Spaces:
Running
on
Zero
Running
on
Zero
geshang777
commited on
Commit
·
cb86726
1
Parent(s):
4c269f1
add sam2
Browse files- app.py +3 -21
- sam2_weights/sam2.1_hiera_large.pt +3 -0
app.py
CHANGED
@@ -50,24 +50,6 @@ def install_sam2():
|
|
50 |
# 以可编辑模式安装SAM2
|
51 |
subprocess.run(["pip", "install", "-e", "."], check=True)
|
52 |
|
53 |
-
# 切换到checkpoints目录下载模型
|
54 |
-
os.chdir("checkpoints")
|
55 |
-
if not os.path.exists("download_ckpts.sh"):
|
56 |
-
subprocess.run([
|
57 |
-
"wget",
|
58 |
-
"https://raw.githubusercontent.com/facebookresearch/sam2/main/checkpoints/download_ckpts.sh"
|
59 |
-
], check=True)
|
60 |
-
subprocess.run(["chmod", "+x", "download_ckpts.sh"], check=True)
|
61 |
-
|
62 |
-
# 下载检查点(添加超时和重试)
|
63 |
-
result = subprocess.run(["./download_ckpts.sh"], check=False)
|
64 |
-
if result.returncode != 0:
|
65 |
-
print("Warning: Checkpoint download failed, trying alternative method...")
|
66 |
-
subprocess.run([
|
67 |
-
"wget",
|
68 |
-
"https://dl.fbaipublicfiles.com/sam2/sam2.1_hiera_large.pt",
|
69 |
-
"-O", "sam2.1_hiera_large.pt"
|
70 |
-
], check=True)
|
71 |
|
72 |
except Exception as e:
|
73 |
print(f"Error during SAM2 installation: {str(e)}")
|
@@ -94,7 +76,7 @@ print("🎉 SAM2 modules imported successfully!")
|
|
94 |
# ------------------ 初始化模型 ------------------
|
95 |
# 使用相对路径
|
96 |
MODEL_PATH = "geshang/Seg-R1-COD"
|
97 |
-
SAM_CHECKPOINT = "
|
98 |
|
99 |
# 自动检测设备
|
100 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
@@ -123,8 +105,8 @@ class CustomSAMWrapper:
|
|
123 |
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
124 |
|
125 |
self.device = torch.device(device)
|
126 |
-
model_cfg = os.path.join("third_party/sam2", "configs/sam2.1/sam2.1_hiera_l.yaml")
|
127 |
-
sam_model = build_sam2(
|
128 |
sam_model = sam_model.to(self.device)
|
129 |
self.predictor = SAM2ImagePredictor(sam_model)
|
130 |
self.last_mask = None
|
|
|
50 |
# 以可编辑模式安装SAM2
|
51 |
subprocess.run(["pip", "install", "-e", "."], check=True)
|
52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
except Exception as e:
|
55 |
print(f"Error during SAM2 installation: {str(e)}")
|
|
|
76 |
# ------------------ 初始化模型 ------------------
|
77 |
# 使用相对路径
|
78 |
MODEL_PATH = "geshang/Seg-R1-COD"
|
79 |
+
SAM_CHECKPOINT = "sam2_weights/sam2.1_hiera_large.pt"
|
80 |
|
81 |
# 自动检测设备
|
82 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
105 |
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
106 |
|
107 |
self.device = torch.device(device)
|
108 |
+
# model_cfg = os.path.join("third_party/sam2", "configs/sam2.1/sam2.1_hiera_l.yaml")
|
109 |
+
sam_model = build_sam2("configs/sam2.1/sam2.1_hiera_l.yaml", model_path)
|
110 |
sam_model = sam_model.to(self.device)
|
111 |
self.predictor = SAM2ImagePredictor(sam_model)
|
112 |
self.last_mask = None
|
sam2_weights/sam2.1_hiera_large.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2647878d5dfa5098f2f8649825738a9345572bae2d4350a2468587ece47dd318
|
3 |
+
size 898083611
|