# model2mnn / pth2onnx.py
# Last revision: da006e5 ("update") by tumuyan2
import argparse
import torch
import os
import re
import onnx
from spandrel import ImageModelDescriptor, ModelLoader
from onnxsim import simplify
def _default_onnx_path(pth_path: str, scale: int, channel: int, use_fp16: bool, output_folder: str = None) -> str:
    """Build the default ONNX output path for ``pth_path``.

    The result is ``<stem>[-x<scale>][-Grayscale](-fp16).onnx``, placed next to
    the input file, or inside ``output_folder`` when one is given.

    ``-x<scale>`` is appended only when the file stem does not already carry a
    scale tag: ``<scale>x`` or ``x<scale>`` (case-insensitive) bounded by ``_``,
    ``-`` or the start/end of the stem, e.g. ``4x-UltraSharp`` or ``model_x4``.
    """
    base_path, _ = os.path.splitext(pth_path)
    if output_folder:
        base_path = os.path.join(output_folder, os.path.basename(base_path))

    filename = os.path.basename(pth_path).upper()
    # Match against the stem, not the full file name: with the extension still
    # attached, a trailing tag such as "_X4.PTH" is followed by "." and would
    # never satisfy the ([_-]|$) boundary, producing names like "model_x4-x4".
    stem = os.path.splitext(filename)[0]
    pattern = rf"(^|[_-])({scale}X|X{scale})([_-]|$)"
    if re.search(pattern, stem):
        print(f'File name contains scale info: {filename} ')
    else:
        base_path = f"{base_path}-x{scale}"

    grayscale_suffix = "-Grayscale" if channel == 1 else ""
    extension = "-fp16.onnx" if use_fp16 else ".onnx"
    return base_path + grayscale_suffix + extension


def convert_pth_to_onnx(pth_path: str, onnx_path: str = None, channel: int = 0, tilesize: int = 64, use_fp16: bool = False, simplify_model: bool = False, min_size: int = 1024 * 1024, output_folder: str = None, opset: int = 11, dynamic_axes: bool = True):
    """
    Load a PyTorch image model with Spandrel and export it to ONNX.

    Args:
        pth_path: Path to the input model file (anything Spandrel's
            ``ModelLoader.load_from_file`` accepts).
        onnx_path: Path for the output .onnx file. If None, a name is derived
            from ``pth_path`` (see ``_default_onnx_path``). If given together
            with ``output_folder``, it is joined onto ``output_folder``
            (``os.path.join`` keeps an absolute ``onnx_path`` unchanged).
        channel: Number of input channels of the dummy tracing input;
            0 means use the model descriptor's ``input_channels``.
        tilesize: Height and width of the square dummy input used for tracing;
            values < 1 fall back to 64.
        use_fp16: Cast model and dummy input to half precision (moved to CUDA
            when available; otherwise a warning is printed and the cast still
            happens on CPU).
        simplify_model: Run onnxsim's ``simplify`` on the exported model and
            overwrite the file with the result.
        min_size: The exported file must be strictly larger than this many
            bytes; otherwise it is deleted and the export counts as failed.
        output_folder: Directory for the output file; created if missing.
        opset: ONNX opset version passed to ``torch.onnx.export``.
        dynamic_axes: If True, dims 2 and 3 ("height", "width") of both
            "input" and "output" are exported as dynamic.

    Returns:
        The ONNX file path on success; ``False`` if the model could not be
        loaded; ``""`` if export/simplification raised, or the output file
        was missing or not larger than ``min_size``.
    """
    print(f"Loading model from: {pth_path}")
    try:
        # Use Spandrel to load the model architecture and state dict
        model_descriptor = ModelLoader().load_from_file(pth_path)
        # Only single-image models expose the attributes used below
        # (input_channels, scale, model).
        if not isinstance(model_descriptor, ImageModelDescriptor):
            print(f"Error: Expected ImageModelDescriptor, but got {type(model_descriptor)}")
            print("Please ensure the .pth file is compatible with Spandrel's loading mechanism.")
            return False
        # Get the underlying torch.nn.Module
        torch_model = model_descriptor.model
        # Evaluation mode so dropout/batchnorm behave deterministically in the trace
        torch_model.eval()
    except Exception as e:
        print(f"Error loading model: {e}")
        return False

    if channel == 0:
        channel = model_descriptor.input_channels
    if tilesize < 1:
        tilesize = 64
    example_input = torch.randn(1, channel, tilesize, tilesize)
    print("Model input channels:", channel, "tile size:", tilesize)

    if use_fp16:
        if torch.cuda.is_available():
            torch_model.cuda()
            example_input = example_input.cuda()
        else:
            print("Warning: no CUDA device")
        torch_model.half()
        example_input = example_input.half()  # half-precision input to match the model
    print(f"Model loaded successfully: {type(torch_model).__name__}")

    if output_folder:
        os.makedirs(output_folder, exist_ok=True)

    if onnx_path is None:
        onnx_path = _default_onnx_path(pth_path, model_descriptor.scale, channel, use_fp16, output_folder)
    elif output_folder:
        # A relative onnx_path is placed inside output_folder; an absolute one
        # is returned unchanged by os.path.join.
        onnx_path = os.path.join(output_folder, onnx_path)

    print("ONNX model exporting...")
    try:
        if dynamic_axes:
            axes = {
                "input": {2: "height", 3: "width"},
                "output": {2: "height", 3: "width"},
            }
        else:
            axes = {}
        torch.onnx.export(
            torch_model,                # The model instance
            example_input,              # An example input tensor used for tracing
            onnx_path,                  # Where to save the model (file path)
            export_params=True,         # Store the trained parameter weights inside the model file
            opset_version=opset,        # Target ONNX opset (choose based on target runtime)
            do_constant_folding=True,   # Fold constants for optimization
            input_names=['input'],      # The model's input names
            output_names=['output'],    # The model's output names
            dynamic_axes=axes,
        )
        print(f"ONNX model export successful: {onnx_path}")

        # Optional: simplify the ONNX model in place
        if simplify_model:
            model = onnx.load(onnx_path)
            model_simplified, check_ok = simplify(model)
            if not check_ok:
                print("Warning: simplified ONNX model could not be validated")
            onnx.save(model_simplified, onnx_path)
            print(f"ONNX model simplified successfully: {onnx_path}")

        # Reject missing or suspiciously small outputs (e.g. a graph exported
        # without its weights).
        if not os.path.exists(onnx_path):
            print(f"ONNX model file was not created: {onnx_path}")
            return ""
        file_size = os.path.getsize(onnx_path)
        if file_size > min_size:
            return onnx_path
        os.remove(onnx_path)
        print(f"ONNX model has unexpected file size ({file_size} bytes), deleted invalid file")
        return ""
    except Exception as e:
        print(f"ONNX model export error: {e}")
        return ""
if __name__ == "__main__":
    # Command-line entry point: parse options, run the conversion, and exit
    # with status 1 when convert_pth_to_onnx reports failure (a falsy result).
    import sys

    parser = argparse.ArgumentParser(description='Convert PyTorch model to ONNX model.')
    parser.add_argument('--pthpath', type=str, required=True, help='Path to the PyTorch model file.')
    parser.add_argument('--onnxpath', type=str, default=None, help='Path to save the ONNX model file.')
    parser.add_argument('--channel', type=int, default=0, help='Channel parameter.')
    parser.add_argument('--tilesize', type=int, default=0, help='Tilesize parameter.')
    parser.add_argument('--fp16', action='store_true', help='Use FP16 precision.')
    parser.add_argument('--simplify', action='store_true', help='Simplify the ONNX model.')
    parser.add_argument('--opset', type=int, default=11, help='ONNX opset version.')
    # The flag turns dynamic height/width axes OFF.
    parser.add_argument('--fixed_axes', action='store_true', help='Use fixed input/output size (disable dynamic height/width axes).')
    parser.add_argument('--output_folder', type=str, default=None, help='Folder to write the ONNX model into (created if missing).')
    args = parser.parse_args()

    success = convert_pth_to_onnx(
        pth_path=args.pthpath,
        onnx_path=args.onnxpath,
        channel=args.channel,
        tilesize=args.tilesize,
        use_fp16=args.fp16,
        simplify_model=args.simplify,
        output_folder=args.output_folder,
        opset=args.opset,
        dynamic_axes=not args.fixed_axes,
    )

    if success:
        print("Conversion process finished.")
    else:
        print("Conversion process failed.")
        # sys.exit works even when the site-provided exit() builtin is absent
        # (python -S, frozen executables).
        sys.exit(1)