Updated model to use PyTorch instead of ONNX
pipeline.py CHANGED (+8 -7)
@@ -16,7 +16,8 @@ class Pipeline:
        )
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.is_torch_script = self.device.type == 'cpu'
-        self.
+        self.checkpoint = f"models/{model_name}.pt"
+        self.model = self._load_model()

        self._log_device_info()

@@ -33,7 +34,7 @@ class Pipeline:

        return np.squeeze(output, axis=0).squeeze()

-    def _load_pytorch_model(self
+    def _load_pytorch_model(self):
        model = SwinMattingModel({
            "encoder": {
                "model_name": "microsoft/swin-small-patch4-window7-224"
@@ -43,23 +44,23 @@ class Pipeline:
                "refine_channels": 16
            }
        })
-        self._load_checkpoint(model
+        self._load_checkpoint(model)

        model.to(self.device)
        model.eval()

        return model

-    def _load_model(self
-        model = self._load_pytorch_model(
+    def _load_model(self):
+        model = self._load_pytorch_model()

        model.to(self.device)
        model.eval()

        return model

-    def _load_checkpoint(self, model
-        checkpoint = torch.load(
+    def _load_checkpoint(self, model):
+        checkpoint = torch.load(self.checkpoint, map_location="cpu", weights_only=True)

        missing_keys, unexpected_keys = model.load_state_dict(checkpoint)
        if missing_keys:
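
For context, a minimal self-contained sketch of the new checkpoint-loading path, not the repository's exact code: the standalone load_checkpoint helper name is hypothetical (the commit implements this inside Pipeline._load_checkpoint), and strict=False is an assumption for illustration, since load_state_dict with its default strict=True raises on mismatched keys rather than returning them, so a check like "if missing_keys:" only sees non-empty lists when strict loading is disabled.

import torch
import torch.nn as nn

def load_checkpoint(model: nn.Module, checkpoint_path: str) -> nn.Module:
    # Load onto CPU first; weights_only=True restricts unpickling to tensor data
    # (supported in recent PyTorch releases). The caller moves the model to the
    # target device afterwards, mirroring the pipeline's flow.
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

    # strict=False is assumed here so key mismatches are reported instead of raised.
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    if missing_keys:
        print(f"Missing keys when loading checkpoint: {missing_keys}")
    if unexpected_keys:
        print(f"Unexpected keys when loading checkpoint: {unexpected_keys}")

    model.eval()
    return model

In the committed code, this logic runs from _load_pytorch_model, which builds SwinMattingModel, loads the checkpoint resolved as models/{model_name}.pt in __init__, and then moves the model to self.device and calls eval() before _load_model returns it.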