# model2mnn / mnnsr.py
# tumuyan2's picture
# update
# 20ad880
import gradio as gr
import os
import MNN
import numpy as np
import cv2
from PIL import Image
import time
# Helper functions copied from the original modelTest function
def _scale_to_uint8(values):
    """Map float samples from the nominal [0, 1] range to uint8 in [0, 255].

    Samples are clipped before the cast so that values outside [0, 1]
    saturate at 0 / 255 instead of wrapping around modulo 256.
    """
    return np.clip(values * 255, 0, 255).astype(np.uint8)


def process_image(data, H, W, C, color='BGR'):
    """Turn planar float image data into an 8-bit, channel-last image.

    Args:
        data: Sample values laid out plane by plane (channel, row, column).
            Any sequence or array works; it is flattened before use.
        H: Image height in pixels.
        W: Image width in pixels.
        C: Number of channel planes. ``C == 1`` is treated as grayscale.
        color: How the planes are interpreted / what is returned:
            ``'RGB'``   - planes 0 and 2 are swapped, so RGB(A) input comes
                          out in OpenCV's BGR(A) order;
            ``'YCbCr'`` - planes are merged as-is, then passed through
                          ``cv2.COLOR_BGR2YCrCb``;
            ``'YUV'``   - planes are merged as-is, then passed through
                          ``cv2.COLOR_BGR2YUV``;
            anything else (default ``'BGR'``) - planes are merged as-is.

    Returns:
        A ``uint8`` numpy array of shape ``(H, W, 3)`` for grayscale input,
        otherwise ``(H, W, C)`` (before any cv2 colour conversion).

    Samples are expected in [0, 1]; out-of-range values are clipped.
    """
    flat = np.asarray(data, dtype=np.float32).reshape(-1)

    if C == 1:
        # Grayscale: replicate the single plane into three identical channels
        # (the same result as cv2.COLOR_GRAY2BGR).
        gray = _scale_to_uint8(flat.reshape(H, W))
        return np.repeat(gray[:, :, np.newaxis], 3, axis=2)

    # Colour: split the flat buffer into its C planes of H*W samples each.
    planes = flat[:C * H * W].reshape(C, H, W)

    if color == 'RGB':
        # RGB -> BGR (OpenCV's default order): swap the R and B planes,
        # leaving any further planes (e.g. alpha) where they are.
        order = list(range(C))
        order[0], order[2] = order[2], order[0]
        planes = planes[order]

    # Planar (C, H, W) -> interleaved (H, W, C), made contiguous so the
    # result can be handed straight to OpenCV.
    pixels = _scale_to_uint8(np.ascontiguousarray(planes.transpose(1, 2, 0)))

    # Optional conversion to the requested colour space.
    if color == 'YCbCr':
        return cv2.cvtColor(pixels, cv2.COLOR_BGR2YCrCb)
    if color == 'YUV':
        return cv2.cvtColor(pixels, cv2.COLOR_BGR2YUV)
    return pixels
def createTensor(tensor):
    """Create a new MNN tensor shaped like *tensor*.

    The new tensor reuses *tensor*'s shape, data type and dimension type and
    is initialised from a float32 array of ones of that shape.
    """
    dims = tensor.getShape()
    return MNN.Tensor(
        dims,
        tensor.getDataType(),
        np.ones(dims, dtype=np.float32),
        tensor.getDimensionType(),
    )
def _match_model_channels(image, image_channel, model_channel):
    """Convert a BGR/BGRA OpenCV image to the channel layout the model expects.

    Unsupported (image_channel, model_channel) pairs are reported on stdout
    and the image is returned unchanged.
    """
    conversions = {
        (3, 1): cv2.COLOR_BGR2GRAY,
        (3, 3): cv2.COLOR_BGR2RGB,
        (3, 4): cv2.COLOR_BGR2RGBA,
        (4, 1): cv2.COLOR_BGRA2GRAY,
        (4, 3): cv2.COLOR_BGRA2RGB,
        (4, 4): cv2.COLOR_BGRA2RGBA,
    }
    code = conversions.get((image_channel, model_channel))
    if code is None:
        print(f"unexpect input: model_channel {model_channel}, image_channel {image_channel}")
        return image
    return cv2.cvtColor(image, code)


def modelTest_for_gradio(modelPath, image_path, tilesize=0, backend=3):
    """Run an MNN model once on a single image tile and time it.

    Args:
        modelPath: Path to the .mnn model. The file name selects the channel
            count: "-Grayscale" -> 1, "-4ch" or "RGBA" -> 4, otherwise 3.
        image_path: Path of the image to read with OpenCV.
        tilesize: Side length the image is resized to and the model input is
            resized for; values <= 0 fall back to 128.
        backend: Value stored under the session config's 'backend' key.

    Returns:
        Tuple ``(outimage, load_time, infer_time)``: the image returned by
        ``process_image(..., color='RGB')``, the seconds spent creating and
        resizing the session, and the seconds spent on channel conversion,
        inference and post-processing.

    Raises:
        ValueError: If the image at *image_path* cannot be read.
    """
    if tilesize <= 0:
        tilesize = 128

    model_name = os.path.basename(modelPath)
    if "-Grayscale" in model_name:
        model_channel = 1
    elif "-4ch" in model_name or "RGBA" in model_name:
        model_channel = 4
    else:
        model_channel = 3

    # Start timing the model load.
    load_start_time = time.time()

    # Load the model (inside the timed section).
    net = MNN.Interpreter(modelPath)
    # set 9 for Session_Backend_Auto, Let BackGround Tuning
    net.setSessionMode(9)
    # set 0 for tune_num
    # net.setSessionHint(0, 20)
    config = {}
    # "CPU" or 0 (default), "OPENCL" or 3, "OPENGL" or 6, "VULKAN" or 7,
    # "METAL" or 1, "TRT" or 9, "CUDA" or 2, "HIAI" or 8
    config['backend'] = backend
    # config['precision'] = "low"
    session = net.createSession(config)
    print("Run on backendtype: %d \n" % net.getSessionInfo(session, 2))

    inputTensor = net.getSessionInput(session)
    net.resizeTensor(inputTensor, (1, model_channel, tilesize, tilesize))
    net.resizeSession(session)

    # Model load time.
    load_time = time.time() - load_start_time
    print(f"Load mnn model: {load_time:.4f} sec")

    # Read the image. cv2.imread signals failure by returning None rather
    # than raising, so check explicitly instead of failing later on `.ndim`.
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Failed to read image: {image_path}")

    # NOTE(review): with cv2.imread's default flag the result is presumably
    # always 3-channel BGR, so the 1- and 4-channel paths below would only be
    # reached if the read flag changes - confirm.
    if image.ndim == 2:
        image_channel = 1
    else:
        image_channel = image.shape[2]

    image = cv2.resize(image, (tilesize, tilesize))

    # Start timing inference (includes pre- and post-processing).
    infer_start_time = time.time()

    # Reconcile the image's channel layout with what the model expects.
    image = _match_model_channels(image, image_channel, model_channel)

    # Scale to [0, 1] and, for multi-channel models, go from HWC to CHW.
    image = image / 255.0
    if model_channel >= 3:
        image = image.transpose((2, 0, 1))
    image = image.astype(np.float32)

    tmp_input = MNN.Tensor(
        (1, model_channel, tilesize, tilesize),
        MNN.Halide_Type_Float,
        image,
        MNN.Tensor_DimensionType_Caffe,
    )
    inputTensor.copyFrom(tmp_input)

    # Run inference.
    net.runSession(session)
    outputTensor = net.getSessionOutput(session)
    outputShape = outputTensor.getShape()
    print("outputShape",outputShape)

    # Copy the result into a host tensor and convert it to an 8-bit image.
    outputHost = createTensor(outputTensor)
    outputTensor.copyToHostTensor(outputHost)
    outimage = process_image(
        outputHost.getData(),
        outputShape[2],
        outputShape[3],
        outputShape[1],
        color='RGB',
    )

    # Inference time.
    infer_time = time.time() - infer_start_time
    print(f"Infer latency time: {infer_time:.4f} sec")

    # Return the image, the load time and the inference time.
    return outimage, load_time, infer_time
# def gradio_interface(modelPath, input_image):
# processed_image_np = modelTest_for_gradio(modelPath, input_image)
# processed_image_pil = Image.fromarray(cv2.cvtColor(processed_image_np, cv2.COLOR_BGR2RGB))
# return processed_image_pil