File size: 6,700 Bytes
2152dfa
 
 
 
 
 
 
 
da006e5
2152dfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20ad880
2152dfa
 
62053b6
2152dfa
 
 
 
 
 
 
20ad880
2152dfa
20ad880
2152dfa
 
da006e5
 
 
 
 
 
 
 
2152dfa
 
 
 
 
b5d90c6
2152dfa
 
 
da006e5
2152dfa
20ad880
2152dfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20ad880
2152dfa
 
 
20ad880
2152dfa
 
 
 
 
 
 
 
 
 
 
b5d90c6
da006e5
2152dfa
 
b5d90c6
da006e5
 
2152dfa
 
da006e5
 
b5d90c6
da006e5
2152dfa
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import argparse
import torch
import os
import re
import onnx
from spandrel import ImageModelDescriptor, ModelLoader
from onnxsim import simplify

def _resolve_onnx_path(pth_path, onnx_path, scale, channel, use_fp16, output_folder):
    """Return the file path the exported ONNX model should be written to.

    When ``onnx_path`` is given it is used as-is, or joined onto
    ``output_folder`` if one is set. Otherwise a name is derived from
    ``pth_path``: the extension is dropped, ``-x<scale>`` is appended unless
    the file name already carries a scale tag, then ``-Grayscale`` (for
    single-channel models) and ``-fp16`` (for half precision) are appended.
    """
    if onnx_path is not None:
        # os.path.join leaves an absolute onnx_path untouched.
        return os.path.join(output_folder, onnx_path) if output_folder else onnx_path

    base_path, _ = os.path.splitext(pth_path)
    stem = os.path.basename(base_path)
    if output_folder:
        base_path = os.path.join(output_folder, stem)

    # Look for a scale tag such as "4x" or "x4" delimited by "_", "-" or the
    # start/end of the name (case-insensitive). The match must run against
    # the extension-less stem: with the extension attached, a trailing tag
    # ("Model_x4.pth") is followed by "." and the "$" alternative never fires.
    scale_tag = re.escape(str(scale))
    pattern = rf'(^|[_-])({scale_tag}X|X{scale_tag})([_-]|$)'
    filename = os.path.basename(pth_path).upper()
    if re.search(pattern, stem.upper()):
        print(f'File name contains scale info: {filename} ')
    else:
        base_path = f"{base_path}-x{scale}"

    grayscale_suffix = "-Grayscale" if channel == 1 else ""
    extension = "-fp16.onnx" if use_fp16 else ".onnx"
    return base_path + grayscale_suffix + extension


def convert_pth_to_onnx(
    pth_path: str,
    onnx_path: str = None,
    channel: int = 0,
    tilesize: int = 64,
    use_fp16: bool = False,
    simplify_model: bool = False,
    min_size: int = 1024 * 1024,
    output_folder: str = None,
    opset: int = 11,
    dynamic_axes: bool = True,
):
    """Load a PyTorch model from a .pth file with Spandrel and export it to ONNX.

    Args:
        pth_path: Path to the input .pth model file.
        onnx_path: Path of the output .onnx file. When None, a name is derived
            from ``pth_path`` (scale, grayscale and fp16 suffixes are added).
            When ``output_folder`` is set, the path is joined onto it.
        channel: Number of input channels of the example input. 0 means "use
            the input channel count reported by the loaded model descriptor".
        tilesize: Height and width of the random example input used for
            tracing. Values below 1 fall back to 64.
        use_fp16: Convert the model and the example input to half precision.
            The model is moved to CUDA when available; otherwise a warning is
            printed and the conversion proceeds on CPU.
        simplify_model: Run onnxsim on the exported file. The simplified model
            replaces the export only if onnxsim's validation check passes.
        min_size: The exported file must be strictly larger than this many
            bytes; otherwise it is deleted and the conversion counts as failed.
        output_folder: Optional directory for the output; created if missing.
        opset: ONNX opset version passed to ``torch.onnx.export``.
        dynamic_axes: When True, axes 2 and 3 of "input" and "output" are
            exported as dynamic ("height" / "width").

    Returns:
        The path of the written ONNX file on success. On failure a falsy
        value is returned: ``False`` if the model could not be loaded (or is
        not an ``ImageModelDescriptor``), ``""`` if the export raised or the
        resulting file was missing or too small.
    """
    print(f"Loading model from: {pth_path}")
    try:
        # Use Spandrel to load the model architecture and state dict
        model_descriptor = ModelLoader().load_from_file(pth_path)

        # Ensure it's the expected type from Spandrel
        if not isinstance(model_descriptor, ImageModelDescriptor):
            print(f"Error: Expected ImageModelDescriptor, but got {type(model_descriptor)}")
            print("Please ensure the .pth file is compatible with Spandrel's loading mechanism.")
            return False

        # Get the underlying torch.nn.Module
        torch_model = model_descriptor.model

        # Set the model to evaluation mode (important for dropout, batchnorm layers)
        torch_model.eval()

    except Exception as e:
        print(f"Error loading model: {e}")
        return False

    if channel == 0:
        channel = model_descriptor.input_channels
    if tilesize < 1:
        tilesize = 64
    example_input = torch.randn(1, channel, tilesize, tilesize)
    print("Model input channels:", channel, "tile size:", tilesize)

    if use_fp16:
        if torch.cuda.is_available():
            torch_model.cuda()
            example_input = example_input.cuda()
        else:
            print("Warning: no CUDA device")
        torch_model.half()
        example_input = example_input.half()  # match the model's half precision
    print(f"Model loaded successfully: {type(torch_model).__name__}")

    if output_folder:
        os.makedirs(output_folder, exist_ok=True)

    onnx_path = _resolve_onnx_path(
        pth_path, onnx_path, model_descriptor.scale, channel, use_fp16, output_folder
    )

    print("ONNX model exporting...")
    try:
        # Only the spatial axes are dynamic; batch and channel stay fixed.
        if dynamic_axes:
            axes = {
                "input": {2: "height", 3: "width"},
                "output": {2: "height", 3: "width"},
            }
        else:
            axes = {}

        torch.onnx.export(
            torch_model,                 # The model instance
            example_input,               # An example input tensor
            onnx_path,                   # Where to save the model (file path)
            export_params=True,          # Store the trained parameter weights inside the model file
            opset_version=opset,         # The ONNX version to export the model to (choose based on target runtime)
            do_constant_folding=True,    # Whether to execute constant folding for optimization
            input_names=['input'],       # The model's input names
            output_names=['output'],     # The model's output names
            dynamic_axes=axes,
        )
        print(f"ONNX model export successful: {onnx_path}")

        # Optional: simplify the ONNX model
        if simplify_model:
            model = onnx.load(onnx_path)
            model_simplified, check_ok = simplify(model)
            if check_ok:
                onnx.save(model_simplified, onnx_path)
                print(f"ONNX model simplified successfully: {onnx_path}")
            else:
                # Do not overwrite a good export with a model that failed
                # onnxsim's own validation.
                print("Warning: simplified ONNX model failed validation, keeping the unsimplified model")

        # Validate the exported file: it must exist and exceed min_size bytes.
        if os.path.exists(onnx_path):
            file_size = os.path.getsize(onnx_path)
            if file_size > min_size:
                return onnx_path

            os.remove(onnx_path)
            print(f"ONNX model has unexpected file size ({file_size} bytes), deleted invalid file")
        return ""

    except Exception as e:
        print(f"ONNX model export error: {e}")
        return ""

if __name__ == "__main__":
    # Command-line entry point: parse the options, run the conversion and
    # report the outcome through the process exit status (1 on failure).
    parser = argparse.ArgumentParser(description='Convert PyTorch model to ONNX model.')
    parser.add_argument('--pthpath', type=str, required=True, help='Path to the PyTorch model file.')
    parser.add_argument('--onnxpath', type=str, default=None, help='Path to save the ONNX model file.')
    parser.add_argument('--channel', type=int, default=0, help='Channel parameter.')
    parser.add_argument('--tilesize', type=int, default=0, help='Tilesize parameter.')
    parser.add_argument('--fp16', action='store_true', help='Use FP16 precision.')
    parser.add_argument('--simplify', action='store_true', help='Simplify the ONNX model.')
    parser.add_argument('--opset', type=int, default=11, help='ONNX opset version.')
    # The flag switches dynamic axes OFF, so the help text must say so.
    parser.add_argument('--fixed_axes', action='store_true', help='Use fixed input/output axes (disable dynamic height/width axes).')
    args = parser.parse_args()

    success = convert_pth_to_onnx(
        pth_path=args.pthpath,
        onnx_path=args.onnxpath,
        channel=args.channel,
        tilesize=args.tilesize,
        use_fp16=args.fp16,
        simplify_model=args.simplify,
        opset=args.opset,
        dynamic_axes=not args.fixed_axes,
    )

    # convert_pth_to_onnx returns the output path on success and a falsy
    # value (False or "") on failure.
    if success:
        print("Conversion process finished.")
    else:
        print("Conversion process failed.")
        # SystemExit works even where the site-provided exit() helper is
        # unavailable (python -S, frozen executables).
        raise SystemExit(1)