Prabhsimran09 commited on
Commit
943a1f9
·
verified ·
1 Parent(s): 4c1b7aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -30
app.py CHANGED
@@ -1,38 +1,65 @@
1
  import gradio as gr
2
- from transformers import YolosImageProcessor, YolosForObjectDetection
3
- from PIL import Image, ImageDraw
4
- import torch
5
 
6
- # Load YOLOS processor and model
7
- processor = YolosImageProcessor.from_pretrained("hustvl/yolos-small")
8
- model = YolosForObjectDetection.from_pretrained("hustvl/yolos-small")
9
 
10
- def detect_objects(image):
11
- # Preprocess the image
12
- inputs = processor(images=image, return_tensors="pt")
13
- outputs = model(**inputs)
14
 
15
- # Post-process outputs
16
- target_sizes = torch.tensor([image.size[::-1]])
17
- results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.5)[0]
18
 
19
- # Draw boxes on image
 
 
 
20
  draw = ImageDraw.Draw(image)
21
- for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
22
- box = [round(i, 2) for i in box.tolist()]
23
- label_text = f"{model.config.id2label[label.item()]}: {round(score.item(), 2)}"
24
- draw.rectangle(box, outline="red", width=2)
25
- draw.text((box[0], box[1]), label_text, fill="red")
26
-
27
- return image
28
-
29
- # Gradio Interface
30
- iface = gr.Interface(
31
- fn=detect_objects,
32
- inputs=gr.Image(type="pil"),
33
- outputs=gr.Image(type="pil"),
34
- title="🟡 Object Detection with YOLOS",
35
- description="Upload an image to detect objects using YOLOS (You Only Look One-level Series)."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  )
37
 
38
- iface.launch()
 
1
  import gradio as gr
2
+ from transformers import pipeline
3
+ from PIL import Image, ImageDraw, ImageFont
 
4
 
5
+ # Load the YOLOS object detection model
6
+ detector = pipeline("object-detection", model="hustvl/yolos-small")
 
7
 
8
+ # Define some colors to differentiate classes
9
+ COLORS = ["red", "blue", "green", "orange", "purple", "yellow", "cyan", "magenta"]
 
 
10
 
11
+ # Helper function to assign color per label
12
+ def get_color_for_label(label):
13
+ return COLORS[hash(label) % len(COLORS)]
14
 
15
+ # Main function: detect, draw, and return outputs
16
+ def detect_and_draw(image, threshold):
17
+ results = detector(image)
18
+ image = image.convert("RGB")
19
  draw = ImageDraw.Draw(image)
20
+
21
+ try:
22
+ font = ImageFont.truetype("arial.ttf", 16)
23
+ except:
24
+ font = ImageFont.load_default()
25
+
26
+ annotations = []
27
+
28
+ for obj in results:
29
+ score = obj["score"]
30
+ if score < threshold:
31
+ continue
32
+
33
+ label = f"{obj['label']} ({score:.2f})"
34
+ box = obj["box"]
35
+ color = get_color_for_label(obj["label"])
36
+
37
+ draw.rectangle(
38
+ [(box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])],
39
+ outline=color,
40
+ width=3,
41
+ )
42
+
43
+ draw.text((box["xmin"] + 5, box["ymin"] + 5), label, fill=color, font=font)
44
+
45
+ box_coords = (box["xmin"], box["ymin"], box["xmax"], box["ymax"])
46
+ annotations.append((box_coords, label))
47
+
48
+ # Return the annotated image and annotations (no download option)
49
+ return image, annotations
50
+
51
+ # Gradio UI setup
52
+ demo = gr.Interface(
53
+ fn=detect_and_draw,
54
+ inputs=[
55
+ gr.Image(type="pil", label="Upload Image"),
56
+ gr.Slider(minimum=0.1, maximum=1.0, value=0.5, step=0.05, label="Confidence Threshold"),
57
+ ],
58
+ outputs=[
59
+ gr.AnnotatedImage(label="Detected Image"),
60
+ ],
61
+ title="YOLOS Object Detection",
62
+ description="Upload an image to detect objects using the YOLOS-small model. Adjust the confidence threshold using the slider.",
63
  )
64
 
65
+ demo.launch()