tumuyan2 commited on
Commit
5de10eb
·
1 Parent(s): a1b158c
Files changed (6) hide show
  1. .gitignore +1 -0
  2. app.py +39 -22
  3. app_mnnsr.py +28 -0
  4. mnnsr.py +150 -0
  5. requirements.txt +4 -1
  6. sample.jpg +0 -0
.gitignore CHANGED
@@ -1,2 +1,3 @@
1
  /__pycache__
2
  /output*
 
 
1
  /__pycache__
2
  /output*
3
+ *.mnn
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  import requests
3
  import os
@@ -11,6 +12,12 @@ import gdown
11
  import sys
12
  from typing import Optional
13
  import datetime
 
 
 
 
 
 
14
 
15
  log_to_terminal = True
16
  task_counter = 0
@@ -52,9 +59,7 @@ def download_gdrive_file(
52
  Optional[str]: 如果下载成功,返回文件的完整路径;否则返回 None。
53
  """
54
  print(f"--- 开始处理链接: {url} ---")
55
-
56
-
57
- # 3. 准备下载路径并创建文件夹
58
  try:
59
  os.makedirs(folder, exist_ok=True)
60
  except OSError as e:
@@ -179,7 +184,7 @@ def download_file2folder(url: str, folder: str, filesize_max: int, filesize_min:
179
  async def _process_model(model_input: Union[str, gr.File], tilesize: int, output_dir: str,task_id:int,fp16:bool):
180
  log = ('初始化日志记录...\n')
181
  print_log(task_id, '初始化日志记录', '开始')
182
- yield [],[], log
183
 
184
  if isinstance(model_input, str):
185
  input_path = model_input
@@ -188,12 +193,12 @@ async def _process_model(model_input: Union[str, gr.File], tilesize: int, output
188
  input_path = model_input.name
189
  log += f'已上传文件: {input_path}\n'
190
  print_log(task_id, log.split('\n')[-1], '开始')
191
- yield [], [], log
192
 
193
  if not input_path:
194
  log += ( f'未获得正确的模型文件\n')
195
  print_log(task_id, f'未获得正确的模型文件', '错误')
196
- yield [],[], log
197
  return
198
 
199
 
@@ -201,7 +206,7 @@ async def _process_model(model_input: Union[str, gr.File], tilesize: int, output
201
  onnx_path = input_path
202
  log += ( '输入已经是 ONNX 文件\n')
203
  print_log(task_id, '输入已经是 ONNX 文件', '跳过')
204
- yield [],[], log
205
  else:
206
  print_log(task_id, f'转换 PTH 模型为 ONNX, folder={output_dir}', '开始')
207
  onnx_path = convert_pth_to_onnx(input_path, tilesize=tilesize, output_folder=output_dir,use_fp16=fp16)
@@ -211,7 +216,7 @@ async def _process_model(model_input: Union[str, gr.File], tilesize: int, output
211
  else:
212
  log += ( '生成ONNX模型失败\n')
213
  print_log(task_id, '生成ONNX模型', '错误')
214
- yield [], [], log
215
  return
216
 
217
 
@@ -222,11 +227,11 @@ async def _process_model(model_input: Union[str, gr.File], tilesize: int, output
222
  log += ( '正在将 ONNX 模型转换为 MNN 格式...\n')
223
  print_log(task_id, '正在将 ONNX 模型转换为 MNN 格式', '开始')
224
  convertmnn(onnx_path, mnn_path,fp16)
225
- yield onnx_path,[], log
226
  except Exception as e:
227
  log += ( f'转换 MNN 模型时出错: {str(e)}\n')
228
  print_log(task_id, f'转换 MNN 模型时出错: {str(e)}', '错误')
229
- yield onnx_path,[], log
230
 
231
  print_log(task_id, '模型转换任务完成', '完成')
232
 
@@ -240,7 +245,7 @@ async def _process_model(model_input: Union[str, gr.File], tilesize: int, output
240
  yield onnx_path, mnn_path, log
241
 
242
  with gr.Blocks() as demo:
243
- gr.Markdown("# 模型转换工具")
244
  model_type_opt = ['从链接下载', '直接上传文件']
245
  with gr.Row():
246
  with gr.Column():
@@ -259,13 +264,20 @@ with gr.Blocks() as demo:
259
  tilesize = gr.Number(label="Tilesize", value=0, precision=0)
260
  # 添加fp16和try_run复选框
261
  fp16 = gr.Checkbox(label="FP16", value=False)
262
- try_run = gr.Checkbox(label="测试mnn推理", value=False)
263
  convert_btn = gr.Button("开始转换")
264
  with gr.Column():
265
  # with gr.Row():
266
  log_box = gr.Textbox(label="转换日志", lines=10, interactive=False)
267
- onnx_output = gr.File(label="ONNX 模型输出")
268
- mnn_output = gr.File(label="MNN 模型输出")
 
 
 
 
 
 
 
269
 
270
  async def process_model(input_type, url_input, file_input, tilesize, fp16, try_run):
271
  global task_counter
@@ -279,7 +291,7 @@ with gr.Blocks() as demo:
279
  if input_type == model_type_opt[0] and url_input:
280
  log = f'正在下载模型文件: {url_input}\n'
281
  print_log(task_counter, f'正在下载模型文件: {url_input}', '开始')
282
- yield None, None, log
283
 
284
  if url_input.startswith("https://drive.google.com/"):
285
  model_input = download_gdrive_file(
@@ -296,22 +308,21 @@ with gr.Blocks() as demo:
296
  filesize_min=1024 # 1KB
297
  )
298
 
299
-
300
  if not model_input:
301
  log += f'\n模型文件下载失败\n'
302
  print_log(task_counter, f'模型文件载', '失败')
303
- yield None, None, log
304
  return
305
 
306
  log += f'\n模型文件已下载到: {model_input}\n'
307
  print_log(task_counter, f'模型文件已下载到: {model_input}', '完成')
308
- yield None, None, log
309
  elif input_type == model_type_opt[1] and file_input:
310
  model_input = file_input
311
  else:
312
  # 改为通过yield返回错误日志
313
  log = '\n请选择输入类型并提供有效的输入!'
314
- yield None, None, log
315
  return
316
 
317
  onnx_path = None
@@ -320,17 +331,23 @@ with gr.Blocks() as demo:
320
  async for result in _process_model(model_input, int(tilesize), output_dir, task_counter, fp16):
321
  if isinstance(result, tuple) and len(result) == 3:
322
  onnx_path, mnn_path, process_log = result
323
- yield onnx_path, mnn_path, log+process_log
324
  elif isinstance(result, tuple) and len(result) == 2:
325
  # 处理纯日志yield
326
  _, process_log = result
327
- yield None, None, log+process_log
328
  # yield onnx_path, mnn_path, log+process_log
329
 
 
 
 
 
 
 
330
  convert_btn.click(
331
  process_model,
332
  inputs=[input_type, url_input, file_input, tilesize, fp16, try_run],
333
- outputs=[onnx_output, mnn_output, log_box],
334
  api_name="convert_model"
335
  )
336
 
 
1
+ from cgi import test
2
  import gradio as gr
3
  import requests
4
  import os
 
12
  import sys
13
  from typing import Optional
14
  import datetime
15
+ from mnnsr import modelTest_for_gradio
16
+ from PIL import Image
17
+ import cv2
18
+
19
+
20
+
21
 
22
  log_to_terminal = True
23
  task_counter = 0
 
59
  Optional[str]: 如果下载成功,返回文件的完整路径;否则返回 None。
60
  """
61
  print(f"--- 开始处理链接: {url} ---")
62
+ # 准备下载路径并创建文件夹
 
 
63
  try:
64
  os.makedirs(folder, exist_ok=True)
65
  except OSError as e:
 
184
  async def _process_model(model_input: Union[str, gr.File], tilesize: int, output_dir: str,task_id:int,fp16:bool):
185
  log = ('初始化日志记录...\n')
186
  print_log(task_id, '初始化日志记录', '开始')
187
+ yield [],log
188
 
189
  if isinstance(model_input, str):
190
  input_path = model_input
 
193
  input_path = model_input.name
194
  log += f'已上传文件: {input_path}\n'
195
  print_log(task_id, log.split('\n')[-1], '开始')
196
+ yield [], log
197
 
198
  if not input_path:
199
  log += ( f'未获得正确的模型文件\n')
200
  print_log(task_id, f'未获得正确的模型文件', '错误')
201
+ yield [],log
202
  return
203
 
204
 
 
206
  onnx_path = input_path
207
  log += ( '输入已经是 ONNX 文件\n')
208
  print_log(task_id, '输入已经是 ONNX 文件', '跳过')
209
+ yield [],log
210
  else:
211
  print_log(task_id, f'转换 PTH 模型为 ONNX, folder={output_dir}', '开始')
212
  onnx_path = convert_pth_to_onnx(input_path, tilesize=tilesize, output_folder=output_dir,use_fp16=fp16)
 
216
  else:
217
  log += ( '生成ONNX模型失败\n')
218
  print_log(task_id, '生成ONNX模型', '错误')
219
+ yield [],log
220
  return
221
 
222
 
 
227
  log += ( '正在将 ONNX 模型转换为 MNN 格式...\n')
228
  print_log(task_id, '正在将 ONNX 模型转换为 MNN 格式', '开始')
229
  convertmnn(onnx_path, mnn_path,fp16)
230
+ yield onnx_path,log
231
  except Exception as e:
232
  log += ( f'转换 MNN 模型时出错: {str(e)}\n')
233
  print_log(task_id, f'转换 MNN 模型时出错: {str(e)}', '错误')
234
+ yield onnx_path,log
235
 
236
  print_log(task_id, '模型转换任务完成', '完成')
237
 
 
245
  yield onnx_path, mnn_path, log
246
 
247
  with gr.Blocks() as demo:
248
+ gr.Markdown("# MNN模型转换工具")
249
  model_type_opt = ['从链接下载', '直接上传文件']
250
  with gr.Row():
251
  with gr.Column():
 
264
  tilesize = gr.Number(label="Tilesize", value=0, precision=0)
265
  # 添加fp16和try_run复选框
266
  fp16 = gr.Checkbox(label="FP16", value=False)
267
+ try_run = gr.Checkbox(label="pymnnsr测试", value=False)
268
  convert_btn = gr.Button("开始转换")
269
  with gr.Column():
270
  # with gr.Row():
271
  log_box = gr.Textbox(label="转换日志", lines=10, interactive=False)
272
+ onnx_output = gr.File(label="ONNX 模型输出",file_types=["filepath"])
273
+ mnn_output = gr.File(label="MNN 模型输出",file_types=["filepath"])
274
+ img_output = gr.Image(type="pil", label="测试输出(边缘正常说明模型转换成功,色彩有bug)" ,visible=False)
275
+ def show_try_run(try_run):
276
+ if try_run:
277
+ return gr.update(visible=True)
278
+ else:
279
+ return gr.update(visible=False)
280
+ try_run.change(show_try_run, inputs=try_run, outputs=img_output)
281
 
282
  async def process_model(input_type, url_input, file_input, tilesize, fp16, try_run):
283
  global task_counter
 
291
  if input_type == model_type_opt[0] and url_input:
292
  log = f'正在下载模型文件: {url_input}\n'
293
  print_log(task_counter, f'正在下载模型文件: {url_input}', '开始')
294
+ yield None, None, log, None
295
 
296
  if url_input.startswith("https://drive.google.com/"):
297
  model_input = download_gdrive_file(
 
308
  filesize_min=1024 # 1KB
309
  )
310
 
 
311
  if not model_input:
312
  log += f'\n模型文件下载失败\n'
313
  print_log(task_counter, f'模型文件载', '失败')
314
+ yield None, None, log, None
315
  return
316
 
317
  log += f'\n模型文件已下载到: {model_input}\n'
318
  print_log(task_counter, f'模型文件已下载到: {model_input}', '完成')
319
+ yield None, None, log, None
320
  elif input_type == model_type_opt[1] and file_input:
321
  model_input = file_input
322
  else:
323
  # 改为通过yield返回错误日志
324
  log = '\n请选择输入类型并提供有效的输入!'
325
+ yield None, None, log,None
326
  return
327
 
328
  onnx_path = None
 
331
  async for result in _process_model(model_input, int(tilesize), output_dir, task_counter, fp16):
332
  if isinstance(result, tuple) and len(result) == 3:
333
  onnx_path, mnn_path, process_log = result
334
+ yield onnx_path, mnn_path, log+process_log,None
335
  elif isinstance(result, tuple) and len(result) == 2:
336
  # 处理纯日志yield
337
  _, process_log = result
338
+ yield None, None, log+process_log,None
339
  # yield onnx_path, mnn_path, log+process_log
340
 
341
+ if mnn_path and try_run:
342
+ processed_image_np = modelTest_for_gradio(mnn_path, "./sample.jpg")
343
+ processed_image_pil = Image.fromarray(cv2.cvtColor(processed_image_np, cv2.COLOR_BGR2RGB))
344
+ # processed_image_pil = Image.fromarray(processed_image_np)
345
+ yield onnx_path, mnn_path, log+process_log,processed_image_pil
346
+
347
  convert_btn.click(
348
  process_model,
349
  inputs=[input_type, url_input, file_input, tilesize, fp16, try_run],
350
+ outputs=[onnx_output, mnn_output, log_box, img_output],
351
  api_name="convert_model"
352
  )
353
 
app_mnnsr.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import MNN
4
+ import numpy as np
5
+ import cv2
6
+ from PIL import Image
7
+ from mnnsr import modelTest_for_gradio
8
+
9
+ def gradio_interface(modelPath, input_image):
10
+ processed_image_np = modelTest_for_gradio(modelPath, input_image)
11
+
12
+ processed_image_pil = Image.fromarray(cv2.cvtColor(processed_image_np, cv2.COLOR_BGR2RGB))
13
+
14
+ return processed_image_pil
15
+
16
+ # 创建Gradio界面
17
+ iface = gr.Interface(
18
+ fn=gradio_interface,
19
+ # inputs=gr.Image(type="pil", label="上传图像"),
20
+ inputs = [gr.File(label="上传MNN模型"), gr.Image(type="filepath", label="上传图像",value="./sample.jpg")],
21
+ outputs=gr.Image(type="pil", label="处理后的图像"),
22
+ title="MNN图像超分辨率处理",
23
+ description="上传图像,使用MNN模型进行超分辨率处理"
24
+ )
25
+
26
+ if __name__ == "__main__":
27
+ # 启动Gradio界面
28
+ iface.launch()
mnnsr.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import MNN
4
+ import numpy as np
5
+ import cv2
6
+ from PIL import Image
7
+
8
+
9
+ # 复制原始modelTest函数中的必要函数
10
+ def process_image(data, H, W, C, color='BGR'):
11
+ """
12
+ 处理图像数据(灰度或彩色)并转换为指定色彩空间
13
+
14
+ 参数:
15
+ data: 输入数据指针 (numpy数组形式传入)
16
+ H: 高度
17
+ W: 宽度
18
+ C: 通道数 (1 或 3)
19
+ color: 目标色彩空间 ('BGR', 'RGB', 'YCbCr', 'YUV')
20
+
21
+ 返回:
22
+ numpy数组 处理后的图像
23
+ """
24
+ if C == 1:
25
+ # 灰度图像处理
26
+ gray = np.array(data, dtype=np.float32).reshape(H, W)
27
+ gray = (gray * 255).astype(np.uint8)
28
+ result = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
29
+ return result
30
+ else:
31
+ # 彩色图像处理
32
+ # 将数据拆分为各个通道
33
+ channels = [
34
+ np.array(data[i*H*W : (i+1)*H*W], dtype=np.float32).reshape(H, W)
35
+ for i in range(C)
36
+ ]
37
+ if color == 'RGB':
38
+ # RGB -> BGR (OpenCV默认顺序)
39
+ channels[0], channels[2] = channels[2], channels[0] # 交换R和B通道
40
+ result = cv2.merge(channels)
41
+ result = (result * 255).astype(np.uint8)
42
+ else:
43
+ # 先合并为BGR格式
44
+ rgb = cv2.merge(channels)
45
+ rgb = (rgb * 255).astype(np.uint8)
46
+ # 转换为目标色彩空间
47
+ if color == 'YCbCr':
48
+ result = cv2.cvtColor(rgb, cv2.COLOR_BGR2YCrCb)
49
+ elif color == 'YUV':
50
+ result = cv2.cvtColor(rgb, cv2.COLOR_BGR2YUV)
51
+ else:
52
+ result = rgb # 默认为BGR
53
+
54
+ return result
55
+
56
+ def createTensor(tensor):
57
+ shape = tensor.getShape()
58
+ data = np.ones(shape, dtype=np.float32)
59
+ return MNN.Tensor(shape, tensor.getDataType(), data, tensor.getDimensionType())
60
+
61
+ def modelTest_for_gradio(modelPath, image_path, tilesize = 128):
62
+ model_name = os.path.basename(modelPath)
63
+ if "-Grayscale" in model_name:
64
+ model_channel = 1
65
+ elif "-4ch" in model_name:
66
+ model_channel = 4
67
+ else:
68
+ model_channel = 3
69
+
70
+ net = MNN.Interpreter(modelPath)
71
+ # set 9 for Session_Backend_Auto, Let BackGround Tuning
72
+ net.setSessionMode(9)
73
+ # set 0 for tune_num
74
+ # net.setSessionHint(0, 20)
75
+ config = {}
76
+ # "CPU"或0(默认), "OPENCL"或3,"OPENGL"或6, "VULKAN"或7, "METAL"或1, "TRT"或9, "CUDA"或2, "HIAI"或8
77
+ config['backend'] = 3
78
+ #config['precision'] = "low"
79
+ session = net.createSession(config)
80
+
81
+ print("Run on backendtype: %d \n" % net.getSessionInfo(session, 2))
82
+
83
+ # 读取图像
84
+ image = cv2.imread(image_path)
85
+ if image.ndim == 2:
86
+ # 为了方便处理,先将其扩展为3维数组 (height, width, 1)
87
+ print("extend dims")
88
+ image = np.expand_dims(image, axis=-1)
89
+ image_channel = image.shape[2]
90
+
91
+ image = cv2.resize(image, (tilesize, tilesize))
92
+
93
+ # 处理通道数不匹配的情况
94
+ if image_channel == 3:
95
+ if model_channel == 1:
96
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
97
+ elif model_channel == 3:
98
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
99
+ elif model_channel == 4:
100
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGBA)
101
+ else:
102
+ print(f"unexpect input: model_channel {model_channel}, image_channel {image_channel}")
103
+ elif image_channel == 4:
104
+ if model_channel == 1:
105
+ image = cv2.cvtColor(image, cv2.COLOR_BGRA2GRAY)
106
+ elif model_channel == 3:
107
+ image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
108
+ elif model_channel == 4:
109
+ image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
110
+ else:
111
+ print(f"unexpect input: model_channel {model_channel}, image_channel {image_channel}")
112
+ else:
113
+ print(f"unexpect input: model_channel {model_channel}, image_channel {image_channel}")
114
+
115
+ # 显示图像(在Gradio中不需要)
116
+ # display(Image(data=cv2.imencode('.jpg', image)[1].tobytes()))
117
+
118
+ image = image/255.0
119
+ #preprocess it
120
+ image = image.transpose((2, 0, 1))
121
+ #change numpy data type as np.float32 to match tensor's format
122
+ image = image.astype(np.float32)
123
+ #cv2 read shape is NHWC, Tensor's need is NCHW,transpose it
124
+ tmp_input = MNN.Tensor((1, model_channel, tilesize, tilesize), MNN.Halide_Type_Float, image, MNN.Tensor_DimensionType_Caffe)
125
+
126
+ # input
127
+ inputTensor = net.getSessionInput(session)
128
+ net.resizeTensor(inputTensor, (1, model_channel, tilesize, tilesize))
129
+ net.resizeSession(session)
130
+ inputTensor.copyFrom(tmp_input)
131
+ # infer
132
+ net.runSession(session)
133
+ outputTensor = net.getSessionOutput(session)
134
+ # output
135
+ outputShape = outputTensor.getShape()
136
+ print("outputShape",outputShape)
137
+ outputHost = createTensor(outputTensor)
138
+ outputTensor.copyToHostTensor(outputHost)
139
+
140
+ outimage = process_image(outputHost.getData(), outputShape[2], outputShape[3], outputShape[1], color='RGB')
141
+
142
+ # 返回处理后的图像(numpy数组)
143
+ return outimage
144
+
145
+ # def gradio_interface(modelPath, input_image):
146
+ # processed_image_np = modelTest_for_gradio(modelPath, input_image)
147
+
148
+ # processed_image_pil = Image.fromarray(cv2.cvtColor(processed_image_np, cv2.COLOR_BGR2RGB))
149
+
150
+ # return processed_image_pil
requirements.txt CHANGED
@@ -3,6 +3,9 @@ torch
3
  pnnx
4
  onnx
5
  onnxsim
6
- mnn
7
  gradio
8
  gdown
 
 
 
 
 
3
  pnnx
4
  onnx
5
  onnxsim
 
6
  gradio
7
  gdown
8
+ MNN
9
+ numpy
10
+ opencv-python
11
+ Pillow
sample.jpg ADDED