"""Gradio demo for Seg-R1.

A Qwen2.5-VL model reads an image plus a question and emits a bounding box,
prompt points and point labels; SAM2 turns those prompts into a mask that is
overlaid on the uploaded image.
"""

import ast
import base64
import io
import os
import re
import subprocess
import sys
from typing import List, Optional, Tuple

import cv2
import gradio as gr
import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image as PILImage
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

import spaces


def add_sam2_to_path():
    """Prepend the vendored SAM2 checkout to ``sys.path`` and return its absolute path.

    Presumably done so ``sam2`` is importable in this same process right after
    :func:`install_sam2` runs (an editable install made mid-process is not
    necessarily picked up) — verify if the install flow changes.
    """
    sam2_dir = os.path.abspath("third_party/sam2")
    if sam2_dir not in sys.path:
        sys.path.insert(0, sam2_dir)
    return sam2_dir


def install_sam2():
    """Clone SAM2 into ``third_party/sam2`` and pip-install it, unless it already exists.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` or ``pip install`` fails.
    """
    sam2_dir = "third_party/sam2"
    if os.path.exists(sam2_dir):
        print("SAM2 already exists, skipping installation.")
        return

    print("Installing SAM2...")
    os.makedirs("third_party", exist_ok=True)
    subprocess.run(
        ["git", "clone", "--recursive", "https://github.com/facebookresearch/sam2.git", sam2_dir],
        check=True,
    )
    try:
        # ``sys.executable -m pip`` installs into the interpreter running this app
        # (a bare ``pip`` on PATH may belong to a different environment), and
        # ``cwd=`` avoids mutating the process-wide working directory.
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "-e", "."],
            cwd=sam2_dir,
            check=True,
        )
    except Exception as e:
        print(f"Error during SAM2 installation: {str(e)}")
        raise
    print("✅ SAM2 installed successfully!")


install_sam2()
sam2_dir = add_sam2_to_path()

# SAM2 can only be imported once the checkout exists and is on sys.path.
from sam2.build_sam import build_sam2  # noqa: E402
from sam2.sam2_image_predictor import SAM2ImagePredictor  # noqa: E402

print("🎉 SAM2 modules imported successfully!")

MODEL_PATH = "geshang/Seg-R1-7B"
SAM_CHECKPOINT = "sam2_weights/sam2.1_hiera_large.pt"
SAM_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
# Hard-coded to CUDA; a ``torch.cuda.is_available()`` CPU fallback used to be
# commented out here.
DEVICE = "cuda"
# (height, width) passed to ``TF.resize``; both the VLM and SAM see this size.
RESIZE_SIZE = (1024, 1024)

# Presumably mirrors the Seg-R1 training prompt, so the wording is kept verbatim
# (including "should enclosed"). The tags delimit the fields that
# parse_custom_format() and run_pipeline() extract from the model output.
# NOTE(review): tag names reconstructed from the Seg-R1 output format; confirm
# against the training prompt.
SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process should enclosed within <think> </think> tags, and the bounding box, points and points labels should be "
    "enclosed within <bbox></bbox>, <points></points>, and <labels></labels>, respectively. i.e., "
    "<think> reasoning process here </think> <bbox>[x1,y1,x2,y2]</bbox>, <points>[[x3,y3],[x4,y4],...]</points>, "
    "<labels>[1,0,...]</labels>"
    "Where 1 indicates a foreground (object) point, and 0 indicates a background point."
)

try:
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32,
        device_map="auto" if DEVICE == "cuda" else None,
    )
    processor = AutoProcessor.from_pretrained(MODEL_PATH, use_fast=True)
    print(f"Qwen model loaded on {DEVICE}")
except Exception as e:
    # Keep the UI alive; the request handlers report the missing model instead.
    print(f"Error loading Qwen model: {e}")
    model = None
    processor = None


# SAM Wrapper
class CustomSAMWrapper:
    """Thin wrapper around ``SAM2ImagePredictor`` that returns a single mask per call."""

    def __init__(self, model_path: str, device: str = DEVICE):
        """Build SAM2 from the checkpoint at ``model_path`` on ``device``.

        Loading errors are not caught and propagate to the caller.
        """
        self.device = torch.device(device)
        sam_model = build_sam2(SAM_CONFIG, model_path, self.device)
        sam_model = sam_model.to(self.device)
        self.predictor = SAM2ImagePredictor(sam_model)
        self.last_mask = None  # most recent mask returned by predict()
        print(f"SAM model loaded on {device}")

    def predict(
        self,
        image: PILImage.Image,
        points: List[Tuple[int, int]],
        labels: List[int],
        bbox: Optional[List[List[int]]] = None,
    ) -> Tuple[np.ndarray, float]:
        """Segment ``image`` from optional point / label / box prompts.

        Empty or ``None`` prompts are passed to SAM2 as ``None``.

        Returns:
            ``(mask, score)`` for the single mask SAM2 produces
            (``multimask_output=False``). On any error, an all-False
            ``(image.height, image.width)`` mask and ``0.0``.
        """
        if self.predictor is None:
            return np.zeros((image.height, image.width), dtype=bool), 0.0
        try:
            input_points = np.array(points) if points else None
            input_labels = np.array(labels) if labels else None
            input_bboxes = np.array(bbox) if bbox else None

            # SAM2's set_image expects an RGB HWC array. PIL already yields RGB,
            # so the former cv2.COLOR_BGR2RGB call swapped channels and fed BGR.
            rgb_image = np.array(image.convert("RGB"))
            self.predictor.set_image(rgb_image)

            masks, scores, _ = self.predictor.predict(
                point_coords=input_points,
                point_labels=input_labels,
                box=input_bboxes,
                multimask_output=False,
            )
            self.last_mask = masks[0]
            return masks[0], scores[0]
        except Exception as e:
            print(f"SAM prediction error: {e}")
            return np.zeros((image.height, image.width), dtype=bool), 0.0


# Regex bodies for the list literals the model emits.
_POINTS_BODY = r"\[\s*(?:\[\s*\d+\s*,\s*\d+\s*\]\s*,?\s*)+\]"
_LABELS_BODY = r"\[\s*(?:\d+\s*,?\s*)+\]"
_BBOX_BODY = r"\[\s*\d+\s*,\s*\d+\s*,\s*\d+\s*,\s*\d+\s*\]"


def _tagged_pattern(tag: str, body: str) -> str:
    """Return a regex capturing ``body`` between ``<tag>`` and ``</tag>``."""
    return rf"<{tag}>\s*({body})\s*</{tag}>"


def parse_custom_format(content: str):
    """Extract points, labels and bboxes from a model response.

    Fields wrapped in ``<points>``, ``<labels>`` and ``<bbox>`` tags are
    preferred. If a tagged field is absent, the first bare list literal of
    the right shape is used instead (note that a bare label search can
    then pick up a bbox list).

    Returns:
        ``(points, labels, bboxes)``:
          * ``points`` is an ``(N, 2)`` int array or ``None``.
          * ``labels`` is an ``(N,)`` array or ``None``. Both points and labels
            are reset to ``None`` when their shapes disagree.
          * ``bboxes`` is an ``(M, 4)`` array or ``None``.
        All three are ``None`` on a parse error.
    """

    def _first(tag: str, body: str) -> Optional[str]:
        match = re.search(_tagged_pattern(tag, body), content) or re.search(f"({body})", content)
        return match.group(1) if match else None

    def _all(tag: str, body: str) -> List[str]:
        return re.findall(_tagged_pattern(tag, body), content) or re.findall(f"({body})", content)

    try:
        point_str = _first("points", _POINTS_BODY)
        label_str = _first("labels", _LABELS_BODY)
        bbox_strs = _all("bbox", _BBOX_BODY)

        # literal_eval instead of eval: this text comes from model output.
        points = np.array(ast.literal_eval(point_str)) if point_str else None
        labels = np.array(ast.literal_eval(label_str)) if label_str else None

        if points is not None and labels is not None:
            if not (len(points.shape) == 2 and points.shape[1] == 2 and len(labels) == points.shape[0]):
                points, labels = None, None

        bboxes = []
        for bbox_str in bbox_strs:
            bbox = np.array(ast.literal_eval(bbox_str))
            if len(bbox.shape) == 1 and bbox.shape[0] == 4:
                bboxes.append(bbox)
        bboxes = np.stack(bboxes, axis=0) if bboxes else None

        return points, labels, bboxes
    except Exception as e:
        print("Error parsing content:", e)
        return None, None, None


def prepare_test_messages(image, prompt):
    """Build a one-conversation batch: system prompt, base64 JPEG image, user prompt."""
    image = TF.resize(image, RESIZE_SIZE)
    buffered = io.BytesIO()
    # JPEG cannot store alpha/palette modes, so normalise to RGB before encoding.
    image.convert("RGB").save(buffered, format="JPEG")
    img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

    messages = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": f"data:image/jpeg;base64,{img_base64}"},
                {"type": "text", "text": prompt},
            ],
        },
    ]
    return [messages]


def answer_question(batch_messages):
    """Generate one decoded response per conversation in ``batch_messages``.

    Returns a single-element error list if the model is unavailable or
    generation fails.
    """
    if model is None or processor is None:
        return ["Model not loaded. Please check logs."]
    try:
        text = [
            processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
            for msg in batch_messages
        ]
        image_inputs, video_inputs = process_vision_info(batch_messages)
        inputs = processor(
            text=text,
            images=image_inputs,
            videos=video_inputs,
            return_tensors="pt",
            padding=True,
        ).to(DEVICE)
        outputs = model.generate(**inputs, use_cache=True, max_new_tokens=1024)
        # Drop the echoed prompt tokens so only the newly generated text is decoded.
        trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, outputs)]
        return processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    except Exception as e:
        print(f"Error generating answer: {e}")
        return ["Error generating response"]


DEFAULT_MASK_COLORS = (
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (255, 255, 0),
    (255, 0, 255),
    (0, 255, 255),
    (128, 128, 255),
)


def visualize_masks_on_image(
    image: PILImage.Image,
    masks_np: list,
    colors=DEFAULT_MASK_COLORS,
    alpha=0.5,
):
    """Alpha-blend coloured binary masks over ``image``.

    Mask ``i`` is painted with ``colors[i % len(colors)]``. Overlapping masks
    are combined with a bitwise OR. Masks of a different size are resized
    (nearest-neighbour) to the image. If ``masks_np`` is empty, ``image`` is
    returned unchanged.
    """
    if not masks_np:
        return image

    image_np = np.array(image.convert("RGB"))
    height, width = image_np.shape[:2]
    color_mask = np.zeros((height, width, 3), dtype=np.uint8)
    for i, mask in enumerate(masks_np):
        color = np.array(colors[i % len(colors)], dtype=np.uint8)
        mask = mask.astype(np.uint8)
        if mask.shape[:2] != (height, width):
            # Nearest-neighbour keeps the mask strictly binary.
            mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
        color_mask |= mask[:, :, None] * color

    blended = cv2.addWeighted(image_np, 1 - alpha, color_mask, alpha, 0)
    return PILImage.fromarray(blended)


def _predict_from_bboxes(sam_wrapper, img, points, labels, bboxes) -> np.ndarray:
    """Union one SAM mask per bbox, prompting each box with the points that fall inside it."""
    final_mask = np.zeros(RESIZE_SIZE[::-1], dtype=bool)
    for box in bboxes:
        b = box.tolist()
        selected_points, selected_labels = None, None
        if points is not None and labels is not None:
            inside = (
                (points[:, 0] >= b[0]) & (points[:, 0] <= b[2])
                & (points[:, 1] >= b[1]) & (points[:, 1] <= b[3])
            )
            if inside.any():
                selected_points = points[inside].tolist()
                selected_labels = labels[inside].tolist()
        try:
            mask, _ = sam_wrapper.predict(img, selected_points, selected_labels, b)
            final_mask |= mask > 0
        except Exception as e:
            print(f"Mask prediction error for bbox: {e}")
            continue
    return final_mask


def _predict_from_points(sam_wrapper, img, points, labels) -> np.ndarray:
    """Predict a mask from point prompts alone; on error, return an all-False mask."""
    try:
        mask, _ = sam_wrapper.predict(img, points.tolist(), labels.tolist(), None)
        return mask > 0
    except Exception as e:
        print(f"Mask prediction error: {e}")
        return np.zeros(RESIZE_SIZE[::-1], dtype=bool)


@spaces.GPU
@torch.no_grad()
def run_pipeline(image: PILImage.Image, prompt: str):
    """Answer ``prompt`` about ``image``; return ``(response_text, overlay_image)``.

    The VLM's bbox / points / labels are converted into a SAM2 mask on the
    ``RESIZE_SIZE`` copy of the image. The mask is then resized back and
    overlaid on the original.

    The returned text depends on the outcome:
      * When a mask is produced and the response contains
        ``<think>...</think>``, only the reasoning is returned.
      * When no prompts can be parsed, the raw response and ``None`` are
        returned.
      * On errors, an error message and ``None`` are returned.
    """
    if model is None or processor is None:
        return "Models not loaded. Please check logs.", None
    if image is None:
        return "Please upload an image.", None

    # NOTE(review): SAM2 is rebuilt on every request. Caching it would save load
    # time, but check that against ZeroGPU (@spaces.GPU) process semantics first.
    sam_wrapper = CustomSAMWrapper(SAM_CHECKPOINT, device=DEVICE)

    try:
        img_original = image.copy()
        img_resized = TF.resize(image, RESIZE_SIZE)
        messages = prepare_test_messages(img_resized, prompt)
        output_text = answer_question(messages)[0]
        print(f"Model output: {output_text}")

        points, labels, bbox = parse_custom_format(output_text)
        if bbox is not None:
            mask_pred = _predict_from_bboxes(sam_wrapper, img_resized, points, labels, bbox)
        elif points is not None and labels is not None:
            mask_pred = _predict_from_points(sam_wrapper, img_resized, points, labels)
        else:
            return output_text, None

        mask_img = PILImage.fromarray(mask_pred.astype(np.uint8) * 255).resize(img_original.size)
        mask_np = np.array(mask_img.convert("L")) > 128
        visualized_img = visualize_masks_on_image(img_original, masks_np=[mask_np], alpha=0.6)

        # Show only the reasoning. The old pattern r'(.*?)' always matched the
        # empty string and blanked the response; keep the full text if no tags.
        match = re.search(r"<think>(.*?)</think>", output_text, re.DOTALL)
        if match:
            output_text = match.group(1).strip()
        return output_text, visualized_img
    except Exception as e:
        print(f"Pipeline error: {e}")
        return f"Error processing request: {str(e)}", None


def load_description(fp):
    """Return the UTF-8 text content of the file at ``fp``."""
    with open(fp, "r", encoding="utf-8") as f:
        return f.read()


with gr.Blocks(title="Seg-R1") as demo:
    gr.HTML(load_description("assets/title.md"))
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            text_input = gr.Textbox(lines=2, label="Question", placeholder="Ask about objects in the image...")
            submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column():
            text_output = gr.Textbox(label="Model Response", interactive=False)
            image_output = gr.Image(type="pil", label="Segmentation Result", interactive=False)

    submit_btn.click(
        fn=run_pipeline,
        inputs=[image_input, text_input],
        outputs=[text_output, image_output],
    )

    gr.Examples(
        examples=[
            ["imgs/camourflage1.jpg", "There is a creature hidden in its surroundings, segment it."],
            ["imgs/camourflage2.jpg", "Please segment the camouflaged object in this image."],
            ["imgs/dog_in_sheeps.jpg", "Find the one that suffers."],
            ["imgs/kind_lady.jpg", "Find the most uncommon part of this picture."],
            ["imgs/painting.jpg", "Identify and segment the man and the sky."],
            ["imgs/man_and_cat.jpg", "Identify and segment the cat and the glasses of the man."],
        ],
        inputs=[image_input, text_input],
        outputs=[text_output, image_output],
        fn=run_pipeline,
        cache_examples=True,
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)