vvmnnnkv committed on
Commit
02a40a0
·
1 Parent(s): 2f285dd
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,34 @@
  ---
- title: Owlv2 Visual Prompt
- emoji: πŸŒ–
- colorFrom: yellow
- colorTo: blue
+ title: OWLv2 Visual Prompt
+ short_description: OWLv2 zero-shot detection with visual prompt
+ emoji: πŸ‘€
  sdk: gradio
- sdk_version: 5.38.2
+ sdk_version: 4.44.1
  app_file: app.py
- pinned: false
- short_description: OWLv2 zero-shot detection with visual prompt
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # OWLv2: Zero-shot detection with visual prompt πŸ‘€
+
+ This demo showcases the OWLv2 model's ability to perform zero-shot object detection using visual and text prompts.
+
+ You can either provide a text prompt or an image as a visual prompt to detect objects in the target image.
+
+ For visual prompting, the following sample code is used, taken from the HF documentation:
+ ```python
+ processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
+ model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")
+
+ target_image = Image.open(...)
+ prompt_image = Image.open(...)
+ inputs = processor(images=target_image, query_images=prompt_image, return_tensors="pt")
+
+ # forward pass
+ with torch.no_grad():
+     outputs = model.image_guided_detection(**inputs)
+
+ target_sizes = torch.Tensor([target_image.size[::-1]])
+
+ results = processor.post_process_image_guided_detection(outputs=outputs, threshold=0.9, nms_threshold=0.3, target_sizes=target_sizes)
+ ```
+
+ For some reason, the visual prompt works much worse than text prompts; this may be an issue in the HF implementation.
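
The README snippet only covers the visual-prompt path. For completeness, here is a minimal sketch of the text-prompt path, mirroring the `post_process_grounded_object_detection` call used in `app.py`; the image path, query strings, and 0.25 threshold are taken from the Space's examples and defaults, not from the README itself.

```python
import torch
from PIL import Image
from transformers import AutoProcessor, Owlv2ForObjectDetection

processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")

# Image and queries borrowed from the Space's text-prompt examples
target_image = Image.open("test-data/target2.jpg")
texts = [["cat", "remote"]]  # one list of text queries per image

inputs = processor(images=target_image, text=texts, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

target_sizes = torch.tensor([target_image.size[::-1]])
result = processor.post_process_grounded_object_detection(
    outputs=outputs, target_sizes=target_sizes, threshold=0.25
)[0]

for box, score, label in zip(result["boxes"], result["scores"], result["labels"]):
    print(texts[0][int(label)], round(score.item(), 2), [round(c) for c in box.tolist()])
```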
app.py ADDED
@@ -0,0 +1,296 @@
+ import sys
+
+ # Mock audio modules to avoid installing them
+ sys.modules["audioop"] = type("audioop", (), {"__file__": ""})()
+ sys.modules["pyaudioop"] = type("pyaudioop", (), {"__file__": ""})()
+
+ import torch
+ import gradio as gr
+ import supervision as sv
+ import spaces
+ from transformers import AutoProcessor, Owlv2ForObjectDetection
+
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+ @spaces.GPU
+ def init_model(model_id):
+     processor = AutoProcessor.from_pretrained(model_id)
+     model = Owlv2ForObjectDetection.from_pretrained(model_id)
+     model.eval()
+     model.to(DEVICE)
+     return processor, model
+
+ @spaces.GPU
+ def inference(prompts, target_image, model_id, conf_thresh, iou_thresh, prompt_type):
+     processor, model = init_model(model_id)
+
+     result = None
+     class_names = {}
+
+     if prompt_type == "Text":
+         inputs = processor(
+             images=target_image,
+             text=prompts["texts"],
+             return_tensors="pt"
+         ).to(DEVICE)
+
+         with torch.no_grad():
+             outputs = model(**inputs)
+
+         target_sizes = torch.tensor([target_image.size[::-1]])
+         result = processor.post_process_grounded_object_detection(
+             outputs=outputs,
+             target_sizes=target_sizes,
+             threshold=conf_thresh
+         )[0]
+         class_names = {k: v for k, v in enumerate(prompts["texts"])}
+
+     elif prompt_type == "Visual":
+         inputs = processor(
+             images=target_image,
+             query_images=prompts["images"],
+             return_tensors="pt"
+         ).to(DEVICE)
+         with torch.no_grad():
+             outputs = model.image_guided_detection(**inputs)
+
+         # Post-process results
+         target_sizes = torch.tensor([target_image.size[::-1]])
+         result = processor.post_process_image_guided_detection(
+             outputs=outputs,
+             target_sizes=target_sizes,
+             threshold=conf_thresh,
+             nms_threshold=iou_thresh
+         )[0]
+
+         # prepare for supervision: add 0 label for all boxes
+         result['labels'] = torch.zeros(len(result['boxes']), dtype=torch.int64)
+         class_names = {0: "object"}
+
+     detections = sv.Detections.from_transformers(result, class_names)
+
+     resolution_wh = target_image.size
+     thickness = sv.calculate_optimal_line_thickness(resolution_wh=resolution_wh)
+     text_scale = sv.calculate_optimal_text_scale(resolution_wh=resolution_wh)
+
+     labels = [
+         f"{class_name} {confidence:.2f}"
+         for class_name, confidence
+         in zip(detections['class_name'], detections.confidence)
+     ]
+
+     annotated_image = target_image.copy()
+     annotated_image = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX, thickness=thickness).annotate(
+         scene=annotated_image, detections=detections)
+     annotated_image = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX, text_scale=text_scale, smart_position=True).annotate(
+         scene=annotated_image, detections=detections, labels=labels)
+
+     return annotated_image
+
+
+ def app():
+     with gr.Blocks():
+         with gr.Row():
+             with gr.Column():
+                 with gr.Row():
+                     target_image = gr.Image(type="pil", label="Target Image", visible=True, interactive=True)
+
+                 detect_button = gr.Button(value="Detect Objects")
+                 prompt_type = gr.State(value='Visual')  # Default prompt type
+
+                 with gr.Tab("Visual") as visual_tab:
+                     with gr.Row():
+                         prompt_image = gr.Image(type="pil", label="Prompt Image", visible=True, interactive=True)
+
+                 with gr.Tab("Text") as text_tab:
+                     texts = gr.Textbox(label="Input Texts", value='', placeholder='person,bus', visible=True, interactive=True)
+
+                 visual_tab.select(
+                     fn=lambda: ("Visual", gr.update(visible=True)),
+                     inputs=None,
+                     outputs=[prompt_type, prompt_image]
+                 )
+
+                 text_tab.select(
+                     fn=lambda: ("Text", gr.update(value=None, visible=False)),
+                     inputs=None,
+                     outputs=[prompt_type, prompt_image]
+                 )
+
+                 model_id = gr.Dropdown(
+                     label="Model",
+                     choices=[
+                         "google/owlv2-base-patch16-ensemble",
+                         "google/owlv2-large-patch14"
+                     ],
+                     value="google/owlv2-base-patch16-ensemble",
+                 )
+                 conf_thresh = gr.Slider(
+                     label="Confidence Threshold",
+                     minimum=0.0,
+                     maximum=1.0,
+                     step=0.05,
+                     value=0.25,
+                 )
+                 iou_thresh = gr.Slider(
+                     label="IoU Threshold",
+                     minimum=0.0,
+                     maximum=1.0,
+                     step=0.05,
+                     value=0.70,
+                 )
+
+             with gr.Column():
+                 output_image = gr.Image(type="numpy", label="Annotated Image", visible=True)
+
+
+         def run_inference(prompt_image, target_image, texts, model_id, conf_thresh, iou_thresh, prompt_type):
+             # add text/built-in prompts
+             if prompt_type == "Text":
+                 texts = [text.strip() for text in texts.split(',')]
+                 prompts = {
+                     "texts": texts
+                 }
+             # add visual prompt
+             elif prompt_type == "Visual":
+                 prompts = {
+                     "images": prompt_image,
+                 }
+
+             return inference(prompts, target_image, model_id, conf_thresh, iou_thresh, prompt_type)
+
+         detect_button.click(
+             fn=run_inference,
+             inputs=[prompt_image, target_image, texts, model_id, conf_thresh, iou_thresh, prompt_type],
+             outputs=[output_image],
+         )
+
+         ###################### Examples ##########################
+         image_examples_list = [[
+             "test-data/target1.jpg",
+             "test-data/prompt1.jpg",
+             "google/owlv2-base-patch16-ensemble",
+             0.9,
+             0.3,
+         ],
+         [
+             "test-data/target2.jpg",
+             "test-data/prompt2.jpg",
+             "google/owlv2-base-patch16-ensemble",
+             0.9,
+             0.3,
+         ],
+         [
+             "test-data/target3.jpg",
+             "test-data/prompt3.jpg",
+             "google/owlv2-base-patch16-ensemble",
+             0.9,
+             0.3,
+         ],
+         [
+             "test-data/target4.jpg",
+             "test-data/prompt4.jpg",
+             "google/owlv2-base-patch16-ensemble",
+             0.9,
+             0.3,
+         ]
+         ]
+
+         text_examples = gr.Examples(
+             examples=[[
+                 "test-data/target1.jpg",
+                 "logo",
+                 "google/owlv2-base-patch16-ensemble",
+                 0.3],
+                 [
+                 "test-data/target2.jpg",
+                 "cat,remote",
+                 "google/owlv2-base-patch16-ensemble",
+                 0.3],
+                 [
+                 "test-data/target3.jpg",
+                 "frog,spider,lizard",
+                 "google/owlv2-base-patch16-ensemble",
+                 0.3],
+                 [
+                 "test-data/target4.jpg",
+                 "cat",
+                 "google/owlv2-base-patch16-ensemble",
+                 0.3]
+             ],
+             inputs=[target_image, texts, model_id, conf_thresh],
+             visible=False, cache_examples=False, label="Text Prompt Examples")
+
+         image_examples = gr.Examples(
+             examples=image_examples_list,
+             inputs=[target_image, prompt_image, model_id, conf_thresh, iou_thresh],
+             visible=True, cache_examples=False, label="Box Visual Prompt Examples")
+
+         # Examples update
+         def update_text_examples():
+             return gr.Dataset(visible=True), gr.Dataset(visible=False), gr.update(visible=False)
+
+         def update_visual_examples():
+             return gr.Dataset(visible=False), gr.Dataset(visible=True), gr.update(visible=True)
+
+         text_tab.select(
+             fn=update_text_examples,
+             inputs=None,
+             outputs=[text_examples.dataset, image_examples.dataset, iou_thresh]
+         )
+
+         visual_tab.select(
+             fn=update_visual_examples,
+             inputs=None,
+             outputs=[text_examples.dataset, image_examples.dataset, iou_thresh]
+         )
+
+     return target_image, prompt_image, model_id, conf_thresh, iou_thresh, image_examples_list
+
+ gradio_app = gr.Blocks()
+ with gradio_app:
+     gr.HTML(
+         """
+         <h1 style='text-align: center'>OWLv2: Zero-shot detection with visual prompt πŸ‘€</h1>
+         """)
+     gr.Markdown("""
+ This demo showcases the OWLv2 model's ability to perform zero-shot object detection using visual and text prompts.
+
+ You can either provide a text prompt or an image as a visual prompt to detect objects in the target image.
+
+ For visual prompting, the following sample code is used, taken from the HF documentation:
+ ```python
+ processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
+ model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")
+
+ target_image = Image.open(...)
+ prompt_image = Image.open(...)
+ inputs = processor(images=target_image, query_images=prompt_image, return_tensors="pt")
+
+ # forward pass
+ with torch.no_grad():
+     outputs = model.image_guided_detection(**inputs)
+
+ target_sizes = torch.Tensor([target_image.size[::-1]])
+
+ results = processor.post_process_image_guided_detection(outputs=outputs, threshold=0.9, nms_threshold=0.3, target_sizes=target_sizes)
+ ```
+
+ For some reason, the visual prompt works much worse than text prompts; this may be an issue in the HF implementation.
+     """)
+
+     with gr.Row():
+         with gr.Column():
+             # Create a list of all UI components
+             ui_components = app()
+             # Unpack the components
+             target_image, prompt_image, model_id, conf_thresh, iou_thresh, image_examples_list = ui_components
+
+     gradio_app.load(
+         fn=lambda: image_examples_list[1],
+         outputs=[target_image, prompt_image, model_id, conf_thresh, iou_thresh]
+     )
+
+
+ if __name__ == '__main__':
+     gradio_app.launch(allowed_paths=["figures"])
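
The same visual-prompt pipeline used in `app.py` can also be exercised outside Gradio. Below is a hypothetical standalone sketch, not part of the commit: the file paths come from the repo's `test-data` examples, the 0.9/0.3 thresholds from the example defaults, and it assumes the pinned `supervision`/`transformers` versions from `requirements.txt`.

```python
import torch
from PIL import Image
import supervision as sv
from transformers import AutoProcessor, Owlv2ForObjectDetection

model_id = "google/owlv2-base-patch16-ensemble"
processor = AutoProcessor.from_pretrained(model_id)
model = Owlv2ForObjectDetection.from_pretrained(model_id).eval()

# Example pair from the repo's test-data (visual-prompt examples)
target_image = Image.open("test-data/target1.jpg")
prompt_image = Image.open("test-data/prompt1.jpg")

inputs = processor(images=target_image, query_images=prompt_image, return_tensors="pt")
with torch.no_grad():
    outputs = model.image_guided_detection(**inputs)

target_sizes = torch.tensor([target_image.size[::-1]])
result = processor.post_process_image_guided_detection(
    outputs=outputs, target_sizes=target_sizes, threshold=0.9, nms_threshold=0.3
)[0]

# Image-guided detection returns no labels, so assign a single "object" class,
# exactly as app.py does before handing results to supervision.
result["labels"] = torch.zeros(len(result["boxes"]), dtype=torch.int64)
detections = sv.Detections.from_transformers(result, {0: "object"})

annotated = sv.BoxAnnotator().annotate(scene=target_image.copy(), detections=detections)
annotated.save("annotated.jpg")
print(len(detections), "matches found")
```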
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ gradio==4.44.1
+ gradio_client==1.3.0
+ supervision==0.26.1
+ transformers==4.53.2
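
As a side note, the pins above can be sanity-checked at runtime. This is a small hypothetical helper, not part of the commit; the package names and versions are taken from the `requirements.txt` shown here.

```python
# Verify that the installed packages match the versions pinned in requirements.txt.
from importlib.metadata import version

pinned = {
    "gradio": "4.44.1",
    "gradio_client": "1.3.0",
    "supervision": "0.26.1",
    "transformers": "4.53.2",
}
for package, expected in pinned.items():
    installed = version(package)
    status = "OK" if installed == expected else f"MISMATCH (expected {expected})"
    print(f"{package}=={installed} {status}")
```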
test-data/prompt1.jpg ADDED

Git LFS Details

  • SHA256: e326e064a7d33d3d539faa554eaf52a831613808f69e5c000805dfa61df397aa
  • Pointer size: 130 Bytes
  • Size of remote file: 23 kB
test-data/prompt2.jpg ADDED

Git LFS Details

  • SHA256: 6ba641627c08ef424b7dc3e6cee069aa0dd49615a31c54ae0de1ecd1fabd9dea
  • Pointer size: 130 Bytes
  • Size of remote file: 97.6 kB
test-data/prompt3.jpg ADDED

Git LFS Details

  • SHA256: eb6d5e499f7f99925bafc47892233958eaa4aa32b6b7480ce823a13ef73b9c0e
  • Pointer size: 130 Bytes
  • Size of remote file: 86.1 kB
test-data/prompt4.jpg ADDED

Git LFS Details

  • SHA256: e7bada4545c89aa9a17ec5d4578a944bc3470c2e2dbec9ceae4764849f59e933
  • Pointer size: 130 Bytes
  • Size of remote file: 42.4 kB
test-data/target1.jpg ADDED

Git LFS Details

  • SHA256: aaa26d1e67cdebcfbcf06272051f2b0b187534f3da6941b1777f39be7cbf5ccc
  • Pointer size: 130 Bytes
  • Size of remote file: 31 kB
test-data/target2.jpg ADDED

Git LFS Details

  • SHA256: dea9e7ef97386345f7cff32f9055da4982da5471c48d575146c796ab4563b04e
  • Pointer size: 131 Bytes
  • Size of remote file: 173 kB
test-data/target3.jpg ADDED

Git LFS Details

  • SHA256: bfc2f628115e44d2fecb92be08ac310f2961e8bcbf5db3b8552d85effcfdd3f3
  • Pointer size: 131 Bytes
  • Size of remote file: 331 kB
test-data/target4.jpg ADDED

Git LFS Details

  • SHA256: 4d8f81777acc59322bb774d730b8933c2d15f5849d72dc67ce5a6475f534a379
  • Pointer size: 131 Bytes
  • Size of remote file: 137 kB