import argparse
import os
import re
from typing import Optional, Union

import torch
import onnx
from spandrel import ImageModelDescriptor, ModelLoader
from onnxsim import simplify


def _stem_has_scale_tag(pth_path: str, scale: int) -> bool:
    """Return True if the file stem of *pth_path* already carries a scale tag.

    A tag is ``<scale>x`` or ``x<scale>`` (case-insensitive) delimited by
    ``_``/``-`` or by the start/end of the stem, e.g. ``4x_Model.pth`` or
    ``model_x4.pth``.
    """
    # Match against the stem, not the full file name: with the extension still
    # attached, the ``$`` alternative can never match a trailing tag.
    stem = os.path.splitext(os.path.basename(pth_path))[0].upper()
    pattern = rf"(^|[_-])({scale}X|X{scale})([_-]|$)"
    return re.search(pattern, stem) is not None


def convert_pth_to_onnx(
    pth_path: str,
    onnx_path: Optional[str] = None,
    channel: int = 0,
    tilesize: int = 64,
    use_fp16: bool = False,
    simplify_model: bool = False,
    min_size: int = 1024 * 1024,
    output_folder: Optional[str] = None,
    opset: int = 11,
    dynamic_axes: bool = True,
) -> Union[str, bool]:
    """Load a PyTorch model with Spandrel and export it to ONNX.

    Args:
        pth_path: Path to the input .pth model file.
        onnx_path: Path to save the output .onnx file. When None, the name is
            derived from ``pth_path``: ``-x<scale>`` is appended unless the
            file stem already contains a scale tag, then ``-Grayscale`` for
            single-channel models, then ``-fp16.onnx`` or ``.onnx``.
        channel: Number of input channels for the example input. 0 means
            "use the ``input_channels`` reported by the model descriptor".
        tilesize: Height and width of the example input used for tracing.
            Values below 1 fall back to 64.
        use_fp16: Convert the model and the example input to half precision
            before exporting. The model is moved to CUDA when available.
        simplify_model: Run onnxsim on the exported file and overwrite it
            with the simplified graph if the simplifier's check passes.
        min_size: An exported file whose size in bytes is not greater than
            this value is deleted and the conversion is reported as failed.
        output_folder: Optional directory for the output; created if missing.
            It is joined in front of ``onnx_path`` (or of the derived name).
        opset: ONNX opset version passed to ``torch.onnx.export``.
        dynamic_axes: When True, axes 2 and 3 of ``input`` and ``output`` are
            exported as dynamic ``height``/``width`` axes.

    Returns:
        The path of the exported ONNX file on success. On failure a falsy
        value: ``False`` if the model could not be loaded, ``""`` if the
        export failed or the exported file was rejected.
    """
    print(f"Loading model from: {pth_path}")
    try:
        # Use Spandrel to load the model architecture and state dict
        model_descriptor = ModelLoader().load_from_file(pth_path)

        # Ensure it's the expected type from Spandrel
        if not isinstance(model_descriptor, ImageModelDescriptor):
            print(f"Error: Expected ImageModelDescriptor, but got {type(model_descriptor)}")
            print("Please ensure the .pth file is compatible with Spandrel's loading mechanism.")
            return False

        # Get the underlying torch.nn.Module
        torch_model = model_descriptor.model

        # Set the model to evaluation mode (important for dropout, batchnorm layers)
        torch_model.eval()
    except Exception as e:
        print(f"Error loading model: {e}")
        return False

    if channel == 0:
        channel = model_descriptor.input_channels
    if tilesize < 1:
        tilesize = 64
    example_input = torch.randn(1, channel, tilesize, tilesize)
    print("Model input channels:", channel, "tile size:", tilesize)

    if use_fp16:
        if torch.cuda.is_available():
            torch_model.cuda()
            example_input = example_input.cuda()
        else:
            print("Warning: no CUDA device")
        torch_model.half()
        # The example input must be half precision as well
        example_input = example_input.half()

    print(f"Model loaded successfully: {type(torch_model).__name__}")

    if output_folder:
        os.makedirs(output_folder, exist_ok=True)

    if onnx_path is None:
        base_path, _ = os.path.splitext(pth_path)
        if output_folder:
            base_path = os.path.join(output_folder, os.path.basename(base_path))
        scale = model_descriptor.scale
        # Only append "-x<scale>" when the file name does not already say so
        if _stem_has_scale_tag(pth_path, scale):
            filename = os.path.basename(pth_path).upper()
            print(f'File name contains scale info: {filename} ')
        else:
            base_path = f"{base_path}-x{scale}"
        onnx_path = (
            base_path
            + ("-Grayscale" if channel == 1 else "")
            + ("-fp16.onnx" if use_fp16 else ".onnx")
        )
    elif output_folder:
        # os.path.join leaves an absolute onnx_path unchanged
        onnx_path = os.path.join(output_folder, onnx_path)

    print("ONNX model exporting...")
    try:
        if dynamic_axes:
            axes = {
                "input": {2: "height", 3: "width"},
                "output": {2: "height", 3: "width"},
            }
        else:
            axes = {}

        torch.onnx.export(
            torch_model,               # The model instance
            example_input,             # An example input tensor
            onnx_path,                 # Where to save the model (file path)
            export_params=True,        # Store the trained parameter weights inside the model file
            opset_version=opset,       # The ONNX version to export the model to
            do_constant_folding=True,  # Whether to execute constant folding for optimization
            input_names=['input'],     # The model's input names
            output_names=['output'],   # The model's output names
            dynamic_axes=axes,
        )
        print(f"ONNX model export successful: {onnx_path}")

        # Optional: simplify the ONNX model
        if simplify_model:
            model = onnx.load(onnx_path)
            model_simplified, check_ok = simplify(model)
            if check_ok:
                onnx.save(model_simplified, onnx_path)
                print(f"ONNX model simplified successfully: {onnx_path}")
            else:
                # Do not overwrite a good export with an unvalidated graph
                print("Warning: simplified ONNX model failed validation, keeping the unsimplified model")

        # Validate the exported file: suspiciously small files are discarded
        if os.path.exists(onnx_path):
            file_size = os.path.getsize(onnx_path)
            if file_size > min_size:
                return onnx_path
            os.remove(onnx_path)
            print(f"ONNX model has unexpected file size ({file_size} bytes), deleted invalid file")
        return ""
    except Exception as e:
        print(f"ONNX model export error: {e}")
        return ""


def main() -> None:
    """Parse command-line arguments and run the conversion."""
    parser = argparse.ArgumentParser(description='Convert PyTorch model to ONNX model.')
    parser.add_argument('--pthpath', type=str, required=True, help='Path to the PyTorch model file.')
    parser.add_argument('--onnxpath', type=str, default=None, help='Path to save the ONNX model file.')
    parser.add_argument('--channel', type=int, default=0, help='Channel parameter.')
    parser.add_argument('--tilesize', type=int, default=0, help='Tilesize parameter.')
    parser.add_argument('--fp16', action='store_true', help='Use FP16 precision.')
    parser.add_argument('--simplify', action='store_true', help='Simplify the ONNX model.')
    parser.add_argument('--opset', type=int, default=11, help='ONNX opset version.')
    parser.add_argument('--fixed_axes', action='store_true',
                        help='Use fixed axes (disable dynamic height/width axes).')
    args = parser.parse_args()

    success = convert_pth_to_onnx(
        pth_path=args.pthpath,
        onnx_path=args.onnxpath,
        channel=args.channel,
        tilesize=args.tilesize,
        use_fp16=args.fp16,
        simplify_model=args.simplify,
        opset=args.opset,
        dynamic_axes=not args.fixed_axes,
    )

    if success:
        print("Conversion process finished.")
    else:
        print("Conversion process failed.")
        raise SystemExit(1)  # Exit with error code


if __name__ == "__main__":
    main()