File size: 6,700 Bytes
2152dfa da006e5 2152dfa 20ad880 2152dfa 62053b6 2152dfa 20ad880 2152dfa 20ad880 2152dfa da006e5 2152dfa b5d90c6 2152dfa da006e5 2152dfa 20ad880 2152dfa 20ad880 2152dfa 20ad880 2152dfa b5d90c6 da006e5 2152dfa b5d90c6 da006e5 2152dfa da006e5 b5d90c6 da006e5 2152dfa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import argparse
import torch
import os
import re
import onnx
from spandrel import ImageModelDescriptor, ModelLoader
from onnxsim import simplify
def _resolve_onnx_path(pth_path: str, output_folder: str, scale: int, channel: int, use_fp16: bool) -> str:
    """Derive an output .onnx path from the input .pth path.

    Appends ``-x{scale}`` unless the filename already encodes the scale
    (e.g. ``4x`` / ``x4`` separated by ``_``/``-`` or at a name boundary),
    appends ``-Grayscale`` for single-channel models and ``-fp16`` for
    half-precision exports.
    """
    base_path, _ = os.path.splitext(pth_path)
    if output_folder:
        base_path = os.path.join(output_folder, os.path.basename(base_path))
    # Check whether the file name already contains scale info such as
    # "4x"/"x4" (case-insensitive), delimited by '_', '-' or a name boundary.
    filename = os.path.basename(pth_path).upper()
    pattern = f'(^|[_-])({scale}X|X{scale})([_-]|$)'
    if re.search(pattern, filename):
        print(f'File name contains scale info: {filename}')
    else:
        base_path = f"{base_path}-x{scale}"
    suffix = ("-Grayscale" if channel == 1 else "") + ("-fp16.onnx" if use_fp16 else ".onnx")
    return base_path + suffix


def convert_pth_to_onnx(pth_path: str, onnx_path: str = None, channel: int = 0, tilesize: int = 64, use_fp16: bool = False, simplify_model: bool = False, min_size: int = 1024*1024, output_folder: str = None, opset: int = 11, dynamic_axes: bool = True):
    """
    Loads a PyTorch model from a .pth file using Spandrel and converts it to ONNX format.

    Args:
        pth_path: Path to the input .pth model file.
        onnx_path: Path to save the output .onnx file. When None, a name is
            derived from pth_path (scale / channel / precision suffixes added).
        channel: Number of input channels for the model. 0 means "auto":
            use the channel count reported by Spandrel.
        tilesize: Height/width of the example input tensor; values < 1 fall
            back to 64.
        use_fp16: Convert the model and example input to half precision
            (moved to CUDA first when a device is available).
        simplify_model: Run onnx-simplifier on the exported model.
        min_size: Minimum byte size for the exported file to be considered
            valid; smaller files are deleted.
        output_folder: Optional directory for the output file (created if
            missing). Applied both to auto-derived and explicit onnx_path.
        opset: ONNX opset version to export with.
        dynamic_axes: Export height/width as dynamic dimensions.

    Returns:
        The output path (str) on success, "" on export/validation failure,
        or False when the model could not be loaded.
    """
    print(f"Loading model from: {pth_path}")
    try:
        # Use Spandrel to load the model architecture and state dict.
        model_descriptor = ModelLoader().load_from_file(pth_path)
        # Ensure it's the expected type from Spandrel.
        if not isinstance(model_descriptor, ImageModelDescriptor):
            print(f"Error: Expected ImageModelDescriptor, but got {type(model_descriptor)}")
            print("Please ensure the .pth file is compatible with Spandrel's loading mechanism.")
            return False
        # Get the underlying torch.nn.Module.
        torch_model = model_descriptor.model
        # Evaluation mode matters for dropout/batchnorm layers.
        torch_model.eval()
    except Exception as e:
        print(f"Error loading model: {e}")
        return False

    if channel == 0:
        channel = model_descriptor.input_channels
    if tilesize < 1:
        tilesize = 64
    example_input = torch.randn(1, channel, tilesize, tilesize)
    print("Model input channels:", channel, "tile size:", tilesize)

    if use_fp16:
        if torch.cuda.is_available():
            torch_model.cuda()
            example_input = example_input.cuda()
        else:
            # Half precision on CPU may fail for some ops; warn but proceed.
            print("Warning: no CUDA device")
        torch_model.half()
        example_input = example_input.half()  # Convert to half-precision input.
    print(f"Model loaded successfully: {type(torch_model).__name__}")

    if output_folder:
        os.makedirs(output_folder, exist_ok=True)
    if onnx_path is None:
        onnx_path = _resolve_onnx_path(pth_path, output_folder, model_descriptor.scale, channel, use_fp16)
    elif output_folder:
        # Explicit onnx_path: still place it inside output_folder when given.
        onnx_path = os.path.join(output_folder, onnx_path)

    print(f"ONNX model exporting...")
    try:
        if dynamic_axes:
            axes = {
                "input": {2: "height", 3: "width"},
                "output": {2: "height", 3: "width"},
            }
        else:
            axes = {}
        torch.onnx.export(
            torch_model,            # The model instance
            example_input,          # An example input tensor
            onnx_path,              # Where to save the model (file path)
            export_params=True,     # Store the trained parameter weights inside the model file
            opset_version=opset,    # The ONNX version to export to (choose based on target runtime)
            do_constant_folding=True,  # Execute constant folding for optimization
            input_names=['input'],
            output_names=['output'],
            dynamic_axes=axes
        )
        print(f"ONNX model export successful: {onnx_path}")

        # Optional: simplify the ONNX model in place.
        if simplify_model:
            model = onnx.load(onnx_path)
            model_simplified, _ = simplify(model)
            onnx.save(model_simplified, onnx_path)
            print(f"ONNX model simplified successfully: {onnx_path}")

        # Validate the output file: tiny files indicate a broken export.
        if os.path.exists(onnx_path):
            file_size = os.path.getsize(onnx_path)
            if file_size > min_size:
                return onnx_path
            os.remove(onnx_path)
            print(f"ONNX model has unexpected file size ({file_size} bytes), deleted invalid file")
        return ""
    except Exception as e:
        print(f"ONNX model export error: {e}")
        return ""
if __name__ == "__main__":
    # argparse is already imported at the top of the file.
    parser = argparse.ArgumentParser(description='Convert PyTorch model to ONNX model.')
    parser.add_argument('--pthpath', type=str, required=True, help='Path to the PyTorch model file.')
    parser.add_argument('--onnxpath', type=str, default=None, help='Path to save the ONNX model file.')
    parser.add_argument('--channel', type=int, default=0, help='Number of input channels (0 = auto-detect).')
    parser.add_argument('--tilesize', type=int, default=0, help='Example-input tile size (values < 1 fall back to 64).')
    parser.add_argument('--fp16', action='store_true', help='Use FP16 precision.')
    parser.add_argument('--simplify', action='store_true', help='Simplify the ONNX model.')
    parser.add_argument('--opset', type=int, default=11, help='ONNX opset version.')
    parser.add_argument('--fixed_axes', action='store_true', help='Export with fixed (non-dynamic) height/width axes.')
    args = parser.parse_args()
    success = convert_pth_to_onnx(
        pth_path=args.pthpath,
        onnx_path=args.onnxpath,
        channel=args.channel,
        tilesize=args.tilesize,
        use_fp16=args.fp16,
        simplify_model=args.simplify,
        opset=args.opset,
        dynamic_axes=not args.fixed_axes,
    )
    if success:
        print("Conversion process finished.")
    else:
        print("Conversion process failed.")
        exit(1)  # Exit with error code
|