import torch
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
import numpy as np

from models import SwinMattingModel


class Pipeline:
    """Load a SwinMattingModel checkpoint and run single-image inference.

    Input images are resized to 512x512 and normalized with the ImageNet
    mean/std before being passed to the model.
    """

    def __init__(self, model_name: str):
        """Build the preprocessing transforms and load ``models/<model_name>.pt``.

        Args:
            model_name: File stem of the checkpoint inside the ``models/`` directory.

        Raises:
            RuntimeError: If the checkpoint's keys do not match the model.
        """
        self.transforms = Compose(
            [
                Resize(size=(512, 512)),
                ToTensor(),
                Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ],
        )
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # NOTE(review): not read anywhere in this file; kept in case external
        # callers depend on it.
        self.is_torch_script = self.device.type == 'cpu'
        self.checkpoint = f"models/{model_name}.pt"
        self.model = self._load_model()
        self._log_device_info()

    def inference(self, image) -> np.ndarray:
        """Run the model on a single image and return its output in [0, 1].

        Args:
            image: Any input accepted by ``Resize`` followed by ``ToTensor``
                (presumably a 3-channel PIL image, since ``Normalize`` uses
                three channel statistics -- TODO confirm against callers).

        Returns:
            The model output as a NumPy array, clipped to [0, 1], with the
            batch axis and every other singleton axis removed.

        Raises:
            RuntimeError: If no model is loaded.
        """
        if self.model is None:
            raise RuntimeError("Model is not loaded. Call _load_model() first.")
        # Add a batch dimension of 1 and move the input to the model's device.
        tensor = self.transforms(image).unsqueeze(0).to(self.device)
        with torch.inference_mode():
            output = self.model(tensor)
        output = output.detach().cpu().numpy()
        output = np.clip(output, a_min=0, a_max=1)
        # Drop the batch axis, then any remaining singleton axes (e.g. channel).
        return np.squeeze(output, axis=0).squeeze()

    def _load_pytorch_model(self):
        """Build the SwinMattingModel, load checkpoint weights, and ready it for inference.

        Returns:
            The model on ``self.device`` in eval mode.
        """
        model = SwinMattingModel({
            "encoder": {
                "model_name": "microsoft/swin-small-patch4-window7-224"
            },
            "decoder": {
                "use_attn": True,
                "refine_channels": 16
            }
        })
        self._load_checkpoint(model)
        model.to(self.device)
        model.eval()
        return model

    def _load_model(self):
        """Return the inference-ready model.

        ``_load_pytorch_model`` already moves the model to ``self.device`` and
        switches it to eval mode, so no further setup is needed here.
        """
        return self._load_pytorch_model()

    def _load_checkpoint(self, model) -> None:
        """Load the state dict at ``self.checkpoint`` into ``model``.

        Raises:
            RuntimeError: If the checkpoint is missing keys the model expects,
                contains keys the model does not have, or has mismatched
                tensor shapes.
        """
        checkpoint = torch.load(self.checkpoint, map_location="cpu", weights_only=True)
        # With the default strict=True, load_state_dict raises on a key
        # mismatch before returning, so the diagnostics below would never run.
        # strict=False returns the key lists instead; shape mismatches still
        # raise inside load_state_dict.
        missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
        problems = []
        if missing_keys:
            problems.append(f"Missing keys in checkpoint: {list(missing_keys)}")
        if unexpected_keys:
            problems.append(f"Unexpected keys in checkpoint: {list(unexpected_keys)}")
        if problems:
            raise RuntimeError("; ".join(problems))

    def _log_device_info(self) -> None:
        """Print the GPU name when running on CUDA; print nothing on CPU."""
        if self.device.type == 'cuda':
            print(f"Hardware: {torch.cuda.get_device_name(torch.cuda.current_device())}")