Updated model to use PyTorch instead of ONNX
- app.py +36 -25
- model.py +139 -0
- requirements.txt +2 -2
app.py
CHANGED
@@ -1,38 +1,47 @@
 import gradio as gr
 import torch
 from torchvision.transforms import Compose, Resize, ToTensor, Normalize
-import onnxruntime as ort
-
 import pymatting
 import numpy as np
-import os
 from PIL import Image
 from typing import Tuple
 import random
 from pathlib import Path
+from model import SwinMattingModel


-def _load_model(checkpoint):
-    """
-    Load the ONNX model for inference.
-    ...
-    return
+def _load_checkpoint(model, checkpoint_path):
+    # Load the checkpoint
+    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
+
+    # Check if there are any errors when loading the state dictionary
+    missing_keys, unexpected_keys = model.load_state_dict(checkpoint)
+    if missing_keys:
+        print(missing_keys)
+        raise RuntimeError("Missing keys in checkpoint.")
+
+    if unexpected_keys:
+        print(unexpected_keys)
+        raise RuntimeError("Unexpected keys in checkpoint.")
+
+
+def _load_model(checkpoint, device):
+    model = SwinMattingModel({
+        "encoder": {
+            "model_name": "microsoft/swin-small-patch4-window7-224"
+        },
+        "decoder": {
+            "use_attn": True,
+            "refine_channels": 16
+        }
+    })
+    _load_checkpoint(model, checkpoint)
+
+    model.to(device)
+    model.eval()
+
+    return model


 transforms = Compose(
@@ -44,9 +53,9 @@ transforms = Compose(
 )

 share_repo = False
-checkpoint_path = "swin_small_patch4_window7_224_512_v1_latest.
+checkpoint_path = "swin_small_patch4_window7_224_512_v1_latest.pt"
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-...
+model = _load_model(checkpoint_path, device)


 def _get_foreground_estimation(image, alpha):
@@ -130,9 +139,11 @@ def _inference(image):
     Returns:
         np.ndarray: The predicted alpha mask.
     """
-    ...
+    with torch.inference_mode():
+        output = model(image)

     # Ensure the output is in valid range [0, 1]
+    output = output.detach().cpu().numpy()
     output = np.clip(output, a_min=0, a_max=1)

     return np.squeeze(output, axis=0).squeeze()
@@ -276,4 +287,4 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
     run_button.click(fn=predict, inputs=input_image, outputs=[output_mask, output_sky])

 # Launch the interface
-demo.launch(share=share_repo)
+demo.launch(share=share_repo, ssr_mode=False)
model.py
ADDED
@@ -0,0 +1,139 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import AutoBackbone
+from typing import Any, Dict
+
+
+class SwinMattingModel(nn.Module):
+    def __init__(self, config: Dict[str, Any]):
+        super().__init__()
+        encoder_config = config['encoder']
+        decoder_config = config['decoder']
+
+        self.encoder = SwinEncoder(model_name=encoder_config["model_name"])
+        self.decoder = MattingDecoder(
+            use_attn=decoder_config["use_attn"],
+            refine_channels=decoder_config["refine_channels"]
+        )
+
+    def forward(self, x):
+        """
+        Args:
+            x (torch.Tensor): Input image [B, 3, 512, 512], normalized as needed for Swin.
+        Returns:
+            torch.Tensor: Alpha matte [B, 1, 512, 512].
+        """
+        features = self.encoder(x)  # list of 4 feature maps
+        return self.decoder(features, x)  # decoded and refined alpha matte
+
+
+class SwinEncoder(nn.Module):
+    def __init__(self, model_name="microsoft/swin-small-patch4-window7-224"):
+        super().__init__()
+
+        self.backbone = AutoBackbone.from_pretrained(model_name, out_indices=(1, 2, 3, 4))
+
+    def forward(self, x):
+        outputs = self.backbone(pixel_values=x)
+        features = outputs.feature_maps
+        features = list(features)
+        return features
+
+
+class MattingDecoder(nn.Module):
+    def __init__(self, use_attn=False, refine_channels=16):
+        super().__init__()
+        self.use_attn = use_attn
+        self.refine_channels = refine_channels
+
+        # Bottom convolution (process 1/32 feature)
+        self.conv_bottom = nn.Conv2d(768, 768, kernel_size=3, padding=1)
+        self.bn_bottom = nn.BatchNorm2d(768)
+
+        # Upsample + fuse with skip connections
+        self.conv_up3 = nn.Conv2d(768 + 384, 384, kernel_size=3, padding=1)
+        self.bn_up3 = nn.BatchNorm2d(384)
+
+        self.conv_up2 = nn.Conv2d(384 + 192, 192, kernel_size=3, padding=1)
+        self.bn_up2 = nn.BatchNorm2d(192)
+
+        self.conv_up1 = nn.Conv2d(192 + 96, 96, kernel_size=3, padding=1)
+        self.bn_up1 = nn.BatchNorm2d(96)
+
+        self.conv_out = nn.Conv2d(96, 1, kernel_size=3, padding=1)
+
+        # Detail refinement
+        self.refine_conv1 = nn.Conv2d(4, self.refine_channels, kernel_size=3, padding=1)
+        self.bn_refine1 = nn.BatchNorm2d(self.refine_channels)
+
+        self.refine_conv2 = nn.Conv2d(self.refine_channels, self.refine_channels, kernel_size=3, padding=1)
+        self.bn_refine2 = nn.BatchNorm2d(self.refine_channels)
+
+        self.refine_conv3 = nn.Conv2d(self.refine_channels, 1, kernel_size=3, padding=1)
+
+        # Attention gates
+        if self.use_attn:
+            self.reduce_768_to_384 = nn.Conv2d(768, 384, kernel_size=1)
+            self.reduce_384_to_192 = nn.Conv2d(384, 192, kernel_size=1)
+            self.reduce_192_to_96 = nn.Conv2d(192, 96, kernel_size=1)
+
+            self.gate_16 = nn.Conv2d(384, 384, kernel_size=1)
+            self.skip_16 = nn.Conv2d(384, 384, kernel_size=1)
+
+            self.gate_8 = nn.Conv2d(192, 192, kernel_size=1)
+            self.skip_8 = nn.Conv2d(192, 192, kernel_size=1)
+
+            self.gate_4 = nn.Conv2d(96, 96, kernel_size=1)
+            self.skip_4 = nn.Conv2d(96, 96, kernel_size=1)
+
+    def forward(self, features, original_image):
+        f1, f2, f3, f4 = features  # [1/4, 1/8, 1/16, 1/32]
+
+        # Bottom (1/32)
+        x = F.relu(self.bn_bottom(self.conv_bottom(f4)))
+
+        # 1/16 stage
+        x = F.interpolate(x, scale_factor=2.0, mode='nearest')  # -> [B, 768, 32, 32]
+        if self.use_attn:
+            x_reduced = self.reduce_768_to_384(x)
+            g = self.gate_16(x_reduced)
+            skip = self.skip_16(f3)
+            att = torch.sigmoid(g + skip)
+            f3 = f3 * att
+        x = torch.cat([x, f3], dim=1)
+        x = F.relu(self.bn_up3(self.conv_up3(x)))  # -> [B, 384, 32, 32]
+
+        # 1/8 stage
+        x = F.interpolate(x, scale_factor=2.0, mode='nearest')
+        if self.use_attn:
+            x_reduced = self.reduce_384_to_192(x)
+            g = self.gate_8(x_reduced)
+            skip = self.skip_8(f2)
+            att = torch.sigmoid(g + skip)
+            f2 = f2 * att
+        x = torch.cat([x, f2], dim=1)
+        x = F.relu(self.bn_up2(self.conv_up2(x)))  # -> [B, 192, 64, 64]
+
+        # 1/4 stage
+        x = F.interpolate(x, scale_factor=2.0, mode='nearest')
+        if self.use_attn:
+            x_reduced = self.reduce_192_to_96(x)
+            g = self.gate_4(x_reduced)
+            skip = self.skip_4(f1)
+            att = torch.sigmoid(g + skip)
+            f1 = f1 * att
+        x = torch.cat([x, f1], dim=1)
+        x = F.relu(self.bn_up1(self.conv_up1(x)))  # -> [B, 96, 128, 128]
+
+        # Upsample to full resolution and predict coarse alpha
+        x = F.interpolate(x, size=original_image.shape[-2:], mode='nearest')  # -> [B, 96, 512, 512]
+        coarse_alpha = self.conv_out(x)
+
+        # Detail refinement
+        refine_input = torch.cat([coarse_alpha, original_image], dim=1)
+        r = F.relu(self.bn_refine1(self.refine_conv1(refine_input)))
+        r = F.relu(self.bn_refine2(self.refine_conv2(r)))
+        refined_alpha = self.refine_conv3(r)
+
+        return torch.sigmoid(refined_alpha)
requirements.txt
CHANGED
@@ -1,8 +1,8 @@
 gradio
 torch
 torchvision
+transformers
 numpy
 pillow
 pymatting
-opencv-python
-onnxruntime-gpu
+opencv-python