geshang commited on
Commit
67a2118
·
verified ·
1 Parent(s): a3b73e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -86,7 +86,7 @@ except Exception as e:
86
  class CustomSAMWrapper:
87
  def __init__(self, model_path: str, device: str = DEVICE):
88
  # try:
89
- self.device = "cpu" #torch.device(device)
90
  sam_model = build_sam2("configs/sam2.1/sam2.1_hiera_l.yaml", model_path, self.device)
91
  sam_model = sam_model.to(self.device)
92
  self.predictor = SAM2ImagePredictor(sam_model)
@@ -231,7 +231,8 @@ def visualize_masks_on_image(
231
  blended = cv2.addWeighted(image_np, 1 - alpha, color_mask, alpha, 0)
232
  return PILImage.fromarray(blended)
233
 
234
-
 
235
  def run_pipeline(image: PILImage.Image, prompt: str):
236
  if not model or not processor:
237
  return "Models not loaded. Please check logs.", None
 
86
  class CustomSAMWrapper:
87
  def __init__(self, model_path: str, device: str = DEVICE):
88
  # try:
89
+ self.device = torch.device(device)
90
  sam_model = build_sam2("configs/sam2.1/sam2.1_hiera_l.yaml", model_path, self.device)
91
  sam_model = sam_model.to(self.device)
92
  self.predictor = SAM2ImagePredictor(sam_model)
 
231
  blended = cv2.addWeighted(image_np, 1 - alpha, color_mask, alpha, 0)
232
  return PILImage.fromarray(blended)
233
 
234
+ @spaces.GPU
235
+ @torch.no_grad()
236
  def run_pipeline(image: PILImage.Image, prompt: str):
237
  if not model or not processor:
238
  return "Models not loaded. Please check logs.", None