model2mnn / app.py
tumuyan2's picture
update
5de10eb
raw
history blame
15.6 kB
from cgi import test
import gradio as gr
import requests
import os
import subprocess
from typing import Union
from pth2onnx import convert_pth_to_onnx
import re
import time
from urllib.parse import urlparse, unquote
import gdown
import sys
from typing import Optional
import datetime
from mnnsr import modelTest_for_gradio
from PIL import Image
import cv2
log_to_terminal = True
task_counter = 0
download_cache = {} # 格式: {url: 文件路径}
# 日志函数
def print_log(task_id, stage, status):
if log_to_terminal:
print(f"任务{task_id}: [{status}] {stage}")
# 使用 MNN 库自带的转换工具
def convertmnn(onnx_path: str, mnn_path: str, fp16=False):
param = ['mnnconvert', '-f', 'ONNX', '--modelFile', onnx_path, '--MNNModel', mnn_path, '--bizCode', 'biz', '--info', '--detectSparseSpeedUp']
if fp16:
param.append('--fp16')
subprocess.run(param, check=True)
def download_file(url: str, save_path: str):
response = requests.get(url)
with open(save_path, 'wb') as f:
f.write(response.content)
def download_gdrive_file(
url: str,
folder: str,
filesize_max: int,
filesize_min: int = 0
) -> Optional[str]:
"""
从 Google Drive 链接获取文件信息,检查大小后下载文件。
Args:
url (str): Google Drive 的分享链接。
folder (str): 用于保存文件的目标文件夹路径。
filesize_max (int): 允许下载的最大文件大小(单位:字节)。
filesize_min (int): 允许下载的最小文件大小(单位:字节),默认为 0。
Returns:
Optional[str]: 如果下载成功,返回文件的完整路径;否则返回 None。
"""
print(f"--- 开始处理链接: {url} ---")
# 准备下载路径并创建文件夹
try:
os.makedirs(folder, exist_ok=True)
except OSError as e:
print(f"错误:创建文件夹 '{folder}' 失败: {e}", file=sys.stderr)
return None
# 4. 下载文件
try:
print(f"开始下载google drive文件到: {folder}")
# 保存当前工作目录
original_dir = os.getcwd()
# 切换到目标文件夹
os.chdir(folder)
# 下载文件,不指定output_path参数
downloaded_filename = gdown.download(url, quiet=True, fuzzy=True)
# 恢复原始工作目录
os.chdir(original_dir)
# 检查下载结果和文件信息
if downloaded_filename:
downloaded_path = os.path.join(folder, downloaded_filename)
if os.path.exists(downloaded_path):
# 获取文件大小
file_size = os.path.getsize(downloaded_path)
if file_size< filesize_max and file_size> filesize_min:
print(f"下载成功!文件已保存至: {downloaded_path}")
return downloaded_path
else:
print(f"文件大小超出范围: {file_size} bytes")
return None
else:
print(f"错误:下载的文件 '{downloaded_path}' 不存在。", file=sys.stderr)
return None
else:
print("错误:gdown 下载过程未返回有效文件名。", file=sys.stderr)
return None
except Exception as e:
print(f"下载过程中发生错误: {e}", file=sys.stderr)
return None
finally:
# 确保工作目录恢复,即使发生异常
if 'original_dir' in locals():
os.chdir(original_dir)
def download_file2folder(url: str, folder: str, filesize_max: int, filesize_min: int) -> Optional[str]:
global download_cache
# 检查缓存是否存在且文件有效
if url in download_cache:
cached_path = download_cache[url]
if os.path.exists(cached_path) and os.path.getsize(cached_path) >= filesize_min:
print(f" 使用缓存文件: {cached_path}")
return cached_path
save_path = None # 初始化save_path为None,避免except块中引用错误
try:
# 发送HTTP请求,流式下载
with requests.get(url, stream=True, timeout=10) as response:
response.raise_for_status() # 检查HTTP错误状态
# 获取文件总大小(如果服务器提供)
total_size = int(response.headers.get('content-length', 0))
if total_size > filesize_max:
return None # 文件大小超过最大值,不下载
# 提取文件名
filename = None
content_disposition = response.headers.get('Content-Disposition')
if content_disposition:
# 使用正则表达式提取filename
match = re.search(r'filename="?([^"]+)"?', content_disposition)
if match:
filename = match.group(1)
# 处理可能的URL编码文件名
filename = unquote(filename)
# 如果响应头没有,从URL解析
if not filename:
parsed_url = urlparse(url)
filename = os.path.basename(parsed_url.path)
# 处理URL编码的文件名
filename = unquote(filename)
# 如果仍然没有文件名,生成默认文件名
if not filename:
filename = f"download_{int(time.time())}.bin"
# 确保目标文件夹存在
os.makedirs(folder, exist_ok=True)
save_path = os.path.join(folder, filename)
downloaded_size = 0
with open(save_path, 'wb') as file:
for chunk in response.iter_content(chunk_size=8192):
if chunk: # 过滤空块
downloaded_size += len(chunk)
# 检查是否超过最大允许大小
if downloaded_size > filesize_max:
file.close()
os.remove(save_path)
return None
file.write(chunk)
# 下载完成后检查最小文件大小
if os.path.getsize(save_path) < filesize_min:
os.remove(save_path)
return None
# 下载成功后更新缓存
if save_path:
download_cache[url] = save_path
return save_path
except Exception as e:
# 发生异常时清理文件
if save_path and os.path.exists(save_path):
os.remove(save_path)
return None
async def _process_model(model_input: Union[str, gr.File], tilesize: int, output_dir: str,task_id:int,fp16:bool):
log = ('初始化日志记录...\n')
print_log(task_id, '初始化日志记录', '开始')
yield [],log
if isinstance(model_input, str):
input_path = model_input
log += f'使用文件: {input_path}\n'
else:
input_path = model_input.name
log += f'已上传文件: {input_path}\n'
print_log(task_id, log.split('\n')[-1], '开始')
yield [], log
if not input_path:
log += ( f'未获得正确的模型文件\n')
print_log(task_id, f'未获得正确的模型文件', '错误')
yield [],log
return
if input_path.endswith('.onnx'):
onnx_path = input_path
log += ( '输入已经是 ONNX 文件\n')
print_log(task_id, '输入已经是 ONNX 文件', '跳过')
yield [],log
else:
print_log(task_id, f'转换 PTH 模型为 ONNX, folder={output_dir}', '开始')
onnx_path = convert_pth_to_onnx(input_path, tilesize=tilesize, output_folder=output_dir,use_fp16=fp16)
if onnx_path:
log += ( f'成功生成ONNX模型: {onnx_path}\n')
print_log(task_id, f'生成ONNX模型: {onnx_path}', '完成')
else:
log += ( '生成ONNX模型失败\n')
print_log(task_id, '生成ONNX模型', '错误')
yield [],log
return
# 转换为 MNN 模型
output_name= os.path.splitext(os.path.basename(onnx_path))[0]
mnn_path = os.path.join(output_dir, f'{output_name}.mnn')
try:
log += ( '正在将 ONNX 模型转换为 MNN 格式...\n')
print_log(task_id, '正在将 ONNX 模型转换为 MNN 格式', '开始')
convertmnn(onnx_path, mnn_path,fp16)
yield onnx_path,log
except Exception as e:
log += ( f'转换 MNN 模型时出错: {str(e)}\n')
print_log(task_id, f'转换 MNN 模型时出错: {str(e)}', '错误')
yield onnx_path,log
print_log(task_id, '模型转换任务完成', '完成')
# 转换为 MNN 模型后对文件检查
if os.path.exists(mnn_path) and os.path.getsize(mnn_path) > 1024: # 1KB = 1024 bytes
log += ( f'MNN 模型已保存到: {mnn_path}\n')
else:
log += ( 'MNN 模型生成失败或文件大小不足1KB\n')
mnn_path = None
yield onnx_path, mnn_path, log
with gr.Blocks() as demo:
gr.Markdown("# MNN模型转换工具")
model_type_opt = ['从链接下载', '直接上传文件']
with gr.Row():
with gr.Column():
input_type = gr.Radio(model_type_opt, label='模型文件来源')
url_input = gr.Textbox(label='模型链接')
file_input = gr.File(label='模型文件', visible=False)
def show_input(input_type):
if input_type == model_type_opt[0]:
return gr.update(visible=True), gr.update(visible=False)
else:
return gr.update(visible=False), gr.update(visible=True)
input_type.change(show_input, inputs=input_type, outputs=[url_input, file_input])
tilesize = gr.Number(label="Tilesize", value=0, precision=0)
# 添加fp16和try_run复选框
fp16 = gr.Checkbox(label="FP16", value=False)
try_run = gr.Checkbox(label="pymnnsr测试", value=False)
convert_btn = gr.Button("开始转换")
with gr.Column():
# with gr.Row():
log_box = gr.Textbox(label="转换日志", lines=10, interactive=False)
onnx_output = gr.File(label="ONNX 模型输出",file_types=["filepath"])
mnn_output = gr.File(label="MNN 模型输出",file_types=["filepath"])
img_output = gr.Image(type="pil", label="测试输出(边缘正常说明模型转换成功,色彩有bug)" ,visible=False)
def show_try_run(try_run):
if try_run:
return gr.update(visible=True)
else:
return gr.update(visible=False)
try_run.change(show_try_run, inputs=try_run, outputs=img_output)
async def process_model(input_type, url_input, file_input, tilesize, fp16, try_run):
global task_counter
task_counter += 1
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = os.path.join(os.getcwd(), f"output_{task_counter}_{timestamp}")
os.makedirs(output_dir, exist_ok=True)
log=""
if input_type == model_type_opt[0] and url_input:
log = f'正在下载模型文件: {url_input}\n'
print_log(task_counter, f'正在下载模型文件: {url_input}', '开始')
yield None, None, log, None
if url_input.startswith("https://drive.google.com/"):
model_input = download_gdrive_file(
url=url_input,
folder=output_dir,
filesize_max=200*1024*1024, # 200MB
filesize_min=1024 # 1KB
)
else:
model_input = download_file2folder(
url=url_input,
folder=output_dir,
filesize_max=200*1024*1024, # 200MB
filesize_min=1024 # 1KB
)
if not model_input:
log += f'\n模型文件下载失败\n'
print_log(task_counter, f'模型文件载', '失败')
yield None, None, log, None
return
log += f'\n模型文件已下载到: {model_input}\n'
print_log(task_counter, f'模型文件已下载到: {model_input}', '完成')
yield None, None, log, None
elif input_type == model_type_opt[1] and file_input:
model_input = file_input
else:
# 改为通过yield返回错误日志
log = '\n请选择输入类型并提供有效的输入!'
yield None, None, log,None
return
onnx_path = None
mnn_path = None
# 调用重命名后的函数
async for result in _process_model(model_input, int(tilesize), output_dir, task_counter, fp16):
if isinstance(result, tuple) and len(result) == 3:
onnx_path, mnn_path, process_log = result
yield onnx_path, mnn_path, log+process_log,None
elif isinstance(result, tuple) and len(result) == 2:
# 处理纯日志yield
_, process_log = result
yield None, None, log+process_log,None
# yield onnx_path, mnn_path, log+process_log
if mnn_path and try_run:
processed_image_np = modelTest_for_gradio(mnn_path, "./sample.jpg")
processed_image_pil = Image.fromarray(cv2.cvtColor(processed_image_np, cv2.COLOR_BGR2RGB))
# processed_image_pil = Image.fromarray(processed_image_np)
yield onnx_path, mnn_path, log+process_log,processed_image_pil
convert_btn.click(
process_model,
inputs=[input_type, url_input, file_input, tilesize, fp16, try_run],
outputs=[onnx_output, mnn_output, log_box, img_output],
api_name="convert_model"
)
# 将示例移至底部并包裹在列组件中
examples_column = gr.Column(visible=True)
with examples_column:
examples = [
[model_type_opt[0], "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"],
[model_type_opt[0], "https://github.com/Phhofm/models/releases/download/4xNomos8kSC/4xNomos8kSC.pth"],
[model_type_opt[0], "https://github.com/Phhofm/models/releases/download/1xDeJPG/1xDeJPG_SRFormer_light.pth"],
[model_type_opt[0], "https://objectstorage.us-phoenix-1.oraclecloud.com/n/ax6ygfvpvzka/b/open-modeldb-files/o/4x-WTP-ColorDS.pth"],
[model_type_opt[0], "https://github.com/Kim2091/Kim2091-Models/releases/download/2x-AnimeSharpV4/2x-AnimeSharpV4_RCAN_fp16_op17.onnx"],
[model_type_opt[0], "https://drive.google.com/uc?export=download&confirm=1&id=1PeqL1ikJbBJbVzvlqvtb4d7QdSW7BzrQ"],
[model_type_opt[0], "https://drive.google.com/file/d/1maYmC5yyzWCC42X5O0HeDuepsLFh7AV4/view?usp=drive_link"],
]
example_input = gr.Examples(examples=examples, inputs=[input_type, url_input], label='示例模型链接')
demo.launch()