Svane20 commited on
Commit
d3ba20a
·
1 Parent(s): 5725bfe

Updated model to use PyTorch instead of ONNX

Browse files
Files changed (1) hide show
  1. pipeline.py +8 -7
pipeline.py CHANGED
@@ -16,7 +16,8 @@ class Pipeline:
16
  )
17
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18
  self.is_torch_script = self.device.type == 'cpu'
19
- self.model = self._load_model(model_name)
 
20
 
21
  self._log_device_info()
22
 
@@ -33,7 +34,7 @@ class Pipeline:
33
 
34
  return np.squeeze(output, axis=0).squeeze()
35
 
36
- def _load_pytorch_model(self, checkpoint):
37
  model = SwinMattingModel({
38
  "encoder": {
39
  "model_name": "microsoft/swin-small-patch4-window7-224"
@@ -43,23 +44,23 @@ class Pipeline:
43
  "refine_channels": 16
44
  }
45
  })
46
- self._load_checkpoint(model, checkpoint)
47
 
48
  model.to(self.device)
49
  model.eval()
50
 
51
  return model
52
 
53
- def _load_model(self, model_name):
54
- model = self._load_pytorch_model(checkpoint=f"models/{model_name}.pt")
55
 
56
  model.to(self.device)
57
  model.eval()
58
 
59
  return model
60
 
61
- def _load_checkpoint(self, model, checkpoint_path):
62
- checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
63
 
64
  missing_keys, unexpected_keys = model.load_state_dict(checkpoint)
65
  if missing_keys:
 
16
  )
17
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18
  self.is_torch_script = self.device.type == 'cpu'
19
+ self.checkpoint = f"models/{model_name}.pt"
20
+ self.model = self._load_model()
21
 
22
  self._log_device_info()
23
 
 
34
 
35
  return np.squeeze(output, axis=0).squeeze()
36
 
37
+ def _load_pytorch_model(self):
38
  model = SwinMattingModel({
39
  "encoder": {
40
  "model_name": "microsoft/swin-small-patch4-window7-224"
 
44
  "refine_channels": 16
45
  }
46
  })
47
+ self._load_checkpoint(model)
48
 
49
  model.to(self.device)
50
  model.eval()
51
 
52
  return model
53
 
54
+ def _load_model(self):
55
+ model = self._load_pytorch_model()
56
 
57
  model.to(self.device)
58
  model.eval()
59
 
60
  return model
61
 
62
+ def _load_checkpoint(self, model):
63
+ checkpoint = torch.load(self.checkpoint, map_location="cpu", weights_only=True)
64
 
65
  missing_keys, unexpected_keys = model.load_state_dict(checkpoint)
66
  if missing_keys: