import datetime
import os
import re
import subprocess
import sys
import time
from typing import Optional, Union
from urllib.parse import unquote, urlparse

import cv2
import gdown
import gradio as gr
import requests
from PIL import Image

from mnnsr import modelTest_for_gradio
from pth2onnx import convert_pth_to_onnx

log_to_terminal = True
task_counter = 0
download_cache = {}  # format: {url: file path}


# Logging helper
def print_log(task_id, stage, status):
    if log_to_terminal:
        print(f"Task {task_id}: [{status}] {stage}")


# Use the converter tool shipped with the MNN library
def convertmnn(onnx_path: str, mnn_path: str, fp16=False):
    param = ['mnnconvert', '-f', 'ONNX', '--modelFile', onnx_path, '--MNNModel', mnn_path,
             '--bizCode', 'biz', '--info', '--detectSparseSpeedUp']
    if fp16:
        param.append('--fp16')
    try:
        subprocess.run(param, check=True)
    except Exception as e:
        print(f"Error while converting to MNN model: {e}")
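# For reference, the subprocess call above is equivalent to running the MNN
# converter CLI directly (assuming the MNN tools are installed so that
# `mnnconvert` is on PATH); the file names below are placeholders:
#
#   mnnconvert -f ONNX --modelFile model.onnx --MNNModel model.mnn \
#       --bizCode biz --info --detectSparseSpeedUp [--fp16]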
def download_file(url: str, save_path: str):
    response = requests.get(url)
    with open(save_path, 'wb') as f:
        f.write(response.content)


def download_gdrive_file(
    url: str,
    folder: str,
    filesize_max: int,
    filesize_min: int = 0
) -> Optional[str]:
    """
    Resolve a Google Drive share link, check the file size, and download the file.

    Args:
        url (str): Google Drive share link.
        folder (str): Target folder to save the file into.
        filesize_max (int): Maximum allowed file size in bytes.
        filesize_min (int): Minimum allowed file size in bytes. Defaults to 0.

    Returns:
        Optional[str]: Full path of the downloaded file on success, otherwise None.
    """
    # Reuse the cached file if it exists and is still valid
    global download_cache
    if url in download_cache:
        cached_path = download_cache[url]
        if os.path.exists(cached_path) and os.path.getsize(cached_path) >= filesize_min:
            print(f"Using cached file: {cached_path}")
            return cached_path

    print(f"--- Processing link: {url} ---")

    # Prepare the download path and create the folder
    try:
        os.makedirs(folder, exist_ok=True)
    except OSError as e:
        print(f"Error: failed to create folder '{folder}': {e}", file=sys.stderr)
        return None

    # Download the file
    try:
        print(f"Downloading Google Drive file to: {folder}")
        # Remember the current working directory
        original_dir = os.getcwd()
        # Switch to the target folder
        os.chdir(folder)
        # Download without specifying an output path
        downloaded_filename = gdown.download(url, quiet=True, fuzzy=True)
        # Restore the original working directory
        os.chdir(original_dir)

        # Check the download result and file info
        if downloaded_filename:
            downloaded_path = os.path.join(folder, downloaded_filename)
            if os.path.exists(downloaded_path):
                # Check the file size
                file_size = os.path.getsize(downloaded_path)
                if filesize_min < file_size < filesize_max:
                    print(f"Download succeeded! File saved to: {downloaded_path}")
                    download_cache[url] = downloaded_path
                    return downloaded_path
                else:
                    print(f"File size out of range: {file_size} bytes")
                    return None
            else:
                print(f"Error: downloaded file '{downloaded_path}' does not exist.", file=sys.stderr)
                return None
        else:
            print("Error: gdown did not return a valid file name.", file=sys.stderr)
            return None
    except Exception as e:
        print(f"Error during download: {e}", file=sys.stderr)
        return None
    finally:
        # Make sure the working directory is restored even if an exception occurred
        if 'original_dir' in locals():
            os.chdir(original_dir)


def download_file2folder(url: str, folder: str, filesize_max: int, filesize_min: int) -> Optional[str]:
    global download_cache
    # Reuse the cached file if it exists and is still valid
    if url in download_cache:
        cached_path = download_cache[url]
        if os.path.exists(cached_path) and os.path.getsize(cached_path) >= filesize_min:
            print(f"Using cached file: {cached_path}")
            return cached_path

    save_path = None  # Initialize so the except block can reference it safely
    try:
        # Stream the HTTP download
        with requests.get(url, stream=True, timeout=10) as response:
            response.raise_for_status()  # Raise on HTTP error status

            # Total size, if the server provides it
            total_size = int(response.headers.get('content-length', 0))
            if total_size > filesize_max:
                return None  # File larger than the allowed maximum, do not download

            # Extract the file name from the Content-Disposition header
            filename = None
            content_disposition = response.headers.get('Content-Disposition')
            if content_disposition:
                match = re.search(r'filename="?([^"]+)"?', content_disposition)
                if match:
                    filename = match.group(1)
                    # Handle URL-encoded file names
                    filename = unquote(filename)

            # Fall back to the URL path if the header did not provide a name
            if not filename:
                parsed_url = urlparse(url)
                filename = os.path.basename(parsed_url.path)
                # Handle URL-encoded file names
                filename = unquote(filename)

            # Still no file name: generate a default one
            if not filename:
                filename = f"download_{int(time.time())}.bin"

            # Make sure the target folder exists
            os.makedirs(folder, exist_ok=True)
            save_path = os.path.join(folder, filename)

            downloaded_size = 0
            with open(save_path, 'wb') as file:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # Skip keep-alive chunks
                        downloaded_size += len(chunk)
                        # Abort if the download exceeds the allowed maximum
                        if downloaded_size > filesize_max:
                            file.close()
                            os.remove(save_path)
                            return None
                        file.write(chunk)

        # Enforce the minimum file size after the download finishes
        if os.path.getsize(save_path) < filesize_min:
            os.remove(save_path)
            return None

        # Update the cache on success
        if save_path:
            download_cache[url] = save_path
        return save_path

    except Exception as e:
        # Clean up the partial file on failure
        if save_path and os.path.exists(save_path):
            os.remove(save_path)
        return None
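# A minimal usage sketch of the two download helpers above. It is never called in
# this script; the URL and folder are placeholders, and the 200 MB / 1 KB limits
# mirror the ones used by the Gradio handler further below.
def _example_download(url: str = "https://example.com/model.pth",
                      folder: str = "./output") -> Optional[str]:
    # Google Drive links go through gdown; everything else uses plain requests.
    if url.startswith("https://drive.google.com/"):
        return download_gdrive_file(url, folder, filesize_max=200 * 1024 * 1024, filesize_min=1024)
    return download_file2folder(url, folder, filesize_max=200 * 1024 * 1024, filesize_min=1024)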
async def _process_model(model_input: Union[str, gr.File], tilesize: int, output_dir: str, task_id: int,
                         fp16: bool, onnxsim: bool, opset: int, dynamic_axes: bool, fp16mnn: bool):
    log = 'Initializing log...\n'
    print_log(task_id, 'Initialize log', 'start')
    yield [], log

    if isinstance(model_input, str):
        input_path = model_input
        log += f'Using file: {input_path}\n'
    else:
        input_path = model_input.name
        log += f'Uploaded file: {input_path}\n'
    print_log(task_id, log.splitlines()[-1], 'start')
    yield [], log

    if not input_path:
        log += 'No valid model file received\n'
        print_log(task_id, 'No valid model file received', 'error')
        yield [], log
        return

    if input_path.endswith('.onnx'):
        onnx_path = input_path
        log += 'Input is already an ONNX file\n'
        print_log(task_id, 'Input is already an ONNX file', 'skip')
        yield [], log
    else:
        print_log(task_id, f'Converting PTH model to ONNX, folder={output_dir}', 'start')
        onnx_path = convert_pth_to_onnx(input_path, tilesize=tilesize, output_folder=output_dir,
                                        use_fp16=fp16, simplify_model=onnxsim, opset=opset,
                                        dynamic_axes=dynamic_axes)
        if onnx_path:
            log += f'ONNX model generated successfully: {onnx_path}\n'
            print_log(task_id, f'Generated ONNX model: {onnx_path}', 'done')
        else:
            log += 'Failed to generate ONNX model\n'
            print_log(task_id, 'Generate ONNX model', 'error')
            yield [], log
            return

    # Convert to an MNN model
    output_name = os.path.splitext(os.path.basename(onnx_path))[0]
    mnn_path = os.path.join(output_dir, f'{output_name}.mnn')
    try:
        log += 'Converting ONNX model to MNN format...\n'
        print_log(task_id, 'Converting ONNX model to MNN format', 'start')
        convertmnn(onnx_path, mnn_path, fp16mnn)
        yield onnx_path, log
    except Exception as e:
        log += f'Error while converting to MNN model: {str(e)}\n'
        print_log(task_id, f'Error while converting to MNN model: {str(e)}', 'error')
        yield onnx_path, log
    print_log(task_id, 'Model conversion task finished', 'done')

    # Check the output file after the MNN conversion
    if os.path.exists(mnn_path) and os.path.getsize(mnn_path) > 1024:  # 1 KB = 1024 bytes
        log += f'MNN model saved to: {mnn_path}\n'
    else:
        log += 'MNN model generation failed or file is smaller than 1 KB\n'
        mnn_path = None
    yield onnx_path, mnn_path, log
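# A minimal sketch (never called here) of driving the async generator above
# outside Gradio, e.g. for local debugging; "model.pth" is a placeholder path.
async def _debug_process_model(model_path: str = "model.pth"):
    async for result in _process_model(model_path, tilesize=64, output_dir="./output",
                                       task_id=0, fp16=False, onnxsim=False, opset=13,
                                       dynamic_axes=True, fp16mnn=True):
        # Every yield ends with the accumulated log text.
        print(result[-1])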
with gr.Blocks() as demo:
    gr.Markdown("# MNN Model Converter | MNN模型转换工具")
    model_type_opt = ['Download Link', 'Upload File']

    with gr.Row():
        with gr.Column():
            input_type = gr.Radio(model_type_opt, value=model_type_opt[0], label='Input Model')
            url_input = gr.Textbox(label="")
            file_input = gr.File(label='', visible=False)

            def show_input(input_type):
                if input_type == model_type_opt[0]:
                    return gr.update(visible=True), gr.update(visible=False)
                else:
                    return gr.update(visible=False), gr.update(visible=True)

            input_type.change(show_input, inputs=input_type, outputs=[url_input, file_input])

            tilesize = gr.Number(label="Dummy input width/height, default 64", value=64, precision=0)
            opset = gr.Number(label="ONNX export opset version, suggest 9/11/13/16/17/18", value=13, precision=0)
            fp16 = gr.Checkbox(label="Use fp16 in ONNX export", value=False)
            onnxsim = gr.Checkbox(label="ONNX export simplify model", value=False)
            dynamic_axes = gr.Checkbox(label="ONNX input apply dynamic axes", value=True)
            fp16mnn = gr.Checkbox(label="Use fp16 in MNN convert, only reduce filesize", value=True)
            try_run = gr.Checkbox(label="MNNSR test", value=False)
            convert_btn = gr.Button("Run")

        with gr.Column():
            # with gr.Row():
            log_box = gr.Textbox(label="Log", lines=10, interactive=False)
            onnx_output = gr.File(label="ONNX", file_types=["filepath"])
            mnn_output = gr.File(label="MNN", file_types=["filepath"])
            img_output = gr.Image(type="pil", label="MNNSR Image Output", visible=False)

            def show_try_run(try_run):
                if try_run:
                    return gr.update(visible=True)
                else:
                    return gr.update(visible=False)

            try_run.change(show_try_run, inputs=try_run, outputs=img_output)

    async def process_model(input_type, url_input, file_input, tilesize, fp16, onnxsim, opset,
                            dynamic_axes, fp16mnn, try_run):
        global task_counter
        task_counter += 1
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = os.path.join(os.getcwd(), "output", f"{task_counter}_{timestamp}")
        os.makedirs(output_dir, exist_ok=True)
        log = ""

        if input_type == model_type_opt[0] and url_input:
            log = f'Downloading model file: {url_input}\n'
            print_log(task_counter, f'Downloading model file: {url_input}', 'start')
            yield None, None, log, None
            if url_input.startswith("https://drive.google.com/"):
                model_input = download_gdrive_file(
                    url=url_input,
                    folder=output_dir,
                    filesize_max=200 * 1024 * 1024,  # 200 MB
                    filesize_min=1024  # 1 KB
                )
            else:
                model_input = download_file2folder(
                    url=url_input,
                    folder=output_dir,
                    filesize_max=200 * 1024 * 1024,  # 200 MB
                    filesize_min=1024  # 1 KB
                )
            if not model_input:
                log += 'Model file download failed\n'
                print_log(task_counter, 'Model file download', 'failed')
                yield None, None, log, None
                return
            log += f'Model file downloaded to: {model_input}\n'
            print_log(task_counter, f'Model file downloaded to: {model_input}', 'done')
            yield None, None, log, None
        elif input_type == model_type_opt[1] and file_input:
            model_input = file_input
        else:
            log = 'Please choose an input type and provide a valid input!\n'
            yield None, None, log, None
            return

        onnx_path = None
        mnn_path = None
        process_log = ''
        # Call the renamed conversion helper
        async for result in _process_model(model_input, tilesize if tilesize > 0 else 64, output_dir,
                                           task_counter, fp16, onnxsim, opset, dynamic_axes, fp16mnn):
            if isinstance(result, tuple) and len(result) == 3:
                onnx_path, mnn_path, process_log = result
                yield onnx_path, mnn_path, log + process_log, None
            elif isinstance(result, tuple) and len(result) == 2:
                # Handle log-only yields
                _, process_log = result
                yield None, None, log + process_log, None
        # yield onnx_path, mnn_path, log+process_log

        if mnn_path:
            if try_run:
                print_log(task_counter, f'Testing model: {mnn_path}', 'start')
                processed_image_np, load_time, infer_time = modelTest_for_gradio(
                    mnn_path, "./sample.jpg", tilesize if tilesize > 0 and dynamic_axes else 0, 0)
                processed_image_pil = Image.fromarray(cv2.cvtColor(processed_image_np, cv2.COLOR_BGR2RGB))
                # processed_image_pil = Image.fromarray(processed_image_np)
                yield (onnx_path, mnn_path,
                       log + process_log + f"MNNSR model load took {load_time:.4f} s, "
                                           f"inference ({tilesize} px) took {infer_time:.4f} s",
                       processed_image_pil)
            else:
                yield onnx_path, mnn_path, log + process_log, None
        return

    convert_btn.click(
        process_model,
        inputs=[input_type, url_input, file_input, tilesize, fp16, onnxsim, opset, dynamic_axes, fp16mnn, try_run],
        outputs=[onnx_output, mnn_output, log_box, img_output],
        api_name="convert_nmm_model"
    )

    examples_column = gr.Column(visible=True)
    with examples_column:
        examples = [
            [model_type_opt[0], "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth"],
            [model_type_opt[0], "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"],
            [model_type_opt[0], "https://github.com/Phhofm/models/releases/download/4xNomos8kSC/4xNomos8kSC.pth"],
            [model_type_opt[0], "https://github.com/Phhofm/models/releases/download/1xDeJPG/1xDeJPG_SRFormer_light.pth"],
            [model_type_opt[0], "https://objectstorage.us-phoenix-1.oraclecloud.com/n/ax6ygfvpvzka/b/open-modeldb-files/o/4x-WTP-ColorDS.pth"],
            [model_type_opt[0], "https://github.com/Kim2091/Kim2091-Models/releases/download/2x-AnimeSharpV4/2x-AnimeSharpV4_RCAN_fp16_op17.onnx"],
            [model_type_opt[0], "https://drive.google.com/uc?export=download&confirm=1&id=1PeqL1ikJbBJbVzvlqvtb4d7QdSW7BzrQ"],
            [model_type_opt[0], "https://drive.google.com/file/d/1maYmC5yyzWCC42X5O0HeDuepsLFh7AV4/view?usp=drive_link"],
            [model_type_opt[0], "https://objectstorage.us-phoenix-1.oraclecloud.com/n/ax6ygfvpvzka/b/open-modeldb-files/o/1x-MangaJPEGHQ.pth"],
        ]
        example_input = gr.Examples(examples=examples, inputs=[input_type, url_input])

demo.launch(ssr_mode=False, server_name="0.0.0.0")
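# Running the app: assuming this script is saved as app.py (the actual filename is
# not given here), start it with `python app.py`. launch() with
# server_name="0.0.0.0" exposes the Gradio UI on all interfaces, on Gradio's
# default port (7860) unless configured otherwise.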