File size: 15,640 Bytes
5de10eb
2152dfa
 
 
 
 
 
62053b6
 
 
 
 
2152dfa
a1b158c
5de10eb
 
 
 
 
 
2152dfa
 
 
62053b6
2152dfa
62053b6
2152dfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62053b6
 
 
 
 
 
2152dfa
62053b6
 
2152dfa
62053b6
 
 
 
 
2152dfa
62053b6
2152dfa
62053b6
5de10eb
62053b6
 
 
 
 
 
 
 
a1b158c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62053b6
a1b158c
62053b6
 
a1b158c
62053b6
a1b158c
 
 
 
 
62053b6
 
 
 
 
 
 
 
 
 
 
 
2152dfa
 
 
 
 
 
 
 
 
 
 
62053b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2152dfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62053b6
 
 
 
2152dfa
 
 
62053b6
2152dfa
 
 
72c8c5e
55f5f0f
2152dfa
5de10eb
2152dfa
72c8c5e
 
55f5f0f
2152dfa
 
55f5f0f
72c8c5e
5de10eb
2152dfa
 
55f5f0f
2152dfa
5de10eb
2152dfa
 
 
 
 
55f5f0f
 
5de10eb
2152dfa
 
72c8c5e
2152dfa
55f5f0f
2152dfa
 
55f5f0f
2152dfa
5de10eb
2152dfa
 
 
 
 
 
 
55f5f0f
2152dfa
72c8c5e
5de10eb
2152dfa
55f5f0f
2152dfa
5de10eb
2152dfa
 
 
 
 
55f5f0f
2152dfa
55f5f0f
2152dfa
 
 
 
 
5de10eb
62053b6
6b67881
 
62053b6
6b67881
62053b6
2152dfa
6b67881
62053b6
6b67881
 
 
 
 
 
 
 
 
5de10eb
6b67881
 
 
 
5de10eb
 
 
 
 
 
 
 
 
6b67881
 
72c8c5e
 
a1b158c
 
 
72c8c5e
55f5f0f
72c8c5e
62053b6
55f5f0f
72c8c5e
5de10eb
62053b6
a1b158c
62053b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5de10eb
62053b6
 
55f5f0f
72c8c5e
5de10eb
62053b6
2152dfa
 
 
 
5de10eb
2152dfa
 
 
 
 
72c8c5e
2152dfa
55f5f0f
5de10eb
2152dfa
 
 
5de10eb
55f5f0f
2152dfa
5de10eb
 
 
 
 
 
2152dfa
 
6b67881
5de10eb
2152dfa
 
 
 
 
 
 
62053b6
 
 
 
 
 
 
2152dfa
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
from cgi import test
import gradio as gr
import requests
import os
import subprocess
from typing import Union
from pth2onnx import convert_pth_to_onnx
import re
import time
from urllib.parse import urlparse, unquote
import gdown
import sys
from typing import Optional
import datetime
from mnnsr import modelTest_for_gradio
from PIL import Image
import cv2




log_to_terminal = True
task_counter = 0
download_cache = {}  # 格式: {url: 文件路径}

# 日志函数
def print_log(task_id, stage, status):
    if log_to_terminal:
        print(f"任务{task_id}: [{status}] {stage}")

# 使用 MNN 库自带的转换工具
def convertmnn(onnx_path: str, mnn_path: str, fp16=False):
    param = ['mnnconvert', '-f', 'ONNX', '--modelFile', onnx_path, '--MNNModel', mnn_path, '--bizCode', 'biz', '--info', '--detectSparseSpeedUp']
    if fp16:
        param.append('--fp16')
    subprocess.run(param, check=True)

def download_file(url: str, save_path: str):
    response = requests.get(url)
    with open(save_path, 'wb') as f:
        f.write(response.content)

def download_gdrive_file(

    url: str, 

    folder: str, 

    filesize_max: int, 

    filesize_min: int = 0

) -> Optional[str]:
    """

    从 Google Drive 链接获取文件信息,检查大小后下载文件。



    Args:

        url (str): Google Drive 的分享链接。

        folder (str): 用于保存文件的目标文件夹路径。

        filesize_max (int): 允许下载的最大文件大小(单位:字节)。

        filesize_min (int): 允许下载的最小文件大小(单位:字节),默认为 0。



    Returns:

        Optional[str]: 如果下载成功,返回文件的完整路径;否则返回 None。

    """
    print(f"--- 开始处理链接: {url} ---")
    # 准备下载路径并创建文件夹
    try:
        os.makedirs(folder, exist_ok=True)
    except OSError as e:
        print(f"错误:创建文件夹 '{folder}' 失败: {e}", file=sys.stderr)
        return None

    # 4. 下载文件
    try:
        print(f"开始下载google drive文件到: {folder}")
        # 保存当前工作目录
        original_dir = os.getcwd()
        # 切换到目标文件夹
        os.chdir(folder)
        # 下载文件,不指定output_path参数
        downloaded_filename = gdown.download(url, quiet=True, fuzzy=True)
        # 恢复原始工作目录
        os.chdir(original_dir)
        
        # 检查下载结果和文件信息
        if downloaded_filename:
            downloaded_path = os.path.join(folder, downloaded_filename)
            if os.path.exists(downloaded_path):
                # 获取文件大小
                file_size = os.path.getsize(downloaded_path)
                if file_size< filesize_max and file_size> filesize_min:
                    print(f"下载成功!文件已保存至: {downloaded_path}")
                    return downloaded_path
                else:
                    print(f"文件大小超出范围: {file_size} bytes")
                    return None
            else:
                print(f"错误:下载的文件 '{downloaded_path}' 不存在。", file=sys.stderr)
                return None
        else:
            print("错误:gdown 下载过程未返回有效文件名。", file=sys.stderr)
            return None
    except Exception as e:
        print(f"下载过程中发生错误: {e}", file=sys.stderr)
        return None
    finally:
        # 确保工作目录恢复,即使发生异常
        if 'original_dir' in locals():
            os.chdir(original_dir)
            

def download_file2folder(url: str, folder: str, filesize_max: int, filesize_min: int) -> Optional[str]:
    global download_cache
    
    # 检查缓存是否存在且文件有效
    if url in download_cache:
        cached_path = download_cache[url]
        if os.path.exists(cached_path) and os.path.getsize(cached_path) >= filesize_min:
            print(f"    使用缓存文件: {cached_path}")
            return cached_path
    
    save_path = None  # 初始化save_path为None,避免except块中引用错误
    
    try:
        # 发送HTTP请求,流式下载
        with requests.get(url, stream=True, timeout=10) as response:
            response.raise_for_status()  # 检查HTTP错误状态
            
            # 获取文件总大小(如果服务器提供)
            total_size = int(response.headers.get('content-length', 0))
            if total_size > filesize_max:
                return None  # 文件大小超过最大值,不下载
            
            # 提取文件名
            filename = None
            content_disposition = response.headers.get('Content-Disposition')
            if content_disposition:
                # 使用正则表达式提取filename
                match = re.search(r'filename="?([^"]+)"?', content_disposition)
                if match:
                    filename = match.group(1)
                    # 处理可能的URL编码文件名
                    filename = unquote(filename)
            
            # 如果响应头没有,从URL解析
            if not filename:
                parsed_url = urlparse(url)
                filename = os.path.basename(parsed_url.path)
                # 处理URL编码的文件名
                filename = unquote(filename)
            
            # 如果仍然没有文件名,生成默认文件名
            if not filename:
                filename = f"download_{int(time.time())}.bin"
            
            # 确保目标文件夹存在
            os.makedirs(folder, exist_ok=True)
            save_path = os.path.join(folder, filename)
            
            downloaded_size = 0
            with open(save_path, 'wb') as file:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # 过滤空块
                        downloaded_size += len(chunk)
                        # 检查是否超过最大允许大小
                        if downloaded_size > filesize_max:
                            file.close()
                            os.remove(save_path)
                            return None
                        file.write(chunk)
        
        # 下载完成后检查最小文件大小
        if os.path.getsize(save_path) < filesize_min:
            os.remove(save_path)
            return None
        
        # 下载成功后更新缓存
        if save_path:
            download_cache[url] = save_path
        return save_path
    
    except Exception as e:
        # 发生异常时清理文件
        if save_path and os.path.exists(save_path):
            os.remove(save_path)
        return None

async def _process_model(model_input: Union[str, gr.File], tilesize: int, output_dir: str,task_id:int,fp16:bool):
    log = ('初始化日志记录...\n')
    print_log(task_id, '初始化日志记录', '开始')
    yield [],log

    if isinstance(model_input, str):
        input_path = model_input
        log += f'使用文件: {input_path}\n'
    else:
        input_path = model_input.name
        log += f'已上传文件: {input_path}\n'
    print_log(task_id, log.split('\n')[-1], '开始')
    yield [], log

    if not input_path:
        log += ( f'未获得正确的模型文件\n')
        print_log(task_id, f'未获得正确的模型文件', '错误')
        yield  [],log
        return


    if input_path.endswith('.onnx'):
        onnx_path = input_path
        log += ( '输入已经是 ONNX 文件\n')
        print_log(task_id, '输入已经是 ONNX 文件', '跳过')
        yield [],log
    else:
        print_log(task_id, f'转换 PTH 模型为 ONNX, folder={output_dir}', '开始')
        onnx_path = convert_pth_to_onnx(input_path, tilesize=tilesize, output_folder=output_dir,use_fp16=fp16)
        if onnx_path:
            log += ( f'成功生成ONNX模型: {onnx_path}\n')
            print_log(task_id, f'生成ONNX模型: {onnx_path}', '完成')
        else:
            log += ( '生成ONNX模型失败\n')
            print_log(task_id, '生成ONNX模型', '错误')
            yield [],log
            return


    # 转换为 MNN 模型
    output_name= os.path.splitext(os.path.basename(onnx_path))[0]
    mnn_path = os.path.join(output_dir,  f'{output_name}.mnn')
    try:
        log += ( '正在将 ONNX 模型转换为 MNN 格式...\n')
        print_log(task_id, '正在将 ONNX 模型转换为 MNN 格式', '开始')
        convertmnn(onnx_path, mnn_path,fp16)
        yield onnx_path,log
    except Exception as e:
        log += ( f'转换 MNN 模型时出错: {str(e)}\n')
        print_log(task_id, f'转换 MNN 模型时出错: {str(e)}', '错误')
        yield onnx_path,log

    print_log(task_id, '模型转换任务完成', '完成')
    
    # 转换为 MNN 模型后对文件检查
    if os.path.exists(mnn_path) and os.path.getsize(mnn_path) > 1024:  # 1KB = 1024 bytes
        log += ( f'MNN 模型已保存到: {mnn_path}\n')
    else:
        log += ( 'MNN 模型生成失败或文件大小不足1KB\n')
        mnn_path = None
    
    yield onnx_path, mnn_path, log

with gr.Blocks() as demo:
    gr.Markdown("# MNN模型转换工具")
    model_type_opt = ['从链接下载', '直接上传文件']
    with gr.Row():
        with gr.Column():
            input_type = gr.Radio(model_type_opt, label='模型文件来源')
            url_input = gr.Textbox(label='模型链接')
            file_input = gr.File(label='模型文件', visible=False)

            def show_input(input_type):
                if input_type == model_type_opt[0]:
                    return gr.update(visible=True), gr.update(visible=False)
                else:
                    return gr.update(visible=False), gr.update(visible=True)
            
            input_type.change(show_input, inputs=input_type, outputs=[url_input, file_input])
            
            tilesize = gr.Number(label="Tilesize", value=0, precision=0)
            # 添加fp16和try_run复选框
            fp16 = gr.Checkbox(label="FP16", value=False)
            try_run = gr.Checkbox(label="pymnnsr测试", value=False)
            convert_btn = gr.Button("开始转换")
        with gr.Column():
        # with gr.Row():
            log_box = gr.Textbox(label="转换日志", lines=10, interactive=False)
            onnx_output = gr.File(label="ONNX 模型输出",file_types=["filepath"])
            mnn_output = gr.File(label="MNN 模型输出",file_types=["filepath"])
            img_output = gr.Image(type="pil", label="测试输出(边缘正常说明模型转换成功,色彩有bug)" ,visible=False)
            def show_try_run(try_run):
                if try_run:
                    return gr.update(visible=True)
                else:
                    return gr.update(visible=False)
            try_run.change(show_try_run, inputs=try_run, outputs=img_output)

    async def process_model(input_type, url_input, file_input, tilesize, fp16, try_run):
        global task_counter
        task_counter += 1

        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = os.path.join(os.getcwd(), f"output_{task_counter}_{timestamp}")
        os.makedirs(output_dir, exist_ok=True)
        log=""

        if input_type == model_type_opt[0] and url_input:
            log = f'正在下载模型文件: {url_input}\n'
            print_log(task_counter, f'正在下载模型文件: {url_input}', '开始')
            yield None, None, log, None

            if url_input.startswith("https://drive.google.com/"):
                model_input = download_gdrive_file(
                    url=url_input,
                    folder=output_dir,
                    filesize_max=200*1024*1024,  # 200MB
                    filesize_min=1024     # 1KB
                )
            else:
                model_input = download_file2folder(
                    url=url_input,
                    folder=output_dir,
                    filesize_max=200*1024*1024,  # 200MB
                    filesize_min=1024     # 1KB
                )

            if not model_input:
                log += f'\n模型文件下载失败\n'
                print_log(task_counter, f'模型文件载', '失败')
                yield None, None, log, None
                return
            
            log += f'\n模型文件已下载到: {model_input}\n'
            print_log(task_counter, f'模型文件已下载到: {model_input}', '完成')
            yield None, None, log, None
        elif input_type == model_type_opt[1] and file_input:
            model_input = file_input
        else:
            # 改为通过yield返回错误日志
            log = '\n请选择输入类型并提供有效的输入!'
            yield None, None, log,None
            return
        
        onnx_path = None
        mnn_path = None
        # 调用重命名后的函数
        async for result in _process_model(model_input, int(tilesize), output_dir, task_counter, fp16):
            if isinstance(result, tuple) and len(result) == 3:
                onnx_path, mnn_path, process_log = result
                yield onnx_path, mnn_path, log+process_log,None
            elif isinstance(result, tuple) and len(result) == 2:
                # 处理纯日志yield
                _, process_log = result
                yield None, None, log+process_log,None
            # yield onnx_path, mnn_path, log+process_log

            if mnn_path and try_run:
                processed_image_np = modelTest_for_gradio(mnn_path, "./sample.jpg")
                processed_image_pil = Image.fromarray(cv2.cvtColor(processed_image_np, cv2.COLOR_BGR2RGB))
                # processed_image_pil = Image.fromarray(processed_image_np)
                yield onnx_path, mnn_path, log+process_log,processed_image_pil

    convert_btn.click(
        process_model,
        inputs=[input_type, url_input, file_input, tilesize, fp16, try_run],
        outputs=[onnx_output, mnn_output, log_box, img_output],
        api_name="convert_model"
    )

    # 将示例移至底部并包裹在列组件中
    examples_column = gr.Column(visible=True)
    with examples_column:
        examples = [
            [model_type_opt[0], "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"],
            [model_type_opt[0], "https://github.com/Phhofm/models/releases/download/4xNomos8kSC/4xNomos8kSC.pth"],
            [model_type_opt[0], "https://github.com/Phhofm/models/releases/download/1xDeJPG/1xDeJPG_SRFormer_light.pth"],
            [model_type_opt[0], "https://objectstorage.us-phoenix-1.oraclecloud.com/n/ax6ygfvpvzka/b/open-modeldb-files/o/4x-WTP-ColorDS.pth"],
            [model_type_opt[0], "https://github.com/Kim2091/Kim2091-Models/releases/download/2x-AnimeSharpV4/2x-AnimeSharpV4_RCAN_fp16_op17.onnx"],
            [model_type_opt[0], "https://drive.google.com/uc?export=download&confirm=1&id=1PeqL1ikJbBJbVzvlqvtb4d7QdSW7BzrQ"],
            [model_type_opt[0], "https://drive.google.com/file/d/1maYmC5yyzWCC42X5O0HeDuepsLFh7AV4/view?usp=drive_link"],
        ]
        example_input = gr.Examples(examples=examples, inputs=[input_type, url_input], label='示例模型链接')
    
demo.launch()