geshang777 commited on
Commit
cb86726
·
1 Parent(s): 4c269f1
Files changed (2) hide show
  1. app.py +3 -21
  2. sam2_weights/sam2.1_hiera_large.pt +3 -0
app.py CHANGED
@@ -50,24 +50,6 @@ def install_sam2():
50
  # 以可编辑模式安装SAM2
51
  subprocess.run(["pip", "install", "-e", "."], check=True)
52
 
53
- # 切换到checkpoints目录下载模型
54
- os.chdir("checkpoints")
55
- if not os.path.exists("download_ckpts.sh"):
56
- subprocess.run([
57
- "wget",
58
- "https://raw.githubusercontent.com/facebookresearch/sam2/main/checkpoints/download_ckpts.sh"
59
- ], check=True)
60
- subprocess.run(["chmod", "+x", "download_ckpts.sh"], check=True)
61
-
62
- # 下载检查点(添加超时和重试)
63
- result = subprocess.run(["./download_ckpts.sh"], check=False)
64
- if result.returncode != 0:
65
- print("Warning: Checkpoint download failed, trying alternative method...")
66
- subprocess.run([
67
- "wget",
68
- "https://dl.fbaipublicfiles.com/sam2/sam2.1_hiera_large.pt",
69
- "-O", "sam2.1_hiera_large.pt"
70
- ], check=True)
71
 
72
  except Exception as e:
73
  print(f"Error during SAM2 installation: {str(e)}")
@@ -94,7 +76,7 @@ print("🎉 SAM2 modules imported successfully!")
94
  # ------------------ 初始化模型 ------------------
95
  # 使用相对路径
96
  MODEL_PATH = "geshang/Seg-R1-COD"
97
- SAM_CHECKPOINT = "third_party/sam2/checkpoints/sam2.1_hiera_large.pt"
98
 
99
  # 自动检测设备
100
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -123,8 +105,8 @@ class CustomSAMWrapper:
123
  from sam2.sam2_image_predictor import SAM2ImagePredictor
124
 
125
  self.device = torch.device(device)
126
- model_cfg = os.path.join("third_party/sam2", "configs/sam2.1/sam2.1_hiera_l.yaml")
127
- sam_model = build_sam2(model_cfg, model_path)
128
  sam_model = sam_model.to(self.device)
129
  self.predictor = SAM2ImagePredictor(sam_model)
130
  self.last_mask = None
 
50
  # 以可编辑模式安装SAM2
51
  subprocess.run(["pip", "install", "-e", "."], check=True)
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  except Exception as e:
55
  print(f"Error during SAM2 installation: {str(e)}")
 
76
  # ------------------ 初始化模型 ------------------
77
  # 使用相对路径
78
  MODEL_PATH = "geshang/Seg-R1-COD"
79
+ SAM_CHECKPOINT = "sam2_weights/sam2.1_hiera_large.pt"
80
 
81
  # 自动检测设备
82
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
105
  from sam2.sam2_image_predictor import SAM2ImagePredictor
106
 
107
  self.device = torch.device(device)
108
+ # model_cfg = os.path.join("third_party/sam2", "configs/sam2.1/sam2.1_hiera_l.yaml")
109
+ sam_model = build_sam2("configs/sam2.1/sam2.1_hiera_l.yaml", model_path)
110
  sam_model = sam_model.to(self.device)
111
  self.predictor = SAM2ImagePredictor(sam_model)
112
  self.last_mask = None
sam2_weights/sam2.1_hiera_large.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2647878d5dfa5098f2f8649825738a9345572bae2d4350a2468587ece47dd318
3
+ size 898083611