import os
import subprocess
import gradio as gr
from PIL import Image as PILImage
import torchvision.transforms.functional as TF
import numpy as np
import torch
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from qwen_vl_utils import process_vision_info
import re
import io
import base64
import cv2
from typing import List, Tuple, Optional
import sys
import spaces  # Hugging Face Spaces ZeroGPU helpers

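# SAM2 setup: the facebookresearch/sam2 repo is cloned into third_party/sam2,
# installed in editable mode at startup, and added to sys.path so the sam2
# package can be imported below.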
def add_sam2_to_path():
    sam2_dir = os.path.abspath("third_party/sam2")
    if sam2_dir not in sys.path:
        sys.path.insert(0, sam2_dir)
    return sam2_dir

def install_sam2():
    sam2_dir = "third_party/sam2"
    if not os.path.exists(sam2_dir):
        print("Installing SAM2...")
        os.makedirs("third_party", exist_ok=True)
        subprocess.run([
            "git", "clone",
            "--recursive",
            "https://github.com/facebookresearch/sam2.git",
            sam2_dir
        ], check=True)
        original_dir = os.getcwd()
        try:
            os.chdir(sam2_dir)
            subprocess.run(["pip", "install", "-e", "."], check=True)
        except Exception as e:
            print(f"Error during SAM2 installation: {str(e)}")
            raise
        finally:
            os.chdir(original_dir)
        print("✅ SAM2 installed successfully!")
    else:
        print("SAM2 already exists, skipping installation.")

install_sam2()
sam2_dir = add_sam2_to_path()

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
print("🎉 SAM2 modules imported successfully!")

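# Configuration: Seg-R1-7B is a Qwen2.5-VL-based model hosted on the Hugging Face Hub;
# SAM_CHECKPOINT points at the local SAM 2.1 hiera-large weights used for mask
# decoding. Input images are resized to RESIZE_SIZE before inference.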
MODEL_PATH = "geshang/Seg-R1-7B"
SAM_CHECKPOINT = "sam2_weights/sam2.1_hiera_large.pt"
DEVICE = "cuda"  # if torch.cuda.is_available() else "cpu"
RESIZE_SIZE = (1024, 1024)

# Load the Seg-R1 model (Qwen2.5-VL architecture) and its processor.
try:
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32,
        device_map="auto" if DEVICE == "cuda" else None
    )
    processor = AutoProcessor.from_pretrained(MODEL_PATH, use_fast=True)
    print(f"Qwen model loaded on {DEVICE}")
except Exception as e:
    print(f"Error loading Qwen model: {e}")
    model = None
    processor = None

# SAM Wrapper: thin convenience layer around SAM2ImagePredictor that accepts
# point prompts, point labels, and an optional bounding box.
class CustomSAMWrapper:
    def __init__(self, model_path: str, device: str = DEVICE):
        # try:
        self.device = torch.device(device)
        sam_model = build_sam2("configs/sam2.1/sam2.1_hiera_l.yaml", model_path, self.device)
        sam_model = sam_model.to(self.device)
        self.predictor = SAM2ImagePredictor(sam_model)
        self.last_mask = None
        print(f"SAM model loaded on {device}")
        # except Exception as e:
        #     print(f"Error loading SAM model: {e}")
        #     self.predictor = None

    def predict(self, image: PILImage.Image,
                points: List[Tuple[int, int]],
                labels: List[int],
                bbox: Optional[List[List[int]]] = None) -> Tuple[np.ndarray, float]:
        # Returns (mask, score); falls back to an empty mask on failure.
        if not self.predictor:
            return np.zeros((image.height, image.width), dtype=bool), 0.0
        try:
            input_points = np.array(points) if points else None
            input_labels = np.array(labels) if labels else None
            input_bboxes = np.array(bbox) if bbox else None
            image_np = np.array(image)
            rgb_image = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
            self.predictor.set_image(rgb_image)
            mask_pred, score, logits = self.predictor.predict(
                point_coords=input_points,
                point_labels=input_labels,
                box=input_bboxes,
                multimask_output=False,
            )
            self.last_mask = mask_pred[0]
            return mask_pred[0], score[0]
        except Exception as e:
            print(f"SAM prediction error: {e}")
            return np.zeros((image.height, image.width), dtype=bool), 0.0

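# Seg-R1 emits its segmentation prompts as plain-text tags in its response:
#   <bbox>[x1, y1, x2, y2]</bbox>, <points>[[x, y], ...]</points>,
#   <labels>[1, 0, ...]</labels>   (1 = foreground point, 0 = background point)
# parse_custom_format extracts these with regular expressions.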
def parse_custom_format(content: str):
    point_pattern = r"<points>\s*(\[\s*(?:\[\s*\d+\s*,\s*\d+\s*\]\s*,?\s*)+\])\s*</points>"
    label_pattern = r"<labels>\s*(\[\s*(?:\d+\s*,?\s*)+\])\s*</labels>"
    bbox_pattern = r"<bbox>\s*(\[\s*\d+\s*,\s*\d+\s*,\s*\d+\s*,\s*\d+\s*\])\s*</bbox>"
    point_match = re.search(point_pattern, content)
    label_match = re.search(label_pattern, content)
    bbox_matches = re.findall(bbox_pattern, content)
    try:
        # The captured groups are Python-style list literals produced by the model.
        points = np.array(eval(point_match.group(1))) if point_match else None
        labels = np.array(eval(label_match.group(1))) if label_match else None
        # Points and labels must form an (N, 2) array paired with N labels.
        if points is not None and labels is not None:
            if not (len(points.shape) == 2 and points.shape[1] == 2 and len(labels) == points.shape[0]):
                points, labels = None, None
        bboxes = []
        for bbox_str in bbox_matches:
            bbox = np.array(eval(bbox_str))
            if len(bbox.shape) == 1 and bbox.shape[0] == 4:
                bboxes.append(bbox)
        bboxes = np.stack(bboxes, axis=0) if bboxes else None
        return points, labels, bboxes
    except Exception as e:
        print("Error parsing content:", e)
        return None, None, None

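# Illustrative example (hypothetical output, not from a real run): given
#   "<think>...</think> <bbox>[120, 60, 480, 400]</bbox>
#    <points>[[200, 150], [350, 300]]</points> <labels>[1, 1]</labels>"
# parse_custom_format returns points of shape (2, 2), labels of shape (2,),
# and bboxes of shape (1, 4).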
def prepare_test_messages(image, prompt):
    buffered = io.BytesIO()
    image = TF.resize(image, RESIZE_SIZE)
    image.save(buffered, format="JPEG")
    img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
    SYSTEM_PROMPT = (
        "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
        "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
        "process should enclosed within <think> </think> tags, and the bounding box, points and points labels should be enclosed within <bbox></bbox>, <points></points>, and <labels></labels>, respectively. i.e., "
        "<think> reasoning process here </think> <bbox>[x1,y1,x2,y2]</bbox>, <points>[[x3,y3],[x4,y4],...]</points>, <labels>[1,0,...]</labels>"
        "Where 1 indicates a foreground (object) point, and 0 indicates a background point."
    )
    messages = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": f"data:image/jpeg;base64,{img_base64}"},
                {"type": "text", "text": prompt},
            ],
        },
    ]
    return [messages]

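# answer_question runs batched generation: the chat template is applied, the
# images are packed with process_vision_info, and the prompt tokens are
# stripped from each output so only newly generated text is decoded.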
def answer_question(batch_messages):
    if not model or not processor:
        return ["Model not loaded. Please check logs."]
    try:
        text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
        image_inputs, video_inputs = process_vision_info(batch_messages)
        inputs = processor(text=text, images=image_inputs, videos=video_inputs, return_tensors="pt", padding=True).to(DEVICE)
        outputs = model.generate(**inputs, use_cache=True, max_new_tokens=1024)
        trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, outputs)]
        return processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    except Exception as e:
        print(f"Error generating answer: {e}")
        return ["Error generating response"]

def visualize_masks_on_image(
    image: PILImage.Image,
    masks_np: list,
    colors=[(255, 0, 0), (0, 255, 0), (0, 0, 255),
            (255, 255, 0), (255, 0, 255), (0, 255, 255),
            (128, 128, 255)],
    alpha=0.5,
):
    if not masks_np:
        return image
    image_np = np.array(image)
    color_mask = np.zeros((image_np.shape[0], image_np.shape[1], 3), dtype=np.uint8)
    for i, mask in enumerate(masks_np):
        color = colors[i % len(colors)]
        mask = mask.astype(np.uint8)
        if mask.shape[:2] != image_np.shape[:2]:
            mask = cv2.resize(mask, (image_np.shape[1], image_np.shape[0]))
        color_mask[:, :, 0] = color_mask[:, :, 0] | (mask * color[0])
        color_mask[:, :, 1] = color_mask[:, :, 1] | (mask * color[1])
        color_mask[:, :, 2] = color_mask[:, :, 2] | (mask * color[2])
    blended = cv2.addWeighted(image_np, 1 - alpha, color_mask, alpha, 0)
    return PILImage.fromarray(blended)

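# run_pipeline is the end-to-end entry point used by the UI: resize the input,
# query Seg-R1 for <think>/<bbox>/<points>/<labels>, prompt SAM2 with the
# parsed geometry (one call per bounding box when several are given), and
# overlay the resulting mask on the original-resolution image.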
@spaces.GPU  # request a ZeroGPU slot for the duration of each call (required on ZeroGPU Spaces)
def run_pipeline(image: PILImage.Image, prompt: str):
    sam_wrapper = CustomSAMWrapper(SAM_CHECKPOINT, device=DEVICE)
    if not model or not processor:
        return "Models not loaded. Please check logs.", None
    try:
        img_original = image.copy()
        img_resized = TF.resize(image, RESIZE_SIZE)
        messages = prepare_test_messages(img_resized, prompt)
        output_text = answer_question(messages)[0]
        print(f"Model output: {output_text}")
        points, labels, bbox = parse_custom_format(output_text)
        mask_pred = None
        final_mask = np.zeros(RESIZE_SIZE[::-1], dtype=bool)
        if (points is not None and labels is not None) or (bbox is not None):
            img = img_resized
            if bbox is not None and len(bbox.shape) == 2:
                # One SAM2 call per bounding box, keeping only the points that
                # fall inside that box; the per-box masks are OR-ed together.
                for b in bbox:
                    b = b.tolist()
                    if points is not None and labels is not None:
                        in_bbox_mask = (
                            (points[:, 0] >= b[0]) & (points[:, 0] <= b[2]) &
                            (points[:, 1] >= b[1]) & (points[:, 1] <= b[3])
                        )
                        selected_points = points[in_bbox_mask]
                        selected_labels = labels[in_bbox_mask]
                    else:
                        selected_points, selected_labels = None, None
                    try:
                        mask, _ = sam_wrapper.predict(
                            img,
                            selected_points.tolist() if selected_points is not None and len(selected_points) > 0 else None,
                            selected_labels.tolist() if selected_labels is not None and len(selected_labels) > 0 else None,
                            b
                        )
                        final_mask |= (mask > 0)
                    except Exception as e:
                        print(f"Mask prediction error for bbox: {e}")
                        continue
                mask_pred = final_mask
            else:
                # No usable bounding boxes: prompt SAM2 with points/labels only.
                try:
                    mask_pred, _ = sam_wrapper.predict(
                        img,
                        points.tolist() if points is not None else None,
                        labels.tolist() if labels is not None else None,
                        bbox.tolist() if bbox is not None else None
                    )
                    mask_pred = mask_pred > 0
                except Exception as e:
                    print(f"Mask prediction error: {e}")
                    mask_pred = np.zeros(RESIZE_SIZE[::-1], dtype=bool)
        else:
            # Nothing to segment; return the raw model response.
            return output_text, None
        # Upscale the mask back to the original resolution and overlay it.
        mask_np = mask_pred
        mask_img = PILImage.fromarray((mask_np * 255).astype(np.uint8)).resize(img_original.size)
        mask_img = mask_img.convert("L")
        mask_np = np.array(mask_img) > 128
        visualized_img = visualize_masks_on_image(
            img_original,
            masks_np=[mask_np],
            alpha=0.6
        )
        # Show only the <think> block in the textbox when it is present.
        match = re.search(r'(<think>.*?</think>)', output_text, re.DOTALL)
        if match:
            output_text = match.group(1)
        return output_text, visualized_img
    except Exception as e:
        print(f"Pipeline error: {e}")
        return f"Error processing request: {str(e)}", None

def load_description(fp):
    with open(fp, 'r', encoding='utf-8') as f:
        content = f.read()
    return content

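# Gradio UI: image and question in, model reasoning and segmentation overlay out.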
with gr.Blocks(title="Seg-R1") as demo:
    # gr.Markdown("# Seg-R1")
    # gr.Markdown("Upload an image and ask questions about segmentation.")
    gr.HTML(load_description("assets/title.md"))
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            text_input = gr.Textbox(lines=2, label="Question", placeholder="Ask about objects in the image...")
            submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column():
            text_output = gr.Textbox(label="Model Response", interactive=False)
            image_output = gr.Image(type="pil", label="Segmentation Result", interactive=False)
    submit_btn.click(
        fn=run_pipeline,
        inputs=[image_input, text_input],
        outputs=[text_output, image_output]
    )
    gr.Examples(
        examples=[
            ["imgs/camourflage1.jpg", "There is a creature hidden in its surroundings, segment it."],
            ["imgs/camourflage2.jpg", "Please segment the camouflaged object in this image."],
            ["imgs/dog_in_sheeps.jpg", "Find the one that suffers."],
            ["imgs/kind_lady.jpg", "Find the most uncommon part of this picture."],
            ["imgs/painting.jpg", "Identify and segment the man and the sky."],
            ["imgs/man_and_cat.jpg", "Identify and segment the cat and the glasses of the man."],
        ],
        inputs=[image_input, text_input],
        outputs=[text_output, image_output],
        fn=run_pipeline,
        cache_examples=True
    )

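# Programmatic use (sketch, outside the UI): run_pipeline can be called directly, e.g.
#   img = PILImage.open("imgs/camourflage1.jpg")
#   text, overlay = run_pipeline(img, "Please segment the camouflaged object in this image.")
#   # overlay is a PIL image with the mask blended in, or None if the model
#   # produced no usable <bbox>/<points> geometry.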
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)