优化批量图片识别
This commit is contained in:
parent
dcba129694
commit
3f72706e24
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
.env
|
||||
@ -1,3 +1,4 @@
|
||||
import PIL
|
||||
import gradio as gr
|
||||
import subprocess
|
||||
import os
|
||||
@ -199,18 +200,27 @@ def run_ocr(image):
|
||||
return content
|
||||
|
||||
|
||||
def ocr_image(image):
|
||||
# 将图片转换为 base64 编码
|
||||
def ocr_image(image, user_query=""):
|
||||
# 压缩图片并降低分辨率
|
||||
max_size = 1600
|
||||
aspect_ratio = image.width / image.height
|
||||
if aspect_ratio > 1: # 宽度大于高度
|
||||
new_size = (max_size, int(max_size / aspect_ratio))
|
||||
else: # 高度大于或等于宽度
|
||||
new_size = (int(max_size * aspect_ratio), max_size)
|
||||
compressed_image = image.resize(new_size, PIL.Image.Resampling.LANCZOS)
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format="JPEG")
|
||||
|
||||
compressed_image.save(buffered, format="JPEG", quality=75) # 设置JPEG质量为85
|
||||
image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
print(f"Base64 字符串长度: {len(image_base64) / 1024:.2f} k")
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "识别图片中的文字并以纯文本(txt)格式输出"
|
||||
"text": "识别图片中的文字并以纯文本(txt)格式输出,如果图片分为左右两栏,则先输出左边栏再输出右边栏的内容。" + user_query
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
@ -234,12 +244,12 @@ def ocr_image(image):
|
||||
return chat_completion.choices[0].message.content
|
||||
|
||||
|
||||
def batch_ocr(files):
|
||||
def batch_ocr(files, user_query):
|
||||
results = []
|
||||
for file in files:
|
||||
try:
|
||||
image = Image.open(file)
|
||||
ocr_result = ocr_image(image)
|
||||
ocr_result = ocr_image(image, user_query)
|
||||
results.append(f"\n\n{ocr_result}\n\n")
|
||||
except Exception as e:
|
||||
print(e)
|
||||
@ -334,9 +344,11 @@ with gr.Blocks() as iface:
|
||||
output_file = gr.File(label="Download OCR Results")
|
||||
with gr.Row():
|
||||
process_button = gr.Button("Process Images")
|
||||
with gr.Row():
|
||||
user_query_text = gr.Textbox(placeholder="输入额外的要求")
|
||||
|
||||
# 绑定按钮点击事件
|
||||
process_button.click(batch_ocr, inputs=input_files, outputs=output_file)
|
||||
process_button.click(batch_ocr, inputs=[input_files, user_query_text], outputs=output_file)
|
||||
|
||||
iface.load(fn=update_log_output, outputs=[log_output], every=1)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user