File size: 15,640 Bytes
5de10eb 2152dfa 62053b6 2152dfa a1b158c 5de10eb 2152dfa 62053b6 2152dfa 62053b6 2152dfa 62053b6 2152dfa 62053b6 2152dfa 62053b6 2152dfa 62053b6 2152dfa 62053b6 5de10eb 62053b6 a1b158c 62053b6 a1b158c 62053b6 a1b158c 62053b6 a1b158c 62053b6 2152dfa 62053b6 2152dfa 62053b6 2152dfa 62053b6 2152dfa 72c8c5e 55f5f0f 2152dfa 5de10eb 2152dfa 72c8c5e 55f5f0f 2152dfa 55f5f0f 72c8c5e 5de10eb 2152dfa 55f5f0f 2152dfa 5de10eb 2152dfa 55f5f0f 5de10eb 2152dfa 72c8c5e 2152dfa 55f5f0f 2152dfa 55f5f0f 2152dfa 5de10eb 2152dfa 55f5f0f 2152dfa 72c8c5e 5de10eb 2152dfa 55f5f0f 2152dfa 5de10eb 2152dfa 55f5f0f 2152dfa 55f5f0f 2152dfa 5de10eb 62053b6 6b67881 62053b6 6b67881 62053b6 2152dfa 6b67881 62053b6 6b67881 5de10eb 6b67881 5de10eb 6b67881 72c8c5e a1b158c 72c8c5e 55f5f0f 72c8c5e 62053b6 55f5f0f 72c8c5e 5de10eb 62053b6 a1b158c 62053b6 5de10eb 62053b6 55f5f0f 72c8c5e 5de10eb 62053b6 2152dfa 5de10eb 2152dfa 72c8c5e 2152dfa 55f5f0f 5de10eb 2152dfa 5de10eb 55f5f0f 2152dfa 5de10eb 2152dfa 6b67881 5de10eb 2152dfa 62053b6 2152dfa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 |
from cgi import test
import gradio as gr
import requests
import os
import subprocess
from typing import Union
from pth2onnx import convert_pth_to_onnx
import re
import time
from urllib.parse import urlparse, unquote
import gdown
import sys
from typing import Optional
import datetime
from mnnsr import modelTest_for_gradio
from PIL import Image
import cv2
log_to_terminal = True
task_counter = 0
download_cache = {} # 格式: {url: 文件路径}
# 日志函数
def print_log(task_id, stage, status):
if log_to_terminal:
print(f"任务{task_id}: [{status}] {stage}")
# 使用 MNN 库自带的转换工具
def convertmnn(onnx_path: str, mnn_path: str, fp16=False):
param = ['mnnconvert', '-f', 'ONNX', '--modelFile', onnx_path, '--MNNModel', mnn_path, '--bizCode', 'biz', '--info', '--detectSparseSpeedUp']
if fp16:
param.append('--fp16')
subprocess.run(param, check=True)
def download_file(url: str, save_path: str):
response = requests.get(url)
with open(save_path, 'wb') as f:
f.write(response.content)
def download_gdrive_file(
url: str,
folder: str,
filesize_max: int,
filesize_min: int = 0
) -> Optional[str]:
"""
从 Google Drive 链接获取文件信息,检查大小后下载文件。
Args:
url (str): Google Drive 的分享链接。
folder (str): 用于保存文件的目标文件夹路径。
filesize_max (int): 允许下载的最大文件大小(单位:字节)。
filesize_min (int): 允许下载的最小文件大小(单位:字节),默认为 0。
Returns:
Optional[str]: 如果下载成功,返回文件的完整路径;否则返回 None。
"""
print(f"--- 开始处理链接: {url} ---")
# 准备下载路径并创建文件夹
try:
os.makedirs(folder, exist_ok=True)
except OSError as e:
print(f"错误:创建文件夹 '{folder}' 失败: {e}", file=sys.stderr)
return None
# 4. 下载文件
try:
print(f"开始下载google drive文件到: {folder}")
# 保存当前工作目录
original_dir = os.getcwd()
# 切换到目标文件夹
os.chdir(folder)
# 下载文件,不指定output_path参数
downloaded_filename = gdown.download(url, quiet=True, fuzzy=True)
# 恢复原始工作目录
os.chdir(original_dir)
# 检查下载结果和文件信息
if downloaded_filename:
downloaded_path = os.path.join(folder, downloaded_filename)
if os.path.exists(downloaded_path):
# 获取文件大小
file_size = os.path.getsize(downloaded_path)
if file_size< filesize_max and file_size> filesize_min:
print(f"下载成功!文件已保存至: {downloaded_path}")
return downloaded_path
else:
print(f"文件大小超出范围: {file_size} bytes")
return None
else:
print(f"错误:下载的文件 '{downloaded_path}' 不存在。", file=sys.stderr)
return None
else:
print("错误:gdown 下载过程未返回有效文件名。", file=sys.stderr)
return None
except Exception as e:
print(f"下载过程中发生错误: {e}", file=sys.stderr)
return None
finally:
# 确保工作目录恢复,即使发生异常
if 'original_dir' in locals():
os.chdir(original_dir)
def download_file2folder(url: str, folder: str, filesize_max: int, filesize_min: int) -> Optional[str]:
global download_cache
# 检查缓存是否存在且文件有效
if url in download_cache:
cached_path = download_cache[url]
if os.path.exists(cached_path) and os.path.getsize(cached_path) >= filesize_min:
print(f" 使用缓存文件: {cached_path}")
return cached_path
save_path = None # 初始化save_path为None,避免except块中引用错误
try:
# 发送HTTP请求,流式下载
with requests.get(url, stream=True, timeout=10) as response:
response.raise_for_status() # 检查HTTP错误状态
# 获取文件总大小(如果服务器提供)
total_size = int(response.headers.get('content-length', 0))
if total_size > filesize_max:
return None # 文件大小超过最大值,不下载
# 提取文件名
filename = None
content_disposition = response.headers.get('Content-Disposition')
if content_disposition:
# 使用正则表达式提取filename
match = re.search(r'filename="?([^"]+)"?', content_disposition)
if match:
filename = match.group(1)
# 处理可能的URL编码文件名
filename = unquote(filename)
# 如果响应头没有,从URL解析
if not filename:
parsed_url = urlparse(url)
filename = os.path.basename(parsed_url.path)
# 处理URL编码的文件名
filename = unquote(filename)
# 如果仍然没有文件名,生成默认文件名
if not filename:
filename = f"download_{int(time.time())}.bin"
# 确保目标文件夹存在
os.makedirs(folder, exist_ok=True)
save_path = os.path.join(folder, filename)
downloaded_size = 0
with open(save_path, 'wb') as file:
for chunk in response.iter_content(chunk_size=8192):
if chunk: # 过滤空块
downloaded_size += len(chunk)
# 检查是否超过最大允许大小
if downloaded_size > filesize_max:
file.close()
os.remove(save_path)
return None
file.write(chunk)
# 下载完成后检查最小文件大小
if os.path.getsize(save_path) < filesize_min:
os.remove(save_path)
return None
# 下载成功后更新缓存
if save_path:
download_cache[url] = save_path
return save_path
except Exception as e:
# 发生异常时清理文件
if save_path and os.path.exists(save_path):
os.remove(save_path)
return None
async def _process_model(model_input: Union[str, gr.File], tilesize: int, output_dir: str,task_id:int,fp16:bool):
log = ('初始化日志记录...\n')
print_log(task_id, '初始化日志记录', '开始')
yield [],log
if isinstance(model_input, str):
input_path = model_input
log += f'使用文件: {input_path}\n'
else:
input_path = model_input.name
log += f'已上传文件: {input_path}\n'
print_log(task_id, log.split('\n')[-1], '开始')
yield [], log
if not input_path:
log += ( f'未获得正确的模型文件\n')
print_log(task_id, f'未获得正确的模型文件', '错误')
yield [],log
return
if input_path.endswith('.onnx'):
onnx_path = input_path
log += ( '输入已经是 ONNX 文件\n')
print_log(task_id, '输入已经是 ONNX 文件', '跳过')
yield [],log
else:
print_log(task_id, f'转换 PTH 模型为 ONNX, folder={output_dir}', '开始')
onnx_path = convert_pth_to_onnx(input_path, tilesize=tilesize, output_folder=output_dir,use_fp16=fp16)
if onnx_path:
log += ( f'成功生成ONNX模型: {onnx_path}\n')
print_log(task_id, f'生成ONNX模型: {onnx_path}', '完成')
else:
log += ( '生成ONNX模型失败\n')
print_log(task_id, '生成ONNX模型', '错误')
yield [],log
return
# 转换为 MNN 模型
output_name= os.path.splitext(os.path.basename(onnx_path))[0]
mnn_path = os.path.join(output_dir, f'{output_name}.mnn')
try:
log += ( '正在将 ONNX 模型转换为 MNN 格式...\n')
print_log(task_id, '正在将 ONNX 模型转换为 MNN 格式', '开始')
convertmnn(onnx_path, mnn_path,fp16)
yield onnx_path,log
except Exception as e:
log += ( f'转换 MNN 模型时出错: {str(e)}\n')
print_log(task_id, f'转换 MNN 模型时出错: {str(e)}', '错误')
yield onnx_path,log
print_log(task_id, '模型转换任务完成', '完成')
# 转换为 MNN 模型后对文件检查
if os.path.exists(mnn_path) and os.path.getsize(mnn_path) > 1024: # 1KB = 1024 bytes
log += ( f'MNN 模型已保存到: {mnn_path}\n')
else:
log += ( 'MNN 模型生成失败或文件大小不足1KB\n')
mnn_path = None
yield onnx_path, mnn_path, log
with gr.Blocks() as demo:
gr.Markdown("# MNN模型转换工具")
model_type_opt = ['从链接下载', '直接上传文件']
with gr.Row():
with gr.Column():
input_type = gr.Radio(model_type_opt, label='模型文件来源')
url_input = gr.Textbox(label='模型链接')
file_input = gr.File(label='模型文件', visible=False)
def show_input(input_type):
if input_type == model_type_opt[0]:
return gr.update(visible=True), gr.update(visible=False)
else:
return gr.update(visible=False), gr.update(visible=True)
input_type.change(show_input, inputs=input_type, outputs=[url_input, file_input])
tilesize = gr.Number(label="Tilesize", value=0, precision=0)
# 添加fp16和try_run复选框
fp16 = gr.Checkbox(label="FP16", value=False)
try_run = gr.Checkbox(label="pymnnsr测试", value=False)
convert_btn = gr.Button("开始转换")
with gr.Column():
# with gr.Row():
log_box = gr.Textbox(label="转换日志", lines=10, interactive=False)
onnx_output = gr.File(label="ONNX 模型输出",file_types=["filepath"])
mnn_output = gr.File(label="MNN 模型输出",file_types=["filepath"])
img_output = gr.Image(type="pil", label="测试输出(边缘正常说明模型转换成功,色彩有bug)" ,visible=False)
def show_try_run(try_run):
if try_run:
return gr.update(visible=True)
else:
return gr.update(visible=False)
try_run.change(show_try_run, inputs=try_run, outputs=img_output)
async def process_model(input_type, url_input, file_input, tilesize, fp16, try_run):
global task_counter
task_counter += 1
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = os.path.join(os.getcwd(), f"output_{task_counter}_{timestamp}")
os.makedirs(output_dir, exist_ok=True)
log=""
if input_type == model_type_opt[0] and url_input:
log = f'正在下载模型文件: {url_input}\n'
print_log(task_counter, f'正在下载模型文件: {url_input}', '开始')
yield None, None, log, None
if url_input.startswith("https://drive.google.com/"):
model_input = download_gdrive_file(
url=url_input,
folder=output_dir,
filesize_max=200*1024*1024, # 200MB
filesize_min=1024 # 1KB
)
else:
model_input = download_file2folder(
url=url_input,
folder=output_dir,
filesize_max=200*1024*1024, # 200MB
filesize_min=1024 # 1KB
)
if not model_input:
log += f'\n模型文件下载失败\n'
print_log(task_counter, f'模型文件载', '失败')
yield None, None, log, None
return
log += f'\n模型文件已下载到: {model_input}\n'
print_log(task_counter, f'模型文件已下载到: {model_input}', '完成')
yield None, None, log, None
elif input_type == model_type_opt[1] and file_input:
model_input = file_input
else:
# 改为通过yield返回错误日志
log = '\n请选择输入类型并提供有效的输入!'
yield None, None, log,None
return
onnx_path = None
mnn_path = None
# 调用重命名后的函数
async for result in _process_model(model_input, int(tilesize), output_dir, task_counter, fp16):
if isinstance(result, tuple) and len(result) == 3:
onnx_path, mnn_path, process_log = result
yield onnx_path, mnn_path, log+process_log,None
elif isinstance(result, tuple) and len(result) == 2:
# 处理纯日志yield
_, process_log = result
yield None, None, log+process_log,None
# yield onnx_path, mnn_path, log+process_log
if mnn_path and try_run:
processed_image_np = modelTest_for_gradio(mnn_path, "./sample.jpg")
processed_image_pil = Image.fromarray(cv2.cvtColor(processed_image_np, cv2.COLOR_BGR2RGB))
# processed_image_pil = Image.fromarray(processed_image_np)
yield onnx_path, mnn_path, log+process_log,processed_image_pil
convert_btn.click(
process_model,
inputs=[input_type, url_input, file_input, tilesize, fp16, try_run],
outputs=[onnx_output, mnn_output, log_box, img_output],
api_name="convert_model"
)
# 将示例移至底部并包裹在列组件中
examples_column = gr.Column(visible=True)
with examples_column:
examples = [
[model_type_opt[0], "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"],
[model_type_opt[0], "https://github.com/Phhofm/models/releases/download/4xNomos8kSC/4xNomos8kSC.pth"],
[model_type_opt[0], "https://github.com/Phhofm/models/releases/download/1xDeJPG/1xDeJPG_SRFormer_light.pth"],
[model_type_opt[0], "https://objectstorage.us-phoenix-1.oraclecloud.com/n/ax6ygfvpvzka/b/open-modeldb-files/o/4x-WTP-ColorDS.pth"],
[model_type_opt[0], "https://github.com/Kim2091/Kim2091-Models/releases/download/2x-AnimeSharpV4/2x-AnimeSharpV4_RCAN_fp16_op17.onnx"],
[model_type_opt[0], "https://drive.google.com/uc?export=download&confirm=1&id=1PeqL1ikJbBJbVzvlqvtb4d7QdSW7BzrQ"],
[model_type_opt[0], "https://drive.google.com/file/d/1maYmC5yyzWCC42X5O0HeDuepsLFh7AV4/view?usp=drive_link"],
]
example_input = gr.Examples(examples=examples, inputs=[input_type, url_input], label='示例模型链接')
demo.launch() |