xinjie.wang committed on
Commit
0c94688
·
1 Parent(s): d31a703
Files changed (1) hide show
  1. app.py +497 -40
app.py CHANGED
@@ -1,44 +1,501 @@
1
# Debug entry point: configure logging and surface the active GPT client
# configuration at startup for quick inspection.
import logging

import gradio as gr

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

from embodied_gen.utils.gpt_clients import GPT_CLIENT

# Print the client endpoint details so misconfiguration is visible immediately.
print(GPT_CLIENT.api_version, GPT_CLIENT.model_name, GPT_CLIENT.endpoint)
11
-
12
def debug_gptclient(text_prompt, images, system_role):
    """Forward a text prompt (and optional images) to GPT_CLIENT and return the reply.

    Args:
        text_prompt: The user's text prompt.
        images: Optional image input forwarded as-is to the client
            (Gradio supplies PIL images or file paths here).
        system_role: System-role string passed through to the client.

    Returns:
        The client's response string, a placeholder message when the client
        returns nothing, or an ``Error: ...`` string if an exception occurs.
    """
    try:
        # Falsy image input (None / empty list) is normalized to None.
        image_payload = images if images else None
        reply = GPT_CLIENT.query(
            text_prompt=text_prompt,
            image_base64=image_payload,
            system_role=system_role,
        )
        if not reply:
            return "No response received or an error occurred."
        return reply
    except Exception as exc:
        # Debug UI: report the failure as text instead of crashing the app.
        return f"Error: {str(exc)}"
24
-
25
# Build the Gradio debug interface; components are named first for readability.
prompt_box = gr.Textbox(
    label="Text Prompt", placeholder="Enter your text prompt here"
)
image_picker = gr.File(
    label="Images (Optional)", type="filepath", file_count="multiple"
)
role_box = gr.Textbox(
    label="System Role (Optional)",
    placeholder="Enter system role or leave empty for default",
    value="You are a highly knowledgeable assistant specializing in physics, engineering, and object properties.",
)
iface = gr.Interface(
    fn=debug_gptclient,
    inputs=[prompt_box, image_picker, role_box],
    outputs=gr.Textbox(label="Response"),
    title="GPTclient Debug Interface",
    description="A simple interface to debug GPTclient inputs and outputs.",
)
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
if __name__ == "__main__":
    # Start the debug Interface when run as a script.
    iface.launch()
 
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+
18
+ import os
19
+
20
+ os.environ["GRADIO_APP"] = "imageto3d"
21
+ from glob import glob
22
+
23
+ import gradio as gr
24
+ from common import (
25
+ MAX_SEED,
26
+ VERSION,
27
+ active_btn_by_content,
28
+ custom_theme,
29
+ end_session,
30
+ extract_3d_representations_v2,
31
+ extract_urdf,
32
+ get_seed,
33
+ image_css,
34
+ image_to_3d,
35
+ lighting_css,
36
+ preprocess_image_fn,
37
+ preprocess_sam_image_fn,
38
+ select_point,
39
+ start_session,
40
  )
41
 
42
# Image-to-3D asset generation UI. Layout: a two-column row (inputs left,
# outputs right) followed by event wiring. Nesting below is reconstructed
# from Gradio context-manager semantics — NOTE(review): indentation was lost
# in the diff scrape; verify against the repository copy.
with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
    # Header banner with version and external links.
    gr.Markdown(
        """
        ## ***EmbodiedGen***: Image-to-3D Asset
        **🔖 Version**: {VERSION}
        <p style="display: flex; gap: 10px; flex-wrap: nowrap;">
        <a href="https://horizonrobotics.github.io/robot_lab/embodied_gen/index.html">
        <img alt="🌐 Project Page" src="https://img.shields.io/badge/🌐-Project_Page-blue">
        </a>
        <a href="https://arxiv.org/abs/xxxx.xxxxx">
        <img alt="📄 arXiv" src="https://img.shields.io/badge/📄-arXiv-b31b1b">
        </a>
        <a href="https://github.com/HorizonRobotics/EmbodiedGen">
        <img alt="💻 GitHub" src="https://img.shields.io/badge/GitHub-000000?logo=github">
        </a>
        <a href="https://www.youtube.com/watch?v=SnHhzHeb_aI">
        <img alt="🎥 Video" src="https://img.shields.io/badge/🎥-Video-red">
        </a>
        </p>

        🖼️ Generate physically plausible 3D asset from single input image.

        """.format(VERSION=VERSION),
        elem_classes=["header"],
    )

    gr.HTML(image_css)
    # gr.HTML(lighting_css)  # optional mesh-lighting CSS, currently disabled
    with gr.Row():
        # Left column: image input tabs, generation settings, action buttons.
        with gr.Column(scale=2):
            with gr.Tabs() as input_tabs:
                with gr.Tab(
                    label="Image(auto seg)", id=0
                ) as single_image_input_tab:
                    # Hidden cache of the raw (pre-segmentation) upload.
                    raw_image_cache = gr.Image(
                        format="png",
                        image_mode="RGB",
                        type="pil",
                        visible=False,
                    )
                    image_prompt = gr.Image(
                        label="Input Image",
                        format="png",
                        image_mode="RGBA",
                        type="pil",
                        height=400,
                        elem_classes=["image_fit"],
                    )
                    gr.Markdown(
                        """
                        If you are not satisfied with the auto segmentation
                        result, please switch to the `Image(SAM seg)` tab."""
                    )
                with gr.Tab(
                    label="Image(SAM seg)", id=1
                ) as samimage_input_tab:
                    with gr.Row():
                        with gr.Column(scale=1):
                            image_prompt_sam = gr.Image(
                                label="Input Image",
                                type="numpy",
                                height=400,
                                elem_classes=["image_fit"],
                            )
                            # Hidden holder for the SAM segmentation result.
                            image_seg_sam = gr.Image(
                                label="SAM Seg Image",
                                image_mode="RGBA",
                                type="pil",
                                height=400,
                                visible=False,
                            )
                        with gr.Column(scale=1):
                            image_mask_sam = gr.AnnotatedImage(
                                elem_classes=["image_fit"]
                            )

                    fg_bg_radio = gr.Radio(
                        ["foreground_point", "background_point"],
                        label="Select foreground(green) or background(red) points, by default foreground",  # noqa
                        value="foreground_point",
                    )
                    gr.Markdown(
                        """ Click the `Input Image` to select SAM points,
                        after get the satisified segmentation, click `Generate`
                        button to generate the 3D asset. \n
                        Note: If the segmented foreground is too small relative
                        to the entire image area, the generation will fail.
                        """
                    )

            with gr.Accordion(label="Generation Settings", open=False):
                with gr.Row():
                    seed = gr.Slider(
                        0, MAX_SEED, label="Seed", value=0, step=1
                    )
                    texture_size = gr.Slider(
                        1024,
                        4096,
                        label="UV texture size",
                        value=2048,
                        step=256,
                    )
                    rmbg_tag = gr.Radio(
                        choices=["rembg", "rmbg14"],
                        value="rembg",
                        label="Background Removal Model",
                    )
                with gr.Row():
                    randomize_seed = gr.Checkbox(
                        label="Randomize Seed", value=False
                    )
                    project_delight = gr.Checkbox(
                        label="Backproject delighting",
                        value=False,
                    )
                gr.Markdown("Geo Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(
                        0.0,
                        10.0,
                        label="Guidance Strength",
                        value=7.5,
                        step=0.1,
                    )
                    ss_sampling_steps = gr.Slider(
                        1, 50, label="Sampling Steps", value=12, step=1
                    )
                gr.Markdown("Visual Appearance Generation")
                with gr.Row():
                    slat_guidance_strength = gr.Slider(
                        0.0,
                        10.0,
                        label="Guidance Strength",
                        value=3.0,
                        step=0.1,
                    )
                    slat_sampling_steps = gr.Slider(
                        1, 50, label="Sampling Steps", value=12, step=1
                    )

            # Pipeline buttons are enabled step-by-step as prior stages succeed.
            generate_btn = gr.Button(
                "🚀 1. Generate(~0.5 mins)",
                variant="primary",
                interactive=False,
            )
            model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False)
            with gr.Row():
                extract_rep3d_btn = gr.Button(
                    "🔍 2. Extract 3D Representation(~2 mins)",
                    variant="primary",
                    interactive=False,
                )
            with gr.Accordion(
                label="Enter Asset Attributes(optional)", open=False
            ):
                asset_cat_text = gr.Textbox(
                    label="Enter Asset Category (e.g., chair)"
                )
                height_range_text = gr.Textbox(
                    label="Enter **Height Range** in meter (e.g., 0.5-0.6)"
                )
                mass_range_text = gr.Textbox(
                    label="Enter **Mass Range** in kg (e.g., 1.1-1.2)"
                )
                asset_version_text = gr.Textbox(
                    label=f"Enter version (e.g., {VERSION})"
                )
            with gr.Row():
                extract_urdf_btn = gr.Button(
                    "🧩 3. Extract URDF with physics(~1 mins)",
                    variant="primary",
                    interactive=False,
                )
            with gr.Row():
                gr.Markdown(
                    "#### Estimated Asset 3D Attributes(No input required)"
                )
            with gr.Row():
                est_type_text = gr.Textbox(
                    label="Asset category", interactive=False
                )
                est_height_text = gr.Textbox(
                    label="Real height(.m)", interactive=False
                )
                est_mass_text = gr.Textbox(
                    label="Mass(.kg)", interactive=False
                )
                est_mu_text = gr.Textbox(
                    label="Friction coefficient", interactive=False
                )
            with gr.Row():
                download_urdf = gr.DownloadButton(
                    label="⬇️ 4. Download URDF",
                    variant="primary",
                    interactive=False,
                )

            gr.Markdown(
                """ NOTE: If `Asset Attributes` are provided, the provided
                properties will be used; otherwise, the GPT-preset properties
                will be applied. \n
                The `Download URDF` file is restored to the real scale and
                has quality inspection, open with an editor to view details.
                """
            )

            # Example galleries, toggled depending on the active input tab.
            with gr.Row() as single_image_example:
                examples = gr.Examples(
                    label="Image Gallery",
                    examples=[
                        [image_path]
                        for image_path in sorted(
                            glob("assets/example_image/*")
                        )
                    ],
                    inputs=[image_prompt, rmbg_tag],
                    fn=preprocess_image_fn,
                    outputs=[image_prompt, raw_image_cache],
                    run_on_click=True,
                    examples_per_page=10,
                )

            with gr.Row(visible=False) as single_sam_image_example:
                examples = gr.Examples(
                    label="Image Gallery",
                    examples=[
                        [image_path]
                        for image_path in sorted(
                            glob("assets/example_image/*")
                        )
                    ],
                    inputs=[image_prompt_sam],
                    fn=preprocess_sam_image_fn,
                    outputs=[image_prompt_sam, raw_image_cache],
                    run_on_click=True,
                    examples_per_page=10,
                )
        # Right column: generated video, Gaussian and mesh previews.
        with gr.Column(scale=1):
            video_output = gr.Video(
                label="Generated 3D Asset",
                autoplay=True,
                loop=True,
                height=300,
            )
            model_output_gs = gr.Model3D(
                label="Gaussian Representation", height=300, interactive=False
            )
            aligned_gs = gr.Textbox(visible=False)
            gr.Markdown(
                """ The rendering of `Gaussian Representation` takes additional 10s. """  # noqa
            )
            with gr.Row():
                model_output_mesh = gr.Model3D(
                    label="Mesh Representation",
                    height=300,
                    interactive=False,
                    clear_color=[0.8, 0.8, 0.8, 1],
                    elem_id="lighter_mesh",
                )

    # Per-session state: which tab produced the input, the generation buffer,
    # and the SAM click points.
    is_samimage = gr.State(False)
    output_buf = gr.State()
    selected_points = gr.State(value=[])

    demo.load(start_session)
    demo.unload(end_session)

    # Tab switches flip the is_samimage flag and swap the example galleries.
    single_image_input_tab.select(
        lambda: tuple(
            [False, gr.Row.update(visible=True), gr.Row.update(visible=False)]
        ),
        outputs=[is_samimage, single_image_example, single_sam_image_example],
    )
    samimage_input_tab.select(
        lambda: tuple(
            [True, gr.Row.update(visible=True), gr.Row.update(visible=False)]
        ),
        outputs=[is_samimage, single_sam_image_example, single_image_example],
    )

    image_prompt.upload(
        preprocess_image_fn,
        inputs=[image_prompt, rmbg_tag],
        outputs=[image_prompt, raw_image_cache],
    )
    # A new input image resets every downstream button/preview/attribute field.
    image_prompt.change(
        lambda: tuple(
            [
                gr.Button(interactive=False),
                gr.Button(interactive=False),
                gr.Button(interactive=False),
                None,
                "",
                None,
                None,
                "",
                "",
                "",
                "",
                "",
                "",
                "",
                "",
            ]
        ),
        outputs=[
            extract_rep3d_btn,
            extract_urdf_btn,
            download_urdf,
            model_output_gs,
            aligned_gs,
            model_output_mesh,
            video_output,
            asset_cat_text,
            height_range_text,
            mass_range_text,
            asset_version_text,
            est_type_text,
            est_height_text,
            est_mass_text,
            est_mu_text,
        ],
    )
    image_prompt.change(
        active_btn_by_content,
        inputs=image_prompt,
        outputs=generate_btn,
    )

    image_prompt_sam.upload(
        preprocess_sam_image_fn,
        inputs=[image_prompt_sam],
        outputs=[image_prompt_sam, raw_image_cache],
    )
    # SAM-tab reset additionally clears the annotated mask and click points.
    image_prompt_sam.change(
        lambda: tuple(
            [
                gr.Button(interactive=False),
                gr.Button(interactive=False),
                gr.Button(interactive=False),
                None,
                None,
                None,
                "",
                "",
                "",
                "",
                "",
                "",
                "",
                "",
                None,
                [],
            ]
        ),
        outputs=[
            extract_rep3d_btn,
            extract_urdf_btn,
            download_urdf,
            model_output_gs,
            model_output_mesh,
            video_output,
            asset_cat_text,
            height_range_text,
            mass_range_text,
            asset_version_text,
            est_type_text,
            est_height_text,
            est_mass_text,
            est_mu_text,
            image_mask_sam,
            selected_points,
        ],
    )

    image_prompt_sam.select(
        select_point,
        [
            image_prompt_sam,
            selected_points,
            fg_bg_radio,
        ],
        [image_mask_sam, image_seg_sam],
    )
    image_seg_sam.change(
        active_btn_by_content,
        inputs=image_seg_sam,
        outputs=generate_btn,
    )

    # Stage 1: resolve the seed, run image-to-3D, then enable stage 2.
    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).success(
        image_to_3d,
        inputs=[
            image_prompt,
            seed,
            ss_guidance_strength,
            ss_sampling_steps,
            slat_guidance_strength,
            slat_sampling_steps,
            raw_image_cache,
            image_seg_sam,
            is_samimage,
        ],
        outputs=[output_buf, video_output],
    ).success(
        lambda: gr.Button(interactive=True),
        outputs=[extract_rep3d_btn],
    )

    # Stage 2: extract mesh/Gaussian representations, then enable stage 3.
    extract_rep3d_btn.click(
        extract_3d_representations_v2,
        inputs=[
            output_buf,
            project_delight,
            texture_size,
        ],
        outputs=[
            model_output_mesh,
            model_output_gs,
            model_output_obj,
            aligned_gs,
        ],
    ).success(
        lambda: gr.Button(interactive=True),
        outputs=[extract_urdf_btn],
    )

    # Stage 3: build the URDF with physical attributes, then enable download.
    extract_urdf_btn.click(
        extract_urdf,
        inputs=[
            aligned_gs,
            model_output_obj,
            asset_cat_text,
            height_range_text,
            mass_range_text,
            asset_version_text,
        ],
        outputs=[
            download_urdf,
            est_type_text,
            est_height_text,
            est_mass_text,
            est_mu_text,
        ],
        queue=True,
        show_progress="full",
    ).success(
        lambda: gr.Button(interactive=True),
        outputs=[download_urdf],
    )
498
+
499
 
500
  if __name__ == "__main__":
501
+ demo.launch()