Svane20 committed on
Commit 2c67a31 · 1 Parent(s): 6c0e2b4

Updated model to use PyTorch instead of ONNX

Files changed (3)
  1. app.py +36 -25
  2. model.py +139 -0
  3. requirements.txt +2 -2
app.py CHANGED
@@ -1,38 +1,47 @@
 import gradio as gr
 import torch
 from torchvision.transforms import Compose, Resize, ToTensor, Normalize
-import onnxruntime as ort
-
 import pymatting
 import numpy as np
-import os
 from PIL import Image
 from typing import Tuple
 import random
 from pathlib import Path
 
+from model import SwinMattingModel
+
 
-def _load_model(checkpoint):
-    """
-    Load the ONNX model for inference.
-
-    Args:
-        checkpoint (str): Path to the ONNX model file.
-
-    Returns:
-        session (onnxruntime.InferenceSession): The ONNX runtime session.
-        input_name (str): The name of the input tensor.
-        output_name (str): The name of the output tensor.
-    """
-    session_options = ort.SessionOptions()
-    session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
-    session_options.intra_op_num_threads = min(1, os.cpu_count() - 1)
-    providers = ['CUDAExecutionProvider'] if torch.cuda.is_available() else ['CPUExecutionProvider']
-    session = ort.InferenceSession(checkpoint, providers=providers)
-    input_name = session.get_inputs()[0].name
-    output_name = session.get_outputs()[0].name
-
-    return session, input_name, output_name
+def _load_checkpoint(model, checkpoint_path):
+    # Load the checkpoint
+    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
+
+    # Check if there are any errors when loading the state dictionary
+    missing_keys, unexpected_keys = model.load_state_dict(checkpoint)
+    if missing_keys:
+        print(missing_keys)
+        raise RuntimeError("Missing keys in checkpoint.")
+
+    if unexpected_keys:
+        print(unexpected_keys)
+        raise RuntimeError("Unexpected keys in checkpoint.")
+
+
+def _load_model(checkpoint, device):
+    model = SwinMattingModel({
+        "encoder": {
+            "model_name": "microsoft/swin-small-patch4-window7-224"
+        },
+        "decoder": {
+            "use_attn": True,
+            "refine_channels": 16
+        }
+    })
+    _load_checkpoint(model, checkpoint)
+
+    model.to(device)
+    model.eval()
+
+    return model
 
 
 transforms = Compose(
@@ -44,9 +53,9 @@ transforms = Compose(
 )
 
 share_repo = False
-checkpoint_path = "swin_small_patch4_window7_224_512_v1_latest.onnx"
+checkpoint_path = "swin_small_patch4_window7_224_512_v1_latest.pt"
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-session, input_name, output_name = _load_model(checkpoint_path)
+model = _load_model(checkpoint_path, device)
 
 
 def _get_foreground_estimation(image, alpha):
@@ -130,9 +139,11 @@ def _inference(image):
     Returns:
         np.ndarray: The predicted alpha mask.
     """
-    output = session.run(output_names=[output_name], input_feed={input_name: image.cpu().numpy()})[0]
+    with torch.inference_mode():
+        output = model(image)
 
     # Ensure the output is in valid range [0, 1]
+    output = output.detach().cpu().numpy()
     output = np.clip(output, a_min=0, a_max=1)
 
     return np.squeeze(output, axis=0).squeeze()
@@ -276,4 +287,4 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
     run_button.click(fn=predict, inputs=input_image, outputs=[output_mask, output_sky])
 
 # Launch the interface
-demo.launch(share=share_repo)
+demo.launch(share=share_repo, ssr_mode=False)
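
For context, this is roughly how the new PyTorch path in app.py is exercised end to end. It is a minimal sketch, not part of the commit: the Resize size and Normalize statistics are assumptions (the transforms body is not shown in this diff), and "sky.jpg" is a placeholder filename.

    import torch
    from PIL import Image
    from torchvision.transforms import Compose, Resize, ToTensor, Normalize

    # Assumed preprocessing: 512x512 resize + ImageNet normalization (not shown in the diff)
    transforms = Compose([
        Resize((512, 512)),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _load_model("swin_small_patch4_window7_224_512_v1_latest.pt", device)  # from app.py above

    image = transforms(Image.open("sky.jpg").convert("RGB")).unsqueeze(0).to(device)
    with torch.inference_mode():
        alpha = model(image)                          # [1, 1, 512, 512] alpha matte
    alpha = alpha.squeeze().cpu().numpy().clip(0, 1)  # 512x512 mask in [0, 1]
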
model.py ADDED
@@ -0,0 +1,139 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import AutoBackbone
+from typing import Any, Dict
+
+
+class SwinMattingModel(nn.Module):
+    def __init__(self, config: Dict[str, Any]):
+        super().__init__()
+        encoder_config = config['encoder']
+        decoder_config = config['decoder']
+
+        self.encoder = SwinEncoder(model_name=encoder_config["model_name"])
+        self.decoder = MattingDecoder(
+            use_attn=decoder_config["use_attn"],
+            refine_channels=decoder_config["refine_channels"]
+        )
+
+    def forward(self, x):
+        """
+        Args:
+            x (torch.Tensor): Input image [B, 3, 512, 512], normalized as needed for Swin.
+        Returns:
+            torch.Tensor: Alpha matte [B, 1, 512, 512].
+        """
+        features = self.encoder(x)  # list of 4 feature maps
+        return self.decoder(features, x)  # decoded and refined alpha matte
+
+
+class SwinEncoder(nn.Module):
+    def __init__(self, model_name="microsoft/swin-small-patch4-window7-224"):
+        super().__init__()
+
+        self.backbone = AutoBackbone.from_pretrained(model_name, out_indices=(1, 2, 3, 4))
+
+    def forward(self, x):
+        outputs = self.backbone(pixel_values=x)
+        features = outputs.feature_maps
+        features = list(features)
+        return features
+
+
+class MattingDecoder(nn.Module):
+    def __init__(self, use_attn=False, refine_channels=16):
+        super().__init__()
+        self.use_attn = use_attn
+        self.refine_channels = refine_channels
+
+        # Bottom convolution (process 1/32 feature)
+        self.conv_bottom = nn.Conv2d(768, 768, kernel_size=3, padding=1)
+        self.bn_bottom = nn.BatchNorm2d(768)
+
+        # Upsample + fuse with skip connections
+        self.conv_up3 = nn.Conv2d(768 + 384, 384, kernel_size=3, padding=1)
+        self.bn_up3 = nn.BatchNorm2d(384)
+
+        self.conv_up2 = nn.Conv2d(384 + 192, 192, kernel_size=3, padding=1)
+        self.bn_up2 = nn.BatchNorm2d(192)
+
+        self.conv_up1 = nn.Conv2d(192 + 96, 96, kernel_size=3, padding=1)
+        self.bn_up1 = nn.BatchNorm2d(96)
+
+        self.conv_out = nn.Conv2d(96, 1, kernel_size=3, padding=1)
+
+        # Detail refinement
+        self.refine_conv1 = nn.Conv2d(4, self.refine_channels, kernel_size=3, padding=1)
+        self.bn_refine1 = nn.BatchNorm2d(self.refine_channels)
+
+        self.refine_conv2 = nn.Conv2d(self.refine_channels, self.refine_channels, kernel_size=3, padding=1)
+        self.bn_refine2 = nn.BatchNorm2d(self.refine_channels)
+
+        self.refine_conv3 = nn.Conv2d(self.refine_channels, 1, kernel_size=3, padding=1)
+
+        # Attention gates
+        if self.use_attn:
+            self.reduce_768_to_384 = nn.Conv2d(768, 384, kernel_size=1)
+            self.reduce_384_to_192 = nn.Conv2d(384, 192, kernel_size=1)
+            self.reduce_192_to_96 = nn.Conv2d(192, 96, kernel_size=1)
+
+            self.gate_16 = nn.Conv2d(384, 384, kernel_size=1)
+            self.skip_16 = nn.Conv2d(384, 384, kernel_size=1)
+
+            self.gate_8 = nn.Conv2d(192, 192, kernel_size=1)
+            self.skip_8 = nn.Conv2d(192, 192, kernel_size=1)
+
+            self.gate_4 = nn.Conv2d(96, 96, kernel_size=1)
+            self.skip_4 = nn.Conv2d(96, 96, kernel_size=1)
+
+    def forward(self, features, original_image):
+        f1, f2, f3, f4 = features  # [1/4, 1/8, 1/16, 1/32]
+
+        # Bottom (1/32)
+        x = F.relu(self.bn_bottom(self.conv_bottom(f4)))
+
+        # 1/16 stage
+        x = F.interpolate(x, scale_factor=2.0, mode='nearest')  # -> [B, 768, 32, 32]
+        if self.use_attn:
+            x_reduced = self.reduce_768_to_384(x)
+            g = self.gate_16(x_reduced)
+            skip = self.skip_16(f3)
+            att = torch.sigmoid(g + skip)
+            f3 = f3 * att
+        x = torch.cat([x, f3], dim=1)
+        x = F.relu(self.bn_up3(self.conv_up3(x)))  # -> [B, 384, 32, 32]
+
+        # 1/8 stage
+        x = F.interpolate(x, scale_factor=2.0, mode='nearest')
+        if self.use_attn:
+            x_reduced = self.reduce_384_to_192(x)
+            g = self.gate_8(x_reduced)
+            skip = self.skip_8(f2)
+            att = torch.sigmoid(g + skip)
+            f2 = f2 * att
+        x = torch.cat([x, f2], dim=1)
+        x = F.relu(self.bn_up2(self.conv_up2(x)))  # -> [B, 192, 64, 64]
+
+        # 1/4 stage
+        x = F.interpolate(x, scale_factor=2.0, mode='nearest')
+        if self.use_attn:
+            x_reduced = self.reduce_192_to_96(x)
+            g = self.gate_4(x_reduced)
+            skip = self.skip_4(f1)
+            att = torch.sigmoid(g + skip)
+            f1 = f1 * att
+        x = torch.cat([x, f1], dim=1)
+        x = F.relu(self.bn_up1(self.conv_up1(x)))  # -> [B, 96, 128, 128]
+
+        # Upsample to full resolution and predict coarse alpha
+        x = F.interpolate(x, size=original_image.shape[-2:], mode='nearest')  # -> [B, 96, 512, 512]
+        coarse_alpha = self.conv_out(x)
+
+        # Detail refinement
+        refine_input = torch.cat([coarse_alpha, original_image], dim=1)
+        r = F.relu(self.bn_refine1(self.refine_conv1(refine_input)))
+        r = F.relu(self.bn_refine2(self.refine_conv2(r)))
+        refined_alpha = self.refine_conv3(r)
+
+        return torch.sigmoid(refined_alpha)
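
A quick way to sanity-check the new module is a random-input forward pass. This smoke test is a sketch, not part of the commit; it only verifies output shape and range, using the same config that app.py passes to the model:

    import torch
    from model import SwinMattingModel

    model = SwinMattingModel({
        "encoder": {"model_name": "microsoft/swin-small-patch4-window7-224"},
        "decoder": {"use_attn": True, "refine_channels": 16},
    }).eval()

    dummy = torch.randn(1, 3, 512, 512)  # one 512x512 RGB image
    with torch.inference_mode():
        alpha = model(dummy)

    print(alpha.shape)                              # torch.Size([1, 1, 512, 512])
    print(alpha.min().item(), alpha.max().item())   # sigmoid output, stays within [0, 1]
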
requirements.txt CHANGED
@@ -1,8 +1,8 @@
 gradio
 torch
 torchvision
+transformers
 numpy
 pillow
 pymatting
-opencv-python
-onnxruntime-gpu
+opencv-python
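
For reference, the updated dependency set can be installed in one line; onnxruntime-gpu is no longer required after this change:

    pip install gradio torch torchvision transformers numpy pillow pymatting opencv-python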