Spaces: Running on Zero

Commit: initial

Files changed:
- .gitattributes +1 -0
- README.md +29 -8
- app.py +296 -0
- requirements.txt +4 -0
- test-data/prompt1.jpg +3 -0
- test-data/prompt2.jpg +3 -0
- test-data/prompt3.jpg +3 -0
- test-data/prompt4.jpg +3 -0
- test-data/target1.jpg +3 -0
- test-data/target2.jpg +3 -0
- test-data/target3.jpg +3 -0
- test-data/target4.jpg +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -1,13 +1,34 @@
 ---
-title:
-
-
-colorTo: blue
+title: OWLv2 Visual Prompt
+short_description: OWLv2 zero-shot detection with visual prompt
+emoji: π
 sdk: gradio
-sdk_version:
+sdk_version: 4.44.1
 app_file: app.py
-pinned: false
-short_description: OWLv2 zero-shot detection with visual prompt
 ---
 
-
+# OWLv2: Zero-shot detection with visual prompt π
+
+This demo showcases the OWLv2 model's ability to perform zero-shot object detection using visual and text prompts.
+
+You can provide either a text prompt or an image as a visual prompt to detect objects in the target image.
+
+For visual prompting, the following sample code is used, taken from the HF documentation:
+```python
+processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
+model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")
+
+target_image = Image.open(...)
+prompt_image = Image.open(...)
+inputs = processor(images=target_image, query_images=prompt_image, return_tensors="pt")
+
+# forward pass
+with torch.no_grad():
+    outputs = model.image_guided_detection(**inputs)
+
+target_sizes = torch.Tensor([target_image.size[::-1]])
+
+results = processor.post_process_image_guided_detection(outputs=outputs, threshold=0.9, nms_threshold=0.3, target_sizes=target_sizes)
+```
+
+For some reason, the visual prompt works much worse than the text prompt; perhaps it's an issue in the HF implementation.
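For the text-prompt mode, post-processing goes through `post_process_grounded_object_detection` instead. Below is a minimal sketch mirroring the text branch of `app.py` further down; the image path, the "cat,remote" query, and the 0.3 threshold are taken from the bundled text-prompt examples.

```python
import torch
from PIL import Image
from transformers import AutoProcessor, Owlv2ForObjectDetection

processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")

# Sample target bundled with this Space; "cat,remote" is its example text prompt.
target_image = Image.open("test-data/target2.jpg")
texts = ["cat", "remote"]

inputs = processor(images=target_image, text=texts, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

target_sizes = torch.tensor([target_image.size[::-1]])  # (height, width)
results = processor.post_process_grounded_object_detection(
    outputs=outputs, target_sizes=target_sizes, threshold=0.3
)[0]

for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
    # label indexes into the list of text queries
    print(texts[label.item()], round(score.item(), 2), [round(v, 1) for v in box.tolist()])
```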
app.py
ADDED
@@ -0,0 +1,296 @@
import sys

# Mock audio modules to avoid installing them
sys.modules["audioop"] = type("audioop", (), {"__file__": ""})()
sys.modules["pyaudioop"] = type("pyaudioop", (), {"__file__": ""})()

import torch
import gradio as gr
import supervision as sv
import spaces
from transformers import AutoProcessor, Owlv2ForObjectDetection

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

@spaces.GPU
def init_model(model_id):
    processor = AutoProcessor.from_pretrained(model_id)
    model = Owlv2ForObjectDetection.from_pretrained(model_id)
    model.eval()
    model.to(DEVICE)
    return processor, model

@spaces.GPU
def inference(prompts, target_image, model_id, conf_thresh, iou_thresh, prompt_type):
    processor, model = init_model(model_id)

    result = None
    class_names = {}

    if prompt_type == "Text":
        inputs = processor(
            images=target_image,
            text=prompts["texts"],
            return_tensors="pt"
        ).to(DEVICE)

        with torch.no_grad():
            outputs = model(**inputs)

        target_sizes = torch.tensor([target_image.size[::-1]])
        result = processor.post_process_grounded_object_detection(
            outputs=outputs,
            target_sizes=target_sizes,
            threshold=conf_thresh
        )[0]
        class_names = {k: v for k, v in enumerate(prompts["texts"])}

    elif prompt_type == "Visual":
        inputs = processor(
            images=target_image,
            query_images=prompts["images"],
            return_tensors="pt"
        ).to(DEVICE)
        with torch.no_grad():
            outputs = model.image_guided_detection(**inputs)

        # Post-process results
        target_sizes = torch.tensor([target_image.size[::-1]])
        result = processor.post_process_image_guided_detection(
            outputs=outputs,
            target_sizes=target_sizes,
            threshold=conf_thresh,
            nms_threshold=iou_thresh
        )[0]

        # prepare for supervision: add 0 label for all boxes
        result['labels'] = torch.zeros(len(result['boxes']), dtype=torch.int64)
        class_names = {0: "object"}

    detections = sv.Detections.from_transformers(result, class_names)

    resolution_wh = target_image.size
    thickness = sv.calculate_optimal_line_thickness(resolution_wh=resolution_wh)
    text_scale = sv.calculate_optimal_text_scale(resolution_wh=resolution_wh)

    labels = [
        f"{class_name} {confidence:.2f}"
        for class_name, confidence
        in zip(detections['class_name'], detections.confidence)
    ]

    annotated_image = target_image.copy()
    annotated_image = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX, thickness=thickness).annotate(
        scene=annotated_image, detections=detections)
    annotated_image = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX, text_scale=text_scale, smart_position=True).annotate(
        scene=annotated_image, detections=detections, labels=labels)

    return annotated_image


def app():
    with gr.Blocks():
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    target_image = gr.Image(type="pil", label="Target Image", visible=True, interactive=True)

                detect_button = gr.Button(value="Detect Objects")
                prompt_type = gr.State(value='Visual')  # Default prompt type

                with gr.Tab("Visual") as visual_tab:
                    with gr.Row():
                        prompt_image = gr.Image(type="pil", label="Prompt Image", visible=True, interactive=True)

                with gr.Tab("Text") as text_tab:
                    texts = gr.Textbox(label="Input Texts", value='', placeholder='person,bus', visible=True, interactive=True)

                visual_tab.select(
                    fn=lambda: ("Visual", gr.update(visible=True)),
                    inputs=None,
                    outputs=[prompt_type, prompt_image]
                )

                text_tab.select(
                    fn=lambda: ("Text", gr.update(value=None, visible=False)),
                    inputs=None,
                    outputs=[prompt_type, prompt_image]
                )

                model_id = gr.Dropdown(
                    label="Model",
                    choices=[
                        "google/owlv2-base-patch16-ensemble",
                        "google/owlv2-large-patch14"
                    ],
                    value="google/owlv2-base-patch16-ensemble",
                )
                conf_thresh = gr.Slider(
                    label="Confidence Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.25,
                )
                iou_thresh = gr.Slider(
                    label="IoU Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.70,
                )

            with gr.Column():
                output_image = gr.Image(type="numpy", label="Annotated Image", visible=True)

        def run_inference(prompt_image, target_image, texts, model_id, conf_thresh, iou_thresh, prompt_type):
            # add text/built-in prompts
            if prompt_type == "Text":
                texts = [text.strip() for text in texts.split(',')]
                prompts = {
                    "texts": texts
                }
            # add visual prompt
            elif prompt_type == "Visual":
                prompts = {
                    "images": prompt_image,
                }

            return inference(prompts, target_image, model_id, conf_thresh, iou_thresh, prompt_type)

        detect_button.click(
            fn=run_inference,
            inputs=[prompt_image, target_image, texts, model_id, conf_thresh, iou_thresh, prompt_type],
            outputs=[output_image],
        )

        ###################### Examples ##########################
        image_examples_list = [[
            "test-data/target1.jpg",
            "test-data/prompt1.jpg",
            "google/owlv2-base-patch16-ensemble",
            0.9,
            0.3,
        ],
        [
            "test-data/target2.jpg",
            "test-data/prompt2.jpg",
            "google/owlv2-base-patch16-ensemble",
            0.9,
            0.3,
        ],
        [
            "test-data/target3.jpg",
            "test-data/prompt3.jpg",
            "google/owlv2-base-patch16-ensemble",
            0.9,
            0.3,
        ],
        [
            "test-data/target4.jpg",
            "test-data/prompt4.jpg",
            "google/owlv2-base-patch16-ensemble",
            0.9,
            0.3,
        ]
        ]

        text_examples = gr.Examples(
            examples=[[
                "test-data/target1.jpg",
                "logo",
                "google/owlv2-base-patch16-ensemble",
                0.3],
            [
                "test-data/target2.jpg",
                "cat,remote",
                "google/owlv2-base-patch16-ensemble",
                0.3],
            [
                "test-data/target3.jpg",
                "frog,spider,lizard",
                "google/owlv2-base-patch16-ensemble",
                0.3],
            [
                "test-data/target4.jpg",
                "cat",
                "google/owlv2-base-patch16-ensemble",
                0.3]
            ],
            inputs=[target_image, texts, model_id, conf_thresh],
            visible=False, cache_examples=False, label="Text Prompt Examples")

        image_examples = gr.Examples(
            examples=image_examples_list,
            inputs=[target_image, prompt_image, model_id, conf_thresh, iou_thresh],
            visible=True, cache_examples=False, label="Box Visual Prompt Examples")

        # Examples update
        def update_text_examples():
            return gr.Dataset(visible=True), gr.Dataset(visible=False), gr.update(visible=False)

        def update_visual_examples():
            return gr.Dataset(visible=False), gr.Dataset(visible=True), gr.update(visible=True)

        text_tab.select(
            fn=update_text_examples,
            inputs=None,
            outputs=[text_examples.dataset, image_examples.dataset, iou_thresh]
        )

        visual_tab.select(
            fn=update_visual_examples,
            inputs=None,
            outputs=[text_examples.dataset, image_examples.dataset, iou_thresh]
        )

    return target_image, prompt_image, model_id, conf_thresh, iou_thresh, image_examples_list

gradio_app = gr.Blocks()
with gradio_app:
    gr.HTML(
        """
        <h1 style='text-align: center'>OWLv2: Zero-shot detection with visual prompt π</h1>
        """)
    gr.Markdown("""
This demo showcases the OWLv2 model's ability to perform zero-shot object detection using visual and text prompts.

You can provide either a text prompt or an image as a visual prompt to detect objects in the target image.

For visual prompting, the following sample code is used, taken from the HF documentation:
```python
processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")

target_image = Image.open(...)
prompt_image = Image.open(...)
inputs = processor(images=target_image, query_images=prompt_image, return_tensors="pt")

# forward pass
with torch.no_grad():
    outputs = model.image_guided_detection(**inputs)

target_sizes = torch.Tensor([target_image.size[::-1]])

results = processor.post_process_image_guided_detection(outputs=outputs, threshold=0.9, nms_threshold=0.3, target_sizes=target_sizes)
```

For some reason, the visual prompt works much worse than the text prompt; perhaps it's an issue in the HF implementation.
""")

    with gr.Row():
        with gr.Column():
            # Create a list of all UI components
            ui_components = app()
            # Unpack the components
            target_image, prompt_image, model_id, conf_thresh, iou_thresh, image_examples_list = ui_components

    gradio_app.load(
        fn=lambda: image_examples_list[1],
        outputs=[target_image, prompt_image, model_id, conf_thresh, iou_thresh]
    )


if __name__ == '__main__':
    gradio_app.launch(allowed_paths=["figures"])
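To exercise the detection path locally without launching the UI, a sketch along the following lines should work. It assumes the pinned requirements plus the `spaces` package are installed and that `@spaces.GPU` behaves as a no-op outside a ZeroGPU Space; the file names and thresholds come from the first visual-prompt example.

```python
# Hypothetical local smoke test for inference() from app.py, using the first
# bundled visual-prompt example (0.9 / 0.3 match image_examples_list).
from PIL import Image

from app import inference  # importing app.py also builds (but does not launch) the Gradio UI

target = Image.open("test-data/target1.jpg")
prompt = Image.open("test-data/prompt1.jpg")

annotated = inference(
    prompts={"images": prompt},
    target_image=target,
    model_id="google/owlv2-base-patch16-ensemble",
    conf_thresh=0.9,
    iou_thresh=0.3,
    prompt_type="Visual",
)
annotated.save("annotated.jpg")  # inference() returns the annotated PIL image
```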
requirements.txt
ADDED
@@ -0,0 +1,4 @@
gradio==4.44.1
gradio_client==1.3.0
supervision==0.26.1
transformers==4.53.2
test-data/prompt1.jpg ADDED (Git LFS)
test-data/prompt2.jpg ADDED (Git LFS)
test-data/prompt3.jpg ADDED (Git LFS)
test-data/prompt4.jpg ADDED (Git LFS)
test-data/target1.jpg ADDED (Git LFS)
test-data/target2.jpg ADDED (Git LFS)
test-data/target3.jpg ADDED (Git LFS)
test-data/target4.jpg ADDED (Git LFS)