加入批量ocr 功能
This commit is contained in:
parent
7c8b8f47e0
commit
8a7e8de066
1
.env
Normal file
1
.env
Normal file
@ -0,0 +1 @@
|
||||
GLM_API_KEY="<redacted: rotate this key and keep .env out of version control>"
|
||||
1
.env.example
Normal file
1
.env.example
Normal file
@ -0,0 +1 @@
|
||||
GLM_API_KEY=""
|
||||
@ -2,6 +2,13 @@ import gradio as gr
|
||||
import subprocess
|
||||
import os
|
||||
import re
|
||||
import io
|
||||
import base64
|
||||
from dotenv import load_dotenv
|
||||
from openai import OpenAI
|
||||
from PIL import Image
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def convert_to_wav(audio_file):
|
||||
@ -192,6 +199,56 @@ def run_ocr(image):
|
||||
return content
|
||||
|
||||
|
||||
def ocr_image(image):
    """Recognize the text in an image with the GLM vision model.

    The image is JPEG-encoded in memory, embedded in a base64 ``data:`` URL
    and sent, together with a fixed text prompt, to the ``glm-4v-plus`` model
    through the OpenAI-compatible endpoint configured below.

    Args:
        image: A PIL image (any object exposing ``mode``, ``convert`` and
            ``save``).

    Returns:
        The message content of the first choice in the chat completion.
    """
    # JPEG has no alpha channel or palette; Pillow raises OSError when asked
    # to save such modes (typical for PNG uploads), so normalize to RGB first.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")

    # Encode the image as base64 so it can travel inside a data URL.
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "识别图片中的文字并以纯文本(txt)格式输出",
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{image_base64}",
                    },
                },
            ],
        }
    ]

    client = OpenAI(
        # The key must be passed explicitly: it lives in GLM_API_KEY, which is
        # not the variable the OpenAI client falls back to on its own.
        api_key=os.getenv("GLM_API_KEY"),
        base_url="https://open.bigmodel.cn/api/paas/v4/",
    )

    chat_completion = client.chat.completions.create(
        messages=messages,
        model="glm-4v-plus",
    )
    return chat_completion.choices[0].message.content
|
||||
|
||||
|
||||
def batch_ocr(files):
    """Run OCR over several uploaded images and collect the text in one file.

    Each file is opened as an image and passed to ``ocr_image``; the
    recognized texts are written, in upload order, to ``ocr_results.txt`` in
    the current working directory (overwriting any previous results).

    Args:
        files: Iterable of image files (anything ``Image.open`` accepts).
            ``None`` or an empty iterable yields an empty results file.

    Returns:
        The path of the results file, ``"ocr_results.txt"``.
    """
    results = []
    # ``files`` is None when the button is clicked without any upload.
    for file in files or []:
        # Release the underlying file handle as soon as the image is processed.
        with Image.open(file) as image:
            ocr_result = ocr_image(image)
        results.append(f"\n\n{ocr_result}\n\n")

    # Write the results to a plain-text file. The encoding is pinned to UTF-8
    # because the recognized text is frequently non-ASCII and the platform
    # default encoding may be unable to represent it.
    output_file = "ocr_results.txt"
    with open(output_file, "w", encoding="utf-8") as f:
        f.write("\n".join(results))

    return output_file
|
||||
|
||||
|
||||
with gr.Blocks() as iface:
|
||||
gr.Markdown("# 大模型工具集")
|
||||
with gr.Tabs():
|
||||
@ -265,6 +322,19 @@ with gr.Blocks() as iface:
|
||||
|
||||
btn_recognize.click(fn=run_ocr, inputs=image_input, outputs=text_output)
|
||||
|
||||
with gr.Tab("批量图片识别"):
|
||||
gr.Markdown("## OCR 图片批量识别")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
input_files = gr.File(file_count="multiple", label="Upload Images")
|
||||
with gr.Column():
|
||||
output_file = gr.File(label="Download OCR Results")
|
||||
with gr.Row():
|
||||
process_button = gr.Button("Process Images")
|
||||
|
||||
# 绑定按钮点击事件
|
||||
process_button.click(batch_ocr, inputs=input_files, outputs=output_file)
|
||||
|
||||
iface.load(fn=update_log_output, outputs=[log_output], every=1)
|
||||
|
||||
iface.launch(server_name="0.0.0.0")
|
||||
|
||||
@ -1 +1,4 @@
|
||||
gradio==4.44.0
|
||||
openai==1.44.1
|
||||
python-dotenv~=1.0.1
|
||||
pillow~=10.4.0
|
||||
Loading…
x
Reference in New Issue
Block a user