611 lines
21 KiB
Python
611 lines
21 KiB
Python
import PIL
|
||
import gradio as gr
|
||
import subprocess
|
||
import os
|
||
import re
|
||
import io
|
||
import base64
|
||
import datetime
|
||
from dotenv import load_dotenv
|
||
from openai import OpenAI
|
||
from PIL import Image
|
||
from docx import Document
|
||
from docx.shared import Pt, Inches
|
||
from docx.enum.text import WD_PARAGRAPH_ALIGNMENT
|
||
import markdown
|
||
from html.parser import HTMLParser
|
||
|
||
# Load environment variables (python-dotenv) so that GLM_API_KEY, read via
# os.getenv in ocr_image, can come from a .env file.
load_dotenv()
|
||
|
||
|
||
def convert_to_wav(audio_file, output_wav_file="output.wav"):
    """Convert an audio file to mono 16 kHz WAV with ffmpeg.

    Args:
        audio_file: Path of the source audio file.
        output_wav_file: Destination path. Defaults to ``"output.wav"`` in the
            current working directory; ffmpeg is run with ``-y`` so an existing
            file is overwritten on every call.

    Returns:
        The path of the written WAV file.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status,
            so a failed conversion is no longer reported as a success.
    """
    print(f"开始转换音频文件 {audio_file} 为 WAV 格式...")
    subprocess.run(
        [
            "ffmpeg", "-y", "-i", audio_file,
            "-ac", "1",      # downmix to a single channel
            "-ar", "16000",  # resample to 16 kHz
            output_wav_file,
        ],
        check=True,
    )
    print(f"音频文件 {audio_file} 已转换为 WAV 格式,输出文件为 {output_wav_file}")
    return output_wav_file
|
||
|
||
|
||
def transcribe_audio(wav_file, original_filename, offset_time, duration_time,
                     output_dir="/home/tmfc/apps/chyoso-toolkit/whisper_output/"):
    """Transcribe an audio file with the whisper CLI and return the transcript path.

    The whisper CLI names its ``txt`` output after the *input* file
    (``<output_dir>/<stem of wav_file>.txt``). When ``original_filename`` has a
    different stem (e.g. batch mode, where every file is first converted to the
    same ``output.wav``), the transcript is renamed to
    ``<output_dir>/<stem of original_filename>.txt`` so the returned path
    actually exists and successive files do not overwrite each other.

    Args:
        wav_file: Path of the audio file handed to whisper.
        original_filename: File name whose stem names the returned transcript.
        offset_time: Currently unused; kept for interface compatibility.
        duration_time: Currently unused; kept for interface compatibility.
        output_dir: Directory whisper writes its transcript into.

    Returns:
        Path of the transcript text file.

    Raises:
        RuntimeError: If whisper exits with a non-zero status.
    """
    whisper_cmd = [
        "/home/tmfc/apps/whisper/.venv/bin/whisper",
        "--language", "zh",
        "--output_dir", output_dir,
        "--output_format", "txt",
        "--model", "turbo",
        wav_file,
    ]
    print(whisper_cmd)
    print(f"开始转写音频文件 {wav_file}...")
    result = subprocess.run(whisper_cmd, capture_output=True, text=True)
    print(result)
    if result.returncode != 0:
        raise RuntimeError(f"whisper 转写失败 (exit {result.returncode}): {result.stderr}")

    # Where whisper actually wrote the transcript (named after its input file).
    whisper_txt = os.path.join(
        output_dir, os.path.splitext(os.path.basename(wav_file))[0] + ".txt")
    # Where callers expect it (named after the original upload).
    txt_file = os.path.join(output_dir, os.path.splitext(original_filename)[0] + ".txt")
    if os.path.abspath(whisper_txt) != os.path.abspath(txt_file):
        os.replace(whisper_txt, txt_file)

    print(f"音频文件 {wav_file} 转写完成,结果已保存为 {txt_file}")
    return txt_file
|
||
|
||
|
||
def process_audio(audio_file, offset_time, duration_time):
    """Transcribe an uploaded audio file as-is (no WAV conversion step).

    Args:
        audio_file: Path of the uploaded audio file.
        offset_time: Forwarded unchanged to ``transcribe_audio``.
        duration_time: Forwarded unchanged to ``transcribe_audio``.

    Returns:
        The transcript path returned by ``transcribe_audio``.
    """
    print("开始处理音频文件...")
    # No ffmpeg conversion here: the upload goes straight to whisper and the
    # transcript is named after the upload's own basename.
    transcript_path = transcribe_audio(
        audio_file,
        os.path.basename(audio_file),
        offset_time,
        duration_time,
    )
    print("音频文件处理完成")
    return transcript_path
|
||
|
||
|
||
def direct_transcribe(audio_file, offset_time, duration_time):
    """Transcribe ``output.wav`` from the current working directory.

    NOTE(review): ``audio_file`` is ignored; this transcribes whatever
    ``output.wav`` currently exists (the default output of ``convert_to_wav``).
    Confirm that this is the intended behaviour of the "直接转写" button.

    Returns:
        The transcript path returned by ``transcribe_audio``.
    """
    print("开始直接转写音频文件...")
    wav_name = "output.wav"
    transcript_path = transcribe_audio(wav_name, wav_name, offset_time, duration_time)
    print("音频文件直接转写完成")
    return transcript_path
|
||
|
||
|
||
# Directory scanned by list_files() for .mp3/.m4a files to batch-transcribe.
# Keep the trailing slash: paths into it may be built by string concatenation.
batch_directory = '/mnt/d/share/audio/'
|
||
|
||
|
||
def list_files(directory=None, extensions=('.mp3', '.m4a')):
    """Return the sorted names of audio files directly inside *directory*.

    Args:
        directory: Directory to scan; defaults to the module-level
            ``batch_directory``.
        extensions: Filename suffixes to keep (matched case-insensitively).

    Returns:
        Plain file names (not full paths) of regular files only, sorted so the
        UI listing and the batch-processing order are deterministic
        (``os.listdir`` returns entries in arbitrary order).
    """
    if directory is None:
        directory = batch_directory
    suffixes = tuple(ext.lower() for ext in extensions)
    return sorted(
        name for name in os.listdir(directory)
        if name.lower().endswith(suffixes)
        and os.path.isfile(os.path.join(directory, name))
    )
|
||
|
||
|
||
# Batch-transcription progress log: appended to by batch_transcribe() and read
# by get_log()/update_log_output(), which the UI timer polls.
log_content = ""
|
||
|
||
|
||
def batch_transcribe():
    """Convert and transcribe every audio file found by ``list_files()``.

    Progress lines are written to the module-level ``log_content``, which the
    UI timer polls. The log is reset at the start of each run, so it shows only
    the current batch and does not grow without bound on a long-running
    server. A failure on one file is logged, and the batch continues with the
    next file.

    Returns:
        Transcript paths for the files that were processed successfully.
    """
    global log_content
    log_content = ""
    result_files = []
    for name in list_files():
        try:
            # Convert to WAV first, then transcribe.
            log_content += "转换" + name + "为 wav\n"
            wav_file = convert_to_wav(os.path.join(batch_directory, name))
            log_content += "转换wav成功,开始转写\n"
            txt_file = transcribe_audio(wav_file, name, 10, 0)
        except Exception as exc:  # per-file boundary: one bad file must not abort the batch
            log_content += f"处理 {name} 失败: {exc}\n"
            continue
        log_content += "转写 " + name + "完成\n"
        result_files.append(txt_file)
    return result_files
|
||
|
||
|
||
def display_files():
    """Return the batch directory's audio file names, one per line, for the UI."""
    return "\n".join(list_files())
|
||
|
||
|
||
def get_log():
    """Return the current batch-transcription log text."""
    # Reading a module global needs no ``global`` declaration.
    return log_content
|
||
|
||
|
||
def update_log_output():
    """Build a Gradio update that puts the latest log text into the log Textbox."""
    latest = get_log()
    return gr.update(value=latest)
|
||
|
||
|
||
def convert_to_docx(text):
    """Convert Markdown text to ``file.docx`` using the external ``pandoc`` CLI.

    Writes the input to ``file.md`` in the working directory, then runs
    ``pandoc file.md -o file.docx``.

    Returns:
        ``"file.docx"`` on success, otherwise a human-readable message for
        empty input, a pandoc failure, or a missing pandoc executable.
    """
    if not text or not text.strip():
        return "输入框不能为空!"

    # pandoc reads UTF-8 input; don't depend on the platform default encoding.
    with open("file.md", "w", encoding="utf-8") as md_file:
        md_file.write(text)

    try:
        subprocess.run(["pandoc", "file.md", "-o", "file.docx"], check=True)
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        # FileNotFoundError: the pandoc executable itself is not installed.
        return f"转换失败: {e}"

    return "file.docx"
|
||
|
||
|
||
class MarkdownToDocxParser(HTMLParser):
    """Walk the HTML produced by ``markdown.markdown`` and build a Word document.

    Supported elements:
        * headings h1-h6;
        * paragraphs;
        * bold/italic;
        * inline code and ``<pre>`` blocks (Courier New 10pt);
        * bullet list items;
        * ``<br>`` line breaks;
        * tables.

    Only the outermost table becomes a Word table; text of a nested table is
    flattened into the enclosing cell.

    Whitespace: outside ``<pre>``, newline characters inside text nodes are
    dropped. Python-Markdown writes a literal newline after every ``<br />``,
    and python-docx turns each newline in run text into another line break,
    so keeping them would double every line break. This assumes line breaks
    are always marked with ``<br>`` (as with the ``nl2br`` extension).
    """

    def __init__(self, document):
        super().__init__()
        self.doc = document
        self.current_paragraph = None   # paragraph receiving text outside tables
        self.current_run = None         # unused; kept for compatibility
        self.in_bold = False
        self.in_italic = False
        self.in_code = False
        self.in_pre = False             # inside <pre>: keep newlines verbatim
        self.in_heading = False
        self.heading_level = 0
        self.list_items = []            # unused; kept for compatibility
        self.list_item_depth = 0        # number of currently open <li> elements
        # Table state
        self.in_table = False
        self.table_depth = 0            # nesting depth of <table> elements
        self.current_table = None       # unused; kept for compatibility
        self.current_row = None         # unused; kept for compatibility
        self.current_cell = None        # not None while inside an outermost-table cell
        self.table_rows = []            # finished rows; each is a list of cell contents
        self.current_row_cells = []     # cell contents of the row being read
        self.current_cell_content = []  # (text, bold, italic) tuples and {'type': 'break'} dicts
        self.is_header_row = False      # True inside <thead>; informational only

    def handle_starttag(self, tag, attrs):
        if tag in ('h1', 'h2', 'h3', 'h4', 'h5', 'h6'):
            self.in_heading = True
            self.heading_level = int(tag[1])
            self.current_paragraph = self.doc.add_heading(level=self.heading_level)
            self.current_paragraph.text = ''
        elif tag == 'p':
            if not self.in_table:
                if self.list_item_depth and self.current_paragraph is not None:
                    # "Loose" list items wrap their text in <p>: write into the
                    # bullet paragraph instead of leaving it empty, separating
                    # several <p> blocks with a line break.
                    if self.current_paragraph.runs:
                        self.current_paragraph.add_run().add_break()
                else:
                    self.current_paragraph = self.doc.add_paragraph()
            elif self.current_cell is not None:
                # Inside a table cell a new paragraph becomes a line break.
                self.current_cell_content.append({'type': 'break'})
        elif tag in ('strong', 'b'):
            self.in_bold = True
        elif tag in ('em', 'i'):
            self.in_italic = True
        elif tag == 'code':
            self.in_code = True
        elif tag == 'pre':
            self.in_pre = True
            if not self.in_table:
                # Code blocks get a paragraph of their own.
                self.current_paragraph = self.doc.add_paragraph()
        elif tag == 'li':
            if not self.in_table:
                self.list_item_depth += 1
                self.current_paragraph = self.doc.add_paragraph(style='List Bullet')
        elif tag == 'br':
            if self.in_table and self.current_cell is not None:
                self.current_cell_content.append({'type': 'break'})
            elif self.current_paragraph is not None:
                self.current_paragraph.add_run().add_break()
        # Tables
        elif tag == 'table':
            self.table_depth += 1
            if self.table_depth == 1:  # only the outermost table becomes a Word table
                self.in_table = True
                self.table_rows = []
        elif tag == 'thead':
            if self.table_depth == 1:
                self.is_header_row = True
        elif tag == 'tbody':
            if self.table_depth == 1:
                self.is_header_row = False
        elif tag == 'tr':
            if self.table_depth == 1:
                self.current_row_cells = []
        elif tag in ('th', 'td'):
            if self.table_depth == 1:
                self.current_cell = []
                self.current_cell_content = []
                if tag == 'th':
                    self.in_bold = True  # header cells are bold

    def handle_endtag(self, tag):
        if tag in ('h1', 'h2', 'h3', 'h4', 'h5', 'h6'):
            self.in_heading = False
            self.heading_level = 0
            # Close the heading so later text (e.g. whitespace inside a
            # following table) is not appended to it.
            self.current_paragraph = None
        elif tag == 'p':
            # Inside a list item the bullet paragraph stays open until </li>.
            if not self.in_table and not self.list_item_depth:
                self.current_paragraph = None
        elif tag in ('strong', 'b'):
            self.in_bold = False
        elif tag in ('em', 'i'):
            self.in_italic = False
        elif tag == 'code':
            self.in_code = False
        elif tag == 'pre':
            self.in_pre = False
            if not self.in_table:
                self.current_paragraph = None
        elif tag == 'li':
            if not self.in_table:
                self.list_item_depth = max(0, self.list_item_depth - 1)
                self.current_paragraph = None
        # Tables
        elif tag == 'table':
            self.table_depth -= 1
            if self.table_depth == 0:
                self.in_table = False
                self._create_table()
        elif tag == 'thead':
            if self.table_depth == 1:
                self.is_header_row = False
        elif tag == 'tr':
            if self.table_depth == 1 and self.current_row_cells:
                self.table_rows.append(self.current_row_cells)
                self.current_row_cells = []
        elif tag in ('th', 'td'):
            if self.table_depth == 1 and self.current_cell is not None:
                # Keep the cell's content together with its formatting info.
                self.current_row_cells.append(self.current_cell_content.copy())
                self.current_cell = None
                self.current_cell_content = []
                if tag == 'th':
                    self.in_bold = False

    def _create_table(self):
        """Emit the buffered rows of the outermost table as a Word table.

        The table gets as many columns as the longest row; shorter rows leave
        trailing cells empty. Cell content items are ``(text, bold, italic)``
        tuples or ``{'type': 'break'}`` markers. The first row is always bolded
        as the header. ``table_rows`` is cleared afterwards.
        """
        if not self.table_rows:
            return

        max_cols = max(len(row) for row in self.table_rows)
        if max_cols == 0:
            return

        table = self.doc.add_table(rows=len(self.table_rows), cols=max_cols)
        table.style = 'Light Grid Accent 1'

        for i, row_data in enumerate(self.table_rows):
            row = table.rows[i]
            for j, cell_content_list in enumerate(row_data):
                if j >= len(row.cells):
                    continue
                cell = row.cells[j]
                cell.text = ''  # clear the default paragraph
                para = cell.paragraphs[0] if cell.paragraphs else cell.add_paragraph()

                for content_item in cell_content_list:
                    if isinstance(content_item, dict):
                        if content_item.get('type') == 'break':
                            para.add_run().add_break()
                    else:
                        text, is_bold, is_italic = content_item
                        run = para.add_run(text)
                        if is_bold:
                            run.bold = True
                        if is_italic:
                            run.italic = True

                # Bold the first row (header).
                if i == 0:
                    for paragraph in cell.paragraphs:
                        for run in paragraph.runs:
                            run.bold = True

        self.table_rows = []

    def handle_data(self, data):
        if not self.in_pre:
            # A preceding <br> already produced the line break; see class docstring.
            data = data.replace('\r', '').replace('\n', '')
        if not data:
            return

        # Inside a cell of the outermost table: buffer the text with its formatting.
        if self.current_cell is not None:
            self.current_cell_content.append((data, self.in_bold, self.in_italic))
            return

        # Between table tags (outside any cell) there is nowhere to put text.
        if self.in_table:
            return

        if self.current_paragraph is None:
            if not data.strip():
                return  # whitespace between block elements
            self.current_paragraph = self.doc.add_paragraph()

        # Whitespace-only data inside an open paragraph is kept: it is the
        # space between inline elements such as "**a** *b*".
        run = self.current_paragraph.add_run(data)
        if self.in_bold:
            run.bold = True
        if self.in_italic:
            run.italic = True
        if self.in_code or self.in_pre:
            run.font.name = 'Courier New'
            run.font.size = Pt(10)
|
||
|
||
|
||
def convert_markdown_to_word(markdown_text):
    """Convert Markdown text to a .docx file under ``word_output/`` (pure Python).

    Markdown is rendered to HTML (extensions ``extra``, ``nl2br``, ``tables``)
    and the HTML is turned into a Word document by ``MarkdownToDocxParser``.

    Returns:
        ``(message, path)``: a status message for the UI, plus the path of the
        generated document, or ``None`` on empty input or failure.
    """
    if not markdown_text or not markdown_text.strip():
        return "请输入 Markdown 内容!", None

    output_dir = "word_output"
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(output_dir, exist_ok=True)

    # Timestamped file name; append a counter when two conversions happen in
    # the same second so an earlier document is never overwritten.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    docx_filename = f"document_{timestamp}.docx"
    docx_path = os.path.join(output_dir, docx_filename)
    counter = 1
    while os.path.exists(docx_path):
        docx_filename = f"document_{timestamp}_{counter}.docx"
        docx_path = os.path.join(output_dir, docx_filename)
        counter += 1

    try:
        html_content = markdown.markdown(
            markdown_text,
            extensions=['extra', 'nl2br', 'tables'],
        )
        doc = Document()
        parser = MarkdownToDocxParser(doc)
        parser.feed(html_content)
        parser.close()  # flush any text HTMLParser is still buffering
        doc.save(docx_path)
    except Exception as e:
        # UI boundary: report the failure to the user instead of raising.
        return f"转换失败: {e}", None

    # Return only the file generated by this call.
    return f"转换成功!文件已保存为: {docx_filename}", docx_path
|
||
|
||
|
||
def get_word_files():
    """List the ``.docx`` files in ``word_output/``, most recently modified first.

    Returns:
        Paths of the form ``word_output/<name>.docx``, or an empty list when
        the directory does not exist.
    """
    folder = "word_output"
    if not os.path.exists(folder):
        return []
    docx_paths = [
        os.path.join(folder, entry)
        for entry in os.listdir(folder)
        if entry.endswith('.docx')
    ]
    return sorted(docx_paths, key=os.path.getmtime, reverse=True)
|
||
|
||
|
||
def remove_headers(curl_command, headers=("if-none-match", "range")):
    """Strip the given request headers from a browser-copied cURL command.

    By default the conditional/partial-content headers are removed.
    ``If-None-Match`` can yield an empty ``304`` response and ``Range`` only
    part of the file, so without them curl fetches the whole document.

    Each ``-H '<name>: ...'`` / ``--header "<name>: ..."`` option is removed
    together with the whitespace that follows it. Header names match
    case-insensitively (Firefox copies ``If-None-Match``, Chrome
    ``if-none-match``). The value is matched only up to the closing quote of
    the same kind, so other headers on the same line are left alone.

    Args:
        curl_command: The cURL command line.
        headers: Header names to remove.

    Returns:
        The command with those header options removed.
    """
    for name in headers:
        pattern = r"(?:-H|--header)\s*(['\"])" + re.escape(name) + r"\s*:.*?\1\s*"
        curl_command = re.sub(pattern, "", curl_command, flags=re.IGNORECASE)
    return curl_command
|
||
|
||
|
||
def do_download_pdf_file(curl_command, pdf_filename):
    """Run a browser-copied cURL command and save its stdout as a PDF.

    SECURITY NOTE(review): the command is executed through the shell exactly as
    entered (``shell=True``). Anyone who can reach this UI can therefore run
    arbitrary shell commands on the host. Only expose the app to trusted users.

    Args:
        curl_command: Full cURL command line. If-None-Match/Range headers are
            stripped first via ``remove_headers``.
        pdf_filename: Base name for the saved file, without ``.pdf``. Only the
            final path component is used, so it cannot escape the working
            directory. Empty or blank names fall back to ``download``.

    Returns:
        ``(message, path)``: a status message, and the saved file's path or
        ``None`` on failure.
    """
    curl_command = remove_headers(curl_command)

    try:
        result = subprocess.run(curl_command, shell=True, check=True,
                                stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except subprocess.CalledProcessError as e:
        return f"Failed to download file. Error: {e.stderr.decode(errors='replace')}", None

    if not result.stdout:
        # e.g. a 304 response or a command that wrote to a file via -o.
        return "Failed to download file. Error: empty response", None

    safe_name = os.path.basename((pdf_filename or "").strip())
    temp_file_path = f"./{safe_name}.pdf" if safe_name else "./download.pdf"
    with open(temp_file_path, 'wb') as f:
        f.write(result.stdout)
    return "File downloaded successfully", temp_file_path
|
||
|
||
|
||
def download_pdf(curl_command, pdf_filename):
    """Gradio handler: download the PDF and return its path for the File output.

    Raises:
        gr.Error: When the download fails, so the error message is shown in
            the UI instead of the output silently staying empty.
    """
    message, file_path = do_download_pdf_file(curl_command, pdf_filename)
    if file_path is None:
        raise gr.Error(message)
    return file_path
|
||
|
||
|
||
def run_ocr(image):
    """Run the local GOT-OCR script on an uploaded image and return its text.

    Steps:
        1. Save the image to a fixed path.
        2. Run the OCR script as user ``tmfc`` via ``sudo``, writing its
           stdout to a fixed text file.
        3. Return that file's contents.

    Args:
        image: PIL image from the Gradio Image component, or ``None`` when
            nothing was uploaded.

    Returns:
        The recognized text, or a message starting with ``识别失败`` on failure.
    """
    if image is None:
        return "识别失败: 未上传图片"

    image_path = "/home/tmfc/apps/got-ocr/img.png"
    image.save(image_path)

    command = [
        "sudo", "-u", "tmfc",
        "/home/tmfc/miniconda3/envs/got/bin/python3",
        "/home/tmfc/apps/got-ocr/GOT/demo/run_ocr_2.0_crop.py",
        "--model-name", "/home/tmfc/apps/got-ocr/models/",
        "--image-file", image_path,
    ]

    out_file = "/home/tmfc/apps/got-ocr/img.txt"
    try:
        with open(out_file, 'w') as f:
            # check=True so a failing script is reported instead of silently
            # returning empty/partial output; stderr feeds the error message.
            subprocess.run(command, stdout=f, stderr=subprocess.PIPE, text=True, check=True)
    except subprocess.CalledProcessError as e:
        return f"识别失败: {e.stderr or e}"
    except OSError as e:  # e.g. sudo/python missing, or output file not writable
        return f"识别失败: {e}"

    with open(out_file, 'r', encoding='utf-8') as f:
        return f.read()
|
||
|
||
|
||
def _encode_image_for_ocr(image, max_size=1600, quality=75):
    """Shrink *image* so its longer side is at most *max_size* px; return base64 JPEG."""
    width, height = image.size
    scale = max_size / max(width, height)
    if scale < 1:
        # Only downscale: upscaling small images adds payload but no detail.
        new_size = (max(1, round(width * scale)), max(1, round(height * scale)))
        image = image.resize(new_size, Image.Resampling.LANCZOS)
    if image.mode not in ("RGB", "L"):
        # JPEG cannot store alpha/palette modes (e.g. RGBA PNG screenshots).
        image = image.convert("RGB")
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG", quality=quality)
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def ocr_image(image, user_query="", max_size=1600, quality=75):
    """Recognize the text in *image* with the ``glm-4v-plus`` vision model.

    The image is downscaled (longer side at most *max_size* px), encoded as
    JPEG at *quality*, and sent as a base64 data URL to the BigModel
    OpenAI-compatible endpoint, authenticated with the ``GLM_API_KEY``
    environment variable.

    Args:
        image: PIL image to recognize.
        user_query: Extra instructions appended to the built-in prompt.
        max_size: Upper bound for the longer image side, in pixels.
        quality: JPEG quality (1-95).

    Returns:
        The model's text reply.
    """
    image_base64 = _encode_image_for_ocr(image, max_size, quality)
    print(f"Base64 字符串长度: {len(image_base64) / 1024:.2f} k")

    prompt = "识别图片中的文字并以纯文本(txt)格式输出,如果图片分为左右两栏,则先输出左边栏再输出右边栏的内容。" + (user_query or "")
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
                },
            ],
        }
    ]

    client = OpenAI(
        api_key=os.getenv("GLM_API_KEY"),
        base_url="https://open.bigmodel.cn/api/paas/v4/",
    )
    chat_completion = client.chat.completions.create(
        messages=messages,
        model="glm-4v-plus",
    )
    return chat_completion.choices[0].message.content
|
||
|
||
|
||
def batch_ocr(files, user_query):
    """OCR each uploaded image and write all results to ``ocr_results.txt``.

    Args:
        files: Uploaded image paths from the Gradio File component, or
            ``None`` when nothing was uploaded.
        user_query: Extra instructions appended to the OCR prompt.

    Returns:
        Path of the UTF-8 text file with one block per image. An image that
        fails is recorded as a ``[识别失败 ...]`` block, so gaps are visible
        instead of silently missing.
    """
    results = []
    for file in files or []:
        try:
            with Image.open(file) as image:  # release the file handle when done
                ocr_result = ocr_image(image, user_query)
        except Exception as e:  # per-image boundary: keep processing the rest
            print(e)
            results.append(f"\n\n[识别失败 {os.path.basename(str(file))}: {e}]\n\n")
            continue
        results.append(f"\n\n{ocr_result}\n\n")

    # Write the results as a UTF-8 text file (the content is typically Chinese text).
    output_file = "ocr_results.txt"
    with open(output_file, "w", encoding="utf-8") as f:
        f.write("\n".join(results))

    return output_file
|
||
|
||
|
||
def _build_transcribe_tab():
    """Tab: transcribe a single uploaded audio file."""
    with gr.TabItem("音频转写"):
        with gr.Row():
            audio_input = gr.Audio(type="filepath", label="上传音频文件")
            with gr.Column():
                offset_input = gr.Number(label="偏移时间 (秒)")
                duration_input = gr.Number(label="转写时长 (秒)")
        with gr.Row():
            process_btn = gr.Button("处理并转写")
            direct_btn = gr.Button("直接转写")
        transcript_out = gr.File(label="转写结果")

        audio_args = [audio_input, offset_input, duration_input]
        process_btn.click(process_audio, inputs=audio_args, outputs=transcript_out)
        direct_btn.click(direct_transcribe, inputs=audio_args, outputs=transcript_out)


def _build_batch_transcribe_tab():
    """Tab: batch-transcribe the audio directory; returns the log Textbox."""
    with gr.TabItem("音频批量转写"):
        with gr.Row():
            gr.HTML(value='<a href="https://webd.willking.tech" target="_blank">点击上传文件</a>')
        with gr.Row():
            file_list_box = gr.Textbox(label="文件列表")
            with gr.Column():
                refresh_btn = gr.Button("刷新文件")
                batch_btn = gr.Button("批量处理")
            with gr.Column():
                batch_results = gr.Files(label="批量转写结果")
        with gr.Row():
            log_box = gr.Textbox(label="日志信息", lines=10)

        refresh_btn.click(fn=display_files, outputs=file_list_box)
        batch_btn.click(batch_transcribe, outputs=batch_results)
    return log_box


def _build_markdown_tab():
    """Tab: convert Markdown text into a downloadable Word document."""
    with gr.Tab("Markdown 转 Word"):
        gr.Markdown("## Markdown 转 Word 转换器")
        with gr.Row():
            md_text = gr.Textbox(
                lines=15,
                placeholder="请在此输入 Markdown 内容...",
                label="Markdown 内容",
            )
        with gr.Row():
            convert_btn = gr.Button("转换为 Word", variant="primary")
        with gr.Row():
            status_box = gr.Textbox(label="转换状态", interactive=False)
        with gr.Row():
            docx_out = gr.File(label="生成的 Word 文档")

        convert_btn.click(
            convert_markdown_to_word,
            inputs=md_text,
            outputs=[status_box, docx_out],
        )


def _build_pdf_tab():
    """Tab: repair a copied cURL command and download the PDF it fetches."""
    with gr.Tab("下载pdf"):
        gr.Markdown("## pdf 下载指令修复")
        with gr.Row():
            curl_box = gr.Textbox(lines=10, placeholder="请在此输入cURL脚本...")
        with gr.Row():
            with gr.Column():
                name_box = gr.Textbox(placeholder="输入文件名")
            with gr.Column():
                download_btn = gr.Button("下载")
        with gr.Row():
            pdf_out = gr.File(label="下载pdf文件")

        download_btn.click(download_pdf, inputs=[curl_box, name_box], outputs=pdf_out)


def _build_ocr_tab():
    """Tab: OCR a single image with the local GOT-OCR script."""
    with gr.Tab("图片识别"):
        gr.Markdown("## OCR 图片识别")
        with gr.Row():
            with gr.Column():
                image_box = gr.Image(type="pil", label="上传图片")
            with gr.Column():
                recognize_btn = gr.Button("识别")
                ocr_text = gr.Textbox(label="OCR 识别结果")

        recognize_btn.click(fn=run_ocr, inputs=image_box, outputs=ocr_text)


def _build_batch_ocr_tab():
    """Tab: OCR many images with the vision model into one text file."""
    with gr.Tab("批量图片识别"):
        gr.Markdown("## OCR 图片批量识别")
        with gr.Row():
            with gr.Column():
                images_in = gr.File(file_count="multiple", label="上传图片")
            with gr.Column():
                ocr_result_file = gr.File(label="下载识别结果")
        with gr.Row():
            start_btn = gr.Button("开始识别")
        with gr.Row():
            extra_query = gr.Textbox(label="额外识别要求", placeholder="输入额外的要求")

        # Wire the button to the batch OCR handler.
        start_btn.click(batch_ocr, inputs=[images_in, extra_query], outputs=ocr_result_file)


with gr.Blocks() as iface:
    gr.Markdown("# 大模型工具集")
    with gr.Tabs():
        _build_transcribe_tab()
        log_output = _build_batch_transcribe_tab()
        _build_markdown_tab()
        _build_pdf_tab()
        _build_ocr_tab()
        _build_batch_ocr_tab()

    # Gradio 5 Timer API: refresh the batch-transcription log once per second.
    timer = gr.Timer(1.0)
    timer.tick(fn=update_log_output, outputs=[log_output])

# NOTE(review): binding 0.0.0.0 with no authentication exposes handlers that
# run arbitrary shell commands (cURL download) and sudo (OCR) to the network.
iface.launch(server_name="0.0.0.0", server_port=7861)
|