优化批量图片识别
This commit is contained in:
parent
dcba129694
commit
3f72706e24
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
.env
|
||||||
@ -1,3 +1,4 @@
|
|||||||
|
import PIL
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import subprocess
|
import subprocess
|
||||||
import os
|
import os
|
||||||
@ -199,18 +200,27 @@ def run_ocr(image):
|
|||||||
return content
|
return content
|
||||||
|
|
||||||
|
|
||||||
def ocr_image(image):
|
def ocr_image(image, user_query=""):
|
||||||
# 将图片转换为 base64 编码
|
# 压缩图片并降低分辨率
|
||||||
|
max_size = 1600
|
||||||
|
aspect_ratio = image.width / image.height
|
||||||
|
if aspect_ratio > 1: # 宽度大于高度
|
||||||
|
new_size = (max_size, int(max_size / aspect_ratio))
|
||||||
|
else: # 高度大于或等于宽度
|
||||||
|
new_size = (int(max_size * aspect_ratio), max_size)
|
||||||
|
compressed_image = image.resize(new_size, PIL.Image.Resampling.LANCZOS)
|
||||||
buffered = io.BytesIO()
|
buffered = io.BytesIO()
|
||||||
image.save(buffered, format="JPEG")
|
|
||||||
|
compressed_image.save(buffered, format="JPEG", quality=75) # 设置JPEG质量为85
|
||||||
image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||||
|
print(f"Base64 字符串长度: {len(image_base64) / 1024:.2f} k")
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
{
|
{
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": "识别图片中的文字并以纯文本(txt)格式输出"
|
"text": "识别图片中的文字并以纯文本(txt)格式输出,如果图片分为左右两栏,则先输出左边栏再输出右边栏的内容。" + user_query
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
@ -234,12 +244,12 @@ def ocr_image(image):
|
|||||||
return chat_completion.choices[0].message.content
|
return chat_completion.choices[0].message.content
|
||||||
|
|
||||||
|
|
||||||
def batch_ocr(files):
|
def batch_ocr(files, user_query):
|
||||||
results = []
|
results = []
|
||||||
for file in files:
|
for file in files:
|
||||||
try:
|
try:
|
||||||
image = Image.open(file)
|
image = Image.open(file)
|
||||||
ocr_result = ocr_image(image)
|
ocr_result = ocr_image(image, user_query)
|
||||||
results.append(f"\n\n{ocr_result}\n\n")
|
results.append(f"\n\n{ocr_result}\n\n")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
@ -334,9 +344,11 @@ with gr.Blocks() as iface:
|
|||||||
output_file = gr.File(label="Download OCR Results")
|
output_file = gr.File(label="Download OCR Results")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
process_button = gr.Button("Process Images")
|
process_button = gr.Button("Process Images")
|
||||||
|
with gr.Row():
|
||||||
|
user_query_text = gr.Textbox(placeholder="输入额外的要求")
|
||||||
|
|
||||||
# 绑定按钮点击事件
|
# 绑定按钮点击事件
|
||||||
process_button.click(batch_ocr, inputs=input_files, outputs=output_file)
|
process_button.click(batch_ocr, inputs=[input_files, user_query_text], outputs=output_file)
|
||||||
|
|
||||||
iface.load(fn=update_log_output, outputs=[log_output], every=1)
|
iface.load(fn=update_log_output, outputs=[log_output], every=1)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user