|
import gradio as gr |
|
import os |
|
import MNN |
|
import numpy as np |
|
import cv2 |
|
from PIL import Image |
|
import time |
|
|
|
|
|
|
|
def process_image(data, H, W, C, color='BGR'):
    """Convert planar float image data into an interleaved uint8 image.

    ``data`` is read as ``C`` consecutive planes of ``H * W`` values each.
    Every value is multiplied by 255, clamped to ``[0, 255]`` and truncated
    to ``uint8``. Clamping matters because values outside ``[0, 1]`` would
    otherwise wrap around when cast to ``uint8``.

    Args:
        data: Pixel values as a flat sequence or an array of any shape
            (it is flattened before use).
        H: Image height in pixels.
        W: Image width in pixels.
        C: Number of planes stored in ``data``. ``C == 1`` is treated as
            grayscale.
        color: Output selector.
            ``'RGB'`` swaps planes 0 and 2 before interleaving;
            ``'YCbCr'`` applies ``cv2.COLOR_BGR2YCrCb`` to the interleaved
            image; ``'YUV'`` applies ``cv2.COLOR_BGR2YUV``; any other value
            (including the default ``'BGR'``) keeps the plane order as is.
            Ignored when ``C == 1``.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8``. For ``C == 1`` the shape is
        ``(H, W, 3)`` with the gray plane replicated into three channels;
        otherwise the image has ``C`` interleaved channels, optionally
        passed through the OpenCV colour conversion selected by ``color``.

    Raises:
        ValueError: If ``data`` does not contain enough values for the
            requested ``C * H * W`` layout (or, for ``C == 1``, not exactly
            ``H * W`` values).
        IndexError: If ``color == 'RGB'`` and ``C`` is 2 (there is no
            plane 2 to swap with plane 0).
    """
    # Flatten first so both flat buffers and N-D arrays are accepted.
    flat = np.asarray(data, dtype=np.float32).reshape(-1)

    if C == 1:
        gray = np.clip(flat.reshape(H, W) * 255, 0, 255).astype(np.uint8)
        # Replicate the single plane into three identical channels, which is
        # what a GRAY -> BGR conversion produces.
        return np.repeat(gray[:, :, np.newaxis], 3, axis=2)

    plane_size = H * W
    needed = C * plane_size
    if flat.size < needed:
        raise ValueError(
            f"data has {flat.size} values, expected at least {needed} "
            f"for C={C}, H={H}, W={W}"
        )

    # Scale and clamp once, then view the buffer as C planes of H x W.
    planes = (
        np.clip(flat[:needed] * 255, 0, 255)
        .astype(np.uint8)
        .reshape(C, H, W)
    )

    plane_order = list(range(C))
    if color == 'RGB':
        plane_order[0], plane_order[2] = plane_order[2], plane_order[0]

    # Planar (C, H, W) -> interleaved (H, W, C); OpenCV needs contiguous data.
    image = np.ascontiguousarray(planes[plane_order].transpose(1, 2, 0))

    if color == 'YCbCr':
        return cv2.cvtColor(image, cv2.COLOR_BGR2YCrCb)
    if color == 'YUV':
        return cv2.cvtColor(image, cv2.COLOR_BGR2YUV)
    return image
|
|
|
def createTensor(tensor):
    """Create a new ``MNN.Tensor`` modelled on ``tensor``.

    The new tensor takes its shape, data type and dimension type from
    ``tensor`` and is initialised from a float32 array of ones.

    Args:
        tensor: Object exposing ``getShape()``, ``getDataType()`` and
            ``getDimensionType()``.

    Returns:
        The newly constructed ``MNN.Tensor``.
    """
    dims = tensor.getShape()
    # Placeholder contents; only the buffer size/layout is taken from `tensor`.
    placeholder = np.ones(dims, dtype=np.float32)
    return MNN.Tensor(
        dims,
        tensor.getDataType(),
        placeholder,
        tensor.getDimensionType(),
    )
|
|
|
def _match_model_channels(image, image_channel, model_channel):
    """Convert an OpenCV BGR/BGRA image to the channel layout the model takes.

    When no conversion is known for the (image, model) channel pair, a
    diagnostic is printed and the image is returned unchanged.
    """
    conversions = {
        (3, 1): cv2.COLOR_BGR2GRAY,
        (3, 3): cv2.COLOR_BGR2RGB,
        (3, 4): cv2.COLOR_BGR2RGBA,
        (4, 1): cv2.COLOR_BGRA2GRAY,
        (4, 3): cv2.COLOR_BGRA2RGB,
        (4, 4): cv2.COLOR_BGRA2RGBA,
    }
    code = conversions.get((image_channel, model_channel))
    if code is None:
        print(f"unexpect input: model_channel {model_channel}, image_channel {image_channel}")
        return image
    return cv2.cvtColor(image, code)


def modelTest_for_gradio(modelPath, image_path, tilesize = 0, backend = 3):
    """Run one MNN inference on an image file and return the result image.

    The model's channel count is derived from its file name: ``-Grayscale``
    means 1 channel, ``-4ch`` or ``RGBA`` means 4 channels, anything else
    3 channels. The image is resized to ``tilesize x tilesize``, converted to
    the model's channel layout, scaled to ``[0, 1]`` and fed to the network
    as a ``(1, C, tilesize, tilesize)`` float tensor.

    Args:
        modelPath: Path to the MNN model file.
        image_path: Path of the image to load with ``cv2.imread``.
        tilesize: Edge length the input is resized to; values ``<= 0`` fall
            back to 128.
        backend: Value stored under ``'backend'`` in the session config
            (presumably an MNN forward-type code; 3 looks like OpenCL --
            TODO confirm against the MNN version in use).

    Returns:
        Tuple ``(outimage, load_time, infer_time)``: the uint8 image produced
        by ``process_image(..., color='RGB')``, the seconds spent creating
        and resizing the session, and the seconds spent on preprocessing,
        inference and postprocessing.

    Raises:
        ValueError: If ``cv2.imread`` cannot read ``image_path``.

    NOTE(review): ``color='RGB'`` makes ``process_image`` swap planes 0 and 2
    of the model output; confirm that this matches the channel order the
    caller expects.
    """
    if tilesize <= 0:
        tilesize = 128

    model_name = os.path.basename(modelPath)
    if "-Grayscale" in model_name:
        model_channel = 1
    elif "-4ch" in model_name or "RGBA" in model_name:
        model_channel = 4
    else:
        model_channel = 3

    # perf_counter is monotonic, so the measured intervals are not affected
    # by system clock adjustments.
    load_start_time = time.perf_counter()

    net = MNN.Interpreter(modelPath)
    # NOTE(review): 9 presumably selects MNN's "backend auto" session mode --
    # confirm against the SessionMode enum.
    net.setSessionMode(9)

    config = {'backend': backend}
    session = net.createSession(config)
    print("Run on backendtype: %d \n" % net.getSessionInfo(session, 2))

    inputTensor = net.getSessionInput(session)
    net.resizeTensor(inputTensor, (1, model_channel, tilesize, tilesize))
    net.resizeSession(session)

    load_time = time.perf_counter() - load_start_time
    print(f"Load mnn model: {load_time:.4f} sec")

    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals a missing/undecodable file by returning None.
        raise ValueError(f"Failed to read image: {image_path!r}")

    image_channel = 1 if image.ndim == 2 else image.shape[2]
    image = cv2.resize(image, (tilesize, tilesize))

    infer_start_time = time.perf_counter()

    image = _match_model_channels(image, image_channel, model_channel)

    image = image / 255.0
    if model_channel >= 3:
        # Interleaved (H, W, C) -> planar (C, H, W).
        image = image.transpose((2, 0, 1))
    image = np.ascontiguousarray(image, dtype=np.float32)

    tmp_input = MNN.Tensor(
        (1, model_channel, tilesize, tilesize),
        MNN.Halide_Type_Float,
        image,
        MNN.Tensor_DimensionType_Caffe,
    )
    inputTensor.copyFrom(tmp_input)

    net.runSession(session)

    outputTensor = net.getSessionOutput(session)
    outputShape = outputTensor.getShape()
    print("outputShape",outputShape)

    # Copy the result into a separately allocated tensor before reading it.
    outputHost = createTensor(outputTensor)
    outputTensor.copyToHostTensor(outputHost)

    # outputShape is indexed as (batch, channels, height, width).
    outimage = process_image(
        outputHost.getData(),
        outputShape[2],
        outputShape[3],
        outputShape[1],
        color='RGB',
    )

    infer_time = time.perf_counter() - infer_start_time
    print(f"Infer latency time: {infer_time:.4f} sec")

    return outimage, load_time, infer_time
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|