xinjie.wang committed
Commit 8131b67 · 1 Parent(s): 33d9f9a
common.py CHANGED
@@ -55,9 +55,9 @@ from embodied_gen.utils.gpt_clients import GPT_CLIENT
 from embodied_gen.utils.process_media import (
     filter_image_small_connected_components,
     merge_images_video,
-    render_video,
 )
 from embodied_gen.utils.tags import VERSION
+from embodied_gen.utils.trender import render_video
 from embodied_gen.validators.quality_checkers import (
     BaseChecker,
     ImageAestheticChecker,
@@ -94,9 +94,6 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
 os.environ["SPCONV_ALGO"] = "native"
 
 MAX_SEED = 100000
-DELIGHT = DelightingModel()
-IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
-# IMAGESR_MODEL = ImageStableSR()
 
 
 def patched_setup_functions(self):
@@ -136,6 +133,9 @@ def patched_setup_functions(self):
 Gaussian.setup_functions = patched_setup_functions
 
 
+DELIGHT = DelightingModel()
+IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
+# IMAGESR_MODEL = ImageStableSR()
 if os.getenv("GRADIO_APP") == "imageto3d":
     RBG_REMOVER = RembgRemover()
     RBG14_REMOVER = BMGG14Remover()
embodied_gen/data/backproject_v2.py CHANGED
@@ -251,6 +251,7 @@ class TextureBacker:
             during rendering. Defaults to 0.5.
         smooth_texture (bool, optional): If True, apply post-processing (e.g.,
             blurring) to the final texture. Defaults to True.
+        inpaint_smooth (bool, optional): If True, apply inpainting to smooth.
     """

     def __init__(
@@ -262,6 +263,7 @@ class TextureBacker:
         bake_angle_thresh: int = 75,
         mask_thresh: float = 0.5,
         smooth_texture: bool = True,
+        inpaint_smooth: bool = False,
     ) -> None:
         self.camera_params = camera_params
         self.renderer = None
@@ -271,6 +273,7 @@ class TextureBacker:
         self.texture_wh = texture_wh
         self.mask_thresh = mask_thresh
         self.smooth_texture = smooth_texture
+        self.inpaint_smooth = inpaint_smooth

         self.bake_angle_thresh = bake_angle_thresh
         self.bake_unreliable_kernel_size = int(
@@ -446,11 +449,12 @@ class TextureBacker:
     def uv_inpaint(
         self, mesh: trimesh.Trimesh, texture: np.ndarray, mask: np.ndarray
     ) -> np.ndarray:
-        vertices, faces, uv_map = self.get_mesh_np_attrs(mesh)
-        texture, mask = _texture_inpaint_smooth(
-            texture, mask, vertices, faces, uv_map
-        )
+        if self.inpaint_smooth:
+            vertices, faces, uv_map = self.get_mesh_np_attrs(mesh)
+            texture, mask = _texture_inpaint_smooth(
+                texture, mask, vertices, faces, uv_map
+            )

         texture = texture.clip(0, 1)
         texture = cv2.inpaint(
             (texture * 255).astype(np.uint8),
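Usage note (not part of the commit): with this change the vertex-based `_texture_inpaint_smooth` pass only runs when the new flag is set. A minimal sketch, assuming the remaining constructor inputs (`camera_params` and the mesh/texture/mask arrays) are prepared by the caller as elsewhere in this module; helper names below are illustrative.

# Hypothetical usage sketch of the new inpaint_smooth option.
from embodied_gen.data.backproject_v2 import TextureBacker

def bake_with_optional_smooth(camera_params, mesh, texture, mask, use_smooth: bool = False):
    backer = TextureBacker(
        camera_params=camera_params,  # assumed: a CameraSetting built by the caller
        smooth_texture=True,
        inpaint_smooth=use_smooth,  # opt back in to the vertex-based inpaint pass
    )
    # When use_smooth is False, uv_inpaint now skips _texture_inpaint_smooth
    # and relies on cv2.inpaint alone.
    return backer.uv_inpaint(mesh, texture, mask)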
embodied_gen/data/datasets.py CHANGED
@@ -19,8 +19,9 @@ import json
 import logging
 import os
 import random
-from typing import Any, Callable, Dict, List, Tuple
+from typing import Any, Callable, Dict, List, Literal, Tuple

+import numpy as np
 import torch
 import torch.utils.checkpoint
 from PIL import Image
@@ -36,6 +37,7 @@ logger = logging.getLogger(__name__)

 __all__ = [
     "Asset3dGenDataset",
+    "PanoGSplatDataset",
 ]


@@ -222,6 +224,68 @@ class Asset3dGenDataset(Dataset):
         return data


+class PanoGSplatDataset(Dataset):
+    """A PyTorch Dataset for loading panorama-based 3D Gaussian Splatting data.
+
+    This dataset is designed to be compatible with train and eval pipelines
+    that use COLMAP-style camera conventions.
+
+    Args:
+        data_dir (str): Root directory where the dataset file is located.
+        split (str): Dataset split to use, either "train" or "eval".
+        data_name (str, optional): Name of the dataset file (default: "gs_data.pt").
+        max_sample_num (int, optional): Maximum number of samples to load. If None,
+            all available samples in the split will be used.
+    """
+
+    def __init__(
+        self,
+        data_dir: str,
+        split: str = Literal["train", "eval"],
+        data_name: str = "gs_data.pt",
+        max_sample_num: int = None,
+    ) -> None:
+        self.data_path = os.path.join(data_dir, data_name)
+        self.split = split
+        self.max_sample_num = max_sample_num
+        if not os.path.exists(self.data_path):
+            raise FileNotFoundError(
+                f"Dataset file {self.data_path} not found. Please provide the correct path."
+            )
+        self.data = torch.load(self.data_path, weights_only=False)
+        self.frames = self.data[split]
+        if max_sample_num is not None:
+            self.frames = self.frames[:max_sample_num]
+        self.points = self.data.get("points", None)
+        self.points_rgb = self.data.get("points_rgb", None)
+
+    def __len__(self) -> int:
+        return len(self.frames)
+
+    def cvt_blender_to_colmap_coord(self, c2w: np.ndarray) -> np.ndarray:
+        # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
+        tranformed_c2w = np.copy(c2w)
+        tranformed_c2w[:3, 1:3] *= -1
+
+        return tranformed_c2w
+
+    def __getitem__(self, index: int) -> dict[str, any]:
+        data = self.frames[index]
+        c2w = self.cvt_blender_to_colmap_coord(data["camtoworld"])
+        item = dict(
+            camtoworld=c2w,
+            K=data["K"],
+            image_h=data["image_h"],
+            image_w=data["image_w"],
+        )
+        if "image" in data:
+            item["image"] = data["image"]
+        if "image_id" in data:
+            item["image_id"] = data["image_id"]
+
+        return item
+
+
 if __name__ == "__main__":
     index_file = "datasets/objaverse/v1.0/statistics_1.0_gobjaverse_filter/view6s_v4/meta_ac2e0ddea8909db26d102c8465b5bcb2.json"  # noqa
     target_hw = (512, 512)
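For orientation, a minimal sketch of how the new `PanoGSplatDataset` might be consumed (the `gs_data.pt` location is illustrative; per the class docstring it is produced elsewhere in the pipeline):

# Hypothetical usage sketch of PanoGSplatDataset (path is illustrative).
from torch.utils.data import DataLoader
from embodied_gen.data.datasets import PanoGSplatDataset

dataset = PanoGSplatDataset(data_dir="outputs/scene_0000", split="train")
loader = DataLoader(dataset, batch_size=1, shuffle=True)
for batch in loader:
    # Poses are converted from Blender/OpenGL to COLMAP convention in __getitem__.
    c2w, K = batch["camtoworld"], batch["K"]
    break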
embodied_gen/data/differentiable_render.py CHANGED
@@ -33,7 +33,6 @@ from tqdm import tqdm
 from embodied_gen.data.utils import (
     CameraSetting,
     DiffrastRender,
-    RenderItems,
     as_list,
     calc_vertex_normals,
     import_kaolin_mesh,
@@ -42,6 +41,7 @@ from embodied_gen.data.utils import (
     render_pbr,
     save_images,
 )
+from embodied_gen.utils.enum import RenderItems

 os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
 os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
@@ -470,7 +470,7 @@ def parse_args():
         "--pbr_light_factor",
         type=float,
         default=1.0,
-        help="Light factor for mesh PBR rendering (default: 2.)",
+        help="Light factor for mesh PBR rendering (default: 1.)",
     )
     parser.add_argument(
         "--with_mtl",
@@ -482,6 +482,11 @@
         action="store_true",
         help="Whether to generate color .gif rendering file.",
     )
+    parser.add_argument(
+        "--no_index_file",
+        action="store_true",
+        help="Whether skip the index file saving.",
+    )
     parser.add_argument(
         "--gen_color_mp4",
         action="store_true",
@@ -568,7 +573,7 @@ def entrypoint(**kwargs) -> None:
         gen_viewnormal_mp4=args.gen_viewnormal_mp4,
         gen_glonormal_mp4=args.gen_glonormal_mp4,
         light_factor=args.pbr_light_factor,
-        no_index_file=gen_video,
+        no_index_file=gen_video or args.no_index_file,
     )
     image_render.render_mesh(
         mesh_path=args.mesh_path,
embodied_gen/data/mesh_operator.py CHANGED
@@ -395,6 +395,8 @@ class MeshFixer(object):
             self.vertices_np,
             np.hstack([np.full((self.faces.shape[0], 1), 3), self.faces_np]),
         )
+        mesh.clean(inplace=True)
+        mesh.clear_data()
        mesh = mesh.decimate(ratio, progress_bar=True)

         # Update vertices and faces
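Side note on the two new calls: cleaning the PyVista mesh and dropping its point/cell data arrays keeps the decimation input minimal. A standalone sketch on toy geometry (not the MeshFixer internals):

# Toy illustration of clean() + clear_data() before decimate() with PyVista.
import pyvista as pv

mesh = pv.Sphere(theta_resolution=64, phi_resolution=64)
mesh.clean(inplace=True)   # merge duplicate points, drop unused ones
mesh.clear_data()          # remove point/cell data arrays before decimation
decimated = mesh.decimate(0.5, progress_bar=True)  # target ~50% face reduction
print(mesh.n_cells, "->", decimated.n_cells)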
embodied_gen/data/utils.py CHANGED
@@ -38,7 +38,6 @@ except ImportError:
     ChatGLMModel = None
 import logging
 from dataclasses import dataclass, field
-from enum import Enum

 import trimesh
 from kaolin.render.camera import Camera
@@ -57,7 +56,6 @@ __all__ = [
     "load_mesh_to_unit_cube",
     "as_list",
     "CameraSetting",
-    "RenderItems",
     "import_kaolin_mesh",
     "save_mesh_with_mtl",
     "get_images_from_grid",
@@ -160,8 +158,9 @@ class DiffrastRender(object):

         return normalized_maps

+    @staticmethod
     def normalize_map_by_mask(
-        self, map: torch.Tensor, mask: torch.Tensor
+        map: torch.Tensor, mask: torch.Tensor
     ) -> torch.Tensor:
         # Normalize all maps in total by mask, normalized map in [0, 1].
         foreground = (mask == 1).squeeze(dim=-1)
@@ -738,18 +737,6 @@ class CameraSetting:
         self.Ks = Ks


-@dataclass
-class RenderItems(str, Enum):
-    IMAGE = "image_color"
-    ALPHA = "image_mask"
-    VIEW_NORMAL = "image_view_normal"
-    GLOBAL_NORMAL = "image_global_normal"
-    POSITION_MAP = "image_position"
-    DEPTH = "image_depth"
-    ALBEDO = "image_albedo"
-    DIFFUSE = "image_diffuse"
-
-
 def _compute_az_el_by_camera_params(
     camera_params: CameraSetting, flip_az: bool = False
 ):
embodied_gen/models/image_comm_model.py ADDED
@@ -0,0 +1,236 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+ # Text-to-Image generation models from Hugging Face community.
17
+
18
+ import os
19
+ from abc import ABC, abstractmethod
20
+
21
+ import torch
22
+ from diffusers import (
23
+ ChromaPipeline,
24
+ Cosmos2TextToImagePipeline,
25
+ DPMSolverMultistepScheduler,
26
+ FluxPipeline,
27
+ KolorsPipeline,
28
+ StableDiffusion3Pipeline,
29
+ )
30
+ from diffusers.quantizers import PipelineQuantizationConfig
31
+ from huggingface_hub import snapshot_download
32
+ from PIL import Image
33
+ from transformers import AutoModelForCausalLM, SiglipProcessor
34
+
35
+ __all__ = [
36
+ "build_hf_image_pipeline",
37
+ ]
38
+
39
+
40
+ class BasePipelineLoader(ABC):
41
+ def __init__(self, device="cuda"):
42
+ self.device = device
43
+
44
+ @abstractmethod
45
+ def load(self):
46
+ pass
47
+
48
+
49
+ class BasePipelineRunner(ABC):
50
+ def __init__(self, pipe):
51
+ self.pipe = pipe
52
+
53
+ @abstractmethod
54
+ def run(self, prompt: str, **kwargs) -> Image.Image:
55
+ pass
56
+
57
+
58
+ # ===== SD3.5-medium =====
59
+ class SD35Loader(BasePipelineLoader):
60
+ def load(self):
61
+ pipe = StableDiffusion3Pipeline.from_pretrained(
62
+ "stabilityai/stable-diffusion-3.5-medium",
63
+ torch_dtype=torch.float16,
64
+ )
65
+ pipe = pipe.to(self.device)
66
+ pipe.enable_model_cpu_offload()
67
+ pipe.enable_xformers_memory_efficient_attention()
68
+ pipe.enable_attention_slicing()
69
+ return pipe
70
+
71
+
72
+ class SD35Runner(BasePipelineRunner):
73
+ def run(self, prompt: str, **kwargs) -> Image.Image:
74
+ return self.pipe(prompt=prompt, **kwargs).images
75
+
76
+
77
+ # ===== Cosmos2 =====
78
+ class CosmosLoader(BasePipelineLoader):
79
+ def __init__(
80
+ self,
81
+ model_id="nvidia/Cosmos-Predict2-2B-Text2Image",
82
+ local_dir="weights/cosmos2",
83
+ device="cuda",
84
+ ):
85
+ super().__init__(device)
86
+ self.model_id = model_id
87
+ self.local_dir = local_dir
88
+
89
+ def _patch(self):
90
+ def patch_model(cls):
91
+ orig = cls.from_pretrained
92
+
93
+ def new(*args, **kwargs):
94
+ kwargs.setdefault("attn_implementation", "flash_attention_2")
95
+ kwargs.setdefault("torch_dtype", torch.bfloat16)
96
+ return orig(*args, **kwargs)
97
+
98
+ cls.from_pretrained = new
99
+
100
+ def patch_processor(cls):
101
+ orig = cls.from_pretrained
102
+
103
+ def new(*args, **kwargs):
104
+ kwargs.setdefault("use_fast", True)
105
+ return orig(*args, **kwargs)
106
+
107
+ cls.from_pretrained = new
108
+
109
+ patch_model(AutoModelForCausalLM)
110
+ patch_processor(SiglipProcessor)
111
+
112
+ def load(self):
113
+ self._patch()
114
+ snapshot_download(
115
+ repo_id=self.model_id,
116
+ local_dir=self.local_dir,
117
+ local_dir_use_symlinks=False,
118
+ resume_download=True,
119
+ )
120
+
121
+ config = PipelineQuantizationConfig(
122
+ quant_backend="bitsandbytes_4bit",
123
+ quant_kwargs={
124
+ "load_in_4bit": True,
125
+ "bnb_4bit_quant_type": "nf4",
126
+ "bnb_4bit_compute_dtype": torch.bfloat16,
127
+ "bnb_4bit_use_double_quant": True,
128
+ },
129
+ components_to_quantize=["text_encoder", "transformer", "unet"],
130
+ )
131
+
132
+ pipe = Cosmos2TextToImagePipeline.from_pretrained(
133
+ self.model_id,
134
+ torch_dtype=torch.bfloat16,
135
+ quantization_config=config,
136
+ use_safetensors=True,
137
+ safety_checker=None,
138
+ requires_safety_checker=False,
139
+ ).to(self.device)
140
+ return pipe
141
+
142
+
143
+ class CosmosRunner(BasePipelineRunner):
144
+ def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image:
145
+ return self.pipe(
146
+ prompt=prompt, negative_prompt=negative_prompt, **kwargs
147
+ ).images
148
+
149
+
150
+ # ===== Kolors =====
151
+ class KolorsLoader(BasePipelineLoader):
152
+ def load(self):
153
+ pipe = KolorsPipeline.from_pretrained(
154
+ "Kwai-Kolors/Kolors-diffusers",
155
+ torch_dtype=torch.float16,
156
+ variant="fp16",
157
+ ).to(self.device)
158
+ pipe.enable_model_cpu_offload()
159
+ pipe.enable_xformers_memory_efficient_attention()
160
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(
161
+ pipe.scheduler.config, use_karras_sigmas=True
162
+ )
163
+ return pipe
164
+
165
+
166
+ class KolorsRunner(BasePipelineRunner):
167
+ def run(self, prompt: str, **kwargs) -> Image.Image:
168
+ return self.pipe(prompt=prompt, **kwargs).images
169
+
170
+
171
+ # ===== Flux =====
172
+ class FluxLoader(BasePipelineLoader):
173
+ def load(self):
174
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
175
+ pipe = FluxPipeline.from_pretrained(
176
+ "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
177
+ )
178
+ pipe.enable_model_cpu_offload()
179
+ pipe.enable_xformers_memory_efficient_attention()
180
+ pipe.enable_attention_slicing()
181
+ return pipe.to(self.device)
182
+
183
+
184
+ class FluxRunner(BasePipelineRunner):
185
+ def run(self, prompt: str, **kwargs) -> Image.Image:
186
+ return self.pipe(prompt=prompt, **kwargs).images
187
+
188
+
189
+ # ===== Chroma =====
190
+ class ChromaLoader(BasePipelineLoader):
191
+ def load(self):
192
+ return ChromaPipeline.from_pretrained(
193
+ "lodestones/Chroma", torch_dtype=torch.bfloat16
194
+ ).to(self.device)
195
+
196
+
197
+ class ChromaRunner(BasePipelineRunner):
198
+ def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image:
199
+ return self.pipe(
200
+ prompt=prompt, negative_prompt=negative_prompt, **kwargs
201
+ ).images
202
+
203
+
204
+ PIPELINE_REGISTRY = {
205
+ "sd35": (SD35Loader, SD35Runner),
206
+ "cosmos": (CosmosLoader, CosmosRunner),
207
+ "kolors": (KolorsLoader, KolorsRunner),
208
+ "flux": (FluxLoader, FluxRunner),
209
+ "chroma": (ChromaLoader, ChromaRunner),
210
+ }
211
+
212
+
213
+ def build_hf_image_pipeline(name: str, device="cuda") -> BasePipelineRunner:
214
+ if name not in PIPELINE_REGISTRY:
215
+ raise ValueError(f"Unsupported model: {name}")
216
+ loader_cls, runner_cls = PIPELINE_REGISTRY[name]
217
+ pipe = loader_cls(device=device).load()
218
+
219
+ return runner_cls(pipe)
220
+
221
+
222
+ if __name__ == "__main__":
223
+ model_name = "sd35"
224
+ runner = build_hf_image_pipeline(model_name)
225
+ # NOTE: Just for pipeline testing, generation quality at low resolution is poor.
226
+ images = runner.run(
227
+ prompt="A robot holding a sign that says 'Hello'",
228
+ height=512,
229
+ width=512,
230
+ num_inference_steps=10,
231
+ guidance_scale=6,
232
+ num_images_per_prompt=1,
233
+ )
234
+
235
+ for i, img in enumerate(images):
236
+ img.save(f"image_{model_name}_{i}.jpg")
embodied_gen/models/text_model.py CHANGED
@@ -52,6 +52,12 @@ __all__ = [
     "download_kolors_weights",
 ]

+PROMPT_APPEND = (
+    "Angled 3D view of one {object}, centered, no cropping, no occlusion, isolated product photo, "
+    "no surroundings, high-quality appearance, vivid colors, on a plain clean surface, 3D style revealing multiple surfaces"
+)
+PROMPT_KAPPEND = "Single {object}, in the center of the image, white background, 3D style, best quality"
+

 def download_kolors_weights(local_dir: str = "weights/Kolors") -> None:
     logger.info(f"Download kolors weights from huggingface...")
@@ -179,8 +185,7 @@
     ip_image_size: int = 512,
     seed: int = None,
 ) -> list[Image.Image]:
-    prompt = "Single " + prompt + ", in the center of the image"
-    prompt += ", high quality, high resolution, best quality, white background, 3D style"  # noqa
+    prompt = PROMPT_KAPPEND.format(object=prompt.strip())
     logger.info(f"Processing prompt: {prompt}")

     generator = None
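To make the two new templates concrete, a small illustration of how they expand (the object string is an arbitrary example; per this commit, PROMPT_KAPPEND is used inside text2img_gen and PROMPT_APPEND by embodied_gen/scripts/textto3d.py):

# Illustration only: expanding the new prompt templates with an example object.
from embodied_gen.models.text_model import PROMPT_APPEND, PROMPT_KAPPEND

obj = "wooden coffee mug"
print(PROMPT_KAPPEND.format(object=obj))
# -> "Single wooden coffee mug, in the center of the image, white background, 3D style, best quality"
print(PROMPT_APPEND.format(object=obj))  # the longer, multi-surface variant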
embodied_gen/scripts/gen_scene3d.py ADDED
@@ -0,0 +1,191 @@
1
+ import logging
2
+ import os
3
+ import random
4
+ import time
5
+ import warnings
6
+ from dataclasses import dataclass, field
7
+ from shutil import copy, rmtree
8
+
9
+ import torch
10
+ import tyro
11
+ from huggingface_hub import snapshot_download
12
+ from packaging import version
13
+
14
+ # Suppress warnings
15
+ warnings.filterwarnings("ignore", category=FutureWarning)
16
+ logging.getLogger("transformers").setLevel(logging.ERROR)
17
+ logging.getLogger("diffusers").setLevel(logging.ERROR)
18
+
19
+ # TorchVision monkey patch for >0.16
20
+ if version.parse(torch.__version__) >= version.parse("0.16"):
21
+ import sys
22
+ import types
23
+
24
+ import torchvision.transforms.functional as TF
25
+
26
+ functional_tensor = types.ModuleType(
27
+ "torchvision.transforms.functional_tensor"
28
+ )
29
+ functional_tensor.rgb_to_grayscale = TF.rgb_to_grayscale
30
+ sys.modules["torchvision.transforms.functional_tensor"] = functional_tensor
31
+
32
+ from gsplat.distributed import cli
33
+ from txt2panoimg import Text2360PanoramaImagePipeline
34
+ from embodied_gen.trainer.gsplat_trainer import (
35
+ DefaultStrategy,
36
+ GsplatTrainConfig,
37
+ )
38
+ from embodied_gen.trainer.gsplat_trainer import entrypoint as gsplat_entrypoint
39
+ from embodied_gen.trainer.pono2mesh_trainer import Pano2MeshSRPipeline
40
+ from embodied_gen.utils.config import Pano2MeshSRConfig
41
+ from embodied_gen.utils.gaussian import restore_scene_scale_and_position
42
+ from embodied_gen.utils.gpt_clients import GPT_CLIENT
43
+ from embodied_gen.utils.log import logger
44
+ from embodied_gen.utils.process_media import is_image_file, parse_text_prompts
45
+ from embodied_gen.validators.quality_checkers import (
46
+ PanoHeightEstimator,
47
+ PanoImageOccChecker,
48
+ )
49
+
50
+ __all__ = [
51
+ "generate_pano_image",
52
+ "entrypoint",
53
+ ]
54
+
55
+
56
+ @dataclass
57
+ class Scene3DGenConfig:
58
+ prompts: list[str] # Text desc of indoor room or style reference image.
59
+ output_dir: str
60
+ seed: int | None = None
61
+ real_height: float | None = None # The real height of the room in meters.
62
+ pano_image_only: bool = False
63
+ disable_pano_check: bool = False
64
+ keep_middle_result: bool = False
65
+ n_retry: int = 7
66
+ gs3d: GsplatTrainConfig = field(
67
+ default_factory=lambda: GsplatTrainConfig(
68
+ strategy=DefaultStrategy(verbose=True),
69
+ max_steps=4000,
70
+ init_opa=0.9,
71
+ opacity_reg=2e-3,
72
+ sh_degree=0,
73
+ means_lr=1e-4,
74
+ scales_lr=1e-3,
75
+ )
76
+ )
77
+
78
+
79
+ def generate_pano_image(
80
+ prompt: str,
81
+ output_path: str,
82
+ pipeline,
83
+ seed: int,
84
+ n_retry: int,
85
+ checker=None,
86
+ num_inference_steps: int = 40,
87
+ ) -> None:
88
+ for i in range(n_retry):
89
+ logger.info(
90
+ f"GEN Panorama: Retry {i+1}/{n_retry} for prompt: {prompt}, seed: {seed}"
91
+ )
92
+ if is_image_file(prompt):
93
+ raise NotImplementedError("Image mode not implemented yet.")
94
+ else:
95
+ txt_prompt = f"{prompt}, spacious, empty, wide open, open floor, minimal furniture"
96
+ inputs = {
97
+ "prompt": txt_prompt,
98
+ "num_inference_steps": num_inference_steps,
99
+ "upscale": False,
100
+ "seed": seed,
101
+ }
102
+ pano_image = pipeline(inputs)
103
+
104
+ pano_image.save(output_path)
105
+ if checker is None:
106
+ break
107
+
108
+ flag, response = checker(pano_image)
109
+ logger.warning(f"{response}, image saved in {output_path}")
110
+ if flag is True or flag is None:
111
+ break
112
+
113
+ seed = random.randint(0, 100000)
114
+
115
+ return
116
+
117
+
118
+ def entrypoint(*args, **kwargs):
119
+ cfg = tyro.cli(Scene3DGenConfig)
120
+
121
+ # Init global models.
122
+ model_path = snapshot_download("archerfmy0831/sd-t2i-360panoimage")
123
+ IMG2PANO_PIPE = Text2360PanoramaImagePipeline(
124
+ model_path, torch_dtype=torch.float16, device="cuda"
125
+ )
126
+ PANOMESH_CFG = Pano2MeshSRConfig()
127
+ PANO2MESH_PIPE = Pano2MeshSRPipeline(PANOMESH_CFG)
128
+ PANO_CHECKER = PanoImageOccChecker(GPT_CLIENT, box_hw=[95, 1000])
129
+ PANOHEIGHT_ESTOR = PanoHeightEstimator(GPT_CLIENT)
130
+
131
+ prompts = parse_text_prompts(cfg.prompts)
132
+ for idx, prompt in enumerate(prompts):
133
+ start_time = time.time()
134
+ output_dir = os.path.join(cfg.output_dir, f"scene_{idx:04d}")
135
+ os.makedirs(output_dir, exist_ok=True)
136
+ pano_path = os.path.join(output_dir, "pano_image.png")
137
+ with open(f"{output_dir}/prompt.txt", "w") as f:
138
+ f.write(prompt)
139
+
140
+ generate_pano_image(
141
+ prompt,
142
+ pano_path,
143
+ IMG2PANO_PIPE,
144
+ cfg.seed if cfg.seed is not None else random.randint(0, 100000),
145
+ cfg.n_retry,
146
+ checker=None if cfg.disable_pano_check else PANO_CHECKER,
147
+ )
148
+
149
+ if cfg.pano_image_only:
150
+ continue
151
+
152
+ logger.info("GEN and REPAIR Mesh from Panorama...")
153
+ PANO2MESH_PIPE(pano_path, output_dir)
154
+
155
+ logger.info("TRAIN 3DGS from Mesh Init and Cube Image...")
156
+ cfg.gs3d.data_dir = output_dir
157
+ cfg.gs3d.result_dir = f"{output_dir}/gaussian"
158
+ cfg.gs3d.adjust_steps(cfg.gs3d.steps_scaler)
159
+ torch.set_default_device("cpu") # recover default setting.
160
+ cli(gsplat_entrypoint, cfg.gs3d, verbose=True)
161
+
162
+ # Clean up the middle results.
163
+ gs_path = (
164
+ f"{cfg.gs3d.result_dir}/ply/point_cloud_{cfg.gs3d.max_steps-1}.ply"
165
+ )
166
+ copy(gs_path, f"{output_dir}/gs_model.ply")
167
+ video_path = f"{cfg.gs3d.result_dir}/renders/video_step{cfg.gs3d.max_steps-1}.mp4"
168
+ copy(video_path, f"{output_dir}/video.mp4")
169
+ gs_cfg_path = f"{cfg.gs3d.result_dir}/cfg.yml"
170
+ copy(gs_cfg_path, f"{output_dir}/gsplat_cfg.yml")
171
+ if not cfg.keep_middle_result:
172
+ rmtree(cfg.gs3d.result_dir, ignore_errors=True)
173
+ os.remove(f"{output_dir}/{PANOMESH_CFG.gs_data_file}")
174
+
175
+ real_height = (
176
+ PANOHEIGHT_ESTOR(pano_path)
177
+ if cfg.real_height is None
178
+ else cfg.real_height
179
+ )
180
+ gs_path = os.path.join(output_dir, "gs_model.ply")
181
+ mesh_path = os.path.join(output_dir, "mesh_model.ply")
182
+ restore_scene_scale_and_position(real_height, mesh_path, gs_path)
183
+
184
+ elapsed_time = (time.time() - start_time) / 60
185
+ logger.info(
186
+ f"FINISHED 3D scene generation in {output_dir} in {elapsed_time:.2f} mins."
187
+ )
188
+
189
+
190
+ if __name__ == "__main__":
191
+ entrypoint()
embodied_gen/scripts/imageto3d.py CHANGED
@@ -16,29 +16,28 @@
16
 
17
 
18
  import argparse
19
- import logging
20
  import os
 
21
  import sys
22
  from glob import glob
23
  from shutil import copy, copytree, rmtree
24
 
25
  import numpy as np
 
26
  import trimesh
27
  from PIL import Image
28
  from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
29
  from embodied_gen.data.utils import delete_dir, trellis_preprocess
30
  from embodied_gen.models.delight_model import DelightingModel
31
  from embodied_gen.models.gs_model import GaussianOperator
32
- from embodied_gen.models.segment_model import (
33
- BMGG14Remover,
34
- RembgRemover,
35
- SAMPredictor,
36
- )
37
  from embodied_gen.models.sr_model import ImageRealESRGAN
38
  from embodied_gen.scripts.render_gs import entrypoint as render_gs_api
39
  from embodied_gen.utils.gpt_clients import GPT_CLIENT
40
- from embodied_gen.utils.process_media import merge_images_video, render_video
 
41
  from embodied_gen.utils.tags import VERSION
 
42
  from embodied_gen.validators.quality_checkers import (
43
  BaseChecker,
44
  ImageAestheticChecker,
@@ -52,36 +51,25 @@ current_dir = os.path.dirname(current_file_path)
52
  sys.path.append(os.path.join(current_dir, "../.."))
53
  from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
54
 
55
- logging.basicConfig(
56
- format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
57
- )
58
- logger = logging.getLogger(__name__)
59
-
60
-
61
  os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
62
  "~/.cache/torch_extensions"
63
  )
64
  os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
65
  os.environ["SPCONV_ALGO"] = "native"
 
66
 
67
-
68
  DELIGHT = DelightingModel()
69
  IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
70
-
71
  RBG_REMOVER = RembgRemover()
72
- RBG14_REMOVER = BMGG14Remover()
73
- SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
74
  PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
75
  "microsoft/TRELLIS-image-large"
76
  )
77
- PIPELINE.cuda()
78
  SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
79
  GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
80
  AESTHETIC_CHECKER = ImageAestheticChecker()
81
  CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
82
- TMP_DIR = os.path.join(
83
- os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
84
- )
85
 
86
 
87
  def parse_args():
@@ -95,7 +83,6 @@ def parse_args():
95
  parser.add_argument(
96
  "--output_root",
97
  type=str,
98
- required=True,
99
  help="Root directory for saving outputs.",
100
  )
101
  parser.add_argument(
@@ -110,12 +97,26 @@ def parse_args():
110
  default=None,
111
  help="The mass in kg to restore the mesh real weight.",
112
  )
113
- parser.add_argument("--asset_type", type=str, default=None)
114
  parser.add_argument("--skip_exists", action="store_true")
115
- parser.add_argument("--strict_seg", action="store_true")
116
  parser.add_argument("--version", type=str, default=VERSION)
117
- parser.add_argument("--remove_intermediate", type=bool, default=True)
118
- args = parser.parse_args()
119
 
120
  assert (
121
  args.image_path or args.image_root
@@ -125,13 +126,7 @@ def parse_args():
125
  args.image_path += glob(os.path.join(args.image_root, "*.jpg"))
126
  args.image_path += glob(os.path.join(args.image_root, "*.jpeg"))
127
 
128
- return args
129
-
130
-
131
- if __name__ == "__main__":
132
- args = parse_args()
133
-
134
- for image_path in args.image_path:
135
  try:
136
  filename = os.path.basename(image_path).split(".")[0]
137
  output_root = args.output_root
@@ -141,7 +136,7 @@ if __name__ == "__main__":
141
 
142
  mesh_out = f"{output_root}/{filename}.obj"
143
  if args.skip_exists and os.path.exists(mesh_out):
144
- logger.info(
145
  f"Skip {image_path}, already processed in {mesh_out}"
146
  )
147
  continue
@@ -149,67 +144,84 @@ if __name__ == "__main__":
149
  image = Image.open(image_path)
150
  image.save(f"{output_root}/{filename}_raw.png")
151
 
152
- # Segmentation: Get segmented image using SAM or Rembg.
153
  seg_path = f"{output_root}/{filename}_cond.png"
154
- if image.mode != "RGBA":
155
- seg_image = RBG_REMOVER(image, save_path=seg_path)
156
- seg_image = trellis_preprocess(seg_image)
157
- else:
158
- seg_image = image
159
- seg_image.save(seg_path)
160
-
161
- # Run the pipeline
162
- try:
163
- outputs = PIPELINE.run(
164
- seg_image,
165
- preprocess_image=False,
166
- # Optional parameters
167
- # seed=1,
168
- # sparse_structure_sampler_params={
169
- # "steps": 12,
170
- # "cfg_strength": 7.5,
171
- # },
172
- # slat_sampler_params={
173
- # "steps": 12,
174
- # "cfg_strength": 3,
175
- # },
176
  )
177
- except Exception as e:
178
- logger.error(
179
- f"[Pipeline Failed] process {image_path}: {e}, skip."
180
  )
181
- continue
182
 
183
- # Render and save color and mesh videos
184
- gs_model = outputs["gaussian"][0]
185
- mesh_model = outputs["mesh"][0]
186
  color_images = render_video(gs_model)["color"]
187
  normal_images = render_video(mesh_model)["normal"]
188
  video_path = os.path.join(output_root, "gs_mesh.mp4")
189
  merge_images_video(color_images, normal_images, video_path)
190
 
191
- # Save the raw Gaussian model
192
- gs_path = mesh_out.replace(".obj", "_gs.ply")
193
- gs_model.save_ply(gs_path)
194
-
195
- # Rotate mesh and GS by 90 degrees around Z-axis.
196
- rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
197
- gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
198
- mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
199
-
200
- # Addtional rotation for GS to align mesh.
201
- gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
202
- pose = GaussianOperator.trans_to_quatpose(gs_rot)
203
- aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
204
- GaussianOperator.resave_ply(
205
- in_ply=gs_path,
206
- out_ply=aligned_gs_path,
207
- instance_pose=pose,
208
- device="cpu",
209
- )
210
- color_path = os.path.join(output_root, "color.png")
211
- render_gs_api(aligned_gs_path, color_path)
212
-
213
  mesh = trimesh.Trimesh(
214
  vertices=mesh_model.vertices.cpu().numpy(),
215
  faces=mesh_model.faces.cpu().numpy(),
@@ -249,8 +261,8 @@ if __name__ == "__main__":
249
  min_mass, max_mass = map(float, args.mass_range.split("-"))
250
  asset_attrs["min_mass"] = min_mass
251
  asset_attrs["max_mass"] = max_mass
252
- if args.asset_type:
253
- asset_attrs["category"] = args.asset_type
254
  if args.version:
255
  asset_attrs["version"] = args.version
256
 
@@ -289,8 +301,8 @@ if __name__ == "__main__":
289
  ]
290
  images_list.append(images)
291
 
292
- results = BaseChecker.validate(CHECKERS, images_list)
293
- urdf_convertor.add_quality_tag(urdf_path, results)
294
 
295
  # Organize the final result files
296
  result_dir = f"{output_root}/result"
@@ -303,7 +315,7 @@ if __name__ == "__main__":
303
  f"{result_dir}/{urdf_convertor.output_mesh_dir}",
304
  )
305
  copy(video_path, f"{result_dir}/video.mp4")
306
- if args.remove_intermediate:
307
  delete_dir(output_root, keep_subs=["result"])
308
 
309
  except Exception as e:
@@ -311,3 +323,7 @@ if __name__ == "__main__":
311
  continue
312
 
313
  logger.info(f"Processing complete. Outputs saved to {args.output_root}")
 
 
 
 
 
16
 
17
 
18
  import argparse
 
19
  import os
20
+ import random
21
  import sys
22
  from glob import glob
23
  from shutil import copy, copytree, rmtree
24
 
25
  import numpy as np
26
+ import torch
27
  import trimesh
28
  from PIL import Image
29
  from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
30
  from embodied_gen.data.utils import delete_dir, trellis_preprocess
31
  from embodied_gen.models.delight_model import DelightingModel
32
  from embodied_gen.models.gs_model import GaussianOperator
33
+ from embodied_gen.models.segment_model import RembgRemover
 
 
 
 
34
  from embodied_gen.models.sr_model import ImageRealESRGAN
35
  from embodied_gen.scripts.render_gs import entrypoint as render_gs_api
36
  from embodied_gen.utils.gpt_clients import GPT_CLIENT
37
+ from embodied_gen.utils.log import logger
38
+ from embodied_gen.utils.process_media import merge_images_video
39
  from embodied_gen.utils.tags import VERSION
40
+ from embodied_gen.utils.trender import render_video
41
  from embodied_gen.validators.quality_checkers import (
42
  BaseChecker,
43
  ImageAestheticChecker,
 
51
  sys.path.append(os.path.join(current_dir, "../.."))
52
  from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
53
 
 
 
 
 
 
 
54
  os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
55
  "~/.cache/torch_extensions"
56
  )
57
  os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
58
  os.environ["SPCONV_ALGO"] = "native"
59
+ random.seed(0)
60
 
61
+ logger.info("Loading Models...")
62
  DELIGHT = DelightingModel()
63
  IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
 
64
  RBG_REMOVER = RembgRemover()
 
 
65
  PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
66
  "microsoft/TRELLIS-image-large"
67
  )
68
+ # PIPELINE.cuda()
69
  SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
70
  GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
71
  AESTHETIC_CHECKER = ImageAestheticChecker()
72
  CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
 
 
 
73
 
74
 
75
  def parse_args():
 
83
  parser.add_argument(
84
  "--output_root",
85
  type=str,
 
86
  help="Root directory for saving outputs.",
87
  )
88
  parser.add_argument(
 
97
  default=None,
98
  help="The mass in kg to restore the mesh real weight.",
99
  )
100
+ parser.add_argument("--asset_type", type=str, nargs="+", default=None)
101
  parser.add_argument("--skip_exists", action="store_true")
 
102
  parser.add_argument("--version", type=str, default=VERSION)
103
+ parser.add_argument("--keep_intermediate", action="store_true")
104
+ parser.add_argument("--seed", type=int, default=0)
105
+ parser.add_argument(
106
+ "--n_retry",
107
+ type=int,
108
+ default=2,
109
+ )
110
+ args, unknown = parser.parse_known_args()
111
+
112
+ return args
113
+
114
+
115
+ def entrypoint(**kwargs):
116
+ args = parse_args()
117
+ for k, v in kwargs.items():
118
+ if hasattr(args, k) and v is not None:
119
+ setattr(args, k, v)
120
 
121
  assert (
122
  args.image_path or args.image_root
 
126
  args.image_path += glob(os.path.join(args.image_root, "*.jpg"))
127
  args.image_path += glob(os.path.join(args.image_root, "*.jpeg"))
128
 
129
+ for idx, image_path in enumerate(args.image_path):
 
 
 
 
 
 
130
  try:
131
  filename = os.path.basename(image_path).split(".")[0]
132
  output_root = args.output_root
 
136
 
137
  mesh_out = f"{output_root}/{filename}.obj"
138
  if args.skip_exists and os.path.exists(mesh_out):
139
+ logger.warning(
140
  f"Skip {image_path}, already processed in {mesh_out}"
141
  )
142
  continue
 
144
  image = Image.open(image_path)
145
  image.save(f"{output_root}/{filename}_raw.png")
146
 
147
+ # Segmentation: Get segmented image using Rembg.
148
  seg_path = f"{output_root}/{filename}_cond.png"
149
+ seg_image = RBG_REMOVER(image) if image.mode != "RGBA" else image
150
+ seg_image = trellis_preprocess(seg_image)
151
+ seg_image.save(seg_path)
152
+
153
+ seed = args.seed
154
+ for try_idx in range(args.n_retry):
155
+ logger.info(
156
+ f"Try: {try_idx + 1}/{args.n_retry}, Seed: {seed}, Prompt: {seg_path}"
157
  )
158
+ # Run the pipeline
159
+ try:
160
+ PIPELINE.cuda()
161
+ outputs = PIPELINE.run(
162
+ seg_image,
163
+ preprocess_image=False,
164
+ seed=(
165
+ random.randint(0, 100000) if seed is None else seed
166
+ ),
167
+ # Optional parameters
168
+ # sparse_structure_sampler_params={
169
+ # "steps": 12,
170
+ # "cfg_strength": 7.5,
171
+ # },
172
+ # slat_sampler_params={
173
+ # "steps": 12,
174
+ # "cfg_strength": 3,
175
+ # },
176
+ )
177
+ PIPELINE.cpu()
178
+ torch.cuda.empty_cache()
179
+ except Exception as e:
180
+ logger.error(
181
+ f"[Pipeline Failed] process {image_path}: {e}, skip."
182
+ )
183
+ continue
184
+
185
+ gs_model = outputs["gaussian"][0]
186
+ mesh_model = outputs["mesh"][0]
187
+
188
+ # Save the raw Gaussian model
189
+ gs_path = mesh_out.replace(".obj", "_gs.ply")
190
+ gs_model.save_ply(gs_path)
191
+
192
+ # Rotate mesh and GS by 90 degrees around Z-axis.
193
+ rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
194
+ gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
195
+ mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
196
+
197
+ # Addtional rotation for GS to align mesh.
198
+ gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
199
+ pose = GaussianOperator.trans_to_quatpose(gs_rot)
200
+ aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
201
+ GaussianOperator.resave_ply(
202
+ in_ply=gs_path,
203
+ out_ply=aligned_gs_path,
204
+ instance_pose=pose,
205
+ device="cpu",
206
  )
207
+ color_path = os.path.join(output_root, "color.png")
208
+ render_gs_api(aligned_gs_path, color_path)
209
+
210
+ geo_flag, geo_result = GEO_CHECKER([color_path])
211
+ logger.warning(
212
+ f"{GEO_CHECKER.__class__.__name__}: {geo_result} for {seg_path}"
213
+ )
214
+ if geo_flag is True or geo_flag is None:
215
+ break
216
+
217
+ seed = random.randint(0, 100000) if seed is not None else None
218
 
219
+ # Render the video for generated 3D asset.
 
 
220
  color_images = render_video(gs_model)["color"]
221
  normal_images = render_video(mesh_model)["normal"]
222
  video_path = os.path.join(output_root, "gs_mesh.mp4")
223
  merge_images_video(color_images, normal_images, video_path)
224
 
  mesh = trimesh.Trimesh(
226
  vertices=mesh_model.vertices.cpu().numpy(),
227
  faces=mesh_model.faces.cpu().numpy(),
 
261
  min_mass, max_mass = map(float, args.mass_range.split("-"))
262
  asset_attrs["min_mass"] = min_mass
263
  asset_attrs["max_mass"] = max_mass
264
+ if isinstance(args.asset_type, list) and args.asset_type[idx]:
265
+ asset_attrs["category"] = args.asset_type[idx]
266
  if args.version:
267
  asset_attrs["version"] = args.version
268
 
 
301
  ]
302
  images_list.append(images)
303
 
304
+ qa_results = BaseChecker.validate(CHECKERS, images_list)
305
+ urdf_convertor.add_quality_tag(urdf_path, qa_results)
306
 
307
  # Organize the final result files
308
  result_dir = f"{output_root}/result"
 
315
  f"{result_dir}/{urdf_convertor.output_mesh_dir}",
316
  )
317
  copy(video_path, f"{result_dir}/video.mp4")
318
+ if not args.keep_intermediate:
319
  delete_dir(output_root, keep_subs=["result"])
320
 
321
  except Exception as e:
 
323
  continue
324
 
325
  logger.info(f"Processing complete. Outputs saved to {args.output_root}")
326
+
327
+
328
+ if __name__ == "__main__":
329
+ entrypoint()
embodied_gen/scripts/text2image.py CHANGED
@@ -31,6 +31,7 @@ from embodied_gen.models.text_model import (
     build_text2img_pipeline,
     text2img_gen,
 )
+from embodied_gen.utils.process_media import parse_text_prompts

 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -85,7 +86,7 @@
     parser.add_argument(
         "--seed",
         type=int,
-        default=0,
+        default=None,
     )
     args = parser.parse_args()

@@ -101,14 +102,7 @@ def entrypoint(
         if hasattr(args, k) and v is not None:
             setattr(args, k, v)

-    prompts = args.prompts
-    if len(prompts) == 1 and prompts[0].endswith(".txt"):
-        with open(prompts[0], "r") as f:
-            prompts = f.readlines()
-        prompts = [
-            prompt.strip() for prompt in prompts if prompt.strip() != ""
-        ]
-
+    prompts = parse_text_prompts(args.prompts)
     os.makedirs(args.output_root, exist_ok=True)

     ip_img_paths = args.ref_image
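The removed inline parsing is now delegated to `parse_text_prompts`; judging from the code it replaces, a roughly equivalent helper would look like the sketch below (the real helper lives in embodied_gen.utils.process_media and its exact behavior may differ):

# Hypothetical stand-in mirroring the removed inline logic (for illustration only).
def parse_text_prompts_sketch(prompts: list[str]) -> list[str]:
    # A single ".txt" argument is treated as a file with one prompt per line.
    if len(prompts) == 1 and prompts[0].endswith(".txt"):
        with open(prompts[0], "r") as f:
            prompts = [line.strip() for line in f if line.strip()]
    return prompts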
embodied_gen/scripts/textto3d.py ADDED
@@ -0,0 +1,280 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+ import argparse
18
+ import os
19
+ import random
20
+ from collections import defaultdict
21
+
22
+ import numpy as np
23
+ import torch
24
+ from PIL import Image
25
+ from embodied_gen.models.image_comm_model import build_hf_image_pipeline
26
+ from embodied_gen.models.segment_model import RembgRemover
27
+ from embodied_gen.models.text_model import PROMPT_APPEND
28
+ from embodied_gen.scripts.imageto3d import entrypoint as imageto3d_api
29
+ from embodied_gen.utils.gpt_clients import GPT_CLIENT
30
+ from embodied_gen.utils.log import logger
31
+ from embodied_gen.utils.process_media import (
32
+ check_object_edge_truncated,
33
+ render_asset3d,
34
+ )
35
+ from embodied_gen.validators.quality_checkers import (
36
+ ImageSegChecker,
37
+ SemanticConsistChecker,
38
+ TextGenAlignChecker,
39
+ )
40
+
41
+ # Avoid huggingface/tokenizers: The current process just got forked.
42
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
43
+ random.seed(0)
44
+
45
+ logger.info("Loading Models...")
46
+ SEMANTIC_CHECKER = SemanticConsistChecker(GPT_CLIENT)
47
+ SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
48
+ TXTGEN_CHECKER = TextGenAlignChecker(GPT_CLIENT)
49
+ PIPE_IMG = build_hf_image_pipeline(os.environ.get("TEXT_MODEL", "sd35"))
50
+ BG_REMOVER = RembgRemover()
51
+
52
+
53
+ __all__ = [
54
+ "text_to_image",
55
+ "text_to_3d",
56
+ ]
57
+
58
+
59
+ def text_to_image(
60
+ prompt: str,
61
+ save_path: str,
62
+ n_retry: int,
63
+ img_denoise_step: int,
64
+ text_guidance_scale: float,
65
+ n_img_sample: int,
66
+ image_hw: tuple[int, int] = (1024, 1024),
67
+ seed: int = None,
68
+ ) -> bool:
69
+ select_image = None
70
+ success_flag = False
71
+ assert save_path.endswith(".png"), "Image save path must end with `.png`."
72
+ for try_idx in range(n_retry):
73
+ if select_image is not None:
74
+ select_image[0].save(save_path.replace(".png", "_raw.png"))
75
+ select_image[1].save(save_path)
76
+ break
77
+
78
+ f_prompt = PROMPT_APPEND.format(object=prompt)
79
+ logger.info(
80
+ f"Image GEN for {os.path.basename(save_path)}\n"
81
+ f"Try: {try_idx + 1}/{n_retry}, Seed: {seed}, Prompt: {f_prompt}"
82
+ )
83
+ torch.cuda.empty_cache()
84
+ images = PIPE_IMG.run(
85
+ f_prompt,
86
+ num_inference_steps=img_denoise_step,
87
+ guidance_scale=text_guidance_scale,
88
+ num_images_per_prompt=n_img_sample,
89
+ height=image_hw[0],
90
+ width=image_hw[1],
91
+ generator=(
92
+ torch.Generator().manual_seed(seed)
93
+ if seed is not None
94
+ else None
95
+ ),
96
+ )
97
+
98
+ for idx in range(len(images)):
99
+ raw_image: Image.Image = images[idx]
100
+ image = BG_REMOVER(raw_image)
101
+ image.save(save_path)
102
+ semantic_flag, semantic_result = SEMANTIC_CHECKER(
103
+ prompt, [image.convert("RGB")]
104
+ )
105
+ seg_flag, seg_result = SEG_CHECKER(
106
+ [raw_image, image.convert("RGB")]
107
+ )
108
+ image_mask = np.array(image)[..., -1]
109
+ edge_flag = check_object_edge_truncated(image_mask)
110
+ logger.warning(
111
+ f"SEMANTIC: {semantic_result}. SEG: {seg_result}. EDGE: {edge_flag}"
112
+ )
113
+ if (
114
+ (edge_flag and semantic_flag and seg_flag)
115
+ or (edge_flag and semantic_flag is None)
116
+ or (edge_flag and seg_flag is None)
117
+ ):
118
+ select_image = [raw_image, image]
119
+ success_flag = True
120
+ break
121
+
122
+ seed = random.randint(0, 100000) if seed is not None else None
123
+
124
+ return success_flag
125
+
126
+
127
+ def text_to_3d(**kwargs) -> dict:
128
+ args = parse_args()
129
+ for k, v in kwargs.items():
130
+ if hasattr(args, k) and v is not None:
131
+ setattr(args, k, v)
132
+
133
+ if args.asset_names is None or len(args.asset_names) == 0:
134
+ args.asset_names = [f"sample3d_{i}" for i in range(len(args.prompts))]
135
+ img_save_dir = os.path.join(args.output_root, "images")
136
+ asset_save_dir = os.path.join(args.output_root, "asset3d")
137
+ os.makedirs(img_save_dir, exist_ok=True)
138
+ os.makedirs(asset_save_dir, exist_ok=True)
139
+ results = defaultdict(dict)
140
+ for prompt, node in zip(args.prompts, args.asset_names):
141
+ success_flag = False
142
+ n_pipe_retry = args.n_pipe_retry
143
+ seed_img = args.seed_img
144
+ seed_3d = args.seed_3d
145
+ while success_flag is False and n_pipe_retry > 0:
146
+ logger.info(
147
+ f"GEN pipeline for node {node}\n"
148
+ f"Try round: {args.n_pipe_retry-n_pipe_retry+1}/{args.n_pipe_retry}, Prompt: {prompt}"
149
+ )
150
+ # Text-to-image GEN
151
+ save_node = node.replace(" ", "_")
152
+ gen_image_path = f"{img_save_dir}/{save_node}.png"
153
+ textgen_flag = text_to_image(
154
+ prompt,
155
+ gen_image_path,
156
+ args.n_image_retry,
157
+ args.img_denoise_step,
158
+ args.text_guidance_scale,
159
+ args.n_img_sample,
160
+ seed=seed_img,
161
+ )
162
+
163
+ # Asset 3D GEN
164
+ node_save_dir = f"{asset_save_dir}/{save_node}"
165
+ asset_type = node if "sample3d_" not in node else None
166
+ imageto3d_api(
167
+ image_path=[gen_image_path],
168
+ output_root=node_save_dir,
169
+ asset_type=[asset_type],
170
+ seed=random.randint(0, 100000) if seed_3d is None else seed_3d,
171
+ n_retry=args.n_asset_retry,
172
+ keep_intermediate=args.keep_intermediate,
173
+ )
174
+ mesh_path = f"{node_save_dir}/result/mesh/{save_node}.obj"
175
+ image_path = render_asset3d(
176
+ mesh_path,
177
+ output_root=f"{node_save_dir}/result",
178
+ num_images=6,
179
+ elevation=(30, -30),
180
+ output_subdir="renders",
181
+ no_index_file=True,
182
+ )
183
+
184
+ check_text = asset_type if asset_type is not None else prompt
185
+ qa_flag, qa_result = TXTGEN_CHECKER(check_text, image_path)
186
+ logger.warning(
187
+ f"Node {node}, {TXTGEN_CHECKER.__class__.__name__}: {qa_result}"
188
+ )
189
+ results["assets"][node] = f"{node_save_dir}/result"
190
+ results["quality"][node] = qa_result
191
+
192
+ if qa_flag is None or qa_flag is True:
193
+ success_flag = True
194
+ break
195
+
196
+ n_pipe_retry -= 1
197
+ seed_img = (
198
+ random.randint(0, 100000) if seed_img is not None else None
199
+ )
200
+ seed_3d = (
201
+ random.randint(0, 100000) if seed_3d is not None else None
202
+ )
203
+
204
+ torch.cuda.empty_cache()
205
+
206
+ return results
207
+
208
+
209
+ def parse_args():
210
+ parser = argparse.ArgumentParser(description="3D Layout Generation Config")
211
+ parser.add_argument("--prompts", nargs="+", help="text descriptions")
212
+ parser.add_argument(
213
+ "--output_root",
214
+ type=str,
215
+ help="Directory to save outputs",
216
+ )
217
+ parser.add_argument(
218
+ "--asset_names",
219
+ type=str,
220
+ nargs="+",
221
+ default=None,
222
+ help="Asset names to generate",
223
+ )
224
+ parser.add_argument(
225
+ "--n_img_sample",
226
+ type=int,
227
+ default=3,
228
+ help="Number of image samples to generate",
229
+ )
230
+ parser.add_argument(
231
+ "--text_guidance_scale",
232
+ type=float,
233
+ default=7,
234
+ help="Text-to-image guidance scale",
235
+ )
236
+ parser.add_argument(
237
+ "--img_denoise_step",
238
+ type=int,
239
+ default=25,
240
+ help="Denoising steps for image generation",
241
+ )
242
+ parser.add_argument(
243
+ "--n_image_retry",
244
+ type=int,
245
+ default=2,
246
+ help="Max retry count for image generation",
247
+ )
248
+ parser.add_argument(
249
+ "--n_asset_retry",
250
+ type=int,
251
+ default=2,
252
+ help="Max retry count for 3D generation",
253
+ )
254
+ parser.add_argument(
255
+ "--n_pipe_retry",
256
+ type=int,
257
+ default=1,
258
+ help="Max retry count for 3D asset generation",
259
+ )
260
+ parser.add_argument(
261
+ "--seed_img",
262
+ type=int,
263
+ default=None,
264
+ help="Random seed for image generation",
265
+ )
266
+ parser.add_argument(
267
+ "--seed_3d",
268
+ type=int,
269
+ default=0,
270
+ help="Random seed for 3D generation",
271
+ )
272
+ parser.add_argument("--keep_intermediate", action="store_true")
273
+
274
+ args, unknown = parser.parse_known_args()
275
+
276
+ return args
277
+
278
+
279
+ if __name__ == "__main__":
280
+ text_to_3d()
embodied_gen/scripts/textto3d.sh CHANGED
@@ -2,7 +2,9 @@
2
 
3
  # Initialize variables
4
  prompts=()
 
5
  output_root=""
 
6
 
7
  # Parse arguments
8
  while [[ $# -gt 0 ]]; do
@@ -14,10 +16,21 @@ while [[ $# -gt 0 ]]; do
14
  shift
15
  done
16
  ;;
 
 
 
 
 
 
 
17
  --output_root)
18
  output_root="$2"
19
  shift 2
20
  ;;
 
 
 
 
21
  *)
22
  echo "Unknown argument: $1"
23
  exit 1
@@ -28,7 +41,21 @@ done
28
  # Validate required arguments
29
  if [[ ${#prompts[@]} -eq 0 || -z "$output_root" ]]; then
30
  echo "Missing required arguments."
31
- echo "Usage: bash run_text2asset3d.sh --prompts \"Prompt1\" \"Prompt2\" --output_root <path>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  exit 1
33
  fi
34
 
@@ -37,20 +64,30 @@ echo "Prompts:"
37
  for p in "${prompts[@]}"; do
38
  echo " - $p"
39
  done
 
 
 
 
40
  echo "Output root: ${output_root}"
 
41
 
42
- # Concatenate prompts for Python command
43
  prompt_args=""
44
- for p in "${prompts[@]}"; do
45
- prompt_args+="\"$p\" "
 
 
46
  done
47
 
 
48
  # Step 1: Text-to-Image
49
  eval python3 embodied_gen/scripts/text2image.py \
50
  --prompts ${prompt_args} \
51
- --output_root "${output_root}/images"
 
52
 
53
  # Step 2: Image-to-3D
54
  python3 embodied_gen/scripts/imageto3d.py \
55
  --image_root "${output_root}/images" \
56
- --output_root "${output_root}/asset3d"
 
 
2
 
3
  # Initialize variables
4
  prompts=()
5
+ asset_types=()
6
  output_root=""
7
+ seed=0
8
 
9
  # Parse arguments
10
  while [[ $# -gt 0 ]]; do
 
16
  shift
17
  done
18
  ;;
19
+ --asset_types)
20
+ shift
21
+ while [[ $# -gt 0 && ! "$1" =~ ^-- ]]; do
22
+ asset_types+=("$1")
23
+ shift
24
+ done
25
+ ;;
26
  --output_root)
27
  output_root="$2"
28
  shift 2
29
  ;;
30
+ --seed)
31
+ seed="$2"
32
+ shift 2
33
+ ;;
34
  *)
35
  echo "Unknown argument: $1"
36
  exit 1
 
41
  # Validate required arguments
42
  if [[ ${#prompts[@]} -eq 0 || -z "$output_root" ]]; then
43
  echo "Missing required arguments."
44
+ echo "Usage: bash run_text2asset3d.sh --prompts \"Prompt1\" \"Prompt2\" \
45
+ --asset_types \"type1\" \"type2\" --seed <seed_value> --output_root <path>"
46
+ exit 1
47
+ fi
48
+
49
+ # If no asset_types provided, default to ""
50
+ if [[ ${#asset_types[@]} -eq 0 ]]; then
51
+ for (( i=0; i<${#prompts[@]}; i++ )); do
52
+ asset_types+=("")
53
+ done
54
+ fi
55
+
56
+ # Ensure the number of asset_types matches the number of prompts
57
+ if [[ ${#prompts[@]} -ne ${#asset_types[@]} ]]; then
58
+ echo "The number of asset types must match the number of prompts."
59
  exit 1
60
  fi
61
 
 
64
  for p in "${prompts[@]}"; do
65
  echo " - $p"
66
  done
67
+ # echo "Asset types:"
68
+ # for at in "${asset_types[@]}"; do
69
+ # echo " - $at"
70
+ # done
71
  echo "Output root: ${output_root}"
72
+ echo "Seed: ${seed}"
73
 
74
+ # Concatenate prompts and asset types for Python command
75
  prompt_args=""
76
+ asset_type_args=""
77
+ for i in "${!prompts[@]}"; do
78
+ prompt_args+="\"${prompts[$i]}\" "
79
+ asset_type_args+="\"${asset_types[$i]}\" "
80
  done
81
 
82
+
83
  # Step 1: Text-to-Image
84
  eval python3 embodied_gen/scripts/text2image.py \
85
  --prompts ${prompt_args} \
86
+ --output_root "${output_root}/images" \
87
+ --seed ${seed}
88
 
89
  # Step 2: Image-to-3D
90
  python3 embodied_gen/scripts/imageto3d.py \
91
  --image_root "${output_root}/images" \
92
+ --output_root "${output_root}/asset3d" \
93
+ --asset_type ${asset_type_args}
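For readers driving the two stages from Python instead of the shell wrapper, here is a rough, hypothetical equivalent of the updated script (prompts, paths, and asset types are illustrative; the flag names are taken verbatim from the calls above):

```python
import subprocess

# Illustrative inputs; an empty asset type mirrors the script's default of "".
prompts = ["A wooden mug", "A blue plastic bowl"]
asset_types = ["", ""]
output_root = "outputs/demo"
seed = 0

# Step 1: Text-to-Image (same flags as the text2image.py call above).
subprocess.run(
    ["python3", "embodied_gen/scripts/text2image.py",
     "--prompts", *prompts,
     "--output_root", f"{output_root}/images",
     "--seed", str(seed)],
    check=True,
)

# Step 2: Image-to-3D, consuming the images written in step 1.
subprocess.run(
    ["python3", "embodied_gen/scripts/imageto3d.py",
     "--image_root", f"{output_root}/images",
     "--output_root", f"{output_root}/asset3d",
     "--asset_type", *asset_types],
    check=True,
)
```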
embodied_gen/scripts/texture_gen.sh CHANGED
@@ -10,10 +10,6 @@ while [[ $# -gt 0 ]]; do
10
  prompt="$2"
11
  shift 2
12
  ;;
13
- --uuid)
14
- uuid="$2"
15
- shift 2
16
- ;;
17
  --output_root)
18
  output_root="$2"
19
  shift 2
@@ -26,12 +22,13 @@ while [[ $# -gt 0 ]]; do
26
  done
27
 
28
 
29
- if [[ -z "$mesh_path" || -z "$prompt" || -z "$uuid" || -z "$output_root" ]]; then
30
  echo "params missing"
31
- echo "usage: bash run.sh --mesh_path <path> --prompt <text> --uuid <id> --output_root <path>"
32
  exit 1
33
  fi
34
 
 
35
  # Step 1: drender-cli for condition rendering
36
  drender-cli --mesh_path ${mesh_path} \
37
  --output_root ${output_root}/condition \
 
10
  prompt="$2"
11
  shift 2
12
  ;;
 
 
 
 
13
  --output_root)
14
  output_root="$2"
15
  shift 2
 
22
  done
23
 
24
 
25
+ if [[ -z "$mesh_path" || -z "$prompt" || -z "$output_root" ]]; then
26
  echo "params missing"
27
+ echo "usage: bash run.sh --mesh_path <path> --prompt <text> --output_root <path>"
28
  exit 1
29
  fi
30
 
31
+ uuid=$(basename "$output_root")
32
  # Step 1: drender-cli for condition rendering
33
  drender-cli --mesh_path ${mesh_path} \
34
  --output_root ${output_root}/condition \
embodied_gen/trainer/gsplat_trainer.py ADDED
@@ -0,0 +1,678 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+ # Part of the code comes from https://github.com/nerfstudio-project/gsplat/blob/main/examples/simple_trainer.py
17
+ # Both under the Apache License, Version 2.0.
18
+
19
+
20
+ import json
21
+ import os
22
+ import time
23
+ from collections import defaultdict
24
+ from typing import Dict, Optional, Tuple
25
+
26
+ import cv2
27
+ import imageio
28
+ import numpy as np
29
+ import torch
30
+ import torch.nn.functional as F
31
+ import tqdm
32
+ import tyro
33
+ import yaml
34
+ from fused_ssim import fused_ssim
35
+ from gsplat.distributed import cli
36
+ from gsplat.rendering import rasterization
37
+ from gsplat.strategy import DefaultStrategy, MCMCStrategy
38
+ from torch import Tensor
39
+ from torch.utils.tensorboard import SummaryWriter
40
+ from torchmetrics.image import (
41
+ PeakSignalNoiseRatio,
42
+ StructuralSimilarityIndexMeasure,
43
+ )
44
+ from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
45
+ from typing_extensions import Literal, assert_never
46
+ from embodied_gen.data.datasets import PanoGSplatDataset
47
+ from embodied_gen.utils.config import GsplatTrainConfig
48
+ from embodied_gen.utils.gaussian import (
49
+ create_splats_with_optimizers,
50
+ export_splats,
51
+ resize_pinhole_intrinsics,
52
+ set_random_seed,
53
+ )
54
+
55
+
56
+ class Runner:
57
+ """Engine for training and testing from gsplat example.
58
+
59
+ Code from https://github.com/nerfstudio-project/gsplat/blob/main/examples/simple_trainer.py
60
+ """
61
+
62
+ def __init__(
63
+ self,
64
+ local_rank: int,
65
+ world_rank,
66
+ world_size: int,
67
+ cfg: GsplatTrainConfig,
68
+ ) -> None:
69
+ set_random_seed(42 + local_rank)
70
+
71
+ self.cfg = cfg
72
+ self.world_rank = world_rank
73
+ self.local_rank = local_rank
74
+ self.world_size = world_size
75
+ self.device = f"cuda:{local_rank}"
76
+
77
+ # Where to dump results.
78
+ os.makedirs(cfg.result_dir, exist_ok=True)
79
+
80
+ # Setup output directories.
81
+ self.ckpt_dir = f"{cfg.result_dir}/ckpts"
82
+ os.makedirs(self.ckpt_dir, exist_ok=True)
83
+ self.stats_dir = f"{cfg.result_dir}/stats"
84
+ os.makedirs(self.stats_dir, exist_ok=True)
85
+ self.render_dir = f"{cfg.result_dir}/renders"
86
+ os.makedirs(self.render_dir, exist_ok=True)
87
+ self.ply_dir = f"{cfg.result_dir}/ply"
88
+ os.makedirs(self.ply_dir, exist_ok=True)
89
+
90
+ # Tensorboard
91
+ self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb")
92
+ self.trainset = PanoGSplatDataset(cfg.data_dir, split="train")
93
+ self.valset = PanoGSplatDataset(
94
+ cfg.data_dir, split="train", max_sample_num=6
95
+ )
96
+ self.testset = PanoGSplatDataset(cfg.data_dir, split="eval")
97
+ self.scene_scale = cfg.scene_scale
98
+
99
+ # Model
100
+ self.splats, self.optimizers = create_splats_with_optimizers(
101
+ self.trainset.points,
102
+ self.trainset.points_rgb,
103
+ init_num_pts=cfg.init_num_pts,
104
+ init_extent=cfg.init_extent,
105
+ init_opacity=cfg.init_opa,
106
+ init_scale=cfg.init_scale,
107
+ means_lr=cfg.means_lr,
108
+ scales_lr=cfg.scales_lr,
109
+ opacities_lr=cfg.opacities_lr,
110
+ quats_lr=cfg.quats_lr,
111
+ sh0_lr=cfg.sh0_lr,
112
+ shN_lr=cfg.shN_lr,
113
+ scene_scale=self.scene_scale,
114
+ sh_degree=cfg.sh_degree,
115
+ sparse_grad=cfg.sparse_grad,
116
+ visible_adam=cfg.visible_adam,
117
+ batch_size=cfg.batch_size,
118
+ feature_dim=None,
119
+ device=self.device,
120
+ world_rank=world_rank,
121
+ world_size=world_size,
122
+ )
123
+ print("Model initialized. Number of GS:", len(self.splats["means"]))
124
+
125
+ # Densification Strategy
126
+ self.cfg.strategy.check_sanity(self.splats, self.optimizers)
127
+
128
+ if isinstance(self.cfg.strategy, DefaultStrategy):
129
+ self.strategy_state = self.cfg.strategy.initialize_state(
130
+ scene_scale=self.scene_scale
131
+ )
132
+ elif isinstance(self.cfg.strategy, MCMCStrategy):
133
+ self.strategy_state = self.cfg.strategy.initialize_state()
134
+ else:
135
+ assert_never(self.cfg.strategy)
136
+
137
+ # Losses & Metrics.
138
+ self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(
139
+ self.device
140
+ )
141
+ self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device)
142
+
143
+ if cfg.lpips_net == "alex":
144
+ self.lpips = LearnedPerceptualImagePatchSimilarity(
145
+ net_type="alex", normalize=True
146
+ ).to(self.device)
147
+ elif cfg.lpips_net == "vgg":
148
+ # The 3DGS official repo uses lpips vgg, which is equivalent to the following:
149
+ self.lpips = LearnedPerceptualImagePatchSimilarity(
150
+ net_type="vgg", normalize=False
151
+ ).to(self.device)
152
+ else:
153
+ raise ValueError(f"Unknown LPIPS network: {cfg.lpips_net}")
154
+
155
+ def rasterize_splats(
156
+ self,
157
+ camtoworlds: Tensor,
158
+ Ks: Tensor,
159
+ width: int,
160
+ height: int,
161
+ masks: Optional[Tensor] = None,
162
+ rasterize_mode: Optional[Literal["classic", "antialiased"]] = None,
163
+ camera_model: Optional[Literal["pinhole", "ortho", "fisheye"]] = None,
164
+ **kwargs,
165
+ ) -> Tuple[Tensor, Tensor, Dict]:
166
+ means = self.splats["means"] # [N, 3]
167
+ # quats = F.normalize(self.splats["quats"], dim=-1) # [N, 4]
168
+ # rasterization does normalization internally
169
+ quats = self.splats["quats"] # [N, 4]
170
+ scales = torch.exp(self.splats["scales"]) # [N, 3]
171
+ opacities = torch.sigmoid(self.splats["opacities"]) # [N,]
172
+ image_ids = kwargs.pop("image_ids", None)
173
+
174
+ colors = torch.cat(
175
+ [self.splats["sh0"], self.splats["shN"]], 1
176
+ ) # [N, K, 3]
177
+
178
+ if rasterize_mode is None:
179
+ rasterize_mode = (
180
+ "antialiased" if self.cfg.antialiased else "classic"
181
+ )
182
+ if camera_model is None:
183
+ camera_model = self.cfg.camera_model
184
+
185
+ render_colors, render_alphas, info = rasterization(
186
+ means=means,
187
+ quats=quats,
188
+ scales=scales,
189
+ opacities=opacities,
190
+ colors=colors,
191
+ viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4]
192
+ Ks=Ks, # [C, 3, 3]
193
+ width=width,
194
+ height=height,
195
+ packed=self.cfg.packed,
196
+ absgrad=(
197
+ self.cfg.strategy.absgrad
198
+ if isinstance(self.cfg.strategy, DefaultStrategy)
199
+ else False
200
+ ),
201
+ sparse_grad=self.cfg.sparse_grad,
202
+ rasterize_mode=rasterize_mode,
203
+ distributed=self.world_size > 1,
204
+ camera_model=self.cfg.camera_model,
205
+ with_ut=self.cfg.with_ut,
206
+ with_eval3d=self.cfg.with_eval3d,
207
+ **kwargs,
208
+ )
209
+ if masks is not None:
210
+ render_colors[~masks] = 0
211
+ return render_colors, render_alphas, info
212
+
213
+ def train(self):
214
+ cfg = self.cfg
215
+ device = self.device
216
+ world_rank = self.world_rank
217
+
218
+ # Dump cfg.
219
+ if world_rank == 0:
220
+ with open(f"{cfg.result_dir}/cfg.yml", "w") as f:
221
+ yaml.dump(vars(cfg), f)
222
+
223
+ max_steps = cfg.max_steps
224
+ init_step = 0
225
+
226
+ schedulers = [
227
+ # means has a learning rate schedule that ends at 0.01 of the initial value
228
+ torch.optim.lr_scheduler.ExponentialLR(
229
+ self.optimizers["means"], gamma=0.01 ** (1.0 / max_steps)
230
+ ),
231
+ ]
232
+ trainloader = torch.utils.data.DataLoader(
233
+ self.trainset,
234
+ batch_size=cfg.batch_size,
235
+ shuffle=True,
236
+ num_workers=4,
237
+ persistent_workers=True,
238
+ pin_memory=True,
239
+ )
240
+ trainloader_iter = iter(trainloader)
241
+
242
+ # Training loop.
243
+ global_tic = time.time()
244
+ pbar = tqdm.tqdm(range(init_step, max_steps))
245
+ for step in pbar:
246
+ try:
247
+ data = next(trainloader_iter)
248
+ except StopIteration:
249
+ trainloader_iter = iter(trainloader)
250
+ data = next(trainloader_iter)
251
+
252
+ camtoworlds = data["camtoworld"].to(device) # [1, 4, 4]
253
+ Ks = data["K"].to(device) # [1, 3, 3]
254
+ pixels = data["image"].to(device) / 255.0 # [1, H, W, 3]
255
+ image_ids = data["image_id"].to(device)
256
+ masks = (
257
+ data["mask"].to(device) if "mask" in data else None
258
+ ) # [1, H, W]
259
+ if cfg.depth_loss:
260
+ points = data["points"].to(device) # [1, M, 2]
261
+ depths_gt = data["depths"].to(device) # [1, M]
262
+
263
+ height, width = pixels.shape[1:3]
264
+
265
+ # sh schedule
266
+ sh_degree_to_use = min(
267
+ step // cfg.sh_degree_interval, cfg.sh_degree
268
+ )
269
+
270
+ # forward
271
+ renders, alphas, info = self.rasterize_splats(
272
+ camtoworlds=camtoworlds,
273
+ Ks=Ks,
274
+ width=width,
275
+ height=height,
276
+ sh_degree=sh_degree_to_use,
277
+ near_plane=cfg.near_plane,
278
+ far_plane=cfg.far_plane,
279
+ image_ids=image_ids,
280
+ render_mode="RGB+ED" if cfg.depth_loss else "RGB",
281
+ masks=masks,
282
+ )
283
+ if renders.shape[-1] == 4:
284
+ colors, depths = renders[..., 0:3], renders[..., 3:4]
285
+ else:
286
+ colors, depths = renders, None
287
+
288
+ if cfg.random_bkgd:
289
+ bkgd = torch.rand(1, 3, device=device)
290
+ colors = colors + bkgd * (1.0 - alphas)
291
+
292
+ self.cfg.strategy.step_pre_backward(
293
+ params=self.splats,
294
+ optimizers=self.optimizers,
295
+ state=self.strategy_state,
296
+ step=step,
297
+ info=info,
298
+ )
299
+
300
+ # loss
301
+ l1loss = F.l1_loss(colors, pixels)
302
+ ssimloss = 1.0 - fused_ssim(
303
+ colors.permute(0, 3, 1, 2),
304
+ pixels.permute(0, 3, 1, 2),
305
+ padding="valid",
306
+ )
307
+ loss = (
308
+ l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda
309
+ )
310
+ if cfg.depth_loss:
311
+ # query depths from depth map
312
+ points = torch.stack(
313
+ [
314
+ points[:, :, 0] / (width - 1) * 2 - 1,
315
+ points[:, :, 1] / (height - 1) * 2 - 1,
316
+ ],
317
+ dim=-1,
318
+ ) # normalize to [-1, 1]
319
+ grid = points.unsqueeze(2) # [1, M, 1, 2]
320
+ depths = F.grid_sample(
321
+ depths.permute(0, 3, 1, 2), grid, align_corners=True
322
+ ) # [1, 1, M, 1]
323
+ depths = depths.squeeze(3).squeeze(1) # [1, M]
324
+ # calculate loss in disparity space
325
+ disp = torch.where(
326
+ depths > 0.0, 1.0 / depths, torch.zeros_like(depths)
327
+ )
328
+ disp_gt = 1.0 / depths_gt # [1, M]
329
+ depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale
330
+ loss += depthloss * cfg.depth_lambda
331
+
332
+ # regularizations
333
+ if cfg.opacity_reg > 0.0:
334
+ loss += (
335
+ cfg.opacity_reg
336
+ * torch.sigmoid(self.splats["opacities"]).mean()
337
+ )
338
+ if cfg.scale_reg > 0.0:
339
+ loss += cfg.scale_reg * torch.exp(self.splats["scales"]).mean()
340
+
341
+ loss.backward()
342
+
343
+ desc = (
344
+ f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| "
345
+ )
346
+ if cfg.depth_loss:
347
+ desc += f"depth loss={depthloss.item():.6f}| "
348
+ pbar.set_description(desc)
349
+
350
+ # write images (gt and render)
351
+ # if world_rank == 0 and step % 800 == 0:
352
+ # canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy()
353
+ # canvas = canvas.reshape(-1, *canvas.shape[2:])
354
+ # imageio.imwrite(
355
+ # f"{self.render_dir}/train_rank{self.world_rank}.png",
356
+ # (canvas * 255).astype(np.uint8),
357
+ # )
358
+
359
+ if (
360
+ world_rank == 0
361
+ and cfg.tb_every > 0
362
+ and step % cfg.tb_every == 0
363
+ ):
364
+ mem = torch.cuda.max_memory_allocated() / 1024**3
365
+ self.writer.add_scalar("train/loss", loss.item(), step)
366
+ self.writer.add_scalar("train/l1loss", l1loss.item(), step)
367
+ self.writer.add_scalar("train/ssimloss", ssimloss.item(), step)
368
+ self.writer.add_scalar(
369
+ "train/num_GS", len(self.splats["means"]), step
370
+ )
371
+ self.writer.add_scalar("train/mem", mem, step)
372
+ if cfg.depth_loss:
373
+ self.writer.add_scalar(
374
+ "train/depthloss", depthloss.item(), step
375
+ )
376
+ if cfg.tb_save_image:
377
+ canvas = (
378
+ torch.cat([pixels, colors], dim=2)
379
+ .detach()
380
+ .cpu()
381
+ .numpy()
382
+ )
383
+ canvas = canvas.reshape(-1, *canvas.shape[2:])
384
+ self.writer.add_image("train/render", canvas, step)
385
+ self.writer.flush()
386
+
387
+ # save checkpoint before updating the model
388
+ if (
389
+ step in [i - 1 for i in cfg.save_steps]
390
+ or step == max_steps - 1
391
+ ):
392
+ mem = torch.cuda.max_memory_allocated() / 1024**3
393
+ stats = {
394
+ "mem": mem,
395
+ "ellipse_time": time.time() - global_tic,
396
+ "num_GS": len(self.splats["means"]),
397
+ }
398
+ print("Step: ", step, stats)
399
+ with open(
400
+ f"{self.stats_dir}/train_step{step:04d}_rank{self.world_rank}.json",
401
+ "w",
402
+ ) as f:
403
+ json.dump(stats, f)
404
+ data = {"step": step, "splats": self.splats.state_dict()}
405
+ torch.save(
406
+ data,
407
+ f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt",
408
+ )
409
+ if (
410
+ step in [i - 1 for i in cfg.ply_steps] or step == max_steps - 1
411
+ ) and cfg.save_ply:
412
+ sh0 = self.splats["sh0"]
413
+ shN = self.splats["shN"]
414
+ means = self.splats["means"]
415
+ scales = self.splats["scales"]
416
+ quats = self.splats["quats"]
417
+ opacities = self.splats["opacities"]
418
+ export_splats(
419
+ means=means,
420
+ scales=scales,
421
+ quats=quats,
422
+ opacities=opacities,
423
+ sh0=sh0,
424
+ shN=shN,
425
+ format="ply",
426
+ save_to=f"{self.ply_dir}/point_cloud_{step}.ply",
427
+ )
428
+
429
+ # Turn Gradients into Sparse Tensor before running optimizer
430
+ if cfg.sparse_grad:
431
+ assert (
432
+ cfg.packed
433
+ ), "Sparse gradients only work with packed mode."
434
+ gaussian_ids = info["gaussian_ids"]
435
+ for k in self.splats.keys():
436
+ grad = self.splats[k].grad
437
+ if grad is None or grad.is_sparse:
438
+ continue
439
+ self.splats[k].grad = torch.sparse_coo_tensor(
440
+ indices=gaussian_ids[None], # [1, nnz]
441
+ values=grad[gaussian_ids], # [nnz, ...]
442
+ size=self.splats[k].size(), # [N, ...]
443
+ is_coalesced=len(Ks) == 1,
444
+ )
445
+
446
+ if cfg.visible_adam:
447
+ gaussian_cnt = self.splats.means.shape[0]
448
+ if cfg.packed:
449
+ visibility_mask = torch.zeros_like(
450
+ self.splats["opacities"], dtype=bool
451
+ )
452
+ visibility_mask.scatter_(0, info["gaussian_ids"], 1)
453
+ else:
454
+ visibility_mask = (info["radii"] > 0).all(-1).any(0)
455
+
456
+ # optimize
457
+ for optimizer in self.optimizers.values():
458
+ if cfg.visible_adam:
459
+ optimizer.step(visibility_mask)
460
+ else:
461
+ optimizer.step()
462
+ optimizer.zero_grad(set_to_none=True)
463
+ for scheduler in schedulers:
464
+ scheduler.step()
465
+
466
+ # Run post-backward steps after backward and optimizer
467
+ if isinstance(self.cfg.strategy, DefaultStrategy):
468
+ self.cfg.strategy.step_post_backward(
469
+ params=self.splats,
470
+ optimizers=self.optimizers,
471
+ state=self.strategy_state,
472
+ step=step,
473
+ info=info,
474
+ packed=cfg.packed,
475
+ )
476
+ elif isinstance(self.cfg.strategy, MCMCStrategy):
477
+ self.cfg.strategy.step_post_backward(
478
+ params=self.splats,
479
+ optimizers=self.optimizers,
480
+ state=self.strategy_state,
481
+ step=step,
482
+ info=info,
483
+ lr=schedulers[0].get_last_lr()[0],
484
+ )
485
+ else:
486
+ assert_never(self.cfg.strategy)
487
+
488
+ # eval the full set
489
+ if step in [i - 1 for i in cfg.eval_steps]:
490
+ self.eval(step)
491
+ self.render_video(step)
492
+
493
+ @torch.no_grad()
494
+ def eval(
495
+ self,
496
+ step: int,
497
+ stage: str = "val",
498
+ canvas_h: int = 512,
499
+ canvas_w: int = 1024,
500
+ ):
501
+ """Entry for evaluation."""
502
+ print("Running evaluation...")
503
+ cfg = self.cfg
504
+ device = self.device
505
+ world_rank = self.world_rank
506
+
507
+ valloader = torch.utils.data.DataLoader(
508
+ self.valset, batch_size=1, shuffle=False, num_workers=1
509
+ )
510
+ ellipse_time = 0
511
+ metrics = defaultdict(list)
512
+ for i, data in enumerate(valloader):
513
+ camtoworlds = data["camtoworld"].to(device)
514
+ Ks = data["K"].to(device)
515
+ pixels = data["image"].to(device) / 255.0
516
+ height, width = pixels.shape[1:3]
517
+ masks = data["mask"].to(device) if "mask" in data else None
518
+
519
+ pixels = pixels.permute(0, 3, 1, 2) # NHWC -> NCHW
520
+ pixels = F.interpolate(pixels, size=(canvas_h, canvas_w // 2))
521
+
522
+ torch.cuda.synchronize()
523
+ tic = time.time()
524
+ colors, _, _ = self.rasterize_splats(
525
+ camtoworlds=camtoworlds,
526
+ Ks=Ks,
527
+ width=width,
528
+ height=height,
529
+ sh_degree=cfg.sh_degree,
530
+ near_plane=cfg.near_plane,
531
+ far_plane=cfg.far_plane,
532
+ masks=masks,
533
+ ) # [1, H, W, 3]
534
+ torch.cuda.synchronize()
535
+ ellipse_time += max(time.time() - tic, 1e-10)
536
+
537
+ colors = colors.permute(0, 3, 1, 2) # NHWC -> NCHW
538
+ colors = F.interpolate(colors, size=(canvas_h, canvas_w // 2))
539
+ colors = torch.clamp(colors, 0.0, 1.0)
540
+ canvas_list = [pixels, colors]
541
+
542
+ if world_rank == 0:
543
+ canvas = torch.cat(canvas_list, dim=2).squeeze(0)
544
+ canvas = canvas.permute(1, 2, 0) # CHW -> HWC
545
+ canvas = (canvas * 255).to(torch.uint8).cpu().numpy()
546
+ cv2.imwrite(
547
+ f"{self.render_dir}/{stage}_step{step}_{i:04d}.png",
548
+ canvas[..., ::-1],
549
+ )
550
+ metrics["psnr"].append(self.psnr(colors, pixels))
551
+ metrics["ssim"].append(self.ssim(colors, pixels))
552
+ metrics["lpips"].append(self.lpips(colors, pixels))
553
+
554
+ if world_rank == 0:
555
+ ellipse_time /= len(valloader)
556
+
557
+ stats = {
558
+ k: torch.stack(v).mean().item() for k, v in metrics.items()
559
+ }
560
+ stats.update(
561
+ {
562
+ "ellipse_time": ellipse_time,
563
+ "num_GS": len(self.splats["means"]),
564
+ }
565
+ )
566
+ print(
567
+ f"PSNR: {stats['psnr']:.3f}, SSIM: {stats['ssim']:.4f}, LPIPS: {stats['lpips']:.3f} "
568
+ f"Time: {stats['ellipse_time']:.3f}s/image "
569
+ f"Number of GS: {stats['num_GS']}"
570
+ )
571
+ # save stats as json
572
+ with open(
573
+ f"{self.stats_dir}/{stage}_step{step:04d}.json", "w"
574
+ ) as f:
575
+ json.dump(stats, f)
576
+ # save stats to tensorboard
577
+ for k, v in stats.items():
578
+ self.writer.add_scalar(f"{stage}/{k}", v, step)
579
+ self.writer.flush()
580
+
581
+ @torch.no_grad()
582
+ def render_video(
583
+ self, step: int, canvas_h: int = 512, canvas_w: int = 1024
584
+ ):
585
+ testloader = torch.utils.data.DataLoader(
586
+ self.testset, batch_size=1, shuffle=False, num_workers=1
587
+ )
588
+
589
+ images_cache = []
590
+ depth_global_min, depth_global_max = float("inf"), -float("inf")
591
+ for data in testloader:
592
+ camtoworlds = data["camtoworld"].to(self.device)
593
+ Ks = resize_pinhole_intrinsics(
594
+ data["K"].squeeze(),
595
+ raw_hw=(data["image_h"].item(), data["image_w"].item()),
596
+ new_hw=(canvas_h, canvas_w // 2),
597
+ ).to(self.device)
598
+ renders, _, _ = self.rasterize_splats(
599
+ camtoworlds=camtoworlds,
600
+ Ks=Ks[None, ...],
601
+ width=canvas_w // 2,
602
+ height=canvas_h,
603
+ sh_degree=self.cfg.sh_degree,
604
+ near_plane=self.cfg.near_plane,
605
+ far_plane=self.cfg.far_plane,
606
+ render_mode="RGB+ED",
607
+ ) # [1, H, W, 4]
608
+ colors = torch.clamp(renders[0, ..., 0:3], 0.0, 1.0) # [H, W, 3]
609
+ colors = (colors * 255).to(torch.uint8).cpu().numpy()
610
+ depths = renders[0, ..., 3:4] # [H, W, 1], tensor in device.
611
+ images_cache.append([colors, depths])
612
+ depth_global_min = min(depth_global_min, depths.min().item())
613
+ depth_global_max = max(depth_global_max, depths.max().item())
614
+
615
+ video_path = f"{self.render_dir}/video_step{step}.mp4"
616
+ writer = imageio.get_writer(video_path, fps=30)
617
+ for rgb, depth in images_cache:
618
+ depth_normalized = torch.clip(
619
+ (depth - depth_global_min)
620
+ / (depth_global_max - depth_global_min),
621
+ 0,
622
+ 1,
623
+ )
624
+ depth_normalized = (
625
+ (depth_normalized * 255).to(torch.uint8).cpu().numpy()
626
+ )
627
+ depth_map = cv2.applyColorMap(depth_normalized, cv2.COLORMAP_JET)
628
+ image = np.concatenate([rgb, depth_map], axis=1)
629
+ writer.append_data(image)
630
+
631
+ writer.close()
632
+
633
+
634
+ def entrypoint(
635
+ local_rank: int, world_rank, world_size: int, cfg: GsplatTrainConfig
636
+ ):
637
+ runner = Runner(local_rank, world_rank, world_size, cfg)
638
+
639
+ if cfg.ckpt is not None:
640
+ # run eval only
641
+ ckpts = [
642
+ torch.load(file, map_location=runner.device, weights_only=True)
643
+ for file in cfg.ckpt
644
+ ]
645
+ for k in runner.splats.keys():
646
+ runner.splats[k].data = torch.cat(
647
+ [ckpt["splats"][k] for ckpt in ckpts]
648
+ )
649
+ step = ckpts[0]["step"]
650
+ runner.eval(step=step)
651
+ runner.render_video(step=step)
652
+ else:
653
+ runner.train()
654
+ runner.render_video(step=cfg.max_steps - 1)
655
+
656
+
657
+ if __name__ == "__main__":
658
+ configs = {
659
+ "default": (
660
+ "Gaussian splatting training using densification heuristics from the original paper.",
661
+ GsplatTrainConfig(
662
+ strategy=DefaultStrategy(verbose=True),
663
+ ),
664
+ ),
665
+ "mcmc": (
666
+ "Gaussian splatting training using densification from the paper '3D Gaussian Splatting as Markov Chain Monte Carlo'.",
667
+ GsplatTrainConfig(
668
+ init_scale=0.1,
669
+ opacity_reg=0.01,
670
+ scale_reg=0.01,
671
+ strategy=MCMCStrategy(verbose=True),
672
+ ),
673
+ ),
674
+ }
675
+ cfg = tyro.extras.overridable_config_cli(configs)
676
+ cfg.adjust_steps(cfg.steps_scaler)
677
+
678
+ cli(entrypoint, cfg, verbose=True)
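For reference, a minimal sketch of launching this trainer programmatically rather than through tyro's CLI; the data/result paths and step count are illustrative, and the import paths assume the module layout added in this commit:

```python
from gsplat.distributed import cli
from gsplat.strategy import DefaultStrategy

from embodied_gen.trainer.gsplat_trainer import entrypoint
from embodied_gen.utils.config import GsplatTrainConfig

if __name__ == "__main__":
    cfg = GsplatTrainConfig(
        # Assumed to hold the gs_data.pt dumped by the pano-to-mesh pipeline.
        data_dir="outputs/bg",
        result_dir="outputs/bg/gsplat",
        max_steps=7_000,
        strategy=DefaultStrategy(verbose=True),
    )
    # Rescale eval/save/ply schedules by steps_scaler (a no-op at the default 1.0).
    cfg.adjust_steps(cfg.steps_scaler)
    # Single-node launch; cli() supplies (local_rank, world_rank, world_size).
    cli(entrypoint, cfg, verbose=True)
```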
embodied_gen/trainer/pono2mesh_trainer.py ADDED
@@ -0,0 +1,538 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+
18
+ from embodied_gen.utils.monkey_patches import monkey_patch_pano2room
19
+
20
+ monkey_patch_pano2room()
21
+
22
+ import os
23
+
24
+ import cv2
25
+ import numpy as np
26
+ import torch
27
+ import trimesh
28
+ from equilib import cube2equi, equi2pers
29
+ from kornia.morphology import dilation
30
+ from PIL import Image
31
+ from embodied_gen.models.sr_model import ImageRealESRGAN
32
+ from embodied_gen.utils.config import Pano2MeshSRConfig
33
+ from embodied_gen.utils.gaussian import compute_pinhole_intrinsics
34
+ from embodied_gen.utils.log import logger
35
+ from thirdparty.pano2room.modules.geo_predictors import PanoJointPredictor
36
+ from thirdparty.pano2room.modules.geo_predictors.PanoFusionDistancePredictor import (
37
+ PanoFusionDistancePredictor,
38
+ )
39
+ from thirdparty.pano2room.modules.inpainters import PanoPersFusionInpainter
40
+ from thirdparty.pano2room.modules.mesh_fusion.render import (
41
+ features_to_world_space_mesh,
42
+ render_mesh,
43
+ )
44
+ from thirdparty.pano2room.modules.mesh_fusion.sup_info import SupInfoPool
45
+ from thirdparty.pano2room.utils.camera_utils import gen_pano_rays
46
+ from thirdparty.pano2room.utils.functions import (
47
+ depth_to_distance,
48
+ get_cubemap_views_world_to_cam,
49
+ resize_image_with_aspect_ratio,
50
+ rot_z_world_to_cam,
51
+ tensor_to_pil,
52
+ )
53
+
54
+
55
+ class Pano2MeshSRPipeline:
56
+ """Converting panoramic RGB image into 3D mesh representations, followed by inpainting and mesh refinement.
57
+
58
+ This class integrates several key components including:
59
+ - Depth estimation from RGB panorama
60
+ - Inpainting of missing regions under offsets
61
+ - RGB-D to mesh conversion
62
+ - Multi-view mesh repair
63
+ - 3D Gaussian Splatting (3DGS) dataset generation
64
+
65
+ Args:
66
+ config (Pano2MeshSRConfig): Configuration object containing model and pipeline parameters.
67
+
68
+ Example:
69
+ ```python
70
+ pipeline = Pano2MeshSRPipeline(config)
71
+ pipeline(pano_image='example.png', output_dir='./output')
72
+ ```
73
+ """
74
+
75
+ def __init__(self, config: Pano2MeshSRConfig) -> None:
76
+ self.cfg = config
77
+ self.device = config.device
78
+
79
+ # Init models.
80
+ self.inpainter = PanoPersFusionInpainter(save_path=None)
81
+ self.geo_predictor = PanoJointPredictor(save_path=None)
82
+ self.pano_fusion_distance_predictor = PanoFusionDistancePredictor()
83
+ self.super_model = ImageRealESRGAN(outscale=self.cfg.upscale_factor)
84
+
85
+ # Init poses.
86
+ cubemap_w2cs = get_cubemap_views_world_to_cam()
87
+ self.cubemap_w2cs = [p.to(self.device) for p in cubemap_w2cs]
88
+ self.camera_poses = self.load_camera_poses(self.cfg.trajectory_dir)
89
+
90
+ kernel = cv2.getStructuringElement(
91
+ cv2.MORPH_ELLIPSE, self.cfg.kernel_size
92
+ )
93
+ self.kernel = torch.from_numpy(kernel).float().to(self.device)
94
+
95
+ def init_mesh_params(self) -> None:
96
+ torch.set_default_device(self.device)
97
+ self.inpaint_mask = torch.ones(
98
+ (self.cfg.cubemap_h, self.cfg.cubemap_w), dtype=torch.bool
99
+ )
100
+ self.vertices = torch.empty((3, 0), requires_grad=False)
101
+ self.colors = torch.empty((3, 0), requires_grad=False)
102
+ self.faces = torch.empty((3, 0), dtype=torch.long, requires_grad=False)
103
+
104
+ @staticmethod
105
+ def read_camera_pose_file(filepath: str) -> np.ndarray:
106
+ with open(filepath, "r") as f:
107
+ values = [float(num) for line in f for num in line.split()]
108
+
109
+ return np.array(values).reshape(4, 4)
110
+
111
+ def load_camera_poses(
112
+ self, trajectory_dir: str
113
+ ) -> tuple[np.ndarray, list[torch.Tensor]]:
114
+ pose_filenames = sorted(
115
+ [
116
+ fname
117
+ for fname in os.listdir(trajectory_dir)
118
+ if fname.startswith("camera_pose")
119
+ ]
120
+ )
121
+
122
+ pano_pose_world = None
123
+ relative_poses = []
124
+ for idx, filename in enumerate(pose_filenames):
125
+ pose_path = os.path.join(trajectory_dir, filename)
126
+ pose_matrix = self.read_camera_pose_file(pose_path)
127
+
128
+ if pano_pose_world is None:
129
+ pano_pose_world = pose_matrix.copy()
130
+ pano_pose_world[0, 3] += self.cfg.pano_center_offset[0]
131
+ pano_pose_world[2, 3] += self.cfg.pano_center_offset[1]
132
+
133
+ # Use different reference for the first 6 cubemap views
134
+ reference_pose = pose_matrix if idx < 6 else pano_pose_world
135
+ relative_matrix = pose_matrix @ np.linalg.inv(reference_pose)
136
+ relative_matrix[0:2, :] *= -1 # flip_xy
137
+ relative_matrix = (
138
+ relative_matrix @ rot_z_world_to_cam(180).cpu().numpy()
139
+ )
140
+ relative_matrix[:3, 3] *= self.cfg.pose_scale
141
+ relative_matrix = torch.tensor(
142
+ relative_matrix, dtype=torch.float32
143
+ )
144
+ relative_poses.append(relative_matrix)
145
+
146
+ return relative_poses
147
+
148
+ def load_inpaint_poses(
149
+ self, poses: torch.Tensor
150
+ ) -> dict[int, torch.Tensor]:
151
+ inpaint_poses = dict()
152
+ sampled_views = poses[:: self.cfg.inpaint_frame_stride]
153
+ init_pose = torch.eye(4)
154
+ for idx, w2c_tensor in enumerate(sampled_views):
155
+ w2c = w2c_tensor.cpu().numpy().astype(np.float32)
156
+ c2w = np.linalg.inv(w2c)
157
+ pose_tensor = init_pose.clone()
158
+ pose_tensor[:3, 3] = torch.from_numpy(c2w[:3, 3])
159
+ pose_tensor[:3, 3] *= -1
160
+ inpaint_poses[idx] = pose_tensor.to(self.device)
161
+
162
+ return inpaint_poses
163
+
164
+ def project(self, world_to_cam: torch.Tensor):
165
+ (
166
+ project_image,
167
+ project_depth,
168
+ inpaint_mask,
169
+ _,
170
+ z_buf,
171
+ mesh,
172
+ ) = render_mesh(
173
+ vertices=self.vertices,
174
+ faces=self.faces,
175
+ vertex_features=self.colors,
176
+ H=self.cfg.cubemap_h,
177
+ W=self.cfg.cubemap_w,
178
+ fov_in_degrees=self.cfg.fov,
179
+ RT=world_to_cam,
180
+ blur_radius=self.cfg.blur_radius,
181
+ faces_per_pixel=self.cfg.faces_per_pixel,
182
+ )
183
+ project_image = project_image * ~inpaint_mask
184
+
185
+ return project_image[:3, ...], inpaint_mask, project_depth
186
+
187
+ def render_pano(self, pose: torch.Tensor):
188
+ cubemap_list = []
189
+ for cubemap_pose in self.cubemap_w2cs:
190
+ project_pose = cubemap_pose @ pose
191
+ rgb, inpaint_mask, depth = self.project(project_pose)
192
+ distance_map = depth_to_distance(depth[None, ...])
193
+ mask = inpaint_mask[None, ...]
194
+ cubemap_list.append(torch.cat([rgb, distance_map, mask], dim=0))
195
+
196
+ # Set default tensor type for CPU operation in cube2equi
197
+ with torch.device("cpu"):
198
+ pano_rgbd = cube2equi(
199
+ cubemap_list, "list", self.cfg.pano_h, self.cfg.pano_w
200
+ )
201
+
202
+ pano_rgb = pano_rgbd[:3, :, :]
203
+ pano_depth = pano_rgbd[3:4, :, :].squeeze(0)
204
+ pano_mask = pano_rgbd[4:, :, :].squeeze(0)
205
+
206
+ return pano_rgb, pano_depth, pano_mask
207
+
208
+ def rgbd_to_mesh(
209
+ self,
210
+ rgb: torch.Tensor,
211
+ depth: torch.Tensor,
212
+ inpaint_mask: torch.Tensor,
213
+ world_to_cam: torch.Tensor = None,
214
+ using_distance_map: bool = True,
215
+ ) -> None:
216
+ if world_to_cam is None:
217
+ world_to_cam = torch.eye(4, dtype=torch.float32).to(self.device)
218
+
219
+ if inpaint_mask.sum() == 0:
220
+ return
221
+
222
+ vertices, faces, colors = features_to_world_space_mesh(
223
+ colors=rgb.squeeze(0),
224
+ depth=depth,
225
+ fov_in_degrees=self.cfg.fov,
226
+ world_to_cam=world_to_cam,
227
+ mask=inpaint_mask,
228
+ faces=self.faces,
229
+ vertices=self.vertices,
230
+ using_distance_map=using_distance_map,
231
+ edge_threshold=0.05,
232
+ )
233
+
234
+ faces += self.vertices.shape[1]
235
+ self.vertices = torch.cat([self.vertices, vertices], dim=1)
236
+ self.colors = torch.cat([self.colors, colors], dim=1)
237
+ self.faces = torch.cat([self.faces, faces], dim=1)
238
+
239
+ def get_edge_image_by_depth(
240
+ self, depth: torch.Tensor, dilate_iter: int = 1
241
+ ) -> np.ndarray:
242
+ if isinstance(depth, torch.Tensor):
243
+ depth = depth.cpu().detach().numpy()
244
+
245
+ gray = (depth / depth.max() * 255).astype(np.uint8)
246
+ edges = cv2.Canny(gray, 60, 150)
247
+ if dilate_iter > 0:
248
+ kernel = np.ones((3, 3), np.uint8)
249
+ edges = cv2.dilate(edges, kernel, iterations=dilate_iter)
250
+
251
+ return edges
252
+
253
+ def mesh_repair_by_greedy_view_selection(
254
+ self, pose_dict: dict[str, torch.Tensor], output_dir: str
255
+ ) -> list:
256
+ inpainted_panos_w_pose = []
257
+ while len(pose_dict) > 0:
258
+ logger.info(f"Repairing mesh left rounds {len(pose_dict)}")
259
+ sampled_views = []
260
+ for key, pose in pose_dict.items():
261
+ pano_rgb, pano_distance, pano_mask = self.render_pano(pose)
262
+ completeness = torch.sum(1 - pano_mask) / (pano_mask.numel())
263
+ sampled_views.append((key, completeness.item(), pose))
264
+
265
+ if len(sampled_views) == 0:
266
+ break
267
+
268
+ # Pick a view with relatively low completeness (two-thirds through the sorted list) to inpaint next.
269
+ sampled_views = sorted(sampled_views, key=lambda x: x[1])
270
+ key, _, pose = sampled_views[len(sampled_views) * 2 // 3]
271
+ pose_dict.pop(key)
272
+
273
+ pano_rgb, pano_distance, pano_mask = self.render_pano(pose)
274
+
275
+ colors = pano_rgb.permute(1, 2, 0).clone()
276
+ distances = pano_distance.unsqueeze(-1).clone()
277
+ pano_inpaint_mask = pano_mask.clone()
278
+ init_pose = pose.clone()
279
+ normals = None
280
+ if pano_inpaint_mask.min().item() < 0.5:
281
+ colors, distances, normals = self.inpaint_panorama(
282
+ idx=key,
283
+ colors=colors,
284
+ distances=distances,
285
+ pano_mask=pano_inpaint_mask,
286
+ )
287
+
288
+ init_pose[0, 3], init_pose[1, 3], init_pose[2, 3] = (
289
+ -pose[0, 3],
290
+ pose[2, 3],
291
+ 0,
292
+ )
293
+ rays = gen_pano_rays(
294
+ init_pose, self.cfg.pano_h, self.cfg.pano_w
295
+ )
296
+ conflict_mask = self.sup_pool.geo_check(
297
+ rays, distances.unsqueeze(-1)
298
+ ) # 0 is conflict, 1 not conflict
299
+ pano_inpaint_mask *= conflict_mask
300
+
301
+ self.rgbd_to_mesh(
302
+ colors.permute(2, 0, 1),
303
+ distances,
304
+ pano_inpaint_mask,
305
+ world_to_cam=pose,
306
+ )
307
+
308
+ self.sup_pool.register_sup_info(
309
+ pose=init_pose,
310
+ mask=pano_inpaint_mask.clone(),
311
+ rgb=colors,
312
+ distance=distances.unsqueeze(-1),
313
+ normal=normals,
314
+ )
315
+
316
+ colors = colors.permute(2, 0, 1).unsqueeze(0)
317
+ inpainted_panos_w_pose.append([colors, pose])
318
+
319
+ if self.cfg.visualize:
320
+ from embodied_gen.data.utils import DiffrastRender
321
+
322
+ tensor_to_pil(pano_rgb.unsqueeze(0)).save(
323
+ f"{output_dir}/rendered_pano_{key}.jpg"
324
+ )
325
+ tensor_to_pil(colors).save(
326
+ f"{output_dir}/inpainted_pano_{key}.jpg"
327
+ )
328
+ norm_depth = DiffrastRender.normalize_map_by_mask(
329
+ distances, torch.ones_like(distances)
330
+ )
331
+ heatmap = (norm_depth.cpu().numpy() * 255).astype(np.uint8)
332
+ heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
333
+ Image.fromarray(heatmap).save(
334
+ f"{output_dir}/inpainted_depth_{key}.png"
335
+ )
336
+
337
+ return inpainted_panos_w_pose
338
+
339
+ def inpaint_panorama(
340
+ self,
341
+ idx: int,
342
+ colors: torch.Tensor,
343
+ distances: torch.Tensor,
344
+ pano_mask: torch.Tensor,
345
+ ) -> tuple[torch.Tensor]:
346
+ mask = (pano_mask[None, ..., None] > 0.5).float()
347
+ mask = mask.permute(0, 3, 1, 2)
348
+ mask = dilation(mask, kernel=self.kernel)
349
+ mask = mask[0, 0, ..., None] # hwc
350
+ inpainted_img = self.inpainter.inpaint(idx, colors, mask)
351
+ inpainted_img = colors * (1 - mask) + inpainted_img * mask
352
+ inpainted_distances, inpainted_normals = self.geo_predictor(
353
+ idx,
354
+ inpainted_img,
355
+ distances[..., None],
356
+ mask=mask,
357
+ reg_loss_weight=0.0,
358
+ normal_loss_weight=5e-2,
359
+ normal_tv_loss_weight=5e-2,
360
+ )
361
+
362
+ return inpainted_img, inpainted_distances.squeeze(), inpainted_normals
363
+
364
+ def preprocess_pano(
365
+ self, image: Image.Image | str
366
+ ) -> tuple[torch.Tensor, torch.Tensor]:
367
+ if isinstance(image, str):
368
+ image = Image.open(image)
369
+
370
+ image = image.convert("RGB")
371
+
372
+ if image.size[0] < image.size[1]:
373
+ image = image.transpose(Image.TRANSPOSE)
374
+
375
+ image = resize_image_with_aspect_ratio(image, self.cfg.pano_w)
376
+ image_rgb = torch.tensor(np.array(image)).permute(2, 0, 1) / 255
377
+ image_rgb = image_rgb.to(self.device)
378
+ image_depth = self.pano_fusion_distance_predictor.predict(
379
+ image_rgb.permute(1, 2, 0)
380
+ )
381
+ image_depth = (
382
+ image_depth / image_depth.max() * self.cfg.depth_scale_factor
383
+ )
384
+
385
+ return image_rgb, image_depth
386
+
387
+ def pano_to_perspective(
388
+ self, pano_image: torch.Tensor, pitch: float, yaw: float, fov: float
389
+ ) -> torch.Tensor:
390
+ rots = dict(
391
+ roll=0,
392
+ pitch=pitch,
393
+ yaw=yaw,
394
+ )
395
+ perspective = equi2pers(
396
+ equi=pano_image.squeeze(0),
397
+ rots=rots,
398
+ height=self.cfg.cubemap_h,
399
+ width=self.cfg.cubemap_w,
400
+ fov_x=fov,
401
+ mode="bilinear",
402
+ ).unsqueeze(0)
403
+
404
+ return perspective
405
+
406
+ def pano_to_cubemap(self, pano_rgb: torch.Tensor):
407
+ # Define six canonical cube directions in (pitch, yaw)
408
+ directions = [
409
+ (0, 0),
410
+ (0, 1.5 * np.pi),
411
+ (0, 1.0 * np.pi),
412
+ (0, 0.5 * np.pi),
413
+ (-0.5 * np.pi, 0),
414
+ (0.5 * np.pi, 0),
415
+ ]
416
+
417
+ cubemaps_rgb = []
418
+ for pitch, yaw in directions:
419
+ rgb_view = self.pano_to_perspective(
420
+ pano_rgb, pitch, yaw, fov=self.cfg.fov
421
+ )
422
+ cubemaps_rgb.append(rgb_view.cpu())
423
+
424
+ return cubemaps_rgb
425
+
426
+ def save_mesh(self, output_path: str) -> None:
427
+ vertices_np = self.vertices.T.cpu().numpy()
428
+ colors_np = self.colors.T.cpu().numpy()
429
+ faces_np = self.faces.T.cpu().numpy()
430
+ mesh = trimesh.Trimesh(
431
+ vertices=vertices_np, faces=faces_np, vertex_colors=colors_np
432
+ )
433
+
434
+ mesh.export(output_path)
435
+
436
+ def mesh_pose_to_gs_pose(self, mesh_pose: torch.Tensor) -> np.ndarray:
437
+ pose = mesh_pose.clone()
438
+ pose[0, :] *= -1
439
+ pose[1, :] *= -1
440
+
441
+ Rw2c = pose[:3, :3].cpu().numpy()
442
+ Tw2c = pose[:3, 3:].cpu().numpy()
443
+ yz_reverse = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]])
444
+
445
+ Rc2w = (yz_reverse @ Rw2c).T
446
+ Tc2w = -(Rc2w @ yz_reverse @ Tw2c)
447
+ c2w = np.concatenate((Rc2w, Tc2w), axis=1)
448
+ c2w = np.concatenate((c2w, np.array([[0, 0, 0, 1]])), axis=0)
449
+
450
+ return c2w
451
+
452
+ def __call__(self, pano_image: Image.Image | str, output_dir: str):
453
+ self.init_mesh_params()
454
+ pano_rgb, pano_depth = self.preprocess_pano(pano_image)
455
+ self.sup_pool = SupInfoPool()
456
+ self.sup_pool.register_sup_info(
457
+ pose=torch.eye(4).to(self.device),
458
+ mask=torch.ones([self.cfg.pano_h, self.cfg.pano_w]),
459
+ rgb=pano_rgb.permute(1, 2, 0),
460
+ distance=pano_depth[..., None],
461
+ )
462
+ self.sup_pool.gen_occ_grid(res=256)
463
+
464
+ logger.info("Init mesh from pano RGBD image...")
465
+ depth_edge = self.get_edge_image_by_depth(pano_depth)
466
+ inpaint_edge_mask = (
467
+ ~torch.from_numpy(depth_edge).to(self.device).bool()
468
+ )
469
+ self.rgbd_to_mesh(pano_rgb, pano_depth, inpaint_edge_mask)
470
+
471
+ repair_poses = self.load_inpaint_poses(self.camera_poses)
472
+ inpainted_panos_w_poses = self.mesh_repair_by_greedy_view_selection(
473
+ repair_poses, output_dir
474
+ )
475
+ torch.cuda.empty_cache()
476
+ torch.set_default_device("cpu")
477
+
478
+ if self.cfg.mesh_file is not None:
479
+ mesh_path = os.path.join(output_dir, self.cfg.mesh_file)
480
+ self.save_mesh(mesh_path)
481
+
482
+ if self.cfg.gs_data_file is None:
483
+ return
484
+
485
+ logger.info(f"Dump data for 3DGS training...")
486
+ points_rgb = (self.colors.clip(0, 1) * 255).to(torch.uint8)
487
+ data = {
488
+ "points": self.vertices.permute(1, 0).cpu().numpy(), # (N, 3)
489
+ "points_rgb": points_rgb.permute(1, 0).cpu().numpy(), # (N, 3)
490
+ "train": [],
491
+ "eval": [],
492
+ }
493
+ image_h = self.cfg.cubemap_h * self.cfg.upscale_factor
494
+ image_w = self.cfg.cubemap_w * self.cfg.upscale_factor
495
+ Ks = compute_pinhole_intrinsics(image_w, image_h, self.cfg.fov)
496
+ for idx, (pano_img, pano_pose) in enumerate(inpainted_panos_w_poses):
497
+ cubemaps = self.pano_to_cubemap(pano_img)
498
+ for i in range(len(cubemaps)):
499
+ cubemap = tensor_to_pil(cubemaps[i])
500
+ cubemap = self.super_model(cubemap)
501
+ mesh_pose = self.cubemap_w2cs[i] @ pano_pose
502
+ c2w = self.mesh_pose_to_gs_pose(mesh_pose)
503
+ data["train"].append(
504
+ {
505
+ "camtoworld": c2w.astype(np.float32),
506
+ "K": Ks.astype(np.float32),
507
+ "image": np.array(cubemap),
508
+ "image_h": image_h,
509
+ "image_w": image_w,
510
+ "image_id": len(cubemaps) * idx + i,
511
+ }
512
+ )
513
+
514
+ # Camera poses for evaluation.
515
+ for idx in range(len(self.camera_poses)):
516
+ c2w = self.mesh_pose_to_gs_pose(self.camera_poses[idx])
517
+ data["eval"].append(
518
+ {
519
+ "camtoworld": c2w.astype(np.float32),
520
+ "K": Ks.astype(np.float32),
521
+ "image_h": image_h,
522
+ "image_w": image_w,
523
+ "image_id": idx,
524
+ }
525
+ )
526
+
527
+ data_path = os.path.join(output_dir, self.cfg.gs_data_file)
528
+ torch.save(data, data_path)
529
+
530
+ return
531
+
532
+
533
+ if __name__ == "__main__":
534
+ output_dir = "outputs/bg_v2/test3"
535
+ input_pano = "apps/assets/example_scene/result_pano.png"
536
+ config = Pano2MeshSRConfig()
537
+ pipeline = Pano2MeshSRPipeline(config)
538
+ pipeline(input_pano, output_dir)
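The gs_data.pt payload written above is a plain torch pickle of numpy arrays; a minimal sketch of inspecting it (the path is illustrative and matches the defaults in Pano2MeshSRConfig; the field names are taken from the dict assembled in __call__):

```python
import torch

data = torch.load("outputs/bg_v2/test3/gs_data.pt", weights_only=False)

print(data["points"].shape, data["points_rgb"].shape)  # both (N, 3)
for sample in data["train"][:2]:
    # Per-view record: camera-to-world pose, pinhole intrinsics, upscaled RGB frame.
    print(sample["image_id"], sample["camtoworld"].shape,
          sample["K"].shape, sample["image"].shape)
for sample in data["eval"][:2]:
    # Eval records carry poses and intrinsics only (no rendered image).
    print(sample["image_id"], sample["camtoworld"].shape)
```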
embodied_gen/utils/config.py ADDED
@@ -0,0 +1,190 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+ from dataclasses import dataclass, field
18
+ from typing import List, Optional, Union
19
+
20
+ from gsplat.strategy import DefaultStrategy, MCMCStrategy
21
+ from typing_extensions import Literal, assert_never
22
+
23
+ __all__ = [
24
+ "Pano2MeshSRConfig",
25
+ "GsplatTrainConfig",
26
+ ]
27
+
28
+
29
+ @dataclass
30
+ class Pano2MeshSRConfig:
31
+ mesh_file: str = "mesh_model.ply"
32
+ gs_data_file: str = "gs_data.pt"
33
+ device: str = "cuda"
34
+ blur_radius: int = 0
35
+ faces_per_pixel: int = 8
36
+ fov: int = 90
37
+ pano_w: int = 2048
38
+ pano_h: int = 1024
39
+ cubemap_w: int = 512
40
+ cubemap_h: int = 512
41
+ pose_scale: float = 0.6
42
+ pano_center_offset: tuple = (-0.2, 0.3)
43
+ inpaint_frame_stride: int = 20
44
+ trajectory_dir: str = "apps/assets/example_scene/camera_trajectory"
45
+ visualize: bool = False
46
+ depth_scale_factor: float = 3.4092
47
+ kernel_size: tuple = (9, 9)
48
+ upscale_factor: int = 4
49
+
50
+
51
+ @dataclass
52
+ class GsplatTrainConfig:
53
+ # Path to the .pt files. If provided, training is skipped and only evaluation runs.
54
+ ckpt: Optional[List[str]] = None
55
+ # Render trajectory path
56
+ render_traj_path: str = "interp"
57
+
58
+ # Path to the Mip-NeRF 360 dataset
59
+ data_dir: str = "outputs/bg"
60
+ # Downsample factor for the dataset
61
+ data_factor: int = 4
62
+ # Directory to save results
63
+ result_dir: str = "outputs/bg"
64
+ # Every N images there is a test image
65
+ test_every: int = 8
66
+ # Random crop size for training (experimental)
67
+ patch_size: Optional[int] = None
68
+ # A global scaler that applies to the scene size related parameters
69
+ global_scale: float = 1.0
70
+ # Normalize the world space
71
+ normalize_world_space: bool = True
72
+ # Camera model
73
+ camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole"
74
+
75
+ # Port for the viewer server
76
+ port: int = 8080
77
+
78
+ # Batch size for training. Learning rates are scaled automatically
79
+ batch_size: int = 1
80
+ # A global factor to scale the number of training steps
81
+ steps_scaler: float = 1.0
82
+
83
+ # Number of training steps
84
+ max_steps: int = 30_000
85
+ # Steps to evaluate the model
86
+ eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
87
+ # Steps to save the model
88
+ save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
89
+ # Whether to save ply file (storage size can be large)
90
+ save_ply: bool = True
91
+ # Steps to save the model as ply
92
+ ply_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
93
+ # Whether to disable video generation during training and evaluation
94
+ disable_video: bool = False
95
+
96
+ # Initial number of GSs. Ignored if using sfm
97
+ init_num_pts: int = 100_000
98
+ # Initial extent of GSs as a multiple of the camera extent. Ignored if using sfm
99
+ init_extent: float = 3.0
100
+ # Degree of spherical harmonics
101
+ sh_degree: int = 1
102
+ # Increase the active SH degree by one every this many steps
103
+ sh_degree_interval: int = 1000
104
+ # Initial opacity of GS
105
+ init_opa: float = 0.1
106
+ # Initial scale of GS
107
+ init_scale: float = 1.0
108
+ # Weight for SSIM loss
109
+ ssim_lambda: float = 0.2
110
+
111
+ # Near plane clipping distance
112
+ near_plane: float = 0.01
113
+ # Far plane clipping distance
114
+ far_plane: float = 1e10
115
+
116
+ # Strategy for GS densification
117
+ strategy: Union[DefaultStrategy, MCMCStrategy] = field(
118
+ default_factory=DefaultStrategy
119
+ )
120
+ # Use packed mode for rasterization; this uses less memory but is slightly slower.
121
+ packed: bool = False
122
+ # Use sparse gradients for optimization. (experimental)
123
+ sparse_grad: bool = False
124
+ # Use visible adam from Taming 3DGS. (experimental)
125
+ visible_adam: bool = False
126
+ # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics.
127
+ antialiased: bool = False
128
+
129
+ # Use random background for training to discourage transparency
130
+ random_bkgd: bool = False
131
+
132
+ # LR for 3D point positions
133
+ means_lr: float = 1.6e-4
134
+ # LR for Gaussian scale factors
135
+ scales_lr: float = 5e-3
136
+ # LR for alpha blending weights
137
+ opacities_lr: float = 5e-2
138
+ # LR for orientation (quaternions)
139
+ quats_lr: float = 1e-3
140
+ # LR for SH band 0 (brightness)
141
+ sh0_lr: float = 2.5e-3
142
+ # LR for higher-order SH (detail)
143
+ shN_lr: float = 2.5e-3 / 20
144
+
145
+ # Opacity regularization
146
+ opacity_reg: float = 0.0
147
+ # Scale regularization
148
+ scale_reg: float = 0.0
149
+
150
+ # Enable depth loss. (experimental)
151
+ depth_loss: bool = False
152
+ # Weight for depth loss
153
+ depth_lambda: float = 1e-2
154
+
155
+ # Dump information to tensorboard every this steps
156
+ tb_every: int = 200
157
+ # Save training images to tensorboard
158
+ tb_save_image: bool = False
159
+
160
+ lpips_net: Literal["vgg", "alex"] = "alex"
161
+
162
+ # 3DGUT (unscented transform + eval 3D)
163
+ with_ut: bool = False
164
+ with_eval3d: bool = False
165
+
166
+ scene_scale: float = 1.0
167
+
168
+ def adjust_steps(self, factor: float):
169
+ self.eval_steps = [int(i * factor) for i in self.eval_steps]
170
+ self.save_steps = [int(i * factor) for i in self.save_steps]
171
+ self.ply_steps = [int(i * factor) for i in self.ply_steps]
172
+ self.max_steps = int(self.max_steps * factor)
173
+ self.sh_degree_interval = int(self.sh_degree_interval * factor)
174
+
175
+ strategy = self.strategy
176
+ if isinstance(strategy, DefaultStrategy):
177
+ strategy.refine_start_iter = int(
178
+ strategy.refine_start_iter * factor
179
+ )
180
+ strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor)
181
+ strategy.reset_every = int(strategy.reset_every * factor)
182
+ strategy.refine_every = int(strategy.refine_every * factor)
183
+ elif isinstance(strategy, MCMCStrategy):
184
+ strategy.refine_start_iter = int(
185
+ strategy.refine_start_iter * factor
186
+ )
187
+ strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor)
188
+ strategy.refine_every = int(strategy.refine_every * factor)
189
+ else:
190
+ assert_never(strategy)
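As a quick illustration of the schedule scaling above (a sketch using the default values scaled by 0.5):

```python
from gsplat.strategy import DefaultStrategy

from embodied_gen.utils.config import GsplatTrainConfig

cfg = GsplatTrainConfig(steps_scaler=0.5, strategy=DefaultStrategy())
cfg.adjust_steps(cfg.steps_scaler)

# max_steps: 30_000 -> 15_000; eval/save/ply steps: [7_000, 30_000] -> [3_500, 15_000];
# the strategy's refine/reset intervals are rescaled by the same factor.
print(cfg.max_steps, cfg.eval_steps, cfg.strategy.refine_every)
```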
embodied_gen/utils/enum.py ADDED
@@ -0,0 +1,107 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+ from dataclasses import dataclass, field
18
+ from enum import Enum
19
+
20
+ from dataclasses_json import DataClassJsonMixin
21
+
22
+ __all__ = [
23
+ "RenderItems",
24
+ "Scene3DItemEnum",
25
+ "SpatialRelationEnum",
26
+ "RobotItemEnum",
27
+ ]
28
+
29
+
30
+ @dataclass
31
+ class RenderItems(str, Enum):
32
+ IMAGE = "image_color"
33
+ ALPHA = "image_mask"
34
+ VIEW_NORMAL = "image_view_normal"
35
+ GLOBAL_NORMAL = "image_global_normal"
36
+ POSITION_MAP = "image_position"
37
+ DEPTH = "image_depth"
38
+ ALBEDO = "image_albedo"
39
+ DIFFUSE = "image_diffuse"
40
+
41
+
42
+ @dataclass
43
+ class Scene3DItemEnum(str, Enum):
44
+ BACKGROUND = "background"
45
+ CONTEXT = "context"
46
+ ROBOT = "robot"
47
+ MANIPULATED_OBJS = "manipulated_objs"
48
+ DISTRACTOR_OBJS = "distractor_objs"
49
+ OTHERS = "others"
50
+
51
+ @classmethod
52
+ def object_list(cls, layout_relation: dict) -> list:
53
+ return (
54
+ [
55
+ layout_relation[cls.BACKGROUND.value],
56
+ layout_relation[cls.CONTEXT.value],
57
+ ]
58
+ + layout_relation[cls.MANIPULATED_OBJS.value]
59
+ + layout_relation[cls.DISTRACTOR_OBJS.value]
60
+ )
61
+
62
+ @classmethod
63
+ def object_mapping(cls, layout_relation):
64
+ relation_mapping = {
65
+ # layout_relation[cls.ROBOT.value]: cls.ROBOT.value,
66
+ layout_relation[cls.BACKGROUND.value]: cls.BACKGROUND.value,
67
+ layout_relation[cls.CONTEXT.value]: cls.CONTEXT.value,
68
+ }
69
+ relation_mapping.update(
70
+ {
71
+ item: cls.MANIPULATED_OBJS.value
72
+ for item in layout_relation[cls.MANIPULATED_OBJS.value]
73
+ }
74
+ )
75
+ relation_mapping.update(
76
+ {
77
+ item: cls.DISTRACTOR_OBJS.value
78
+ for item in layout_relation[cls.DISTRACTOR_OBJS.value]
79
+ }
80
+ )
81
+
82
+ return relation_mapping
83
+
84
+
85
+ @dataclass
86
+ class SpatialRelationEnum(str, Enum):
87
+ ON = "ON" # objects on the table
88
+ IN = "IN" # objects in the room
89
+ INSIDE = "INSIDE" # objects inside the shelf/rack
90
+ FLOOR = "FLOOR" # object floor room/bin
91
+
92
+
93
+ @dataclass
94
+ class RobotItemEnum(str, Enum):
95
+ FRANKA = "franka"
96
+ UR5 = "ur5"
97
+ PIPER = "piper"
98
+
99
+
100
+ @dataclass
101
+ class LayoutInfo(DataClassJsonMixin):
102
+ tree: dict[str, list]
103
+ relation: dict[str, str | list[str]]
104
+ objs_desc: dict[str, str] = field(default_factory=dict)
105
+ assets: dict[str, str] = field(default_factory=dict)
106
+ quality: dict[str, str] = field(default_factory=dict)
107
+ position: dict[str, list[float]] = field(default_factory=dict)
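A small usage sketch of the enums above; the layout_relation dict here is hypothetical and only mirrors the keys that object_list and object_mapping expect:

```python
from embodied_gen.utils.enum import Scene3DItemEnum

# Hypothetical layout description keyed by Scene3DItemEnum values.
layout_relation = {
    Scene3DItemEnum.BACKGROUND.value: "kitchen_room",
    Scene3DItemEnum.CONTEXT.value: "dining_table",
    Scene3DItemEnum.MANIPULATED_OBJS.value: ["mug", "plate"],
    Scene3DItemEnum.DISTRACTOR_OBJS.value: ["vase"],
}

objects = Scene3DItemEnum.object_list(layout_relation)
# -> ["kitchen_room", "dining_table", "mug", "plate", "vase"]
roles = Scene3DItemEnum.object_mapping(layout_relation)
# -> {"kitchen_room": "background", "dining_table": "context",
#     "mug": "manipulated_objs", "plate": "manipulated_objs",
#     "vase": "distractor_objs"}
assert roles["mug"] == Scene3DItemEnum.MANIPULATED_OBJS.value
```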
embodied_gen/utils/gaussian.py ADDED
@@ -0,0 +1,331 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+ # Part of the code comes from https://github.com/nerfstudio-project/gsplat
17
+ # Both under the Apache License, Version 2.0.
18
+
19
+
20
+ import math
21
+ import random
22
+ from io import BytesIO
23
+ from typing import Dict, Literal, Optional, Tuple
24
+
25
+ import numpy as np
26
+ import torch
27
+ import trimesh
28
+ from gsplat.optimizers import SelectiveAdam
29
+ from scipy.spatial.transform import Rotation
30
+ from sklearn.neighbors import NearestNeighbors
31
+ from torch import Tensor
32
+ from embodied_gen.models.gs_model import GaussianOperator
33
+
34
+ __all__ = [
35
+ "set_random_seed",
36
+ "export_splats",
37
+ "create_splats_with_optimizers",
38
+ "compute_pinhole_intrinsics",
39
+ "resize_pinhole_intrinsics",
40
+ "restore_scene_scale_and_position",
41
+ ]
42
+
43
+
44
+ def knn(x: Tensor, K: int = 4) -> Tensor:
45
+ x_np = x.cpu().numpy()
46
+ model = NearestNeighbors(n_neighbors=K, metric="euclidean").fit(x_np)
47
+ distances, _ = model.kneighbors(x_np)
48
+ return torch.from_numpy(distances).to(x)
49
+
50
+
51
+ def rgb_to_sh(rgb: Tensor) -> Tensor:
52
+ C0 = 0.28209479177387814
53
+ return (rgb - 0.5) / C0
54
+
55
+
56
+ def set_random_seed(seed: int):
57
+ random.seed(seed)
58
+ np.random.seed(seed)
59
+ torch.manual_seed(seed)
60
+
61
+
62
+ def splat2ply_bytes(
63
+ means: torch.Tensor,
64
+ scales: torch.Tensor,
65
+ quats: torch.Tensor,
66
+ opacities: torch.Tensor,
67
+ sh0: torch.Tensor,
68
+ shN: torch.Tensor,
69
+ ) -> bytes:
70
+ num_splats = means.shape[0]
71
+ buffer = BytesIO()
72
+
73
+ # Write PLY header
74
+ buffer.write(b"ply\n")
75
+ buffer.write(b"format binary_little_endian 1.0\n")
76
+ buffer.write(f"element vertex {num_splats}\n".encode())
77
+ buffer.write(b"property float x\n")
78
+ buffer.write(b"property float y\n")
79
+ buffer.write(b"property float z\n")
80
+ for i, data in enumerate([sh0, shN]):
81
+ prefix = "f_dc" if i == 0 else "f_rest"
82
+ for j in range(data.shape[1]):
83
+ buffer.write(f"property float {prefix}_{j}\n".encode())
84
+ buffer.write(b"property float opacity\n")
85
+ for i in range(scales.shape[1]):
86
+ buffer.write(f"property float scale_{i}\n".encode())
87
+ for i in range(quats.shape[1]):
88
+ buffer.write(f"property float rot_{i}\n".encode())
89
+ buffer.write(b"end_header\n")
90
+
91
+ # Concatenate all tensors in the correct order
92
+ splat_data = torch.cat(
93
+ [means, sh0, shN, opacities.unsqueeze(1), scales, quats], dim=1
94
+ )
95
+ # Ensure correct dtype
96
+ splat_data = splat_data.to(torch.float32)
97
+
98
+ # Write binary data
99
+ float_dtype = np.dtype(np.float32).newbyteorder("<")
100
+ buffer.write(
101
+ splat_data.detach().cpu().numpy().astype(float_dtype).tobytes()
102
+ )
103
+
104
+ return buffer.getvalue()
105
+
106
+
107
+ def export_splats(
108
+ means: torch.Tensor,
109
+ scales: torch.Tensor,
110
+ quats: torch.Tensor,
111
+ opacities: torch.Tensor,
112
+ sh0: torch.Tensor,
113
+ shN: torch.Tensor,
114
+ format: Literal["ply"] = "ply",
115
+ save_to: Optional[str] = None,
116
+ ) -> bytes:
117
+ """Export a Gaussian Splats model to bytes in PLY file format."""
118
+ total_splats = means.shape[0]
119
+ assert means.shape == (total_splats, 3), "Means must be of shape (N, 3)"
120
+ assert scales.shape == (total_splats, 3), "Scales must be of shape (N, 3)"
121
+ assert quats.shape == (
122
+ total_splats,
123
+ 4,
124
+ ), "Quaternions must be of shape (N, 4)"
125
+ assert opacities.shape == (
126
+ total_splats,
127
+ ), "Opacities must be of shape (N,)"
128
+ assert sh0.shape == (total_splats, 1, 3), "sh0 must be of shape (N, 1, 3)"
129
+ assert (
130
+ shN.ndim == 3 and shN.shape[0] == total_splats and shN.shape[2] == 3
131
+ ), f"shN must be of shape (N, K, 3), got {shN.shape}"
132
+
133
+ # Reshape spherical harmonics
134
+ sh0 = sh0.squeeze(1) # Shape (N, 3)
135
+ shN = shN.permute(0, 2, 1).reshape(means.shape[0], -1) # Shape (N, K * 3)
136
+
137
+ # Check for NaN or Inf values
138
+ invalid_mask = (
139
+ torch.isnan(means).any(dim=1)
140
+ | torch.isinf(means).any(dim=1)
141
+ | torch.isnan(scales).any(dim=1)
142
+ | torch.isinf(scales).any(dim=1)
143
+ | torch.isnan(quats).any(dim=1)
144
+ | torch.isinf(quats).any(dim=1)
145
+ | torch.isnan(opacities).any(dim=0)
146
+ | torch.isinf(opacities).any(dim=0)
147
+ | torch.isnan(sh0).any(dim=1)
148
+ | torch.isinf(sh0).any(dim=1)
149
+ | torch.isnan(shN).any(dim=1)
150
+ | torch.isinf(shN).any(dim=1)
151
+ )
152
+
153
+ # Filter out invalid entries
154
+ valid_mask = ~invalid_mask
155
+ means = means[valid_mask]
156
+ scales = scales[valid_mask]
157
+ quats = quats[valid_mask]
158
+ opacities = opacities[valid_mask]
159
+ sh0 = sh0[valid_mask]
160
+ shN = shN[valid_mask]
161
+
162
+ if format == "ply":
163
+ data = splat2ply_bytes(means, scales, quats, opacities, sh0, shN)
164
+ else:
165
+ raise ValueError(f"Unsupported format: {format}")
166
+
167
+ if save_to:
168
+ with open(save_to, "wb") as binary_file:
169
+ binary_file.write(data)
170
+
171
+ return data
172
+
173
+
174
+ def create_splats_with_optimizers(
175
+ points: np.ndarray = None,
176
+ points_rgb: np.ndarray = None,
177
+ init_num_pts: int = 100_000,
178
+ init_extent: float = 3.0,
179
+ init_opacity: float = 0.1,
180
+ init_scale: float = 1.0,
181
+ means_lr: float = 1.6e-4,
182
+ scales_lr: float = 5e-3,
183
+ opacities_lr: float = 5e-2,
184
+ quats_lr: float = 1e-3,
185
+ sh0_lr: float = 2.5e-3,
186
+ shN_lr: float = 2.5e-3 / 20,
187
+ scene_scale: float = 1.0,
188
+ sh_degree: int = 3,
189
+ sparse_grad: bool = False,
190
+ visible_adam: bool = False,
191
+ batch_size: int = 1,
192
+ feature_dim: Optional[int] = None,
193
+ device: str = "cuda",
194
+ world_rank: int = 0,
195
+ world_size: int = 1,
196
+ ) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer]]:
197
+ if points is not None and points_rgb is not None:
198
+ points = torch.from_numpy(points).float()
199
+ rgbs = torch.from_numpy(points_rgb / 255.0).float()
200
+ else:
201
+ points = (
202
+ init_extent * scene_scale * (torch.rand((init_num_pts, 3)) * 2 - 1)
203
+ )
204
+ rgbs = torch.rand((init_num_pts, 3))
205
+
206
+ # Initialize the GS size to be the average dist of the 3 nearest neighbors
207
+ dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1) # [N,]
208
+ dist_avg = torch.sqrt(dist2_avg)
209
+ scales = (
210
+ torch.log(dist_avg * init_scale).unsqueeze(-1).repeat(1, 3)
211
+ ) # [N, 3]
212
+
213
+ # Distribute the GSs to different ranks (also works for single rank)
214
+ points = points[world_rank::world_size]
215
+ rgbs = rgbs[world_rank::world_size]
216
+ scales = scales[world_rank::world_size]
217
+
218
+ N = points.shape[0]
219
+ quats = torch.rand((N, 4)) # [N, 4]
220
+ opacities = torch.logit(torch.full((N,), init_opacity)) # [N,]
221
+
222
+ params = [
223
+ # name, value, lr
224
+ ("means", torch.nn.Parameter(points), means_lr * scene_scale),
225
+ ("scales", torch.nn.Parameter(scales), scales_lr),
226
+ ("quats", torch.nn.Parameter(quats), quats_lr),
227
+ ("opacities", torch.nn.Parameter(opacities), opacities_lr),
228
+ ]
229
+
230
+ if feature_dim is None:
231
+ # color is SH coefficients.
232
+ colors = torch.zeros((N, (sh_degree + 1) ** 2, 3)) # [N, K, 3]
233
+ colors[:, 0, :] = rgb_to_sh(rgbs)
234
+ params.append(("sh0", torch.nn.Parameter(colors[:, :1, :]), sh0_lr))
235
+ params.append(("shN", torch.nn.Parameter(colors[:, 1:, :]), shN_lr))
236
+ else:
237
+ # features will be used for appearance and view-dependent shading
238
+ features = torch.rand(N, feature_dim) # [N, feature_dim]
239
+ params.append(("features", torch.nn.Parameter(features), sh0_lr))
240
+ colors = torch.logit(rgbs) # [N, 3]
241
+ params.append(("colors", torch.nn.Parameter(colors), sh0_lr))
242
+
243
+ splats = torch.nn.ParameterDict({n: v for n, v, _ in params}).to(device)
244
+ # Scale learning rate based on batch size, reference:
245
+ # https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/
246
+ # Note that this would not make the training exactly equivalent, see
247
+ # https://arxiv.org/pdf/2402.18824v1
248
+ BS = batch_size * world_size
249
+ optimizer_class = None
250
+ if sparse_grad:
251
+ optimizer_class = torch.optim.SparseAdam
252
+ elif visible_adam:
253
+ optimizer_class = SelectiveAdam
254
+ else:
255
+ optimizer_class = torch.optim.Adam
256
+ optimizers = {
257
+ name: optimizer_class(
258
+ [{"params": splats[name], "lr": lr * math.sqrt(BS), "name": name}],
259
+ eps=1e-15 / math.sqrt(BS),
260
+ # TODO: check betas logic when BS is larger than 10 betas[0] will be zero.
261
+ betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)),
262
+ )
263
+ for name, _, lr in params
264
+ }
265
+ return splats, optimizers
266
+
267
+
268
+ def compute_pinhole_intrinsics(
269
+ image_w: int, image_h: int, fov_deg: float
270
+ ) -> np.ndarray:
271
+ fov_rad = np.deg2rad(fov_deg)
272
+ fx = image_w / (2 * np.tan(fov_rad / 2))
273
+ fy = fx # assuming square pixels
274
+ cx = image_w / 2
275
+ cy = image_h / 2
276
+ K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
277
+
278
+ return K
279
+
280
+
281
+ def resize_pinhole_intrinsics(
282
+ raw_K: np.ndarray | torch.Tensor,
283
+ raw_hw: tuple[int, int],
284
+ new_hw: tuple[int, int],
285
+ ) -> np.ndarray:
286
+ raw_h, raw_w = raw_hw
287
+ new_h, new_w = new_hw
288
+
289
+ scale_x = new_w / raw_w
290
+ scale_y = new_h / raw_h
291
+
292
+ new_K = raw_K.copy() if isinstance(raw_K, np.ndarray) else raw_K.clone()
293
+ new_K[0, 0] *= scale_x # fx
294
+ new_K[0, 2] *= scale_x # cx
295
+ new_K[1, 1] *= scale_y # fy
296
+ new_K[1, 2] *= scale_y # cy
297
+
298
+ return new_K
299
+
300
+
301
+ def restore_scene_scale_and_position(
302
+ real_height: float, mesh_path: str, gs_path: str
303
+ ) -> None:
304
+ """Scales a mesh and corresponding GS model to match a given real-world height.
305
+
306
+ Uses the 1st and 99th percentile of mesh Z-axis to estimate height,
307
+ applies scaling and vertical alignment, and updates both the mesh and GS model.
308
+
309
+ Args:
310
+ real_height (float): Target real-world height along the Z axis.
311
+ mesh_path (str): Path to the input mesh file.
312
+ gs_path (str): Path to the Gaussian Splatting model file.
313
+ """
314
+ mesh = trimesh.load(mesh_path)
315
+ z_min = np.percentile(mesh.vertices[:, 1], 1)
316
+ z_max = np.percentile(mesh.vertices[:, 1], 99)
317
+ height = z_max - z_min
318
+ scale = real_height / height
319
+
320
+ rot = Rotation.from_quat([0, 1, 0, 0])
321
+ mesh.vertices = rot.apply(mesh.vertices)
322
+ mesh.vertices[:, 1] -= z_min
323
+ mesh.vertices *= scale
324
+ mesh.export(mesh_path)
325
+
326
+ gs_model: GaussianOperator = GaussianOperator.load_from_ply(gs_path)
327
+ gs_model = gs_model.get_gaussians(
328
+ instance_pose=torch.tensor([0.0, -z_min, 0, 0, 1, 0, 0])
329
+ )
330
+ gs_model.rescale(scale)
331
+ gs_model.save_to_ply(gs_path)
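A rough sketch of how the new splat helpers fit together (CPU-only toy sizes; assumes the EmbodiedGen environment with gsplat installed):

from embodied_gen.utils.gaussian import (
    compute_pinhole_intrinsics,
    create_splats_with_optimizers,
    export_splats,
    resize_pinhole_intrinsics,
)

K = compute_pinhole_intrinsics(image_w=640, image_h=480, fov_deg=60.0)
K_half = resize_pinhole_intrinsics(K, raw_hw=(480, 640), new_hw=(240, 320))

splats, optimizers = create_splats_with_optimizers(init_num_pts=1_000, device="cpu")
export_splats(
    means=splats["means"],
    scales=splats["scales"],
    quats=splats["quats"],
    opacities=splats["opacities"],
    sh0=splats["sh0"],
    shN=splats["shN"],
    save_to="toy_splats.ply",
)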
embodied_gen/utils/gpt_clients.py CHANGED
@@ -30,12 +30,20 @@ from tenacity import (
30
  stop_after_delay,
31
  wait_random_exponential,
32
  )
33
- from embodied_gen.utils.process_media import combine_images_to_base64
34
 
35
- logging.basicConfig(level=logging.INFO)
 
36
  logger = logging.getLogger(__name__)
37
 
38
 
 
 
 
 
 
 
 
39
  class GPTclient:
40
  """A client to interact with the GPT model via OpenAI or Azure API."""
41
 
@@ -45,6 +53,7 @@ class GPTclient:
45
  api_key: str,
46
  model_name: str = "yfb-gpt-4o",
47
  api_version: str = None,
 
48
  verbose: bool = False,
49
  ):
50
  if api_version is not None:
@@ -63,6 +72,9 @@ class GPTclient:
63
  self.model_name = model_name
64
  self.image_formats = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif"}
65
  self.verbose = verbose
 
 
 
66
  logger.info(f"Using GPT model: {self.model_name}.")
67
 
68
  @retry(
@@ -77,6 +89,7 @@ class GPTclient:
77
  text_prompt: str,
78
  image_base64: Optional[list[str | Image.Image]] = None,
79
  system_role: Optional[str] = None,
 
80
  ) -> Optional[str]:
81
  """Queries the GPT model with a text and optional image prompts.
82
 
@@ -86,6 +99,7 @@ class GPTclient:
86
  or local image paths or PIL.Image to accompany the text prompt.
87
  system_role (Optional[str]): Optional system-level instructions
88
  that specify the behavior of the assistant.
 
89
 
90
  Returns:
91
  Optional[str]: The response content generated by the model based on
@@ -103,11 +117,11 @@ class GPTclient:
103
 
104
  # Process images if provided
105
  if image_base64 is not None:
106
- image_base64 = (
107
- image_base64
108
- if isinstance(image_base64, list)
109
- else [image_base64]
110
- )
111
  for img in image_base64:
112
  if isinstance(img, Image.Image):
113
  buffer = BytesIO()
@@ -142,8 +156,11 @@ class GPTclient:
142
  "frequency_penalty": 0,
143
  "presence_penalty": 0,
144
  "stop": None,
 
145
  }
146
- payload.update({"model": self.model_name})
 
 
147
 
148
  response = None
149
  try:
@@ -159,8 +176,28 @@ class GPTclient:
159
 
160
  return response
161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
- with open("embodied_gen/utils/gpt_config.yaml", "r") as f:
 
164
  config = yaml.safe_load(f)
165
 
166
  agent_type = config["agent_type"]
@@ -177,32 +214,5 @@ GPT_CLIENT = GPTclient(
177
  api_key=api_key,
178
  api_version=api_version,
179
  model_name=model_name,
 
180
  )
181
-
182
- if __name__ == "__main__":
183
- if "openrouter" in GPT_CLIENT.endpoint:
184
- response = GPT_CLIENT.query(
185
- text_prompt="What is the content in each image?",
186
- image_base64=combine_images_to_base64(
187
- [
188
- "apps/assets/example_image/sample_02.jpg",
189
- "apps/assets/example_image/sample_03.jpg",
190
- ]
191
- ), # input raw image_path if only one image
192
- )
193
- print(response)
194
- else:
195
- response = GPT_CLIENT.query(
196
- text_prompt="What is the content in the images?",
197
- image_base64=[
198
- Image.open("apps/assets/example_image/sample_02.jpg"),
199
- Image.open("apps/assets/example_image/sample_03.jpg"),
200
- ],
201
- )
202
- print(response)
203
-
204
- # test2: text prompt
205
- response = GPT_CLIENT.query(
206
- text_prompt="What is the capital of China?"
207
- )
208
- print(response)
 
30
  stop_after_delay,
31
  wait_random_exponential,
32
  )
33
+ from embodied_gen.utils.process_media import combine_images_to_grid
34
 
35
+ logging.getLogger("httpx").setLevel(logging.WARNING)
36
+ logging.basicConfig(level=logging.WARNING)
37
  logger = logging.getLogger(__name__)
38
 
39
 
40
+ __all__ = [
41
+ "GPTclient",
42
+ ]
43
+
44
+ CONFIG_FILE = "embodied_gen/utils/gpt_config.yaml"
45
+
46
+
47
  class GPTclient:
48
  """A client to interact with the GPT model via OpenAI or Azure API."""
49
 
 
53
  api_key: str,
54
  model_name: str = "yfb-gpt-4o",
55
  api_version: str = None,
56
+ check_connection: bool = True,
57
  verbose: bool = False,
58
  ):
59
  if api_version is not None:
 
72
  self.model_name = model_name
73
  self.image_formats = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif"}
74
  self.verbose = verbose
75
+ if check_connection:
76
+ self.check_connection()
77
+
78
  logger.info(f"Using GPT model: {self.model_name}.")
79
 
80
  @retry(
 
89
  text_prompt: str,
90
  image_base64: Optional[list[str | Image.Image]] = None,
91
  system_role: Optional[str] = None,
92
+ params: Optional[dict] = None,
93
  ) -> Optional[str]:
94
  """Queries the GPT model with a text and optional image prompts.
95
 
 
99
  or local image paths or PIL.Image to accompany the text prompt.
100
  system_role (Optional[str]): Optional system-level instructions
101
  that specify the behavior of the assistant.
102
+ params (Optional[dict]): Additional parameters for GPT setting.
103
 
104
  Returns:
105
  Optional[str]: The response content generated by the model based on
 
117
 
118
  # Process images if provided
119
  if image_base64 is not None:
120
+ if not isinstance(image_base64, list):
121
+ image_base64 = [image_base64]
122
+ # Temporary workaround: OpenRouter cannot take multiple images in one request, so stitch them into a grid.
123
+ if "openrouter" in self.endpoint:
124
+ image_base64 = combine_images_to_grid(image_base64)
125
  for img in image_base64:
126
  if isinstance(img, Image.Image):
127
  buffer = BytesIO()
 
156
  "frequency_penalty": 0,
157
  "presence_penalty": 0,
158
  "stop": None,
159
+ "model": self.model_name,
160
  }
161
+
162
+ if params:
163
+ payload.update(params)
164
 
165
  response = None
166
  try:
 
176
 
177
  return response
178
 
179
+ def check_connection(self) -> None:
180
+ """Check whether the GPT API connection is working."""
181
+ try:
182
+ response = self.completion_with_backoff(
183
+ messages=[
184
+ {"role": "system", "content": "You are a test system."},
185
+ {"role": "user", "content": "Hello"},
186
+ ],
187
+ model=self.model_name,
188
+ temperature=0,
189
+ max_tokens=100,
190
+ )
191
+ content = response.choices[0].message.content
192
+ logger.info(f"Connection check success.")
193
+ except Exception as e:
194
+ raise ConnectionError(
195
+ f"Failed to connect to GPT API at {self.endpoint}, "
196
+ f"please check setting in `{CONFIG_FILE}` and `README`."
197
+ )
198
 
199
+
200
+ with open(CONFIG_FILE, "r") as f:
201
  config = yaml.safe_load(f)
202
 
203
  agent_type = config["agent_type"]
 
214
  api_key=api_key,
215
  api_version=api_version,
216
  model_name=model_name,
217
+ check_connection=False,
218
  )
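A sketch of the updated client usage; the endpoint, key, and model id below are placeholders, and `params` is forwarded into the request payload:

from embodied_gen.utils.gpt_clients import GPTclient

client = GPTclient(
    endpoint="https://openrouter.ai/api/v1",
    api_key="<your-api-key>",
    model_name="<provider/model-id>",
    check_connection=False,  # set True to fail fast if the endpoint is unreachable
)
answer = client.query(
    text_prompt="What is the content in each image?",
    image_base64=[
        "apps/assets/example_image/sample_02.jpg",
        "apps/assets/example_image/sample_03.jpg",
    ],
    params={"temperature": 0.2, "max_tokens": 256},
)
print(answer)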
 
 
 
 
 
embodied_gen/utils/log.py ADDED
@@ -0,0 +1,48 @@
 
 
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+ import logging
18
+
19
+ from colorlog import ColoredFormatter
20
+
21
+ __all__ = [
22
+ "logger",
23
+ ]
24
+
25
+ LOG_FORMAT = (
26
+ "%(log_color)s[%(asctime)s] %(levelname)-8s | %(message)s%(reset)s"
27
+ )
28
+ DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
29
+
30
+ formatter = ColoredFormatter(
31
+ LOG_FORMAT,
32
+ datefmt=DATE_FORMAT,
33
+ log_colors={
34
+ "DEBUG": "cyan",
35
+ "INFO": "green",
36
+ "WARNING": "yellow",
37
+ "ERROR": "red",
38
+ "CRITICAL": "bold_red",
39
+ },
40
+ )
41
+
42
+ handler = logging.StreamHandler()
43
+ handler.setFormatter(formatter)
44
+
45
+ logger = logging.getLogger(__name__)
46
+ logger.setLevel(logging.INFO)
47
+ logger.addHandler(handler)
48
+ logger.propagate = False
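Modules can simply reuse the shared colored logger, for example:

from embodied_gen.utils.log import logger

logger.info("asset generation started")
logger.warning("falling back to CPU rendering")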
embodied_gen/utils/monkey_patches.py ADDED
@@ -0,0 +1,152 @@
 
 
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+ import os
18
+ import sys
19
+ import zipfile
20
+
21
+ import torch
22
+ from huggingface_hub import hf_hub_download
23
+ from omegaconf import OmegaConf
24
+ from PIL import Image
25
+ from torchvision import transforms
26
+
27
+
28
+ def monkey_patch_pano2room():
29
+ current_file_path = os.path.abspath(__file__)
30
+ current_dir = os.path.dirname(current_file_path)
31
+ sys.path.append(os.path.join(current_dir, "../.."))
32
+ sys.path.append(os.path.join(current_dir, "../../thirdparty/pano2room"))
33
+ from thirdparty.pano2room.modules.geo_predictors.omnidata.omnidata_normal_predictor import (
34
+ OmnidataNormalPredictor,
35
+ )
36
+ from thirdparty.pano2room.modules.geo_predictors.omnidata.omnidata_predictor import (
37
+ OmnidataPredictor,
38
+ )
39
+
40
+ def patched_omni_depth_init(self):
41
+ self.img_size = 384
42
+ self.model = torch.hub.load(
43
+ 'alexsax/omnidata_models', 'depth_dpt_hybrid_384'
44
+ )
45
+ self.model.eval()
46
+ self.trans_totensor = transforms.Compose(
47
+ [
48
+ transforms.Resize(self.img_size, interpolation=Image.BILINEAR),
49
+ transforms.CenterCrop(self.img_size),
50
+ transforms.Normalize(mean=0.5, std=0.5),
51
+ ]
52
+ )
53
+
54
+ OmnidataPredictor.__init__ = patched_omni_depth_init
55
+
56
+ def patched_omni_normal_init(self):
57
+ self.img_size = 384
58
+ self.model = torch.hub.load(
59
+ 'alexsax/omnidata_models', 'surface_normal_dpt_hybrid_384'
60
+ )
61
+ self.model.eval()
62
+ self.trans_totensor = transforms.Compose(
63
+ [
64
+ transforms.Resize(self.img_size, interpolation=Image.BILINEAR),
65
+ transforms.CenterCrop(self.img_size),
66
+ transforms.Normalize(mean=0.5, std=0.5),
67
+ ]
68
+ )
69
+
70
+ OmnidataNormalPredictor.__init__ = patched_omni_normal_init
71
+
72
+ def patched_panojoint_init(self, save_path=None):
73
+ self.depth_predictor = OmnidataPredictor()
74
+ self.normal_predictor = OmnidataNormalPredictor()
75
+ self.save_path = save_path
76
+
77
+ from modules.geo_predictors import PanoJointPredictor
78
+
79
+ PanoJointPredictor.__init__ = patched_panojoint_init
80
+
81
+ # NOTE: We use gsplat instead.
82
+ # import depth_diff_gaussian_rasterization_min as ddgr
83
+ # from dataclasses import dataclass
84
+ # @dataclass
85
+ # class PatchedGaussianRasterizationSettings:
86
+ # image_height: int
87
+ # image_width: int
88
+ # tanfovx: float
89
+ # tanfovy: float
90
+ # bg: torch.Tensor
91
+ # scale_modifier: float
92
+ # viewmatrix: torch.Tensor
93
+ # projmatrix: torch.Tensor
94
+ # sh_degree: int
95
+ # campos: torch.Tensor
96
+ # prefiltered: bool
97
+ # debug: bool = False
98
+ # ddgr.GaussianRasterizationSettings = PatchedGaussianRasterizationSettings
99
+
100
+ # disable get_has_ddp_rank print in `BaseInpaintingTrainingModule`
101
+ os.environ["NODE_RANK"] = "0"
102
+
103
+ from thirdparty.pano2room.modules.inpainters.lama.saicinpainting.training.trainers import (
104
+ load_checkpoint,
105
+ )
106
+ from thirdparty.pano2room.modules.inpainters.lama_inpainter import (
107
+ LamaInpainter,
108
+ )
109
+
110
+ def patched_lama_inpaint_init(self):
111
+ zip_path = hf_hub_download(
112
+ repo_id="smartywu/big-lama",
113
+ filename="big-lama.zip",
114
+ repo_type="model",
115
+ )
116
+ extract_dir = os.path.splitext(zip_path)[0]
117
+
118
+ if not os.path.exists(extract_dir):
119
+ os.makedirs(extract_dir, exist_ok=True)
120
+ with zipfile.ZipFile(zip_path, "r") as zip_ref:
121
+ zip_ref.extractall(extract_dir)
122
+
123
+ config_path = os.path.join(extract_dir, 'big-lama', 'config.yaml')
124
+ checkpoint_path = os.path.join(
125
+ extract_dir, 'big-lama/models/best.ckpt'
126
+ )
127
+ train_config = OmegaConf.load(config_path)
128
+ train_config.training_model.predict_only = True
129
+ train_config.visualizer.kind = 'noop'
130
+
131
+ self.model = load_checkpoint(
132
+ train_config, checkpoint_path, strict=False, map_location='cpu'
133
+ )
134
+ self.model.freeze()
135
+
136
+ LamaInpainter.__init__ = patched_lama_inpaint_init
137
+
138
+ from diffusers import StableDiffusionInpaintPipeline
139
+ from thirdparty.pano2room.modules.inpainters.SDFT_inpainter import (
140
+ SDFTInpainter,
141
+ )
142
+
143
+ def patched_sd_inpaint_init(self, subset_name=None):
144
+ super(SDFTInpainter, self).__init__()
145
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
146
+ "stabilityai/stable-diffusion-2-inpainting",
147
+ torch_dtype=torch.float16,
148
+ ).to("cuda")
149
+ pipe.enable_model_cpu_offload()
150
+ self.inpaint_pipe = pipe
151
+
152
+ SDFTInpainter.__init__ = patched_sd_inpaint_init
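The patches are meant to be applied before any pano2room predictor or inpainter is instantiated; a rough sketch (assumes the `thirdparty/pano2room` checkout is importable):

from embodied_gen.utils.monkey_patches import monkey_patch_pano2room

monkey_patch_pano2room()  # swaps in HF-hosted omnidata, big-lama, and SD-inpaint weights

from thirdparty.pano2room.modules.inpainters.lama_inpainter import LamaInpainter

inpainter = LamaInpainter()  # now pulls big-lama.zip from the Hugging Face Hub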
embodied_gen/utils/process_media.py CHANGED
@@ -15,34 +15,25 @@
15
  # permissions and limitations under the License.
16
 
17
 
18
- import base64
19
  import logging
20
  import math
 
21
  import os
22
- import sys
23
  from glob import glob
24
- from io import BytesIO
25
  from typing import Union
26
 
27
  import cv2
28
  import imageio
 
 
29
  import numpy as np
30
- import PIL.Image as Image
31
  import spaces
32
- import torch
33
  from moviepy.editor import VideoFileClip, clips_array
34
- from tqdm import tqdm
35
  from embodied_gen.data.differentiable_render import entrypoint as render_api
36
-
37
- current_file_path = os.path.abspath(__file__)
38
- current_dir = os.path.dirname(current_file_path)
39
- sys.path.append(os.path.join(current_dir, "../.."))
40
- from thirdparty.TRELLIS.trellis.renderers.mesh_renderer import MeshRenderer
41
- from thirdparty.TRELLIS.trellis.representations import MeshExtractResult
42
- from thirdparty.TRELLIS.trellis.utils.render_utils import (
43
- render_frames,
44
- yaw_pitch_r_fov_to_extrinsics_intrinsics,
45
- )
46
 
47
  logging.basicConfig(level=logging.INFO)
48
  logger = logging.getLogger(__name__)
@@ -53,9 +44,11 @@ __all__ = [
53
  "merge_images_video",
54
  "filter_small_connected_components",
55
  "filter_image_small_connected_components",
56
- "combine_images_to_base64",
57
- "render_mesh",
58
- "render_video",
 
 
59
  ]
60
 
61
 
@@ -66,12 +59,14 @@ def render_asset3d(
66
  distance: float = 5.0,
67
  num_images: int = 1,
68
  elevation: list[float] = (0.0,),
69
- pbr_light_factor: float = 1.5,
70
  return_key: str = "image_color/*",
71
  output_subdir: str = "renders",
72
  gen_color_mp4: bool = False,
73
  gen_viewnormal_mp4: bool = False,
74
  gen_glonormal_mp4: bool = False,
 
 
75
  ) -> list[str]:
76
  input_args = dict(
77
  mesh_path=mesh_path,
@@ -81,14 +76,13 @@ def render_asset3d(
81
  num_images=num_images,
82
  elevation=elevation,
83
  pbr_light_factor=pbr_light_factor,
84
- with_mtl=True,
 
 
 
 
85
  )
86
- if gen_color_mp4:
87
- input_args["gen_color_mp4"] = True
88
- if gen_viewnormal_mp4:
89
- input_args["gen_viewnormal_mp4"] = True
90
- if gen_glonormal_mp4:
91
- input_args["gen_glonormal_mp4"] = True
92
  try:
93
  _ = render_api(**input_args)
94
  except Exception as e:
@@ -168,12 +162,15 @@ def filter_image_small_connected_components(
168
  return image
169
 
170
 
171
- def combine_images_to_base64(
172
  images: list[str | Image.Image],
173
  cat_row_col: tuple[int, int] = None,
174
  target_wh: tuple[int, int] = (512, 512),
175
- ) -> str:
176
  n_images = len(images)
 
 
 
177
  if cat_row_col is None:
178
  n_col = math.ceil(math.sqrt(n_images))
179
  n_row = math.ceil(n_images / n_col)
@@ -182,88 +179,229 @@ def combine_images_to_base64(
182
 
183
  images = [
184
  Image.open(p).convert("RGB") if isinstance(p, str) else p
185
- for p in images[: n_row * n_col]
186
  ]
187
  images = [img.resize(target_wh) for img in images]
188
 
189
  grid_w, grid_h = n_col * target_wh[0], n_row * target_wh[1]
190
- grid = Image.new("RGB", (grid_w, grid_h), (255, 255, 255))
191
 
192
  for idx, img in enumerate(images):
193
  row, col = divmod(idx, n_col)
194
  grid.paste(img, (col * target_wh[0], row * target_wh[1]))
195
 
196
- buffer = BytesIO()
197
- grid.save(buffer, format="PNG")
198
-
199
- return base64.b64encode(buffer.getvalue()).decode("utf-8")
 
 
 
 
 
 
 
200
 
 
 
 
 
 
 
 
 
 
201
 
202
- @spaces.GPU
203
- def render_mesh(sample, extrinsics, intrinsics, options={}, **kwargs):
204
- renderer = MeshRenderer()
205
- renderer.rendering_options.resolution = options.get("resolution", 512)
206
- renderer.rendering_options.near = options.get("near", 1)
207
- renderer.rendering_options.far = options.get("far", 100)
208
- renderer.rendering_options.ssaa = options.get("ssaa", 4)
209
- rets = {}
210
- for extr, intr in tqdm(zip(extrinsics, intrinsics), desc="Rendering"):
211
- res = renderer.render(sample, extr, intr)
212
- if "normal" not in rets:
213
- rets["normal"] = []
214
- normal = torch.lerp(
215
- torch.zeros_like(res["normal"]), res["normal"], res["mask"]
 
 
 
 
 
 
 
216
  )
217
- normal = np.clip(
218
- normal.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255
219
- ).astype(np.uint8)
220
- rets["normal"].append(normal)
221
 
222
- return rets
 
 
223
 
224
 
225
- @spaces.GPU
226
- def render_video(
227
- sample,
228
- resolution=512,
229
- bg_color=(0, 0, 0),
230
- num_frames=300,
231
- r=2,
232
- fov=40,
233
- **kwargs,
234
- ):
235
- yaws = torch.linspace(0, 2 * 3.1415, num_frames)
236
- yaws = yaws.tolist()
237
- pitch = [0.5] * num_frames
238
- extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(
239
- yaws, pitch, r, fov
240
- )
241
- render_fn = (
242
- render_mesh if isinstance(sample, MeshExtractResult) else render_frames
243
- )
244
- result = render_fn(
245
- sample,
246
- extrinsics,
247
- intrinsics,
248
- {"resolution": resolution, "bg_color": bg_color},
249
- **kwargs,
250
- )
251
 
252
- return result
 
 
 
 
 
 
 
253
 
254
 
255
  if __name__ == "__main__":
256
- # Example usage:
257
  merge_video_video(
258
  "outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh_glo_normal.mp4", # noqa
259
  "outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh.mp4", # noqa
260
  "merge.mp4",
261
  )
262
-
263
- image_base64 = combine_images_to_base64(
264
- [
265
- "apps/assets/example_image/sample_00.jpg",
266
- "apps/assets/example_image/sample_01.jpg",
267
- "apps/assets/example_image/sample_02.jpg",
268
- ]
269
- )
 
15
  # permissions and limitations under the License.
16
 
17
 
 
18
  import logging
19
  import math
20
+ import mimetypes
21
  import os
22
+ import textwrap
23
  from glob import glob
 
24
  from typing import Union
25
 
26
  import cv2
27
  import imageio
28
+ import matplotlib.pyplot as plt
29
+ import networkx as nx
30
  import numpy as np
 
31
  import spaces
32
+ from matplotlib.patches import Patch
33
  from moviepy.editor import VideoFileClip, clips_array
34
+ from PIL import Image
35
  from embodied_gen.data.differentiable_render import entrypoint as render_api
36
+ from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum
 
 
 
 
 
 
 
 
 
37
 
38
  logging.basicConfig(level=logging.INFO)
39
  logger = logging.getLogger(__name__)
 
44
  "merge_images_video",
45
  "filter_small_connected_components",
46
  "filter_image_small_connected_components",
47
+ "combine_images_to_grid",
48
+ "SceneTreeVisualizer",
49
+ "is_image_file",
50
+ "parse_text_prompts",
51
+ "check_object_edge_truncated",
52
  ]
53
 
54
 
 
59
  distance: float = 5.0,
60
  num_images: int = 1,
61
  elevation: list[float] = (0.0,),
62
+ pbr_light_factor: float = 1.2,
63
  return_key: str = "image_color/*",
64
  output_subdir: str = "renders",
65
  gen_color_mp4: bool = False,
66
  gen_viewnormal_mp4: bool = False,
67
  gen_glonormal_mp4: bool = False,
68
+ no_index_file: bool = False,
69
+ with_mtl: bool = True,
70
  ) -> list[str]:
71
  input_args = dict(
72
  mesh_path=mesh_path,
 
76
  num_images=num_images,
77
  elevation=elevation,
78
  pbr_light_factor=pbr_light_factor,
79
+ with_mtl=with_mtl,
80
+ gen_color_mp4=gen_color_mp4,
81
+ gen_viewnormal_mp4=gen_viewnormal_mp4,
82
+ gen_glonormal_mp4=gen_glonormal_mp4,
83
+ no_index_file=no_index_file,
84
  )
85
+
 
 
 
 
 
86
  try:
87
  _ = render_api(**input_args)
88
  except Exception as e:
 
162
  return image
163
 
164
 
165
+ def combine_images_to_grid(
166
  images: list[str | Image.Image],
167
  cat_row_col: tuple[int, int] = None,
168
  target_wh: tuple[int, int] = (512, 512),
169
+ ) -> list[str | Image.Image]:
170
  n_images = len(images)
171
+ if n_images == 1:
172
+ return images
173
+
174
  if cat_row_col is None:
175
  n_col = math.ceil(math.sqrt(n_images))
176
  n_row = math.ceil(n_images / n_col)
 
179
 
180
  images = [
181
  Image.open(p).convert("RGB") if isinstance(p, str) else p
182
+ for p in images
183
  ]
184
  images = [img.resize(target_wh) for img in images]
185
 
186
  grid_w, grid_h = n_col * target_wh[0], n_row * target_wh[1]
187
+ grid = Image.new("RGB", (grid_w, grid_h), (0, 0, 0))
188
 
189
  for idx, img in enumerate(images):
190
  row, col = divmod(idx, n_col)
191
  grid.paste(img, (col * target_wh[0], row * target_wh[1]))
192
 
193
+ return [grid]
194
+
195
+
196
+ class SceneTreeVisualizer:
197
+ def __init__(self, layout_info: LayoutInfo) -> None:
198
+ self.tree = layout_info.tree
199
+ self.relation = layout_info.relation
200
+ self.objs_desc = layout_info.objs_desc
201
+ self.G = nx.DiGraph()
202
+ self.root = self._find_root()
203
+ self._build_graph()
204
+
205
+ self.role_colors = {
206
+ Scene3DItemEnum.BACKGROUND.value: "plum",
207
+ Scene3DItemEnum.CONTEXT.value: "lightblue",
208
+ Scene3DItemEnum.ROBOT.value: "lightcoral",
209
+ Scene3DItemEnum.MANIPULATED_OBJS.value: "lightgreen",
210
+ Scene3DItemEnum.DISTRACTOR_OBJS.value: "lightgray",
211
+ Scene3DItemEnum.OTHERS.value: "orange",
212
+ }
213
+
214
+ def _find_root(self) -> str:
215
+ children = {c for cs in self.tree.values() for c, _ in cs}
216
+ parents = set(self.tree.keys())
217
+ roots = parents - children
218
+ if not roots:
219
+ raise ValueError("No root node found.")
220
+ return next(iter(roots))
221
+
222
+ def _build_graph(self):
223
+ for parent, children in self.tree.items():
224
+ for child, relation in children:
225
+ self.G.add_edge(parent, child, relation=relation)
226
+
227
+ def _get_node_role(self, node: str) -> str:
228
+ if node == self.relation.get(Scene3DItemEnum.BACKGROUND.value):
229
+ return Scene3DItemEnum.BACKGROUND.value
230
+ if node == self.relation.get(Scene3DItemEnum.CONTEXT.value):
231
+ return Scene3DItemEnum.CONTEXT.value
232
+ if node == self.relation.get(Scene3DItemEnum.ROBOT.value):
233
+ return Scene3DItemEnum.ROBOT.value
234
+ if node in self.relation.get(
235
+ Scene3DItemEnum.MANIPULATED_OBJS.value, []
236
+ ):
237
+ return Scene3DItemEnum.MANIPULATED_OBJS.value
238
+ if node in self.relation.get(
239
+ Scene3DItemEnum.DISTRACTOR_OBJS.value, []
240
+ ):
241
+ return Scene3DItemEnum.DISTRACTOR_OBJS.value
242
+ return Scene3DItemEnum.OTHERS.value
243
+
244
+ def _get_positions(
245
+ self, root, width=1.0, vert_gap=0.1, vert_loc=1, xcenter=0.5, pos=None
246
+ ):
247
+ if pos is None:
248
+ pos = {root: (xcenter, vert_loc)}
249
+ else:
250
+ pos[root] = (xcenter, vert_loc)
251
+
252
+ children = list(self.G.successors(root))
253
+ if children:
254
+ dx = width / len(children)
255
+ next_x = xcenter - width / 2 - dx / 2
256
+ for child in children:
257
+ next_x += dx
258
+ pos = self._get_positions(
259
+ child,
260
+ width=dx,
261
+ vert_gap=vert_gap,
262
+ vert_loc=vert_loc - vert_gap,
263
+ xcenter=next_x,
264
+ pos=pos,
265
+ )
266
+ return pos
267
+
268
+ def render(
269
+ self,
270
+ save_path: str,
271
+ figsize=(8, 6),
272
+ dpi=300,
273
+ title: str = "Scene 3D Hierarchy Tree",
274
+ ):
275
+ node_colors = [
276
+ self.role_colors[self._get_node_role(n)] for n in self.G.nodes
277
+ ]
278
+ pos = self._get_positions(self.root)
279
+
280
+ plt.figure(figsize=figsize)
281
+ nx.draw(
282
+ self.G,
283
+ pos,
284
+ with_labels=True,
285
+ arrows=False,
286
+ node_size=2000,
287
+ node_color=node_colors,
288
+ font_size=10,
289
+ font_weight="bold",
290
+ )
291
 
292
+ # Draw edge labels
293
+ edge_labels = nx.get_edge_attributes(self.G, "relation")
294
+ nx.draw_networkx_edge_labels(
295
+ self.G,
296
+ pos,
297
+ edge_labels=edge_labels,
298
+ font_size=9,
299
+ font_color="black",
300
+ )
301
 
302
+ # Draw small description text under each node (if available)
303
+ for node, (x, y) in pos.items():
304
+ desc = self.objs_desc.get(node)
305
+ if desc:
306
+ wrapped = "\n".join(textwrap.wrap(desc, width=30))
307
+ plt.text(
308
+ x,
309
+ y - 0.006,
310
+ wrapped,
311
+ fontsize=6,
312
+ ha="center",
313
+ va="top",
314
+ wrap=True,
315
+ color="black",
316
+ bbox=dict(
317
+ facecolor="dimgray",
318
+ edgecolor="darkgray",
319
+ alpha=0.1,
320
+ boxstyle="round,pad=0.2",
321
+ ),
322
+ )
323
+
324
+ plt.title(title, fontsize=12)
325
+ task_desc = self.relation.get("task_desc", "")
326
+ if task_desc:
327
+ plt.suptitle(
328
+ f"Task Description: {task_desc}", fontsize=10, y=0.999
329
+ )
330
+
331
+ plt.axis("off")
332
+
333
+ legend_handles = [
334
+ Patch(facecolor=color, edgecolor='black', label=role)
335
+ for role, color in self.role_colors.items()
336
+ ]
337
+ plt.legend(
338
+ handles=legend_handles,
339
+ loc="lower center",
340
+ ncol=3,
341
+ bbox_to_anchor=(0.5, -0.1),
342
+ fontsize=9,
343
  )
 
 
 
 
344
 
345
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
346
+ plt.savefig(save_path, dpi=dpi, bbox_inches="tight")
347
+ plt.close()
348
 
349
 
350
+ def load_scene_dict(file_path: str) -> dict:
351
+ scene_dict = {}
352
+ with open(file_path, "r", encoding='utf-8') as f:
353
+ for line in f:
354
+ line = line.strip()
355
+ if not line or ":" not in line:
356
+ continue
357
+ scene_id, desc = line.split(":", 1)
358
+ scene_dict[scene_id.strip()] = desc.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
359
 
360
+ return scene_dict
361
+
362
+
363
+ def is_image_file(filename: str) -> bool:
364
+ mime_type, _ = mimetypes.guess_type(filename)
365
+
366
+ return mime_type is not None and mime_type.startswith('image')
367
+
368
+
369
+ def parse_text_prompts(prompts: list[str]) -> list[str]:
370
+ if len(prompts) == 1 and prompts[0].endswith(".txt"):
371
+ with open(prompts[0], "r") as f:
372
+ prompts = [
373
+ line.strip()
374
+ for line in f
375
+ if line.strip() and not line.strip().startswith("#")
376
+ ]
377
+ return prompts
378
+
379
+
380
+ def check_object_edge_truncated(
381
+ mask: np.ndarray, edge_threshold: int = 5
382
+ ) -> bool:
383
+ """Checks if a binary object mask is truncated at the image edges.
384
+
385
+ Args:
386
+ mask: A 2D binary NumPy array where nonzero values indicate the object region.
387
+ edge_threshold: Number of pixels from each image edge to consider for truncation.
388
+ Defaults to 5.
389
+
390
+ Returns:
391
+ True if the object is fully enclosed (not truncated).
392
+ False if the object touches or crosses any image boundary.
393
+ """
394
+ top = mask[:edge_threshold, :].any()
395
+ bottom = mask[-edge_threshold:, :].any()
396
+ left = mask[:, :edge_threshold].any()
397
+ right = mask[:, -edge_threshold:].any()
398
+
399
+ return not (top or bottom or left or right)
400
 
401
 
402
  if __name__ == "__main__":
 
403
  merge_video_video(
404
  "outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh_glo_normal.mp4", # noqa
405
  "outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh.mp4", # noqa
406
  "merge.mp4",
407
  )
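A short sketch of the new media helpers (file paths are illustrative):

import numpy as np
from PIL import Image
from embodied_gen.utils.process_media import (
    check_object_edge_truncated,
    combine_images_to_grid,
    parse_text_prompts,
)

grid = combine_images_to_grid(
    [
        "apps/assets/example_image/sample_00.jpg",
        "apps/assets/example_image/sample_01.jpg",
    ]
)[0]  # single-element list holding the stitched PIL image

mask = np.array(Image.open("object_mask.png").convert("L")) > 127
print(check_object_edge_truncated(mask))  # False if the mask touches an image edge

prompts = parse_text_prompts(["prompts.txt"])  # or pass the prompts directly as a list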
 
 
 
 
 
 
 
 
embodied_gen/utils/tags.py CHANGED
@@ -1 +1 @@
1
- VERSION = "v0.1.0"
 
1
+ VERSION = "v0.1.2"
embodied_gen/utils/trender.py ADDED
@@ -0,0 +1,90 @@
 
 
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+ import os
18
+ import sys
19
+
20
+ import numpy as np
21
+ import spaces
22
+ import torch
23
+ from tqdm import tqdm
24
+
25
+ current_file_path = os.path.abspath(__file__)
26
+ current_dir = os.path.dirname(current_file_path)
27
+ sys.path.append(os.path.join(current_dir, "../.."))
28
+ from thirdparty.TRELLIS.trellis.renderers.mesh_renderer import MeshRenderer
29
+ from thirdparty.TRELLIS.trellis.representations import MeshExtractResult
30
+ from thirdparty.TRELLIS.trellis.utils.render_utils import (
31
+ render_frames,
32
+ yaw_pitch_r_fov_to_extrinsics_intrinsics,
33
+ )
34
+
35
+ __all__ = [
36
+ "render_video",
37
+ ]
38
+
39
+
40
+ @spaces.GPU
41
+ def render_mesh(sample, extrinsics, intrinsics, options={}, **kwargs):
42
+ renderer = MeshRenderer()
43
+ renderer.rendering_options.resolution = options.get("resolution", 512)
44
+ renderer.rendering_options.near = options.get("near", 1)
45
+ renderer.rendering_options.far = options.get("far", 100)
46
+ renderer.rendering_options.ssaa = options.get("ssaa", 4)
47
+ rets = {}
48
+ for extr, intr in tqdm(zip(extrinsics, intrinsics), desc="Rendering"):
49
+ res = renderer.render(sample, extr, intr)
50
+ if "normal" not in rets:
51
+ rets["normal"] = []
52
+ normal = torch.lerp(
53
+ torch.zeros_like(res["normal"]), res["normal"], res["mask"]
54
+ )
55
+ normal = np.clip(
56
+ normal.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255
57
+ ).astype(np.uint8)
58
+ rets["normal"].append(normal)
59
+
60
+ return rets
61
+
62
+
63
+ @spaces.GPU
64
+ def render_video(
65
+ sample,
66
+ resolution=512,
67
+ bg_color=(0, 0, 0),
68
+ num_frames=300,
69
+ r=2,
70
+ fov=40,
71
+ **kwargs,
72
+ ):
73
+ yaws = torch.linspace(0, 2 * 3.1415, num_frames)
74
+ yaws = yaws.tolist()
75
+ pitch = [0.5] * num_frames
76
+ extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(
77
+ yaws, pitch, r, fov
78
+ )
79
+ render_fn = (
80
+ render_mesh if isinstance(sample, MeshExtractResult) else render_frames
81
+ )
82
+ result = render_fn(
83
+ sample,
84
+ extrinsics,
85
+ intrinsics,
86
+ {"resolution": resolution, "bg_color": bg_color},
87
+ **kwargs,
88
+ )
89
+
90
+ return result
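For a TRELLIS mesh output, the relocated renderer can be used roughly like this (`mesh_result` is assumed to be a `MeshExtractResult` produced upstream):

import imageio
from embodied_gen.utils.trender import render_video

frames = render_video(mesh_result, resolution=512, num_frames=120)["normal"]
imageio.mimsave("mesh_normal.mp4", frames, fps=30)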
embodied_gen/validators/aesthetic_predictor.py CHANGED
@@ -102,7 +102,7 @@ class AestheticPredictor:
102
  def _load_sac_model(self, model_path, input_size):
103
  """Load the SAC model."""
104
  model = self.MLP(input_size)
105
- ckpt = torch.load(model_path)
106
  model.load_state_dict(ckpt)
107
  model.to(self.device)
108
  model.eval()
@@ -135,15 +135,3 @@ class AestheticPredictor:
135
  )
136
 
137
  return prediction.item()
138
-
139
-
140
- if __name__ == "__main__":
141
- # Configuration
142
- img_path = "apps/assets/example_image/sample_00.jpg"
143
-
144
- # Initialize the predictor
145
- predictor = AestheticPredictor()
146
-
147
- # Predict the aesthetic score
148
- score = predictor.predict(img_path)
149
- print("Aesthetic score predicted by the model:", score)
 
102
  def _load_sac_model(self, model_path, input_size):
103
  """Load the SAC model."""
104
  model = self.MLP(input_size)
105
+ ckpt = torch.load(model_path, weights_only=True)
106
  model.load_state_dict(ckpt)
107
  model.to(self.device)
108
  model.eval()
 
135
  )
136
 
137
  return prediction.item()
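The loading change is transparent to callers; the usage from the removed `__main__` demo still applies:

from embodied_gen.validators.aesthetic_predictor import AestheticPredictor

predictor = AestheticPredictor()
score = predictor.predict("apps/assets/example_image/sample_00.jpg")
print("Aesthetic score predicted by the model:", score)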
 
 
 
 
 
 
 
 
 
 
 
 
embodied_gen/validators/quality_checkers.py CHANGED
@@ -16,17 +16,29 @@
16
 
17
 
18
  import logging
19
- import os
20
 
21
- from tqdm import tqdm
 
22
  from embodied_gen.utils.gpt_clients import GPT_CLIENT, GPTclient
23
- from embodied_gen.utils.process_media import render_asset3d
24
  from embodied_gen.validators.aesthetic_predictor import AestheticPredictor
25
 
26
  logging.basicConfig(level=logging.INFO)
27
  logger = logging.getLogger(__name__)
28
 
29
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  class BaseChecker:
31
  def __init__(self, prompt: str = None, verbose: bool = False) -> None:
32
  self.prompt = prompt
@@ -37,16 +49,20 @@ class BaseChecker:
37
  "Subclasses must implement the query method."
38
  )
39
 
40
- def __call__(self, *args, **kwargs) -> bool:
41
  response = self.query(*args, **kwargs)
42
- if response is None:
43
- response = "Error when calling gpt api."
44
-
45
- if self.verbose and response != "YES":
46
  logger.info(response)
47
 
48
- flag = "YES" in response
49
- response = "YES" if flag else response
 
 
 
 
 
 
 
50
 
51
  return flag, response
52
 
@@ -92,21 +108,29 @@ class MeshGeoChecker(BaseChecker):
92
  self.gpt_client = gpt_client
93
  if self.prompt is None:
94
  self.prompt = """
95
- Refer to the provided multi-view rendering images to evaluate
96
- whether the geometry of the 3D object asset is complete and
97
- whether the asset can be placed stably on the ground.
98
- Return "YES" only if reach the requirments,
99
- otherwise "NO" and explain the reason very briefly.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  """
101
 
102
- def query(self, image_paths: str) -> str:
103
- # Hardcode tmp because of the openrouter can't input multi images.
104
- if "openrouter" in self.gpt_client.endpoint:
105
- from embodied_gen.utils.process_media import (
106
- combine_images_to_base64,
107
- )
108
-
109
- image_paths = combine_images_to_base64(image_paths)
110
 
111
  return self.gpt_client.query(
112
  text_prompt=self.prompt,
@@ -137,14 +161,19 @@ class ImageSegChecker(BaseChecker):
137
  self.gpt_client = gpt_client
138
  if self.prompt is None:
139
  self.prompt = """
140
- The first image is the original, and the second image is the
141
- result after segmenting the main object. Evaluate the segmentation
142
- quality to ensure the main object is clearly segmented without
143
- significant truncation. Note that the foreground of the object
144
- needs to be extracted instead of the background.
145
- Minor imperfections can be ignored. If segmentation is acceptable,
146
- return "YES" only; otherwise, return "NO" with
147
- very brief explanation.
 
 
 
 
 
148
  """
149
 
150
  def query(self, image_paths: list[str]) -> str:
@@ -152,13 +181,6 @@ class ImageSegChecker(BaseChecker):
152
  raise ValueError(
153
  "ImageSegChecker requires exactly two images: [raw_image, seg_image]." # noqa
154
  )
155
- # Hardcode tmp because of the openrouter can't input multi images.
156
- if "openrouter" in self.gpt_client.endpoint:
157
- from embodied_gen.utils.process_media import (
158
- combine_images_to_base64,
159
- )
160
-
161
- image_paths = combine_images_to_base64(image_paths)
162
 
163
  return self.gpt_client.query(
164
  text_prompt=self.prompt,
@@ -201,42 +223,358 @@ class ImageAestheticChecker(BaseChecker):
201
  return avg_score > self.thresh, avg_score
202
 
203
 
204
- if __name__ == "__main__":
205
- geo_checker = MeshGeoChecker(GPT_CLIENT)
206
- seg_checker = ImageSegChecker(GPT_CLIENT)
207
- aesthetic_checker = ImageAestheticChecker()
208
-
209
- checkers = [geo_checker, seg_checker, aesthetic_checker]
210
-
211
- output_root = "outputs/test_gpt"
212
-
213
- fails = []
214
- for idx in tqdm(range(150)):
215
- mesh_path = f"outputs/imageto3d/demo_objects/cups/sample_{idx}/sample_{idx}.obj" # noqa
216
- if not os.path.exists(mesh_path):
217
- continue
218
- image_paths = render_asset3d(
219
- mesh_path,
220
- f"{output_root}/{idx}",
221
- num_images=8,
222
- elevation=(30, -30),
223
- distance=5.5,
 
 
 
 
 
 
224
  )
225
 
226
- for cid, checker in enumerate(checkers):
227
- if isinstance(checker, ImageSegChecker):
228
- images = [
229
- f"outputs/imageto3d/demo_objects/cups/sample_{idx}/sample_{idx}_raw.png", # noqa
230
- f"outputs/imageto3d/demo_objects/cups/sample_{idx}/sample_{idx}_cond.png", # noqa
231
- ]
232
- else:
233
- images = image_paths
234
- result, info = checker(images)
235
- logger.info(
236
- f"Checker {checker.__class__.__name__}: {result}, {info}, mesh {mesh_path}" # noqa
 
 
 
 
 
 
 
 
237
  )
 
 
 
 
238
 
239
- if result is False:
240
- fails.append((idx, cid, info))
 
 
 
 
 
 
 
241
 
242
- break
 
 
 
 
 
 
 
 
16
 
17
 
18
  import logging
19
+ import random
20
 
21
+ import json_repair
22
+ from PIL import Image
23
  from embodied_gen.utils.gpt_clients import GPT_CLIENT, GPTclient
 
24
  from embodied_gen.validators.aesthetic_predictor import AestheticPredictor
25
 
26
  logging.basicConfig(level=logging.INFO)
27
  logger = logging.getLogger(__name__)
28
 
29
 
30
+ __all__ = [
31
+ "MeshGeoChecker",
32
+ "ImageSegChecker",
33
+ "ImageAestheticChecker",
34
+ "SemanticConsistChecker",
35
+ "TextGenAlignChecker",
36
+ "PanoImageGenChecker",
37
+ "PanoHeightEstimator",
38
+ "PanoImageOccChecker",
39
+ ]
40
+
41
+
42
  class BaseChecker:
43
  def __init__(self, prompt: str = None, verbose: bool = False) -> None:
44
  self.prompt = prompt
 
49
  "Subclasses must implement the query method."
50
  )
51
 
52
+ def __call__(self, *args, **kwargs) -> tuple[bool, str]:
53
  response = self.query(*args, **kwargs)
54
+ if self.verbose:
 
 
 
55
  logger.info(response)
56
 
57
+ if response is None:
58
+ flag = None
59
+ response = (
60
+ "Error when calling GPT api, check config in "
61
+ "`embodied_gen/utils/gpt_config.yaml` or net connection."
62
+ )
63
+ else:
64
+ flag = "YES" in response
65
+ response = "YES" if flag else response
66
 
67
  return flag, response
68
 
 
108
  self.gpt_client = gpt_client
109
  if self.prompt is None:
110
  self.prompt = """
111
+ You are an expert in evaluating the geometry quality of a generated 3D asset.
112
+ You will be given rendered views of a generated 3D asset on a black background.
113
+ Your task is to evaluate the quality of the 3D asset generation,
114
+ including geometry, structure, and appearance, based on the rendered views.
115
+ Criteria:
116
+ - Is the object in the image a single, complete, and well-formed instance,
117
+ without truncation, missing parts, overlapping duplicates, or redundant geometry?
118
+ - Minor flaws, asymmetries, or simplifications (e.g., less detail on sides or back,
119
+ soft edges) are acceptable if the object is structurally sound and recognizable.
120
+ - Only evaluate geometry. Do not assess texture quality.
121
+ - The asset should not contain any unrelated elements, such as
122
+ ground planes, platforms, or background props (e.g., paper, flooring).
123
+
124
+ If all the above criteria are met, return "YES". Otherwise, return
125
+ "NO" followed by a brief explanation (no more than 20 words).
126
+
127
+ Example:
128
+ Images show a yellow cup standing on a flat white plane -> NO
129
+ -> Response: NO: extra white surface under the object.
130
+ Image shows a chair with simplified back legs and soft edges -> YES
131
  """
132
 
133
+ def query(self, image_paths: list[str | Image.Image]) -> str:
 
 
 
 
 
 
 
134
 
135
  return self.gpt_client.query(
136
  text_prompt=self.prompt,
 
161
  self.gpt_client = gpt_client
162
  if self.prompt is None:
163
  self.prompt = """
164
+ Task: Evaluate the quality of object segmentation between two images:
165
+ the first is the original, the second is the segmented result.
166
+
167
+ Criteria:
168
+ - The main foreground object should be clearly extracted (not the background).
169
+ - The object must appear realistic, with reasonable geometry and color.
170
+ - The object should be geometrically complete: no missing, truncated, or cropped parts.
171
+ - The object must be centered, with a margin on all sides.
172
+ - Ignore minor imperfections (e.g., small holes or fine edge artifacts).
173
+
174
+ Output Rules:
175
+ If segmentation is acceptable, respond with "YES" (and nothing else).
176
+ If not acceptable, respond with "NO", followed by a brief reason (max 20 words).
177
  """
178
 
179
  def query(self, image_paths: list[str]) -> str:
 
181
  raise ValueError(
182
  "ImageSegChecker requires exactly two images: [raw_image, seg_image]." # noqa
183
  )
 
 
 
 
 
 
 
184
 
185
  return self.gpt_client.query(
186
  text_prompt=self.prompt,
 
223
  return avg_score > self.thresh, avg_score
224
 
225
 
226
+ class SemanticConsistChecker(BaseChecker):
227
+ def __init__(
228
+ self,
229
+ gpt_client: GPTclient,
230
+ prompt: str = None,
231
+ verbose: bool = False,
232
+ ) -> None:
233
+ super().__init__(prompt, verbose)
234
+ self.gpt_client = gpt_client
235
+ if self.prompt is None:
236
+ self.prompt = """
237
+ You are an expert in image-text consistency assessment.
238
+ You will be given:
239
+ - A short text description of an object.
240
+ - A segmented image of the same object with the background removed.
241
+
242
+ Criteria:
243
+ - The image must visually match the text description in terms of object type, structure, geometry, and color.
244
+ - The object must appear realistic, with reasonable geometry (e.g., a table must have a plausible number
245
+ of legs with a reasonable distribution. Count the legs visible in the image; (strict) for tables,
246
+ fewer than four legs, or legs that are unevenly distributed, are not allowed. Do not assume
247
+ hidden legs unless they are clearly visible.)
248
+ - Geometric completeness is required: the object must not have missing, truncated, or cropped parts.
249
+ - The image must contain exactly one object. Multiple distinct objects are not allowed.
250
+ A single composite object (e.g., a chair with legs) is acceptable.
251
+ - The object should be shown from a slightly angled (three-quarter) perspective,
252
+ not a flat, front-facing view showing only one surface.
253
+
254
+ Instructions:
255
+ - If all criteria are met, return `"YES"`.
256
+ - Otherwise, return "NO" with a brief explanation (max 20 words).
257
+
258
+ Respond in exactly one of the following formats:
259
+ YES
260
+ or
261
+ NO: brief explanation.
262
+
263
+ Input:
264
+ {}
265
+ """
266
+
267
+ def query(self, text: str, image: list[Image.Image | str]) -> str:
268
+
269
+ return self.gpt_client.query(
270
+ text_prompt=self.prompt.format(text),
271
+ image_base64=image,
272
  )
273
 
274
+
275
+ class TextGenAlignChecker(BaseChecker):
276
+ def __init__(
277
+ self,
278
+ gpt_client: GPTclient,
279
+ prompt: str = None,
280
+ verbose: bool = False,
281
+ ) -> None:
282
+ super().__init__(prompt, verbose)
283
+ self.gpt_client = gpt_client
284
+ if self.prompt is None:
285
+ self.prompt = """
286
+ You are an expert in evaluating the quality of generated 3D assets.
287
+ You will be given:
288
+ - A text description of an object: TEXT
289
+ - Rendered views of the generated 3D asset.
290
+
291
+ Your task is to:
292
+ 1. Determine whether the generated 3D asset roughly reflects the object class
293
+ or a semantically adjacent category described in the text.
294
+ 2. Evaluate the geometry quality of the 3D asset generation based on the rendered views.
295
+
296
+ Criteria:
297
+ - Determine if the generated 3D asset belongs to the category described in the text or a similar category.
298
+ - Focus on functional similarity: if the object serves the same general
299
+ purpose (e.g., writing, placing items), it should be accepted.
300
+ - Is the geometry complete and well-formed, with no missing parts,
301
+ distortions, visual artifacts, or redundant structures?
302
+ - Does the number of object instances match the description?
303
+ There should be only one object unless otherwise specified.
304
+ - Minor flaws in geometry or texture are acceptable; be highly tolerant of texture quality defects.
305
+ - Minor simplifications in geometry or texture (e.g. soft edges, less detail)
306
+ are acceptable if the object is still recognizable.
307
+ - The asset should not contain any unrelated elements, such as
308
+ ground planes, platforms, or background props (e.g., paper, flooring).
309
+
310
+ Example:
311
+ Text: "yellow cup"
312
+ Image: shows a yellow cup standing on a flat white plane -> NO: extra surface under the object.
313
+
314
+ Instructions:
315
+ - If the quality of generated asset is acceptable and faithfully represents the text, return "YES".
316
+ - Otherwise, return "NO" followed by a brief explanation (no more than 20 words).
317
+
318
+ Respond in exactly one of the following formats:
319
+ YES
320
+ or
321
+ NO: brief explanation
322
+
323
+ Input:
324
+ Text description: {}
325
+ """
326
+
327
+ def query(self, text: str, image: list[Image.Image | str]) -> str:
328
+
329
+ return self.gpt_client.query(
330
+ text_prompt=self.prompt.format(text),
331
+ image_base64=image,
332
+ )
333
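For illustration, a minimal usage sketch of the YES/NO contract TextGenAlignChecker exposes (the render paths are placeholders; GPT_CLIENT is the shared client used elsewhere in this module):

    checker = TextGenAlignChecker(GPT_CLIENT)
    # query() formats the text into the prompt and sends it together with the rendered views.
    verdict = checker.query("yellow cup", ["renders/view_0.png", "renders/view_1.png"])
    if verdict.strip().upper().startswith("YES"):
        print("asset accepted")
    else:
        print("asset rejected:", verdict)  # e.g. "NO: extra surface under the object."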
+
334
+
335
+ class PanoImageGenChecker(BaseChecker):
336
+ """A checker class that validates the quality and realism of generated panoramic indoor images.
337
+
338
+ Attributes:
339
+ gpt_client (GPTclient): A GPT client instance used to query for image validation.
340
+ prompt (str): The instruction prompt passed to the GPT model. If None, a default prompt is used.
341
+ verbose (bool): Whether to print internal processing information for debugging.
342
+ """
343
+
344
+ def __init__(
345
+ self,
346
+ gpt_client: GPTclient,
347
+ prompt: str = None,
348
+ verbose: bool = False,
349
+ ) -> None:
350
+ super().__init__(prompt, verbose)
351
+ self.gpt_client = gpt_client
352
+ if self.prompt is None:
353
+ self.prompt = """
354
+ You are a panoramic image analyzer specializing in indoor room structure validation.
355
+
356
+ Given a generated panoramic image, assess if it meets all the criteria:
357
+ - Floor Space: ≥30 percent of the floor is free of objects or obstructions.
358
+ - Visual Clarity: Floor, walls, and ceiling are clear, with no distortion, blur, or noise.
359
+ - Structural Continuity: Surfaces form plausible, continuous geometry
360
+ without breaks, floating parts, or abrupt cuts.
361
+ - Spatial Completeness: Full 360° coverage without missing areas,
362
+ seams, gaps, or stitching artifacts.
363
+ Instructions:
364
+ - If all criteria are met, reply with "YES".
365
+ - Otherwise, reply with "NO: <brief explanation>" (max 20 words).
366
+
367
+ Respond exactly as:
368
+ "YES"
369
+ or
370
+ "NO: brief explanation."
371
+ """
372
+
373
+ def query(self, image_paths: str | Image.Image) -> str:
374
+
375
+ return self.gpt_client.query(
376
+ text_prompt=self.prompt,
377
+ image_base64=image_paths,
378
+ )
379
+
380
+
381
+ class PanoImageOccChecker(BaseChecker):
382
+ """Checks for physical obstacles in the bottom-center region of a panoramic image.
383
+
384
+ This class crops a specified region from the input panoramic image and uses
385
+ a GPT client to determine whether any physical obstacles are present there.
386
+
387
+ Args:
388
+ gpt_client (GPTclient): The GPT-based client used for visual reasoning.
389
+ box_hw (tuple[int, int]): The height and width of the crop box.
390
+ prompt (str, optional): Custom prompt for the GPT client. Defaults to a predefined one.
391
+ verbose (bool, optional): Whether to print verbose logs. Defaults to False.
392
+ """
393
+
394
+ def __init__(
395
+ self,
396
+ gpt_client: GPTclient,
397
+ box_hw: tuple[int, int],
398
+ prompt: str = None,
399
+ verbose: bool = False,
400
+ ) -> None:
401
+ super().__init__(prompt, verbose)
402
+ self.gpt_client = gpt_client
403
+ self.box_hw = box_hw
404
+ if self.prompt is None:
405
+ self.prompt = """
406
+ This image is a cropped region from the bottom-center of a panoramic view.
407
+ Please determine whether there is any obstacle present, such as furniture, tables, or other physical objects.
408
+ Ignore floor textures, rugs, carpets, shadows, and lighting effects; they do not count as obstacles.
409
+ Only consider real, physical objects that could block walking or movement.
410
+
411
+ Instructions:
412
+ - If there is no obstacle, reply: "YES".
413
+ - Otherwise, reply: "NO: <brief explanation>" (max 20 words).
414
+
415
+ Respond exactly as:
416
+ "YES"
417
+ or
418
+ "NO: brief explanation."
419
+ """
420
+
421
+ def query(self, image_paths: str | Image.Image) -> str:
422
+ if isinstance(image_paths, str):
423
+ image_paths = Image.open(image_paths)
424
+
425
+ w, h = image_paths.size
426
+ image_paths = image_paths.crop(
427
+ (
428
+ (w - self.box_hw[1]) // 2,
429
+ h - self.box_hw[0],
430
+ (w + self.box_hw[1]) // 2,
431
+ h,
432
+ )
433
+ )
434
+
435
+ return self.gpt_client.query(
436
+ text_prompt=self.prompt,
437
+ image_base64=image_paths,
438
+ )
439
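For illustration, the bottom-center crop that PanoImageOccChecker.query computes, worked through with assumed sizes (a 2048x1024 panorama and box_hw=(256, 512)):

    w, h = 2048, 1024          # assumed panorama width and height
    box_h, box_w = 256, 512    # assumed box_hw
    crop_box = ((w - box_w) // 2, h - box_h, (w + box_w) // 2, h)
    # crop_box == (768, 768, 1280, 1024): a 512x256 window centered horizontally at the bottom edge.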
+
440
+
441
+ class PanoHeightEstimator(object):
442
+ """Estimate the real ceiling height of an indoor space from a 360Β° panoramic image.
443
+
444
+ Attributes:
445
+ gpt_client (GPTclient): The GPT client used to perform image-based reasoning and return height estimates.
446
+ default_value (float): The fallback height in meters if parsing the GPT output fails.
447
+ prompt (str): The textual instruction used to guide the GPT model for height estimation.
448
+ """
449
+
450
+ def __init__(
451
+ self,
452
+ gpt_client: GPTclient,
453
+ default_value: float = 3.5,
454
+ ) -> None:
455
+ self.gpt_client = gpt_client
456
+ self.default_value = default_value
457
+ self.prompt = """
458
+ You are an expert in building height estimation and panoramic image analysis.
459
+ Your task is to analyze a 360° indoor panoramic image and estimate the **actual height** of the space in meters.
460
+
461
+ Consider the following visual cues:
462
+ 1. Ceiling visibility and reference objects (doors, windows, furniture, appliances).
463
+ 2. Floor features or level differences.
464
+ 3. Room type (e.g., residential, office, commercial).
465
+ 4. Object-to-ceiling proportions (e.g., height of doors relative to ceiling).
466
+ 5. Architectural elements (e.g., chandeliers, shelves, kitchen cabinets).
467
+
468
+ Input: A full 360° panoramic indoor photo.
469
+ Output: A single number in meters representing the estimated room height. Only return the number (e.g., `3.2`)
470
+ """
471
+
472
+ def __call__(self, image_paths: str | Image.Image) -> float:
473
+ result = self.gpt_client.query(
474
+ text_prompt=self.prompt,
475
+ image_base64=image_paths,
476
+ )
477
+ try:
478
+ result = float(result.strip())
479
+ except Exception as e:
480
+ logger.error(
481
+ f"Parser error: failed convert {result} to float, {e}, use default value {self.default_value}."
482
  )
483
+ result = self.default_value
484
+
485
+ return result
486
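For illustration, a minimal usage sketch of PanoHeightEstimator (the panorama path and default_value are placeholders):

    estimator = PanoHeightEstimator(GPT_CLIENT, default_value=3.0)
    # __call__ returns the parsed float, or default_value (here 3.0 m) if the reply cannot be parsed.
    height_m = estimator("outputs/panorama.png")
    print(f"estimated ceiling height: {height_m:.2f} m")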
+
487
 
488
+ class SemanticMatcher(BaseChecker):
489
+ def __init__(
490
+ self,
491
+ gpt_client: GPTclient,
492
+ prompt: str = None,
493
+ verbose: bool = False,
494
+ seed: int = None,
495
+ ) -> None:
496
+ super().__init__(prompt, verbose)
497
+ self.gpt_client = gpt_client
498
+ self.seed = seed
499
+ random.seed(seed)
500
+ if self.prompt is None:
501
+ self.prompt = """
502
+ You are an expert in semantic similarity and scene retrieval.
503
+ You will be given:
504
+ - A dictionary where each key is a scene ID, and each value is a scene description.
505
+ - A query text describing a target scene.
506
+
507
+ Your task:
508
+ return_num = 2
509
+ - Find the <return_num> most semantically similar scene IDs to the query text.
510
+ - If there are fewer than <return_num> distinct relevant matches, repeat the closest ones to make a list of <return_num>.
511
+ - Only output the list of <return_num> scene IDs, sorted from most to least similar.
512
+ - Do NOT use markdown, JSON code blocks, or any formatting syntax; only return a plain list like ["id1", ...].
513
+
514
+ Input example:
515
+ Dictionary:
516
+ "{{
517
+ "t_scene_008": "A study room with full bookshelves and a lamp in the corner.",
518
+ "t_scene_019": "A child's bedroom with pink walls and a small desk.",
519
+ "t_scene_020": "A living room with a wooden floor.",
520
+ "t_scene_021": "A living room with toys scattered on the floor.",
521
+ ...
522
+ "t_scene_office_001": "A very spacious, modern open-plan office with wide desks and no people, panoramic view."
523
+ }}"
524
+ Text:
525
+ "A traditional indoor room"
526
+ Output:
527
+ '["t_scene_office_001", ...]'
528
+
529
+ Input:
530
+ Dictionary:
531
+ {context}
532
+ Text:
533
+ {text}
534
+ Output:
535
+ <topk_key_list>
536
+ """
537
 
538
+ def query(
539
+ self, text: str, context: dict, rand: bool = True, params: dict = None
540
+ ) -> str:
541
+ match_list = self.gpt_client.query(
542
+ self.prompt.format(context=context, text=text),
543
+ params=params,
544
+ )
545
+ match_list = json_repair.loads(match_list)
546
+ result = random.choice(match_list) if rand else match_list[0]
547
+
548
+ return result
549
+
550
+
551
+ def test_semantic_matcher(
552
+ bg_file: str = "outputs/bg_scenes/bg_scene_list.txt",
553
+ ):
554
+ bg_file = "outputs/bg_scenes/bg_scene_list.txt"
555
+ scene_dict = {}
556
+ with open(bg_file, "r") as f:
557
+ for line in f:
558
+ line = line.strip()
559
+ if not line or ":" not in line:
560
+ continue
561
+ scene_id, desc = line.split(":", 1)
562
+ scene_dict[scene_id.strip()] = desc.strip()
563
+
564
+ office_scene = scene_dict.get("t_scene_office_001")
565
+ text = "bright kitchen"
566
+ SCENE_MATCHER = SemanticMatcher(GPT_CLIENT)
567
+ # gpt_params = {
568
+ # "temperature": 0.8,
569
+ # "max_tokens": 500,
570
+ # "top_p": 0.8,
571
+ # "frequency_penalty": 0.3,
572
+ # "presence_penalty": 0.3,
573
+ # }
574
+ gpt_params = None
575
+ match_key = SCENE_MATCHER.query(text, str(scene_dict), params=gpt_params)
576
+ print(match_key, ",", scene_dict[match_key])
577
+
578
+
579
+ if __name__ == "__main__":
580
+ test_semantic_matcher()
embodied_gen/validators/urdf_convertor.py CHANGED
@@ -101,34 +101,42 @@ class URDFGenerator(object):
101
  prompt_template = (
102
  view_desc
103
  + """of the 3D object asset,
104
- category: {category}.
105
- You are an expert in 3D object analysis and physical property estimation.
106
- Give the category of this object asset (within 3 words),
107
- (if category is already provided, use it directly),
108
- accurately describe this 3D object asset (within 15 words),
109
- and give the recommended geometric height range (unit: meter),
110
- weight range (unit: kilogram), the average static friction
111
- coefficient of the object relative to rubber and the average
112
- dynamic friction coefficient of the object relative to rubber.
113
- Return response format as shown in Output Example.
114
-
115
- IMPORTANT:
116
- Inputed images are orthographic projection showing the front, left, right and back views,
117
- the first image is always the front view. Use the object's pose and orientation in the
118
- rendered images to estimate its **true vertical height as it appears in the image**,
119
- not the real-world length or width of the object.
120
- For example:
121
- - A pen standing upright in the front view β†’ vertical height: 0.15-0.2 m
122
- - A pen lying horizontally in the front view β†’ vertical height: 0.01-0.02 m
123
- (based on its thickness in the image)
124
-
125
- Output Example:
126
- Category: cup
127
- Description: shiny golden cup with floral design
128
- Height: 0.1-0.15 m
129
- Weight: 0.3-0.6 kg
130
- Static friction coefficient: 1.1
131
- Dynamic friction coefficient: 0.9
 
 
 
 
 
 
 
 
132
  """
133
  )
134
 
@@ -297,20 +305,24 @@ class URDFGenerator(object):
297
  if not os.path.exists(urdf_path):
298
  raise FileNotFoundError(f"URDF file not found: {urdf_path}")
299
 
300
- mesh_scale = 1.0
301
  tree = ET.parse(urdf_path)
302
  root = tree.getroot()
303
  extra_info = root.find(attr_root)
304
  if extra_info is not None:
305
  scale_element = extra_info.find(attr_name)
306
  if scale_element is not None:
307
- mesh_scale = float(scale_element.text)
 
 
 
 
308
 
309
- return mesh_scale
310
 
311
  @staticmethod
312
  def add_quality_tag(
313
- urdf_path: str, results, output_path: str = None
314
  ) -> None:
315
  if output_path is None:
316
  output_path = urdf_path
@@ -366,17 +378,11 @@ class URDFGenerator(object):
366
  output_root,
367
  num_images=self.render_view_num,
368
  output_subdir=self.output_render_dir,
 
369
  )
370
 
371
- # Hardcode tmp because of the openrouter can't input multi images.
372
- if "openrouter" in self.gpt_client.endpoint:
373
- from embodied_gen.utils.process_media import (
374
- combine_images_to_base64,
375
- )
376
-
377
- image_path = combine_images_to_base64(image_path)
378
-
379
  response = self.gpt_client.query(text_prompt, image_path)
 
380
  if response is None:
381
  asset_attrs = {
382
  "category": category.lower(),
@@ -412,14 +418,18 @@ class URDFGenerator(object):
412
  if __name__ == "__main__":
413
  urdf_gen = URDFGenerator(GPT_CLIENT, render_view_num=4)
414
  urdf_path = urdf_gen(
415
- mesh_path="outputs/imageto3d/cma/o5/URDF_o5/mesh/o5.obj",
416
  output_root="outputs/test_urdf",
417
- # category="coffee machine",
418
  # min_height=1.0,
419
  # max_height=1.2,
420
  version=VERSION,
421
  )
422
 
 
 
 
 
423
  # zip_files(
424
  # input_paths=[
425
  # "scripts/apps/tmp/2umpdum3e5n/URDF_sample/mesh",
 
101
  prompt_template = (
102
  view_desc
103
  + """of the 3D object asset,
104
+ category: {category}.
105
+ You are an expert in 3D object analysis and physical property estimation.
106
+ Give the category of this object asset (within 3 words), (if category is
107
+ already provided, use it directly), accurately describe this 3D object asset (within 15 words).
108
+ Determine the pose of the object in the first image and estimate the true vertical height
109
+ (vertical projection) range of the object (in meters), i.e., how tall the object appears from top
110
+ to bottom in the front view (first) image. Also give the weight range (unit: kilogram), the average
111
+ static friction coefficient of the object relative to rubber and the average dynamic friction
112
+ coefficient of the object relative to rubber. Return the response in the format shown in the Output Example.
113
+
114
+ Output Example:
115
+ Category: cup
116
+ Description: shiny golden cup with floral design
117
+ Height: 0.1-0.15 m
118
+ Weight: 0.3-0.6 kg
119
+ Static friction coefficient: 0.6
120
+ Dynamic friction coefficient: 0.5
121
+
122
+ IMPORTANT: Estimating Vertical Height from the First (Front View) Image.
123
+ - The "vertical height" refers to the real-world vertical size of the object
124
+ as projected in the first image, aligned with the image's vertical axis.
125
+ - For flat objects like plates, disks, or books, if their face is visible in the front view,
126
+ use the diameter as the vertical height. If the edge is visible, use the thickness instead.
127
+ - This is not necessarily the full length of the object, but how tall it appears
128
+ in the first image vertically, based on its pose and orientation.
129
+ - For objects (e.g., spoons, forks, writing instruments, etc.) shown at an angle in
130
+ the first image, one tilted at 45° will appear shorter vertically than when upright.
131
+ Estimate the vertical projection of the real length based on the object's pose.
132
+ For example:
133
+ - A pen standing upright in the first view (aligned with the image's vertical axis)
134
+ full body visible in the first image → vertical height ≈ 0.14-0.20 m
135
+ - A pen lying flat in the front view (showing thickness) → vertical height ≈ 0.018-0.025 m
136
+ - A pen tilted in the first image (e.g., ~45° angle) → vertical height ≈ 0.07-0.12 m
137
+ - Use the remaining views (all except the first image) to help determine the object's 3D pose and orientation.
138
+ Assume the object is in real-world scale and estimate the approximate vertical height
139
+ (in meters) based on how large it appears vertically in the first image.
140
  """
141
  )
142
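For illustration, the projection arithmetic the prompt above asks the model to approximate, using an assumed 0.15 m pen:

    import math

    pen_length_m = 0.15  # assumed real-world length
    for tilt_deg in (0, 45, 90):  # tilt away from the image's vertical axis
        v = pen_length_m * math.cos(math.radians(tilt_deg))
        print(f"tilt {tilt_deg:>2} deg -> vertical projection ~{v:.3f} m")
    # upright -> ~0.150 m, ~45 deg -> ~0.106 m, lying flat -> ~0.000 m; in the flat case the
    # visible vertical extent is bounded below by the pen's thickness rather than by zero.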
 
 
305
  if not os.path.exists(urdf_path):
306
  raise FileNotFoundError(f"URDF file not found: {urdf_path}")
307
 
308
+ mesh_attr = None
309
  tree = ET.parse(urdf_path)
310
  root = tree.getroot()
311
  extra_info = root.find(attr_root)
312
  if extra_info is not None:
313
  scale_element = extra_info.find(attr_name)
314
  if scale_element is not None:
315
+ mesh_attr = scale_element.text
316
+ try:
317
+ mesh_attr = float(mesh_attr)
318
+ except ValueError:
319
+ pass
320
 
321
+ return mesh_attr
322
 
323
  @staticmethod
324
  def add_quality_tag(
325
+ urdf_path: str, results: list, output_path: str = None
326
  ) -> None:
327
  if output_path is None:
328
  output_path = urdf_path
 
378
  output_root,
379
  num_images=self.render_view_num,
380
  output_subdir=self.output_render_dir,
381
+ no_index_file=True,
382
  )
383
 
 
 
 
 
 
 
 
 
384
  response = self.gpt_client.query(text_prompt, image_path)
385
+ # logger.info(response)
386
  if response is None:
387
  asset_attrs = {
388
  "category": category.lower(),
 
418
  if __name__ == "__main__":
419
  urdf_gen = URDFGenerator(GPT_CLIENT, render_view_num=4)
420
  urdf_path = urdf_gen(
421
+ mesh_path="outputs/layout2/asset3d/marker/result/mesh/marker.obj",
422
  output_root="outputs/test_urdf",
423
+ category="marker",
424
  # min_height=1.0,
425
  # max_height=1.2,
426
  version=VERSION,
427
  )
428
 
429
+ URDFGenerator.add_quality_tag(
430
+ urdf_path, [[urdf_gen.__class__.__name__, "OK"]]
431
+ )
432
+
433
  # zip_files(
434
  # input_paths=[
435
  # "scripts/apps/tmp/2umpdum3e5n/URDF_sample/mesh",