|
from cgi import test
|
|
import gradio as gr
|
|
import requests
|
|
import os
|
|
import subprocess
|
|
from typing import Union
|
|
from pth2onnx import convert_pth_to_onnx
|
|
import re
|
|
import time
|
|
from urllib.parse import urlparse, unquote
|
|
import gdown
|
|
import sys
|
|
from typing import Optional
|
|
import datetime
|
|
from mnnsr import modelTest_for_gradio
|
|
from PIL import Image
|
|
import cv2
|
|
|
|
|
|
|
|
|
|
# When True, print_log() echoes task progress lines to the terminal.
log_to_terminal = True
|
|
# Number of conversion requests started so far; incremented by the
# process_model handler and used to number tasks and output folders.
task_counter = 0
|
|
# Maps a download URL to the local path download_file2folder() saved it to.
download_cache = {}
|
|
|
|
|
|
def print_log(task_id, stage, status):
    """Echo one progress line for a task to the terminal.

    Nothing is printed when the module-level ``log_to_terminal`` flag is
    falsy.

    Args:
        task_id: Identifier of the task the message belongs to.
        stage: Description of the step being reported.
        status: Short status tag, shown in square brackets.
    """
    if not log_to_terminal:
        return
    print(f"任务{task_id}: [{status}] {stage}")
|
|
|
|
|
|
def convertmnn(onnx_path: str, mnn_path: str, fp16=False):
    """Convert an ONNX model to MNN by running the ``mnnconvert`` command.

    Args:
        onnx_path: Path of the ONNX model, passed as ``--modelFile``.
        mnn_path: Path passed as ``--MNNModel``.
        fp16: When truthy, ``--fp16`` is appended to the command line.

    Raises:
        subprocess.CalledProcessError: If ``mnnconvert`` exits with a
            non-zero status (the command is run with ``check=True``).
    """
    command = [
        'mnnconvert',
        '-f', 'ONNX',
        '--modelFile', onnx_path,
        '--MNNModel', mnn_path,
        '--bizCode', 'biz',
        '--info',
        '--detectSparseSpeedUp',
    ]
    if fp16:
        command += ['--fp16']
    subprocess.run(command, check=True)
|
|
|
|
def download_file(url: str, save_path: str):
    """Download ``url`` and write the response body to ``save_path``.

    The body is streamed to disk in chunks so that large model files are
    never held in memory as a whole.

    Args:
        url: Address of the file to fetch.
        save_path: Destination file path; an existing file is overwritten.

    Raises:
        requests.RequestException: On connection failure, timeout, or an
            HTTP error status.
        OSError: If ``save_path`` cannot be written.
    """
    # The timeout keeps an unresponsive server from blocking the caller forever.
    with requests.get(url, stream=True, timeout=10) as response:
        # Fail on 4xx/5xx instead of saving the error page as if it were the file.
        response.raise_for_status()
        with open(save_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)
|
|
|
|
def download_gdrive_file(
    url: str,
    folder: str,
    filesize_max: int,
    filesize_min: int = 0
) -> Optional[str]:
    """Download a Google Drive file into ``folder`` and validate its size.

    The size is checked after the download has finished; a file whose size
    falls outside the allowed range is deleted again. The function does not
    raise: every failure is reported on the console and signalled by
    returning ``None``.

    Args:
        url (str): Google Drive share link.
        folder (str): Target folder the file is saved into (created if
            missing).
        filesize_max (int): Exclusive upper bound on the file size, in bytes.
        filesize_min (int): Exclusive lower bound on the file size, in bytes.
            Defaults to 0.

    Returns:
        Optional[str]: Full path of the downloaded file on success,
        otherwise ``None``.
    """
    print(f"--- 开始处理链接: {url} ---")

    try:
        os.makedirs(folder, exist_ok=True)
    except OSError as e:
        print(f"错误:创建文件夹 '{folder}' 失败: {e}", file=sys.stderr)
        return None

    try:
        print(f"开始下载google drive文件到: {folder}")

        # A trailing path separator makes gdown treat `output` as a directory
        # and keep the remote file name. This avoids os.chdir(), which would
        # change the working directory of the whole process while other
        # requests are being served.
        target_dir = os.path.join(folder, "")
        downloaded_path = gdown.download(url, output=target_dir, quiet=True, fuzzy=True)

        if not downloaded_path:
            print("错误:gdown 下载过程未返回有效文件名。", file=sys.stderr)
            return None

        if not os.path.exists(downloaded_path):
            print(f"错误:下载的文件 '{downloaded_path}' 不存在。", file=sys.stderr)
            return None

        file_size = os.path.getsize(downloaded_path)
        if filesize_min < file_size < filesize_max:
            print(f"下载成功!文件已保存至: {downloaded_path}")
            return downloaded_path

        print(f"文件大小超出范围: {file_size} bytes")
        # Do not leave a rejected (possibly very large) file on disk.
        os.remove(downloaded_path)
        return None
    except Exception as e:
        print(f"下载过程中发生错误: {e}", file=sys.stderr)
        return None
|
|
|
|
|
|
def download_file2folder(url: str, folder: str, filesize_max: int, filesize_min: int) -> Optional[str]:
    """Stream ``url`` into ``folder`` while enforcing size limits.

    Successful downloads are remembered in the module-level
    ``download_cache``; a later call with the same URL returns the cached
    path as long as that file still exists and is at least ``filesize_min``
    bytes.

    The file name is taken from the ``Content-Disposition`` header, falling
    back to the last URL path segment and finally to a timestamped name.
    Only the base name is used, so a server-supplied name cannot place the
    file outside ``folder``.

    Args:
        url: Address of the file to download.
        folder: Directory the file is saved into (created if missing).
        filesize_max: Largest accepted size in bytes.
        filesize_min: Smallest accepted size in bytes.

    Returns:
        Path of the saved (or cached) file, or ``None`` when the download
        fails or the size is outside the accepted range.
    """
    global download_cache

    cached_path = download_cache.get(url)
    if cached_path and os.path.exists(cached_path) and os.path.getsize(cached_path) >= filesize_min:
        print(f" 使用缓存文件: {cached_path}")
        return cached_path

    save_path = None

    try:
        with requests.get(url, stream=True, timeout=10) as response:
            response.raise_for_status()

            # Reject early when the server announces an oversized body.
            total_size = int(response.headers.get('content-length', 0))
            if total_size > filesize_max:
                return None

            filename = None
            content_disposition = response.headers.get('Content-Disposition')
            if content_disposition:
                # Stop at ';' so parameters following an unquoted name are
                # not captured as part of the file name.
                match = re.search(r'filename="?([^";]+)"?', content_disposition)
                if match:
                    filename = unquote(match.group(1).strip())

            if not filename:
                filename = unquote(os.path.basename(urlparse(url).path))

            # The name comes from the remote side: drop any directory part
            # (either separator style) so it cannot escape `folder`.
            filename = os.path.basename(filename.replace('\\', '/'))
            if not filename or filename in ('.', '..'):
                filename = f"download_{int(time.time())}.bin"

            os.makedirs(folder, exist_ok=True)
            save_path = os.path.join(folder, filename)

            downloaded_size = 0
            too_large = False
            with open(save_path, 'wb') as file:
                for chunk in response.iter_content(chunk_size=8192):
                    if not chunk:
                        continue
                    downloaded_size += len(chunk)
                    # Content-Length may be absent or wrong, so the limit is
                    # enforced on the bytes actually received as well.
                    if downloaded_size > filesize_max:
                        too_large = True
                        break
                    file.write(chunk)

            if too_large or downloaded_size < filesize_min:
                os.remove(save_path)
                return None

        download_cache[url] = save_path
        return save_path

    except Exception as e:
        # Report the cause instead of failing silently, then clean up any
        # partially written file.
        print(f"下载文件时出错: {e}", file=sys.stderr)
        if save_path and os.path.exists(save_path):
            os.remove(save_path)
        return None
|
|
|
|
async def _process_model(model_input: Union[str, gr.File], tilesize: int, output_dir: str,task_id:int,fp16:bool):
    """Convert a model file to MNN, yielding progress as the work advances.

    Input whose path ends with ``.onnx`` is used as-is; any other input is
    first converted to ONNX with ``convert_pth_to_onnx``. The ONNX model is
    then converted to MNN with ``convertmnn``.

    Args:
        model_input: Path of the model file, or an uploaded-file object whose
            ``name`` attribute holds the path.
        tilesize: Forwarded to ``convert_pth_to_onnx``.
        output_dir: Directory passed to the ONNX converter and used for the
            generated ``.mnn`` file.
        task_id: Identifier used to tag terminal log lines.
        fp16: Forwarded to both conversion steps.

    Yields:
        Progress updates as 2-tuples ``(partial_result, log)`` and, when the
        MNN step has been attempted, a final 3-tuple
        ``(onnx_path, mnn_path, log)``. ``mnn_path`` is ``None`` when the MNN
        file is missing or not larger than 1024 bytes. When no usable input
        path is found or ONNX generation fails, the generator ends after a
        2-tuple.
    """
    log = '初始化日志记录...\n'
    print_log(task_id, '初始化日志记录', '开始')
    yield [], log

    if isinstance(model_input, str):
        input_path = model_input
        message = f'使用文件: {input_path}'
    else:
        # Tolerate objects without `.name` so the "no model file" branch
        # below reports the problem instead of an AttributeError.
        input_path = getattr(model_input, 'name', None)
        message = f'已上传文件: {input_path}'
    log += message + '\n'
    # Log the message itself: `log` ends with a newline, so the last element
    # of log.split('\n') is always an empty string.
    print_log(task_id, message, '开始')
    yield [], log

    if not input_path:
        log += '未获得正确的模型文件\n'
        print_log(task_id, '未获得正确的模型文件', '错误')
        yield [], log
        return

    if input_path.endswith('.onnx'):
        onnx_path = input_path
        log += '输入已经是 ONNX 文件\n'
        print_log(task_id, '输入已经是 ONNX 文件', '跳过')
        yield [], log
    else:
        print_log(task_id, f'转换 PTH 模型为 ONNX, folder={output_dir}', '开始')
        try:
            onnx_path = convert_pth_to_onnx(input_path, tilesize=tilesize, output_folder=output_dir, use_fp16=fp16)
        except Exception as e:
            # Report converter crashes in the log, like the MNN step does.
            log += f'生成ONNX模型失败: {e}\n'
            print_log(task_id, f'生成ONNX模型时出错: {e}', '错误')
            yield [], log
            return
        if onnx_path:
            log += f'成功生成ONNX模型: {onnx_path}\n'
            print_log(task_id, f'生成ONNX模型: {onnx_path}', '完成')
        else:
            log += '生成ONNX模型失败\n'
            print_log(task_id, '生成ONNX模型', '错误')
            yield [], log
            return

    output_name = os.path.splitext(os.path.basename(onnx_path))[0]
    mnn_path = os.path.join(output_dir, f'{output_name}.mnn')

    log += '正在将 ONNX 模型转换为 MNN 格式...\n'
    print_log(task_id, '正在将 ONNX 模型转换为 MNN 格式', '开始')
    # Publish the log before the blocking conversion starts so the caller
    # can show that the MNN step is running.
    yield onnx_path, log
    try:
        convertmnn(onnx_path, mnn_path, fp16)
    except Exception as e:
        log += f'转换 MNN 模型时出错: {str(e)}\n'
        print_log(task_id, f'转换 MNN 模型时出错: {str(e)}', '错误')
        yield onnx_path, log

    # A missing or tiny (<= 1024 bytes) output file counts as a failed conversion.
    if os.path.exists(mnn_path) and os.path.getsize(mnn_path) > 1024:
        log += f'MNN 模型已保存到: {mnn_path}\n'
        print_log(task_id, '模型转换任务完成', '完成')
    else:
        log += 'MNN 模型生成失败或文件大小不足1KB\n'
        print_log(task_id, 'MNN 模型生成失败或文件大小不足1KB', '错误')
        mnn_path = None

    yield onnx_path, mnn_path, log
|
|
|
|
# Gradio front end: a model source (URL or upload) plus options on the left,
# conversion log and result files on the right, example links underneath.
# NOTE(review): the original indentation was lost; the nesting of the layout
# containers below is reconstructed — confirm against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("# MNN模型转换工具")
    model_type_opt = ['从链接下载', '直接上传文件']

    with gr.Row():
        with gr.Column():
            input_type = gr.Radio(model_type_opt, label='模型文件来源')
            url_input = gr.Textbox(label='模型链接')
            file_input = gr.File(label='模型文件', visible=False)

            def show_input(input_type):
                """Show the URL box for the link option, the upload box otherwise."""
                if input_type == model_type_opt[0]:
                    return gr.update(visible=True), gr.update(visible=False)
                else:
                    return gr.update(visible=False), gr.update(visible=True)

            input_type.change(show_input, inputs=input_type, outputs=[url_input, file_input])

            tilesize = gr.Number(label="Tilesize", value=0, precision=0)

            fp16 = gr.Checkbox(label="FP16", value=False)
            try_run = gr.Checkbox(label="pymnnsr测试", value=False)
            convert_btn = gr.Button("开始转换")

        with gr.Column():
            log_box = gr.Textbox(label="转换日志", lines=10, interactive=False)
            # NOTE(review): "filepath" is a value of gr.File's `type`
            # parameter, not a file-type filter — confirm whether
            # type="filepath" was intended here.
            onnx_output = gr.File(label="ONNX 模型输出", file_types=["filepath"])
            mnn_output = gr.File(label="MNN 模型输出", file_types=["filepath"])
            img_output = gr.Image(type="pil", label="测试输出(边缘正常说明模型转换成功,色彩有bug)", visible=False)

            def show_try_run(try_run):
                """Show the test-output image only while the test option is ticked."""
                if try_run:
                    return gr.update(visible=True)
                else:
                    return gr.update(visible=False)

            try_run.change(show_try_run, inputs=try_run, outputs=img_output)

    async def process_model(input_type, url_input, file_input, tilesize, fp16, try_run):
        """Handle one click of the convert button.

        Obtains the model (download or upload), runs ``_process_model`` and
        optionally a test inference, yielding
        ``(onnx_file, mnn_file, log_text, test_image)`` after each step.
        """
        global task_counter
        task_counter += 1
        # Keep a private copy: this handler suspends at every yield, and
        # another request may bump the global counter in the meantime.
        task_id = task_counter

        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = os.path.join(os.getcwd(), f"output_{task_id}_{timestamp}")
        os.makedirs(output_dir, exist_ok=True)
        log = ""

        if input_type == model_type_opt[0] and url_input:
            log = f'正在下载模型文件: {url_input}\n'
            print_log(task_id, f'正在下载模型文件: {url_input}', '开始')
            yield None, None, log, None

            if url_input.startswith("https://drive.google.com/"):
                model_input = download_gdrive_file(
                    url=url_input,
                    folder=output_dir,
                    filesize_max=200*1024*1024,
                    filesize_min=1024
                )
            else:
                model_input = download_file2folder(
                    url=url_input,
                    folder=output_dir,
                    filesize_max=200*1024*1024,
                    filesize_min=1024
                )

            if not model_input:
                log += f'\n模型文件下载失败\n'
                print_log(task_id, '模型文件下载', '失败')
                yield None, None, log, None
                return

            log += f'\n模型文件已下载到: {model_input}\n'
            print_log(task_id, f'模型文件已下载到: {model_input}', '完成')
            yield None, None, log, None
        elif input_type == model_type_opt[1] and file_input:
            model_input = file_input
        else:
            log = '\n请选择输入类型并提供有效的输入!'
            yield None, None, log, None
            return

        onnx_path = None
        mnn_path = None
        process_log = ""

        # A cleared number box arrives as None; treat it like the default 0.
        tile = int(tilesize or 0)

        async for result in _process_model(model_input, tile, output_dir, task_id, fp16):
            if isinstance(result, tuple) and len(result) == 3:
                # Final result of the conversion.
                onnx_path, mnn_path, process_log = result
                yield onnx_path, mnn_path, log + process_log, None
            elif isinstance(result, tuple) and len(result) == 2:
                # Intermediate progress: only the log text is shown.
                _, process_log = result
                yield None, None, log + process_log, None

        if mnn_path and try_run:
            try:
                processed_image_np = modelTest_for_gradio(mnn_path, "./sample.jpg")
                processed_image_pil = Image.fromarray(cv2.cvtColor(processed_image_np, cv2.COLOR_BGR2RGB))
            except Exception as e:
                # A failed test run must not discard the finished conversion.
                process_log += f'pymnnsr测试失败: {e}\n'
                print_log(task_id, f'pymnnsr测试失败: {e}', '错误')
                yield onnx_path, mnn_path, log + process_log, None
                return

            yield onnx_path, mnn_path, log + process_log, processed_image_pil

    convert_btn.click(
        process_model,
        inputs=[input_type, url_input, file_input, tilesize, fp16, try_run],
        outputs=[onnx_output, mnn_output, log_box, img_output],
        api_name="convert_model"
    )

    examples_column = gr.Column(visible=True)
    with examples_column:
        examples = [
            [model_type_opt[0], "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"],
            [model_type_opt[0], "https://github.com/Phhofm/models/releases/download/4xNomos8kSC/4xNomos8kSC.pth"],
            [model_type_opt[0], "https://github.com/Phhofm/models/releases/download/1xDeJPG/1xDeJPG_SRFormer_light.pth"],
            [model_type_opt[0], "https://objectstorage.us-phoenix-1.oraclecloud.com/n/ax6ygfvpvzka/b/open-modeldb-files/o/4x-WTP-ColorDS.pth"],
            [model_type_opt[0], "https://github.com/Kim2091/Kim2091-Models/releases/download/2x-AnimeSharpV4/2x-AnimeSharpV4_RCAN_fp16_op17.onnx"],
            [model_type_opt[0], "https://drive.google.com/uc?export=download&confirm=1&id=1PeqL1ikJbBJbVzvlqvtb4d7QdSW7BzrQ"],
            [model_type_opt[0], "https://drive.google.com/file/d/1maYmC5yyzWCC42X5O0HeDuepsLFh7AV4/view?usp=drive_link"],
        ]
        example_input = gr.Examples(examples=examples, inputs=[input_type, url_input], label='示例模型链接')

demo.launch()