geraskalnas and xvjiarui committed
Commit d81dead · 0 Parent(s)

Duplicate from xvjiarui/ODISE


Co-authored-by: Jiarui Xu <xvjiarui@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: ODISE
+ emoji: 🤗
+ colorFrom: indigo
+ colorTo: pink
+ sdk: gradio
+ sdk_version: 3.29.0
+ app_file: app.py
+ pinned: true
+ duplicated_from: xvjiarui/ODISE
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,333 @@
+ # ------------------------------------------------------------------------------
+ # Copyright (c) 2022-2023, NVIDIA Corporation & Affiliates. All rights reserved.
+ #
+ # This work is made available under the Nvidia Source Code License.
+ # To view a copy of this license, visit
+ # https://github.com/NVlabs/ODISE/blob/main/LICENSE
+ #
+ # Written by Jiarui Xu
+ # ------------------------------------------------------------------------------
+
+ import os
+ # token = os.environ["GITHUB_TOKEN"]
+ # os.system(f"pip install git+https://xvjiarui:{token}@github.com/xvjiarui/ODISE_NV.git")
+ os.system("pip install git+https://github.com/NVlabs/ODISE.git")
+ os.system("pip freeze")
+
+ import itertools
+ import json
+ from contextlib import ExitStack
+ import gradio as gr
+ import torch
+ from mask2former.data.datasets.register_ade20k_panoptic import ADE20K_150_CATEGORIES
+ from PIL import Image
+ from torch.cuda.amp import autocast
+
+ from detectron2.config import instantiate
+ from detectron2.data import MetadataCatalog
+ from detectron2.data import detection_utils as utils
+ from detectron2.data import transforms as T
+ from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
+ from detectron2.evaluation import inference_context
+ from detectron2.utils.env import seed_all_rng
+ from detectron2.utils.logger import setup_logger
+ from detectron2.utils.visualizer import ColorMode, Visualizer, random_color
+
+ from odise import model_zoo
+ from odise.checkpoint import ODISECheckpointer
+ from odise.config import instantiate_odise
+ from odise.data import get_openseg_labels
+ from odise.modeling.wrapper import OpenPanopticInference
+ from odise.utils.file_io import ODISEHandler, PathManager
+ from odise.model_zoo.model_zoo import _ModelZooUrls
+
+ for k in ODISEHandler.URLS:
+     ODISEHandler.URLS[k] = ODISEHandler.URLS[k].replace("https://github.com/NVlabs/ODISE/releases/download/v1.0.0/", "https://huggingface.co/xvjiarui/download_cache/resolve/main/torch/odise/")
+ PathManager.register_handler(ODISEHandler())
+ _ModelZooUrls.PREFIX = _ModelZooUrls.PREFIX.replace("https://github.com/NVlabs/ODISE/releases/download/v1.0.0/", "https://huggingface.co/xvjiarui/download_cache/resolve/main/torch/odise/")
+
+ setup_logger()
+ logger = setup_logger(name="odise")
+
+ COCO_THING_CLASSES = [
+     label
+     for idx, label in enumerate(get_openseg_labels("coco_panoptic", True))
+     if COCO_CATEGORIES[idx]["isthing"] == 1
+ ]
+ COCO_THING_COLORS = [c["color"] for c in COCO_CATEGORIES if c["isthing"] == 1]
+ COCO_STUFF_CLASSES = [
+     label
+     for idx, label in enumerate(get_openseg_labels("coco_panoptic", True))
+     if COCO_CATEGORIES[idx]["isthing"] == 0
+ ]
+ COCO_STUFF_COLORS = [c["color"] for c in COCO_CATEGORIES if c["isthing"] == 0]
+
+ ADE_THING_CLASSES = [
+     label
+     for idx, label in enumerate(get_openseg_labels("ade20k_150", True))
+     if ADE20K_150_CATEGORIES[idx]["isthing"] == 1
+ ]
+ ADE_THING_COLORS = [c["color"] for c in ADE20K_150_CATEGORIES if c["isthing"] == 1]
+ ADE_STUFF_CLASSES = [
+     label
+     for idx, label in enumerate(get_openseg_labels("ade20k_150", True))
+     if ADE20K_150_CATEGORIES[idx]["isthing"] == 0
+ ]
+ ADE_STUFF_COLORS = [c["color"] for c in ADE20K_150_CATEGORIES if c["isthing"] == 0]
+
+ LVIS_CLASSES = get_openseg_labels("lvis_1203", True)
+ # use beautiful coco colors
+ LVIS_COLORS = list(
+     itertools.islice(itertools.cycle([c["color"] for c in COCO_CATEGORIES]), len(LVIS_CLASSES))
+ )
+
+
+ class VisualizationDemo(object):
+     def __init__(self, model, metadata, aug, instance_mode=ColorMode.IMAGE):
+         """
+         Args:
+             model (nn.Module): the segmentation model to run.
+             metadata (MetadataCatalog): image metadata.
+             aug (T.Augmentation): test-time augmentation applied to the input
+                 image before inference.
+             instance_mode (ColorMode): color mode used by the visualizer.
+         """
+         self.model = model
+         self.metadata = metadata
+         self.aug = aug
+         self.cpu_device = torch.device("cpu")
+         self.instance_mode = instance_mode
+
+     def predict(self, original_image):
+         """
+         Args:
+             original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).
+
+         Returns:
+             predictions (dict):
+                 the output of the model for one image only.
+                 See :doc:`/tutorials/models` for details about the format.
+         """
+         height, width = original_image.shape[:2]
+         aug_input = T.AugInput(original_image, sem_seg=None)
+         self.aug(aug_input)
+         image = aug_input.image
+         image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
+
+         inputs = {"image": image, "height": height, "width": width}
+         logger.info("forwarding")
+         with autocast():
+             predictions = self.model([inputs])[0]
+         logger.info("done")
+         return predictions
+
+     def run_on_image(self, image):
+         """
+         Args:
+             image (np.ndarray): an image of shape (H, W, C) (in BGR order).
+                 This is the format used by OpenCV.
+         Returns:
+             predictions (dict): the output of the model.
+             vis_output (VisImage): the visualized image output.
+         """
+         vis_output = None
+         predictions = self.predict(image)
+         visualizer = Visualizer(image, self.metadata, instance_mode=self.instance_mode)
+         if "panoptic_seg" in predictions:
+             panoptic_seg, segments_info = predictions["panoptic_seg"]
+             vis_output = visualizer.draw_panoptic_seg(
+                 panoptic_seg.to(self.cpu_device), segments_info
+             )
+         else:
+             if "sem_seg" in predictions:
+                 vis_output = visualizer.draw_sem_seg(
+                     predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
+                 )
+             if "instances" in predictions:
+                 instances = predictions["instances"].to(self.cpu_device)
+                 vis_output = visualizer.draw_instance_predictions(predictions=instances)
+
+         return predictions, vis_output
+
+
+ cfg = model_zoo.get_config("Panoptic/odise_label_coco_50e.py", trained=True)
+
+ cfg.model.overlap_threshold = 0
+ cfg.train.device = "cuda" if torch.cuda.is_available() else "cpu"
+ seed_all_rng(42)
+
+ dataset_cfg = cfg.dataloader.test
+ wrapper_cfg = cfg.dataloader.wrapper
+
+ aug = instantiate(dataset_cfg.mapper).augmentations
+
+ model = instantiate_odise(cfg.model)
+ model.to(torch.float16)
+ model.to(cfg.train.device)
+ ODISECheckpointer(model).load(cfg.train.init_checkpoint)
+
+
+ title = "ODISE"
+ description = """
+ <p style='text-align: center'> <a href='https://jerryxu.net/ODISE' target='_blank'>Project Page</a> | <a href='https://arxiv.org/abs/2303.04803' target='_blank'>Paper</a> | <a href='https://github.com/NVlabs/ODISE' target='_blank'>Code</a> | <a href='https://youtu.be/Su7p5KYmcII' target='_blank'>Video</a></p>
+
+ Gradio demo for ODISE: Open-Vocabulary Panoptic Segmentation with Text-to-Image Diffusion Models. \n
+ You may click one of the examples or upload your own image. \n
+
+ ODISE can perform open-vocabulary segmentation: you may input extra classes (separated by commas).
+ The expected format is 'a1,a2;b1,b2', where a1 and a2 are synonyms for the first class.
+ The first entry of each group is displayed as the class name.
+ """ # noqa
+
+ article = """
+ <p style='text-align: center'><a href='https://arxiv.org/abs/2303.04803' target='_blank'>Open-Vocabulary Panoptic Segmentation with Text-to-Image Diffusion Models</a> | <a href='https://github.com/NVlabs/ODISE' target='_blank'>GitHub Repo</a></p>
+ """ # noqa
+
+ examples = [
+     [
+         "demo/examples/coco.jpg",
+         "black pickup truck, pickup truck; blue sky, sky",
+         ["COCO (133 categories)", "ADE (150 categories)", "LVIS (1203 categories)"],
+     ],
+     [
+         "demo/examples/ade.jpg",
+         "luggage, suitcase, baggage;handbag",
+         ["ADE (150 categories)"],
+     ],
+     [
+         "demo/examples/ego4d.jpg",
+         "faucet, tap; kitchen paper, paper towels",
+         ["COCO (133 categories)"],
+     ],
+ ]
+
+
+ def build_demo_classes_and_metadata(vocab, label_list):
+     extra_classes = []
+
+     if vocab:
+         for words in vocab.split(";"):
+             extra_classes.append([word.strip() for word in words.split(",")])
+     extra_colors = [random_color(rgb=True, maximum=1) for _ in range(len(extra_classes))]
+
+     demo_thing_classes = extra_classes
+     demo_stuff_classes = []
+     demo_thing_colors = extra_colors
+     demo_stuff_colors = []
+
+     if any("COCO" in label for label in label_list):
+         demo_thing_classes += COCO_THING_CLASSES
+         demo_stuff_classes += COCO_STUFF_CLASSES
+         demo_thing_colors += COCO_THING_COLORS
+         demo_stuff_colors += COCO_STUFF_COLORS
+     if any("ADE" in label for label in label_list):
+         demo_thing_classes += ADE_THING_CLASSES
+         demo_stuff_classes += ADE_STUFF_CLASSES
+         demo_thing_colors += ADE_THING_COLORS
+         demo_stuff_colors += ADE_STUFF_COLORS
+     if any("LVIS" in label for label in label_list):
+         demo_thing_classes += LVIS_CLASSES
+         demo_thing_colors += LVIS_COLORS
+
+     MetadataCatalog.pop("odise_demo_metadata", None)
+     demo_metadata = MetadataCatalog.get("odise_demo_metadata")
+     demo_metadata.thing_classes = [c[0] for c in demo_thing_classes]
+     demo_metadata.stuff_classes = [
+         *demo_metadata.thing_classes,
+         *[c[0] for c in demo_stuff_classes],
+     ]
+     demo_metadata.thing_colors = demo_thing_colors
+     demo_metadata.stuff_colors = demo_thing_colors + demo_stuff_colors
+     demo_metadata.stuff_dataset_id_to_contiguous_id = {
+         idx: idx for idx in range(len(demo_metadata.stuff_classes))
+     }
+     demo_metadata.thing_dataset_id_to_contiguous_id = {
+         idx: idx for idx in range(len(demo_metadata.thing_classes))
+     }
+
+     demo_classes = demo_thing_classes + demo_stuff_classes
+
+     return demo_classes, demo_metadata
+
+
+ def inference(image_path, vocab, label_list):
+
+     logger.info("building class names")
+     demo_classes, demo_metadata = build_demo_classes_and_metadata(vocab, label_list)
+     with ExitStack() as stack:
+         inference_model = OpenPanopticInference(
+             model=model,
+             labels=demo_classes,
+             metadata=demo_metadata,
+             semantic_on=False,
+             instance_on=False,
+             panoptic_on=True,
+         )
+         stack.enter_context(inference_context(inference_model))
+         stack.enter_context(torch.no_grad())
+
+         demo = VisualizationDemo(inference_model, demo_metadata, aug)
+         img = utils.read_image(image_path, format="RGB")
+         _, visualized_output = demo.run_on_image(img)
+         return Image.fromarray(visualized_output.get_image())
+
+
+ with gr.Blocks(title=title) as demo:
+     gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>" + title + "</h1>")
+     gr.Markdown(description)
+     input_components = []
+     output_components = []
+
+     with gr.Row():
+         output_image_gr = gr.outputs.Image(label="Panoptic Segmentation", type="pil")
+         output_components.append(output_image_gr)
+
+     with gr.Row().style(equal_height=True, mobile_collapse=True):
+         with gr.Column(scale=3, variant="panel") as input_component_column:
+             input_image_gr = gr.inputs.Image(type="filepath")
+             extra_vocab_gr = gr.inputs.Textbox(default="", label="Extra Vocabulary")
+             category_list_gr = gr.inputs.CheckboxGroup(
+                 choices=["COCO (133 categories)", "ADE (150 categories)", "LVIS (1203 categories)"],
+                 default=["COCO (133 categories)", "ADE (150 categories)", "LVIS (1203 categories)"],
+                 label="Category to use",
+             )
+             input_components.extend([input_image_gr, extra_vocab_gr, category_list_gr])
+
+         with gr.Column(scale=2):
+             examples_handler = gr.Examples(
+                 examples=examples,
+                 inputs=[c for c in input_components if not isinstance(c, gr.State)],
+                 outputs=[c for c in output_components if not isinstance(c, gr.State)],
+                 fn=inference,
+                 cache_examples=torch.cuda.is_available(),
+                 examples_per_page=5,
+             )
+             with gr.Row():
+                 clear_btn = gr.Button("Clear")
+                 submit_btn = gr.Button("Submit", variant="primary")
+
+     gr.Markdown(article)
+
+     submit_btn.click(
+         inference,
+         input_components,
+         output_components,
+         api_name="predict",
+         scroll_to_output=True,
+     )
+
+     clear_btn.click(
+         None,
+         [],
+         (input_components + output_components + [input_component_column]),
+         _js=f"""() => {json.dumps(
+             [component.cleared_value if hasattr(component, "cleared_value") else None
+             for component in input_components + output_components] + (
+                 [gr.Column.update(visible=True)]
+             )
+             + ([gr.Column.update(visible=False)])
+         )}
+         """,
+     )
+
+ demo.launch()
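
Note (not part of the commit): the demo text above describes the "Extra Vocabulary" input as 'a1,a2;b1,b2'. A minimal standalone sketch of how `build_demo_classes_and_metadata` in app.py splits that string — groups separated by ';', synonyms within a group separated by ',', with the first entry of each group used as the displayed class name:

```python
# Standalone illustration of the "Extra Vocabulary" parsing used in app.py above.
vocab = "black pickup truck, pickup truck; blue sky, sky"

extra_classes = []
for words in vocab.split(";"):
    # each ';'-separated group is one class; ','-separated entries are synonyms
    extra_classes.append([word.strip() for word in words.split(",")])

print(extra_classes)
# [['black pickup truck', 'pickup truck'], ['blue sky', 'sky']]
print([group[0] for group in extra_classes])  # names shown in the visualization
# ['black pickup truck', 'blue sky']
```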
demo/examples/ade.jpg ADDED
demo/examples/coco.jpg ADDED
demo/examples/ego4d.jpg ADDED
packages.txt ADDED
@@ -0,0 +1,4 @@
+ libtinfo5
+ libsm6
+ libxext6
+ python3-opencv
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ --extra-index-url https://download.pytorch.org/whl/cu116
+ torch==1.13.1+cu116
+ torchvision==0.14.1+cu116
+ xformers==0.0.16
+ numpy==1.23.5
+ matplotlib==3.7.1
+ pillow==9.4.0
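
Hypothetical usage sketch (not part of the commit): once the dependencies above are installed, the ODISE checkpoint has been downloaded, and a CUDA GPU is available, the `inference` helper defined in app.py could be called directly. Importing app.py as-is also starts the Gradio server, so this assumes the module is made importable (e.g. with `demo.launch()` removed); the output path here is illustrative.

```python
# Hypothetical sketch: calling the demo's inference() helper outside Gradio.
result = inference(
    "demo/examples/coco.jpg",                            # one of the bundled example images
    "black pickup truck, pickup truck; blue sky, sky",   # extra open-vocabulary classes
    ["COCO (133 categories)"],                           # built-in label sets to include
)
result.save("panoptic_prediction.png")  # inference() returns a PIL.Image
```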