tumuyan2 commited on
Commit
98e4b89
·
1 Parent(s): b5d90c6
Files changed (2) hide show
  1. app.py +8 -1
  2. mnnsr.py +8 -6
app.py CHANGED
@@ -4,6 +4,8 @@ import requests
4
  import os
5
  import subprocess
6
  from typing import Union
 
 
7
  from pth2onnx import convert_pth_to_onnx
8
  import re
9
  import time
@@ -33,7 +35,11 @@ def convertmnn(onnx_path: str, mnn_path: str, fp16=False):
33
  param = ['mnnconvert', '-f', 'ONNX', '--modelFile', onnx_path, '--MNNModel', mnn_path, '--bizCode', 'biz', '--info', '--detectSparseSpeedUp']
34
  if fp16:
35
  param.append('--fp16')
36
- subprocess.run(param, check=True)
 
 
 
 
37
 
38
  def download_file(url: str, save_path: str):
39
  response = requests.get(url)
@@ -381,6 +387,7 @@ with gr.Blocks() as demo:
381
  [model_type_opt[0], "https://github.com/Kim2091/Kim2091-Models/releases/download/2x-AnimeSharpV4/2x-AnimeSharpV4_RCAN_fp16_op17.onnx"],
382
  [model_type_opt[0], "https://drive.google.com/uc?export=download&confirm=1&id=1PeqL1ikJbBJbVzvlqvtb4d7QdSW7BzrQ"],
383
  [model_type_opt[0], "https://drive.google.com/file/d/1maYmC5yyzWCC42X5O0HeDuepsLFh7AV4/view?usp=drive_link"],
 
384
  ]
385
  example_input = gr.Examples(examples=examples, inputs=[input_type, url_input], label='示例模型链接')
386
 
 
4
  import os
5
  import subprocess
6
  from typing import Union
7
+
8
+ from sympy import E
9
  from pth2onnx import convert_pth_to_onnx
10
  import re
11
  import time
 
35
  param = ['mnnconvert', '-f', 'ONNX', '--modelFile', onnx_path, '--MNNModel', mnn_path, '--bizCode', 'biz', '--info', '--detectSparseSpeedUp']
36
  if fp16:
37
  param.append('--fp16')
38
+ try:
39
+ subprocess.run(param, check=True)
40
+ except Exception as e:
41
+ print(f"转换 MNN 模型时出错: {e}")
42
+
43
 
44
  def download_file(url: str, save_path: str):
45
  response = requests.get(url)
 
387
  [model_type_opt[0], "https://github.com/Kim2091/Kim2091-Models/releases/download/2x-AnimeSharpV4/2x-AnimeSharpV4_RCAN_fp16_op17.onnx"],
388
  [model_type_opt[0], "https://drive.google.com/uc?export=download&confirm=1&id=1PeqL1ikJbBJbVzvlqvtb4d7QdSW7BzrQ"],
389
  [model_type_opt[0], "https://drive.google.com/file/d/1maYmC5yyzWCC42X5O0HeDuepsLFh7AV4/view?usp=drive_link"],
390
+ [model_type_opt[0], "https://objectstorage.us-phoenix-1.oraclecloud.com/n/ax6ygfvpvzka/b/open-modeldb-files/o/1x-MangaJPEGHQ.pth"],
391
  ]
392
  example_input = gr.Examples(examples=examples, inputs=[input_type, url_input], label='示例模型链接')
393
 
mnnsr.py CHANGED
@@ -81,14 +81,15 @@ def modelTest_for_gradio(modelPath, image_path, tilesize = 0, backend = 3):
81
  session = net.createSession(config)
82
 
83
  print("Run on backendtype: %d \n" % net.getSessionInfo(session, 2))
84
-
85
  # 读取图像
86
  image = cv2.imread(image_path)
87
  if image.ndim == 2:
88
  # 为了方便处理,先将其扩展为3维数组 (height, width, 1)
89
- print("extend dims")
90
- image = np.expand_dims(image, axis=-1)
91
- image_channel = image.shape[2]
 
 
92
 
93
  image = cv2.resize(image, (tilesize, tilesize))
94
 
@@ -116,10 +117,11 @@ def modelTest_for_gradio(modelPath, image_path, tilesize = 0, backend = 3):
116
 
117
  # 显示图像(在Gradio中不需要)
118
  # display(Image(data=cv2.imencode('.jpg', image)[1].tobytes()))
119
-
120
  image = image/255.0
121
  #preprocess it
122
- image = image.transpose((2, 0, 1))
 
123
  #change numpy data type as np.float32 to match tensor's format
124
  image = image.astype(np.float32)
125
  #cv2 read shape is NHWC, Tensor's need is NCHW,transpose it
 
81
  session = net.createSession(config)
82
 
83
  print("Run on backendtype: %d \n" % net.getSessionInfo(session, 2))
 
84
  # 读取图像
85
  image = cv2.imread(image_path)
86
  if image.ndim == 2:
87
  # 为了方便处理,先将其扩展为3维数组 (height, width, 1)
88
+ # print("extend dims, image.shape=", image.shape)
89
+ image_channel = 1
90
+ # image = np.expand_dims(image, axis=-1)
91
+ else:
92
+ image_channel = image.shape[2]
93
 
94
  image = cv2.resize(image, (tilesize, tilesize))
95
 
 
117
 
118
  # 显示图像(在Gradio中不需要)
119
  # display(Image(data=cv2.imencode('.jpg', image)[1].tobytes()))
120
+ # print("image.shape=", image.shape)
121
  image = image/255.0
122
  #preprocess it
123
+ if model_channel>=3:
124
+ image = image.transpose((2, 0, 1))
125
  #change numpy data type as np.float32 to match tensor's format
126
  image = image.astype(np.float32)
127
  #cv2 read shape is NHWC, Tensor's need is NCHW,transpose it