|
import argparse
|
|
import torch
|
|
import os
|
|
import re
|
|
import onnx
|
|
from spandrel import ImageModelDescriptor, ModelLoader
|
|
from onnxsim import simplify
|
|
|
|
def _default_onnx_path(pth_path: str, scale: int, channel: int, use_fp16: bool, output_folder: str = None) -> str:
    """Derive an output .onnx path from the input .pth path.

    The .pth path without its extension is reused as the base (relocated into
    ``output_folder`` when given). A ``-x{scale}`` suffix is appended unless the
    file stem already contains a scale token such as ``4x`` or ``x4`` delimited
    by ``_``/``-`` or the start/end of the stem. One-channel models get a
    ``-Grayscale`` suffix and FP16 exports end in ``-fp16.onnx``.
    """
    base_path, _ = os.path.splitext(pth_path)
    if output_folder:
        base_path = os.path.join(output_folder, os.path.basename(base_path))

    # Match against the stem, not the full file name: with the extension
    # included, a trailing token like "_X4" is followed by ".PTH" and the
    # "[_-]|$" delimiter would never match.
    stem = os.path.basename(base_path).upper()
    pattern = rf"(^|[_-])({scale}X|X{scale})([_-]|$)"
    if re.search(pattern, stem):
        print(f'File name contains scale info: {stem} ')
    else:
        base_path = f"{base_path}-x{scale}"

    return base_path + ("-Grayscale" if channel == 1 else "") + ("-fp16.onnx" if use_fp16 else ".onnx")


def _simplify_onnx_in_place(onnx_path: str) -> bool:
    """Simplify the ONNX model at ``onnx_path`` with onnxsim and overwrite it.

    Returns True when the simplified model was written back. If onnxsim raises
    or reports that its validation check failed, a warning is printed and the
    exported (unsimplified) file is left untouched, because it is still a
    usable export. Errors from ``onnx.save`` propagate to the caller.
    """
    try:
        model = onnx.load(onnx_path)
        # onnxsim returns (simplified_model, check_passed).
        model_simplified, check_ok = simplify(model)
    except Exception as e:
        print(f"Warning: ONNX simplification failed, keeping unsimplified model: {e}")
        return False

    if not check_ok:
        print("Warning: simplified ONNX model failed validation, keeping unsimplified model")
        return False

    onnx.save(model_simplified, onnx_path)
    print(f"ONNX model simplified successfully: {onnx_path}")
    return True


def convert_pth_to_onnx(pth_path: str, onnx_path: str = None, channel: int = 0, tilesize: int = 64, use_fp16: bool = False, simplify_model: bool = False, min_size: int = 1024 * 1024, output_folder: str = None, opset: int = 11, dynamic_axes: bool = True):
    """
    Loads a PyTorch model from a .pth file using Spandrel and converts it to ONNX format.

    Args:
        pth_path: Path to the input .pth model file.
        onnx_path: Path to save the output .onnx file. When None, a name is
            derived from ``pth_path`` (scale, grayscale and fp16 suffixes are
            added as needed). When given together with ``output_folder``, it is
            joined onto ``output_folder``.
        channel: Number of input channels for the dummy export input. Values
            below 1 use the channel count reported by the loaded model.
        tilesize: Height and width of the dummy export input. Values below 1
            fall back to 64.
        use_fp16: Convert the model and dummy input to half precision. They are
            moved to CUDA first when a CUDA device is available.
        simplify_model: Run onnxsim on the exported model. A failed
            simplification keeps the unsimplified export.
        min_size: Exported files whose size in bytes is not greater than this
            are deleted and treated as failed exports (a sanity heuristic;
            very small models may need a lower value).
        output_folder: Optional directory for the output; created if missing.
        opset: ONNX opset version passed to ``torch.onnx.export``.
        dynamic_axes: When True, dimensions 2 and 3 of "input" and "output" are
            exported as dynamic ("height" / "width").

    Returns:
        The path of the written .onnx file on success; ``False`` if the model
        could not be loaded or is not an ``ImageModelDescriptor``; ``""`` if
        the export failed or produced a missing/too-small file.
    """
    print(f"Loading model from: {pth_path}")
    try:
        model_descriptor = ModelLoader().load_from_file(pth_path)
        if not isinstance(model_descriptor, ImageModelDescriptor):
            print(f"Error: Expected ImageModelDescriptor, but got {type(model_descriptor)}")
            print("Please ensure the .pth file is compatible with Spandrel's loading mechanism.")
            return False
        torch_model = model_descriptor.model
        torch_model.eval()
    except Exception as e:
        print(f"Error loading model: {e}")
        return False

    # Non-positive values would make torch.randn fail outside any error handling.
    if channel < 1:
        channel = model_descriptor.input_channels
    if tilesize < 1:
        tilesize = 64
    example_input = torch.randn(1, channel, tilesize, tilesize)
    print("Model input channels:", channel, "tile size:", tilesize)

    if use_fp16:
        if torch.cuda.is_available():
            torch_model.cuda()
            example_input = example_input.cuda()
        else:
            print("Warning: no CUDA device")
        torch_model.half()
        example_input = example_input.half()

    print(f"Model loaded successfully: {type(torch_model).__name__}")

    if output_folder:
        os.makedirs(output_folder, exist_ok=True)

    if onnx_path is None:
        onnx_path = _default_onnx_path(pth_path, model_descriptor.scale, channel, use_fp16, output_folder)
    elif output_folder:
        onnx_path = os.path.join(output_folder, onnx_path)

    print("ONNX model exporting...")
    try:
        # An explicit onnx_path may point into a directory that does not exist yet.
        parent_dir = os.path.dirname(onnx_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        if dynamic_axes:
            axes = {
                "input": {2: "height", 3: "width"},
                "output": {2: "height", 3: "width"},
            }
        else:
            axes = {}

        torch.onnx.export(
            torch_model,
            example_input,
            onnx_path,
            export_params=True,
            opset_version=opset,
            do_constant_folding=True,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes=axes,
        )
        print(f"ONNX model export successful: {onnx_path}")

        if simplify_model:
            _simplify_onnx_in_place(onnx_path)

        if not os.path.exists(onnx_path):
            print(f"ONNX model export error: output file not found: {onnx_path}")
            return ""

        file_size = os.path.getsize(onnx_path)
        if file_size > min_size:
            return onnx_path

        # Suspiciously small output: treat as a broken export and clean it up.
        os.remove(onnx_path)
        print(f"ONNX model has unexpected file size ({file_size} bytes), deleted invalid file")
        return ""
    except Exception as e:
        print(f"ONNX model export error: {e}")
        return ""
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse options, run the conversion and exit
    # with status 1 when it reports failure (a falsy return value).
    import sys

    parser = argparse.ArgumentParser(description='Convert PyTorch model to ONNX model.')
    parser.add_argument('--pthpath', type=str, required=True, help='Path to the PyTorch model file.')
    parser.add_argument('--onnxpath', type=str, default=None, help='Path to save the ONNX model file (derived from --pthpath when omitted).')
    parser.add_argument('--channel', type=int, default=0, help="Number of input channels (0 = use the model's own input channel count).")
    parser.add_argument('--tilesize', type=int, default=0, help='Height/width of the dummy export input (values < 1 fall back to 64).')
    parser.add_argument('--fp16', action='store_true', help='Use FP16 precision.')
    parser.add_argument('--simplify', action='store_true', help='Simplify the ONNX model.')
    parser.add_argument('--opset', type=int, default=11, help='ONNX opset version.')
    parser.add_argument('--fixed_axes', action='store_true', help='Export with fixed input/output height and width (disables dynamic axes).')
    parser.add_argument('--output_folder', type=str, default=None, help='Folder to write the ONNX model into (created if missing).')
    parser.add_argument('--min_size', type=int, default=1024 * 1024, help='Minimum plausible ONNX file size in bytes; outputs not larger than this are deleted as invalid.')
    args = parser.parse_args()

    success = convert_pth_to_onnx(
        pth_path=args.pthpath,
        onnx_path=args.onnxpath,
        channel=args.channel,
        tilesize=args.tilesize,
        use_fp16=args.fp16,
        simplify_model=args.simplify,
        min_size=args.min_size,
        output_folder=args.output_folder,
        opset=args.opset,
        dynamic_axes=not args.fixed_axes,
    )

    if success:
        print("Conversion process finished.")
    else:
        print("Conversion process failed.")
        sys.exit(1)
|
|
|