update
Browse files- .gitignore +1 -0
- app.py +39 -22
- app_mnnsr.py +28 -0
- mnnsr.py +150 -0
- requirements.txt +4 -1
- sample.jpg +0 -0
.gitignore
CHANGED
@@ -1,2 +1,3 @@
|
|
1 |
/__pycache__
|
2 |
/output*
|
|
|
|
1 |
/__pycache__
|
2 |
/output*
|
3 |
+
*.mnn
|
app.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import gradio as gr
|
2 |
import requests
|
3 |
import os
|
@@ -11,6 +12,12 @@ import gdown
|
|
11 |
import sys
|
12 |
from typing import Optional
|
13 |
import datetime
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
log_to_terminal = True
|
16 |
task_counter = 0
|
@@ -52,9 +59,7 @@ def download_gdrive_file(
|
|
52 |
Optional[str]: 如果下载成功,返回文件的完整路径;否则返回 None。
|
53 |
"""
|
54 |
print(f"--- 开始处理链接: {url} ---")
|
55 |
-
|
56 |
-
|
57 |
-
# 3. 准备下载路径并创建文件夹
|
58 |
try:
|
59 |
os.makedirs(folder, exist_ok=True)
|
60 |
except OSError as e:
|
@@ -179,7 +184,7 @@ def download_file2folder(url: str, folder: str, filesize_max: int, filesize_min:
|
|
179 |
async def _process_model(model_input: Union[str, gr.File], tilesize: int, output_dir: str,task_id:int,fp16:bool):
|
180 |
log = ('初始化日志记录...\n')
|
181 |
print_log(task_id, '初始化日志记录', '开始')
|
182 |
-
yield [],
|
183 |
|
184 |
if isinstance(model_input, str):
|
185 |
input_path = model_input
|
@@ -188,12 +193,12 @@ async def _process_model(model_input: Union[str, gr.File], tilesize: int, output
|
|
188 |
input_path = model_input.name
|
189 |
log += f'已上传文件: {input_path}\n'
|
190 |
print_log(task_id, log.split('\n')[-1], '开始')
|
191 |
-
yield [],
|
192 |
|
193 |
if not input_path:
|
194 |
log += ( f'未获得正确的模型文件\n')
|
195 |
print_log(task_id, f'未获得正确的模型文件', '错误')
|
196 |
-
yield [],
|
197 |
return
|
198 |
|
199 |
|
@@ -201,7 +206,7 @@ async def _process_model(model_input: Union[str, gr.File], tilesize: int, output
|
|
201 |
onnx_path = input_path
|
202 |
log += ( '输入已经是 ONNX 文件\n')
|
203 |
print_log(task_id, '输入已经是 ONNX 文件', '跳过')
|
204 |
-
yield [],
|
205 |
else:
|
206 |
print_log(task_id, f'转换 PTH 模型为 ONNX, folder={output_dir}', '开始')
|
207 |
onnx_path = convert_pth_to_onnx(input_path, tilesize=tilesize, output_folder=output_dir,use_fp16=fp16)
|
@@ -211,7 +216,7 @@ async def _process_model(model_input: Union[str, gr.File], tilesize: int, output
|
|
211 |
else:
|
212 |
log += ( '生成ONNX模型失败\n')
|
213 |
print_log(task_id, '生成ONNX模型', '错误')
|
214 |
-
yield [],
|
215 |
return
|
216 |
|
217 |
|
@@ -222,11 +227,11 @@ async def _process_model(model_input: Union[str, gr.File], tilesize: int, output
|
|
222 |
log += ( '正在将 ONNX 模型转换为 MNN 格式...\n')
|
223 |
print_log(task_id, '正在将 ONNX 模型转换为 MNN 格式', '开始')
|
224 |
convertmnn(onnx_path, mnn_path,fp16)
|
225 |
-
yield onnx_path,
|
226 |
except Exception as e:
|
227 |
log += ( f'转换 MNN 模型时出错: {str(e)}\n')
|
228 |
print_log(task_id, f'转换 MNN 模型时出错: {str(e)}', '错误')
|
229 |
-
yield onnx_path,
|
230 |
|
231 |
print_log(task_id, '模型转换任务完成', '完成')
|
232 |
|
@@ -240,7 +245,7 @@ async def _process_model(model_input: Union[str, gr.File], tilesize: int, output
|
|
240 |
yield onnx_path, mnn_path, log
|
241 |
|
242 |
with gr.Blocks() as demo:
|
243 |
-
gr.Markdown("# 模型转换工具")
|
244 |
model_type_opt = ['从链接下载', '直接上传文件']
|
245 |
with gr.Row():
|
246 |
with gr.Column():
|
@@ -259,13 +264,20 @@ with gr.Blocks() as demo:
|
|
259 |
tilesize = gr.Number(label="Tilesize", value=0, precision=0)
|
260 |
# 添加fp16和try_run复选框
|
261 |
fp16 = gr.Checkbox(label="FP16", value=False)
|
262 |
-
try_run = gr.Checkbox(label="测试
|
263 |
convert_btn = gr.Button("开始转换")
|
264 |
with gr.Column():
|
265 |
# with gr.Row():
|
266 |
log_box = gr.Textbox(label="转换日志", lines=10, interactive=False)
|
267 |
-
onnx_output = gr.File(label="ONNX 模型输出")
|
268 |
-
mnn_output = gr.File(label="MNN 模型输出")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
269 |
|
270 |
async def process_model(input_type, url_input, file_input, tilesize, fp16, try_run):
|
271 |
global task_counter
|
@@ -279,7 +291,7 @@ with gr.Blocks() as demo:
|
|
279 |
if input_type == model_type_opt[0] and url_input:
|
280 |
log = f'正在下载模型文件: {url_input}\n'
|
281 |
print_log(task_counter, f'正在下载模型文件: {url_input}', '开始')
|
282 |
-
yield None, None, log
|
283 |
|
284 |
if url_input.startswith("https://drive.google.com/"):
|
285 |
model_input = download_gdrive_file(
|
@@ -296,22 +308,21 @@ with gr.Blocks() as demo:
|
|
296 |
filesize_min=1024 # 1KB
|
297 |
)
|
298 |
|
299 |
-
|
300 |
if not model_input:
|
301 |
log += f'\n模型文件下载失败\n'
|
302 |
print_log(task_counter, f'模型文件载', '失败')
|
303 |
-
yield None, None, log
|
304 |
return
|
305 |
|
306 |
log += f'\n模型文件已下载到: {model_input}\n'
|
307 |
print_log(task_counter, f'模型文件已下载到: {model_input}', '完成')
|
308 |
-
yield None, None, log
|
309 |
elif input_type == model_type_opt[1] and file_input:
|
310 |
model_input = file_input
|
311 |
else:
|
312 |
# 改为通过yield返回错误日志
|
313 |
log = '\n请选择输入类型并提供有效的输入!'
|
314 |
-
yield None, None, log
|
315 |
return
|
316 |
|
317 |
onnx_path = None
|
@@ -320,17 +331,23 @@ with gr.Blocks() as demo:
|
|
320 |
async for result in _process_model(model_input, int(tilesize), output_dir, task_counter, fp16):
|
321 |
if isinstance(result, tuple) and len(result) == 3:
|
322 |
onnx_path, mnn_path, process_log = result
|
323 |
-
yield onnx_path, mnn_path, log+process_log
|
324 |
elif isinstance(result, tuple) and len(result) == 2:
|
325 |
# 处理纯日志yield
|
326 |
_, process_log = result
|
327 |
-
yield None, None, log+process_log
|
328 |
# yield onnx_path, mnn_path, log+process_log
|
329 |
|
|
|
|
|
|
|
|
|
|
|
|
|
330 |
convert_btn.click(
|
331 |
process_model,
|
332 |
inputs=[input_type, url_input, file_input, tilesize, fp16, try_run],
|
333 |
-
outputs=[onnx_output, mnn_output, log_box],
|
334 |
api_name="convert_model"
|
335 |
)
|
336 |
|
|
|
1 |
+
from cgi import test
|
2 |
import gradio as gr
|
3 |
import requests
|
4 |
import os
|
|
|
12 |
import sys
|
13 |
from typing import Optional
|
14 |
import datetime
|
15 |
+
from mnnsr import modelTest_for_gradio
|
16 |
+
from PIL import Image
|
17 |
+
import cv2
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
|
22 |
log_to_terminal = True
|
23 |
task_counter = 0
|
|
|
59 |
Optional[str]: 如果下载成功,返回文件的完整路径;否则返回 None。
|
60 |
"""
|
61 |
print(f"--- 开始处理链接: {url} ---")
|
62 |
+
# 准备下载路径并创建文件夹
|
|
|
|
|
63 |
try:
|
64 |
os.makedirs(folder, exist_ok=True)
|
65 |
except OSError as e:
|
|
|
184 |
async def _process_model(model_input: Union[str, gr.File], tilesize: int, output_dir: str,task_id:int,fp16:bool):
|
185 |
log = ('初始化日志记录...\n')
|
186 |
print_log(task_id, '初始化日志记录', '开始')
|
187 |
+
yield [],log
|
188 |
|
189 |
if isinstance(model_input, str):
|
190 |
input_path = model_input
|
|
|
193 |
input_path = model_input.name
|
194 |
log += f'已上传文件: {input_path}\n'
|
195 |
print_log(task_id, log.split('\n')[-1], '开始')
|
196 |
+
yield [], log
|
197 |
|
198 |
if not input_path:
|
199 |
log += ( f'未获得正确的模型文件\n')
|
200 |
print_log(task_id, f'未获得正确的模型文件', '错误')
|
201 |
+
yield [],log
|
202 |
return
|
203 |
|
204 |
|
|
|
206 |
onnx_path = input_path
|
207 |
log += ( '输入已经是 ONNX 文件\n')
|
208 |
print_log(task_id, '输入已经是 ONNX 文件', '跳过')
|
209 |
+
yield [],log
|
210 |
else:
|
211 |
print_log(task_id, f'转换 PTH 模型为 ONNX, folder={output_dir}', '开始')
|
212 |
onnx_path = convert_pth_to_onnx(input_path, tilesize=tilesize, output_folder=output_dir,use_fp16=fp16)
|
|
|
216 |
else:
|
217 |
log += ( '生成ONNX模型失败\n')
|
218 |
print_log(task_id, '生成ONNX模型', '错误')
|
219 |
+
yield [],log
|
220 |
return
|
221 |
|
222 |
|
|
|
227 |
log += ( '正在将 ONNX 模型转换为 MNN 格式...\n')
|
228 |
print_log(task_id, '正在将 ONNX 模型转换为 MNN 格式', '开始')
|
229 |
convertmnn(onnx_path, mnn_path,fp16)
|
230 |
+
yield onnx_path,log
|
231 |
except Exception as e:
|
232 |
log += ( f'转换 MNN 模型时出错: {str(e)}\n')
|
233 |
print_log(task_id, f'转换 MNN 模型时出错: {str(e)}', '错误')
|
234 |
+
yield onnx_path,log
|
235 |
|
236 |
print_log(task_id, '模型转换任务完成', '完成')
|
237 |
|
|
|
245 |
yield onnx_path, mnn_path, log
|
246 |
|
247 |
with gr.Blocks() as demo:
|
248 |
+
gr.Markdown("# MNN模型转换工具")
|
249 |
model_type_opt = ['从链接下载', '直接上传文件']
|
250 |
with gr.Row():
|
251 |
with gr.Column():
|
|
|
264 |
tilesize = gr.Number(label="Tilesize", value=0, precision=0)
|
265 |
# 添加fp16和try_run复选框
|
266 |
fp16 = gr.Checkbox(label="FP16", value=False)
|
267 |
+
try_run = gr.Checkbox(label="pymnnsr测试", value=False)
|
268 |
convert_btn = gr.Button("开始转换")
|
269 |
with gr.Column():
|
270 |
# with gr.Row():
|
271 |
log_box = gr.Textbox(label="转换日志", lines=10, interactive=False)
|
272 |
+
onnx_output = gr.File(label="ONNX 模型输出",file_types=["filepath"])
|
273 |
+
mnn_output = gr.File(label="MNN 模型输出",file_types=["filepath"])
|
274 |
+
img_output = gr.Image(type="pil", label="测试输出(边缘正常说明模型转换成功,色彩有bug)" ,visible=False)
|
275 |
+
def show_try_run(try_run):
|
276 |
+
if try_run:
|
277 |
+
return gr.update(visible=True)
|
278 |
+
else:
|
279 |
+
return gr.update(visible=False)
|
280 |
+
try_run.change(show_try_run, inputs=try_run, outputs=img_output)
|
281 |
|
282 |
async def process_model(input_type, url_input, file_input, tilesize, fp16, try_run):
|
283 |
global task_counter
|
|
|
291 |
if input_type == model_type_opt[0] and url_input:
|
292 |
log = f'正在下载模型文件: {url_input}\n'
|
293 |
print_log(task_counter, f'正在下载模型文件: {url_input}', '开始')
|
294 |
+
yield None, None, log, None
|
295 |
|
296 |
if url_input.startswith("https://drive.google.com/"):
|
297 |
model_input = download_gdrive_file(
|
|
|
308 |
filesize_min=1024 # 1KB
|
309 |
)
|
310 |
|
|
|
311 |
if not model_input:
|
312 |
log += f'\n模型文件下载失败\n'
|
313 |
print_log(task_counter, f'模型文件载', '失败')
|
314 |
+
yield None, None, log, None
|
315 |
return
|
316 |
|
317 |
log += f'\n模型文件已下载到: {model_input}\n'
|
318 |
print_log(task_counter, f'模型文件已下载到: {model_input}', '完成')
|
319 |
+
yield None, None, log, None
|
320 |
elif input_type == model_type_opt[1] and file_input:
|
321 |
model_input = file_input
|
322 |
else:
|
323 |
# 改为通过yield返回错误日志
|
324 |
log = '\n请选择输入类型并提供有效的输入!'
|
325 |
+
yield None, None, log,None
|
326 |
return
|
327 |
|
328 |
onnx_path = None
|
|
|
331 |
async for result in _process_model(model_input, int(tilesize), output_dir, task_counter, fp16):
|
332 |
if isinstance(result, tuple) and len(result) == 3:
|
333 |
onnx_path, mnn_path, process_log = result
|
334 |
+
yield onnx_path, mnn_path, log+process_log,None
|
335 |
elif isinstance(result, tuple) and len(result) == 2:
|
336 |
# 处理纯日志yield
|
337 |
_, process_log = result
|
338 |
+
yield None, None, log+process_log,None
|
339 |
# yield onnx_path, mnn_path, log+process_log
|
340 |
|
341 |
+
if mnn_path and try_run:
|
342 |
+
processed_image_np = modelTest_for_gradio(mnn_path, "./sample.jpg")
|
343 |
+
processed_image_pil = Image.fromarray(cv2.cvtColor(processed_image_np, cv2.COLOR_BGR2RGB))
|
344 |
+
# processed_image_pil = Image.fromarray(processed_image_np)
|
345 |
+
yield onnx_path, mnn_path, log+process_log,processed_image_pil
|
346 |
+
|
347 |
convert_btn.click(
|
348 |
process_model,
|
349 |
inputs=[input_type, url_input, file_input, tilesize, fp16, try_run],
|
350 |
+
outputs=[onnx_output, mnn_output, log_box, img_output],
|
351 |
api_name="convert_model"
|
352 |
)
|
353 |
|
app_mnnsr.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import os
|
3 |
+
import MNN
|
4 |
+
import numpy as np
|
5 |
+
import cv2
|
6 |
+
from PIL import Image
|
7 |
+
from mnnsr import modelTest_for_gradio
|
8 |
+
|
9 |
+
def gradio_interface(modelPath, input_image):
|
10 |
+
processed_image_np = modelTest_for_gradio(modelPath, input_image)
|
11 |
+
|
12 |
+
processed_image_pil = Image.fromarray(cv2.cvtColor(processed_image_np, cv2.COLOR_BGR2RGB))
|
13 |
+
|
14 |
+
return processed_image_pil
|
15 |
+
|
16 |
+
# 创建Gradio界面
|
17 |
+
iface = gr.Interface(
|
18 |
+
fn=gradio_interface,
|
19 |
+
# inputs=gr.Image(type="pil", label="上传图像"),
|
20 |
+
inputs = [gr.File(label="上传MNN模型"), gr.Image(type="filepath", label="上传图像",value="./sample.jpg")],
|
21 |
+
outputs=gr.Image(type="pil", label="处理后的图像"),
|
22 |
+
title="MNN图像超分辨率处理",
|
23 |
+
description="上传图像,使用MNN模型进行超分辨率处理"
|
24 |
+
)
|
25 |
+
|
26 |
+
if __name__ == "__main__":
|
27 |
+
# 启动Gradio界面
|
28 |
+
iface.launch()
|
mnnsr.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import os
|
3 |
+
import MNN
|
4 |
+
import numpy as np
|
5 |
+
import cv2
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
|
9 |
+
# 复制原始modelTest函数中的必要函数
|
10 |
+
def process_image(data, H, W, C, color='BGR'):
|
11 |
+
"""
|
12 |
+
处理图像数据(灰度或彩色)并转换为指定色彩空间
|
13 |
+
|
14 |
+
参数:
|
15 |
+
data: 输入数据指针 (numpy数组形式传入)
|
16 |
+
H: 高度
|
17 |
+
W: 宽度
|
18 |
+
C: 通道数 (1 或 3)
|
19 |
+
color: 目标色彩空间 ('BGR', 'RGB', 'YCbCr', 'YUV')
|
20 |
+
|
21 |
+
返回:
|
22 |
+
numpy数组 处理后的图像
|
23 |
+
"""
|
24 |
+
if C == 1:
|
25 |
+
# 灰度图像处理
|
26 |
+
gray = np.array(data, dtype=np.float32).reshape(H, W)
|
27 |
+
gray = (gray * 255).astype(np.uint8)
|
28 |
+
result = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
|
29 |
+
return result
|
30 |
+
else:
|
31 |
+
# 彩色图像处理
|
32 |
+
# 将数据拆分为各个通道
|
33 |
+
channels = [
|
34 |
+
np.array(data[i*H*W : (i+1)*H*W], dtype=np.float32).reshape(H, W)
|
35 |
+
for i in range(C)
|
36 |
+
]
|
37 |
+
if color == 'RGB':
|
38 |
+
# RGB -> BGR (OpenCV默认顺序)
|
39 |
+
channels[0], channels[2] = channels[2], channels[0] # 交换R和B通道
|
40 |
+
result = cv2.merge(channels)
|
41 |
+
result = (result * 255).astype(np.uint8)
|
42 |
+
else:
|
43 |
+
# 先合并为BGR格式
|
44 |
+
rgb = cv2.merge(channels)
|
45 |
+
rgb = (rgb * 255).astype(np.uint8)
|
46 |
+
# 转换为目标色彩空间
|
47 |
+
if color == 'YCbCr':
|
48 |
+
result = cv2.cvtColor(rgb, cv2.COLOR_BGR2YCrCb)
|
49 |
+
elif color == 'YUV':
|
50 |
+
result = cv2.cvtColor(rgb, cv2.COLOR_BGR2YUV)
|
51 |
+
else:
|
52 |
+
result = rgb # 默认为BGR
|
53 |
+
|
54 |
+
return result
|
55 |
+
|
56 |
+
def createTensor(tensor):
|
57 |
+
shape = tensor.getShape()
|
58 |
+
data = np.ones(shape, dtype=np.float32)
|
59 |
+
return MNN.Tensor(shape, tensor.getDataType(), data, tensor.getDimensionType())
|
60 |
+
|
61 |
+
def modelTest_for_gradio(modelPath, image_path, tilesize = 128):
|
62 |
+
model_name = os.path.basename(modelPath)
|
63 |
+
if "-Grayscale" in model_name:
|
64 |
+
model_channel = 1
|
65 |
+
elif "-4ch" in model_name:
|
66 |
+
model_channel = 4
|
67 |
+
else:
|
68 |
+
model_channel = 3
|
69 |
+
|
70 |
+
net = MNN.Interpreter(modelPath)
|
71 |
+
# set 9 for Session_Backend_Auto, Let BackGround Tuning
|
72 |
+
net.setSessionMode(9)
|
73 |
+
# set 0 for tune_num
|
74 |
+
# net.setSessionHint(0, 20)
|
75 |
+
config = {}
|
76 |
+
# "CPU"或0(默认), "OPENCL"或3,"OPENGL"或6, "VULKAN"或7, "METAL"或1, "TRT"或9, "CUDA"或2, "HIAI"或8
|
77 |
+
config['backend'] = 3
|
78 |
+
#config['precision'] = "low"
|
79 |
+
session = net.createSession(config)
|
80 |
+
|
81 |
+
print("Run on backendtype: %d \n" % net.getSessionInfo(session, 2))
|
82 |
+
|
83 |
+
# 读取图像
|
84 |
+
image = cv2.imread(image_path)
|
85 |
+
if image.ndim == 2:
|
86 |
+
# 为了方便处理,先将其扩展为3维数组 (height, width, 1)
|
87 |
+
print("extend dims")
|
88 |
+
image = np.expand_dims(image, axis=-1)
|
89 |
+
image_channel = image.shape[2]
|
90 |
+
|
91 |
+
image = cv2.resize(image, (tilesize, tilesize))
|
92 |
+
|
93 |
+
# 处理通道数不匹配的情况
|
94 |
+
if image_channel == 3:
|
95 |
+
if model_channel == 1:
|
96 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
97 |
+
elif model_channel == 3:
|
98 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
99 |
+
elif model_channel == 4:
|
100 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGBA)
|
101 |
+
else:
|
102 |
+
print(f"unexpect input: model_channel {model_channel}, image_channel {image_channel}")
|
103 |
+
elif image_channel == 4:
|
104 |
+
if model_channel == 1:
|
105 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGRA2GRAY)
|
106 |
+
elif model_channel == 3:
|
107 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
|
108 |
+
elif model_channel == 4:
|
109 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
|
110 |
+
else:
|
111 |
+
print(f"unexpect input: model_channel {model_channel}, image_channel {image_channel}")
|
112 |
+
else:
|
113 |
+
print(f"unexpect input: model_channel {model_channel}, image_channel {image_channel}")
|
114 |
+
|
115 |
+
# 显示图像(在Gradio中不需要)
|
116 |
+
# display(Image(data=cv2.imencode('.jpg', image)[1].tobytes()))
|
117 |
+
|
118 |
+
image = image/255.0
|
119 |
+
#preprocess it
|
120 |
+
image = image.transpose((2, 0, 1))
|
121 |
+
#change numpy data type as np.float32 to match tensor's format
|
122 |
+
image = image.astype(np.float32)
|
123 |
+
#cv2 read shape is NHWC, Tensor's need is NCHW,transpose it
|
124 |
+
tmp_input = MNN.Tensor((1, model_channel, tilesize, tilesize), MNN.Halide_Type_Float, image, MNN.Tensor_DimensionType_Caffe)
|
125 |
+
|
126 |
+
# input
|
127 |
+
inputTensor = net.getSessionInput(session)
|
128 |
+
net.resizeTensor(inputTensor, (1, model_channel, tilesize, tilesize))
|
129 |
+
net.resizeSession(session)
|
130 |
+
inputTensor.copyFrom(tmp_input)
|
131 |
+
# infer
|
132 |
+
net.runSession(session)
|
133 |
+
outputTensor = net.getSessionOutput(session)
|
134 |
+
# output
|
135 |
+
outputShape = outputTensor.getShape()
|
136 |
+
print("outputShape",outputShape)
|
137 |
+
outputHost = createTensor(outputTensor)
|
138 |
+
outputTensor.copyToHostTensor(outputHost)
|
139 |
+
|
140 |
+
outimage = process_image(outputHost.getData(), outputShape[2], outputShape[3], outputShape[1], color='RGB')
|
141 |
+
|
142 |
+
# 返回处理后的图像(numpy数组)
|
143 |
+
return outimage
|
144 |
+
|
145 |
+
# def gradio_interface(modelPath, input_image):
|
146 |
+
# processed_image_np = modelTest_for_gradio(modelPath, input_image)
|
147 |
+
|
148 |
+
# processed_image_pil = Image.fromarray(cv2.cvtColor(processed_image_np, cv2.COLOR_BGR2RGB))
|
149 |
+
|
150 |
+
# return processed_image_pil
|
requirements.txt
CHANGED
@@ -3,6 +3,9 @@ torch
|
|
3 |
pnnx
|
4 |
onnx
|
5 |
onnxsim
|
6 |
-
mnn
|
7 |
gradio
|
8 |
gdown
|
|
|
|
|
|
|
|
|
|
3 |
pnnx
|
4 |
onnx
|
5 |
onnxsim
|
|
|
6 |
gradio
|
7 |
gdown
|
8 |
+
MNN
|
9 |
+
numpy
|
10 |
+
opencv-python
|
11 |
+
Pillow
|
sample.jpg
ADDED
![]() |