Commit 8131b67 by xinjie.wang
Parent(s): 33d9f9a

update
Files changed:
- common.py +4 -4
- embodied_gen/data/backproject_v2.py +8 -4
- embodied_gen/data/datasets.py +65 -1
- embodied_gen/data/differentiable_render.py +8 -3
- embodied_gen/data/mesh_operator.py +2 -0
- embodied_gen/data/utils.py +2 -15
- embodied_gen/models/image_comm_model.py +236 -0
- embodied_gen/models/text_model.py +7 -2
- embodied_gen/scripts/gen_scene3d.py +191 -0
- embodied_gen/scripts/imageto3d.py +107 -91
- embodied_gen/scripts/text2image.py +3 -9
- embodied_gen/scripts/textto3d.py +280 -0
- embodied_gen/scripts/textto3d.sh +43 -6
- embodied_gen/scripts/texture_gen.sh +3 -6
- embodied_gen/trainer/gsplat_trainer.py +678 -0
- embodied_gen/trainer/pono2mesh_trainer.py +538 -0
- embodied_gen/utils/config.py +190 -0
- embodied_gen/utils/enum.py +107 -0
- embodied_gen/utils/gaussian.py +331 -0
- embodied_gen/utils/gpt_clients.py +47 -37
- embodied_gen/utils/log.py +48 -0
- embodied_gen/utils/monkey_patches.py +152 -0
- embodied_gen/utils/process_media.py +228 -90
- embodied_gen/utils/tags.py +1 -1
- embodied_gen/utils/trender.py +90 -0
- embodied_gen/validators/aesthetic_predictor.py +1 -13
- embodied_gen/validators/quality_checkers.py +410 -72
- embodied_gen/validators/urdf_convertor.py +52 -42
common.py
CHANGED
@@ -55,9 +55,9 @@ from embodied_gen.utils.gpt_clients import GPT_CLIENT
 from embodied_gen.utils.process_media import (
     filter_image_small_connected_components,
     merge_images_video,
-    render_video,
 )
 from embodied_gen.utils.tags import VERSION
+from embodied_gen.utils.trender import render_video
 from embodied_gen.validators.quality_checkers import (
     BaseChecker,
     ImageAestheticChecker,
@@ -94,9 +94,6 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
 os.environ["SPCONV_ALGO"] = "native"
 
 MAX_SEED = 100000
-DELIGHT = DelightingModel()
-IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
-# IMAGESR_MODEL = ImageStableSR()
 
 
 def patched_setup_functions(self):
@@ -136,6 +133,9 @@ def patched_setup_functions(self):
 Gaussian.setup_functions = patched_setup_functions
 
 
+DELIGHT = DelightingModel()
+IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
+# IMAGESR_MODEL = ImageStableSR()
 if os.getenv("GRADIO_APP") == "imageto3d":
     RBG_REMOVER = RembgRemover()
     RBG14_REMOVER = BMGG14Remover()
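In common.py the render_video import now comes from embodied_gen.utils.trender, and the DELIGHT / IMAGESR_MODEL globals are built only after Gaussian.setup_functions is patched. A minimal standalone sketch of the same deferred, environment-gated initialization pattern (HeavyModel is a hypothetical placeholder, not the repo's API):

import os


class HeavyModel:
    def __init__(self, name: str) -> None:
        self.name = name  # stands in for loading large model weights


# built unconditionally, but only after any required monkey patching above
DELIGHT = HeavyModel("delight")
IMAGESR_MODEL = HeavyModel("realesrgan")

if os.getenv("GRADIO_APP") == "imageto3d":
    # app-specific models are constructed only for the matching Gradio app
    RBG_REMOVER = HeavyModel("rembg")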
embodied_gen/data/backproject_v2.py
CHANGED
@@ -251,6 +251,7 @@ class TextureBacker:
             during rendering. Defaults to 0.5.
         smooth_texture (bool, optional): If True, apply post-processing (e.g.,
             blurring) to the final texture. Defaults to True.
+        inpaint_smooth (bool, optional): If True, apply inpainting to smooth.
     """
 
     def __init__(
@@ -262,6 +263,7 @@ class TextureBacker:
         bake_angle_thresh: int = 75,
         mask_thresh: float = 0.5,
         smooth_texture: bool = True,
+        inpaint_smooth: bool = False,
     ) -> None:
         self.camera_params = camera_params
         self.renderer = None
@@ -271,6 +273,7 @@ class TextureBacker:
         self.texture_wh = texture_wh
         self.mask_thresh = mask_thresh
         self.smooth_texture = smooth_texture
+        self.inpaint_smooth = inpaint_smooth
 
         self.bake_angle_thresh = bake_angle_thresh
         self.bake_unreliable_kernel_size = int(
@@ -446,11 +449,12 @@ class TextureBacker:
     def uv_inpaint(
         self, mesh: trimesh.Trimesh, texture: np.ndarray, mask: np.ndarray
     ) -> np.ndarray:
-        vertices, faces, uv_map = self.get_mesh_np_attrs(mesh)
-        texture, mask = _texture_inpaint_smooth(
-            texture, mask, vertices, faces, uv_map
-        )
+        if self.inpaint_smooth:
+            vertices, faces, uv_map = self.get_mesh_np_attrs(mesh)
+            texture, mask = _texture_inpaint_smooth(
+                texture, mask, vertices, faces, uv_map
+            )
 
         texture = texture.clip(0, 1)
         texture = cv2.inpaint(
             (texture * 255).astype(np.uint8),
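The new inpaint_smooth flag makes the mesh-aware _texture_inpaint_smooth pass opt-in before the cv2.inpaint fallback. A standalone sketch of that control flow (the blur stands in for the mesh-aware smoothing, and the mask convention of 1 = valid texel is an assumption, not the module's documented API):

import cv2
import numpy as np


def uv_inpaint_sketch(
    texture: np.ndarray, mask: np.ndarray, inpaint_smooth: bool = False
) -> np.ndarray:
    if inpaint_smooth:
        # optional smoothing pass before the generic inpainting fallback
        texture = cv2.blur(texture, (3, 3))

    texture = texture.clip(0, 1)
    hole_mask = ((1 - mask) * 255).astype(np.uint8)  # assumed: 1 marks valid texels
    return cv2.inpaint(
        (texture * 255).astype(np.uint8), hole_mask, 3, cv2.INPAINT_NS
    )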
embodied_gen/data/datasets.py
CHANGED
@@ -19,8 +19,9 @@ import json
 import logging
 import os
 import random
-from typing import Any, Callable, Dict, List, Tuple
+from typing import Any, Callable, Dict, List, Literal, Tuple
 
+import numpy as np
 import torch
 import torch.utils.checkpoint
 from PIL import Image
@@ -36,6 +37,7 @@ logger = logging.getLogger(__name__)
 
 __all__ = [
     "Asset3dGenDataset",
+    "PanoGSplatDataset",
 ]
 
 
@@ -222,6 +224,68 @@ class Asset3dGenDataset(Dataset):
         return data
 
 
+class PanoGSplatDataset(Dataset):
+    """A PyTorch Dataset for loading panorama-based 3D Gaussian Splatting data.
+
+    This dataset is designed to be compatible with train and eval pipelines
+    that use COLMAP-style camera conventions.
+
+    Args:
+        data_dir (str): Root directory where the dataset file is located.
+        split (str): Dataset split to use, either "train" or "eval".
+        data_name (str, optional): Name of the dataset file (default: "gs_data.pt").
+        max_sample_num (int, optional): Maximum number of samples to load. If None,
+            all available samples in the split will be used.
+    """
+
+    def __init__(
+        self,
+        data_dir: str,
+        split: Literal["train", "eval"] = "train",
+        data_name: str = "gs_data.pt",
+        max_sample_num: int = None,
+    ) -> None:
+        self.data_path = os.path.join(data_dir, data_name)
+        self.split = split
+        self.max_sample_num = max_sample_num
+        if not os.path.exists(self.data_path):
+            raise FileNotFoundError(
+                f"Dataset file {self.data_path} not found. Please provide the correct path."
+            )
+        self.data = torch.load(self.data_path, weights_only=False)
+        self.frames = self.data[split]
+        if max_sample_num is not None:
+            self.frames = self.frames[:max_sample_num]
+        self.points = self.data.get("points", None)
+        self.points_rgb = self.data.get("points_rgb", None)
+
+    def __len__(self) -> int:
+        return len(self.frames)
+
+    def cvt_blender_to_colmap_coord(self, c2w: np.ndarray) -> np.ndarray:
+        # Change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward).
+        transformed_c2w = np.copy(c2w)
+        transformed_c2w[:3, 1:3] *= -1
+
+        return transformed_c2w
+
+    def __getitem__(self, index: int) -> dict[str, Any]:
+        data = self.frames[index]
+        c2w = self.cvt_blender_to_colmap_coord(data["camtoworld"])
+        item = dict(
+            camtoworld=c2w,
+            K=data["K"],
+            image_h=data["image_h"],
+            image_w=data["image_w"],
+        )
+        if "image" in data:
+            item["image"] = data["image"]
+        if "image_id" in data:
+            item["image_id"] = data["image_id"]
+
+        return item
+
+
 if __name__ == "__main__":
     index_file = "datasets/objaverse/v1.0/statistics_1.0_gobjaverse_filter/view6s_v4/meta_ac2e0ddea8909db26d102c8465b5bcb2.json"  # noqa
     target_hw = (512, 512)
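Usage sketch for the new PanoGSplatDataset (assumes a gs_data.pt file already produced by the pano-to-mesh stage; the path is illustrative):

from torch.utils.data import DataLoader

from embodied_gen.data.datasets import PanoGSplatDataset

dataset = PanoGSplatDataset("outputs/scene_0000", split="train")
loader = DataLoader(dataset, batch_size=1, shuffle=True)

for batch in loader:
    # COLMAP-convention camera-to-world pose and intrinsics per frame
    c2w, K = batch["camtoworld"], batch["K"]
    break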
embodied_gen/data/differentiable_render.py
CHANGED
@@ -33,7 +33,6 @@ from tqdm import tqdm
 from embodied_gen.data.utils import (
     CameraSetting,
     DiffrastRender,
-    RenderItems,
     as_list,
     calc_vertex_normals,
     import_kaolin_mesh,
@@ -42,6 +41,7 @@ from embodied_gen.data.utils import (
     render_pbr,
     save_images,
 )
+from embodied_gen.utils.enum import RenderItems
 
 os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
 os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
@@ -470,7 +470,7 @@ def parse_args():
         "--pbr_light_factor",
         type=float,
         default=1.0,
-        help="Light factor for mesh PBR rendering (default:
+        help="Light factor for mesh PBR rendering (default: 1.)",
     )
     parser.add_argument(
         "--with_mtl",
@@ -482,6 +482,11 @@ def parse_args():
         action="store_true",
         help="Whether to generate color .gif rendering file.",
     )
+    parser.add_argument(
+        "--no_index_file",
+        action="store_true",
+        help="Whether to skip the index file saving.",
+    )
     parser.add_argument(
         "--gen_color_mp4",
         action="store_true",
@@ -568,7 +573,7 @@ def entrypoint(**kwargs) -> None:
         gen_viewnormal_mp4=args.gen_viewnormal_mp4,
         gen_glonormal_mp4=args.gen_glonormal_mp4,
         light_factor=args.pbr_light_factor,
-        no_index_file=gen_video,
+        no_index_file=gen_video or args.no_index_file,
     )
     image_render.render_mesh(
         mesh_path=args.mesh_path,
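The renderer can now skip index-file writing explicitly; video-only runs keep doing so implicitly via gen_video. A minimal argparse sketch of how the two combine (standalone, not the module's full parser):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--gen_color_mp4", action="store_true")
parser.add_argument("--no_index_file", action="store_true")
args = parser.parse_args(["--no_index_file"])

gen_video = args.gen_color_mp4
no_index_file = gen_video or args.no_index_file  # True here: index saving skipped
print(no_index_file)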
embodied_gen/data/mesh_operator.py
CHANGED
@@ -395,6 +395,8 @@ class MeshFixer(object):
             self.vertices_np,
             np.hstack([np.full((self.faces.shape[0], 1), 3), self.faces_np]),
         )
+        mesh.clean(inplace=True)
+        mesh.clear_data()
         mesh = mesh.decimate(ratio, progress_bar=True)
 
         # Update vertices and faces
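The extra clean() / clear_data() calls drop duplicate points and attached data arrays before decimation. A standalone pyvista sketch of the same sequence (assumes pyvista is installed; the 0.5 ratio is illustrative):

import pyvista as pv

mesh = pv.Sphere()
mesh.clean(inplace=True)  # merge duplicate points, drop degenerate cells
mesh.clear_data()         # strip point/cell data arrays before decimation
mesh = mesh.decimate(0.5, progress_bar=True)
print(mesh.n_cells)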
embodied_gen/data/utils.py
CHANGED
@@ -38,7 +38,6 @@ except ImportError:
     ChatGLMModel = None
 import logging
 from dataclasses import dataclass, field
-from enum import Enum
 
 import trimesh
 from kaolin.render.camera import Camera
@@ -57,7 +56,6 @@ __all__ = [
     "load_mesh_to_unit_cube",
     "as_list",
     "CameraSetting",
-    "RenderItems",
     "import_kaolin_mesh",
     "save_mesh_with_mtl",
     "get_images_from_grid",
@@ -160,8 +158,9 @@ class DiffrastRender(object):
 
         return normalized_maps
 
+    @staticmethod
     def normalize_map_by_mask(
-        self, map: torch.Tensor, mask: torch.Tensor
+        map: torch.Tensor, mask: torch.Tensor
     ) -> torch.Tensor:
         # Normalize all maps in total by mask, normalized map in [0, 1].
         foreground = (mask == 1).squeeze(dim=-1)
@@ -738,18 +737,6 @@ class CameraSetting:
         self.Ks = Ks
 
 
-@dataclass
-class RenderItems(str, Enum):
-    IMAGE = "image_color"
-    ALPHA = "image_mask"
-    VIEW_NORMAL = "image_view_normal"
-    GLOBAL_NORMAL = "image_global_normal"
-    POSITION_MAP = "image_position"
-    DEPTH = "image_depth"
-    ALBEDO = "image_albedo"
-    DIFFUSE = "image_diffuse"
-
-
 def _compute_az_el_by_camera_params(
     camera_params: CameraSetting, flip_az: bool = False
 ):
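RenderItems now lives in embodied_gen.utils.enum (imported that way in differentiable_render.py above). Assuming the members keep the string values shown in the deleted block:

from embodied_gen.utils.enum import RenderItems

assert RenderItems.IMAGE.value == "image_color"
assert RenderItems.DEPTH.value == "image_depth"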
embodied_gen/models/image_comm_model.py
ADDED
@@ -0,0 +1,236 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
# Text-to-Image generation models from Hugging Face community.

import os
from abc import ABC, abstractmethod

import torch
from diffusers import (
    ChromaPipeline,
    Cosmos2TextToImagePipeline,
    DPMSolverMultistepScheduler,
    FluxPipeline,
    KolorsPipeline,
    StableDiffusion3Pipeline,
)
from diffusers.quantizers import PipelineQuantizationConfig
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import AutoModelForCausalLM, SiglipProcessor

__all__ = [
    "build_hf_image_pipeline",
]


class BasePipelineLoader(ABC):
    def __init__(self, device="cuda"):
        self.device = device

    @abstractmethod
    def load(self):
        pass


class BasePipelineRunner(ABC):
    def __init__(self, pipe):
        self.pipe = pipe

    @abstractmethod
    def run(self, prompt: str, **kwargs) -> Image.Image:
        pass


# ===== SD3.5-medium =====
class SD35Loader(BasePipelineLoader):
    def load(self):
        pipe = StableDiffusion3Pipeline.from_pretrained(
            "stabilityai/stable-diffusion-3.5-medium",
            torch_dtype=torch.float16,
        )
        pipe = pipe.to(self.device)
        pipe.enable_model_cpu_offload()
        pipe.enable_xformers_memory_efficient_attention()
        pipe.enable_attention_slicing()
        return pipe


class SD35Runner(BasePipelineRunner):
    def run(self, prompt: str, **kwargs) -> Image.Image:
        return self.pipe(prompt=prompt, **kwargs).images


# ===== Cosmos2 =====
class CosmosLoader(BasePipelineLoader):
    def __init__(
        self,
        model_id="nvidia/Cosmos-Predict2-2B-Text2Image",
        local_dir="weights/cosmos2",
        device="cuda",
    ):
        super().__init__(device)
        self.model_id = model_id
        self.local_dir = local_dir

    def _patch(self):
        def patch_model(cls):
            orig = cls.from_pretrained

            def new(*args, **kwargs):
                kwargs.setdefault("attn_implementation", "flash_attention_2")
                kwargs.setdefault("torch_dtype", torch.bfloat16)
                return orig(*args, **kwargs)

            cls.from_pretrained = new

        def patch_processor(cls):
            orig = cls.from_pretrained

            def new(*args, **kwargs):
                kwargs.setdefault("use_fast", True)
                return orig(*args, **kwargs)

            cls.from_pretrained = new

        patch_model(AutoModelForCausalLM)
        patch_processor(SiglipProcessor)

    def load(self):
        self._patch()
        snapshot_download(
            repo_id=self.model_id,
            local_dir=self.local_dir,
            local_dir_use_symlinks=False,
            resume_download=True,
        )

        config = PipelineQuantizationConfig(
            quant_backend="bitsandbytes_4bit",
            quant_kwargs={
                "load_in_4bit": True,
                "bnb_4bit_quant_type": "nf4",
                "bnb_4bit_compute_dtype": torch.bfloat16,
                "bnb_4bit_use_double_quant": True,
            },
            components_to_quantize=["text_encoder", "transformer", "unet"],
        )

        pipe = Cosmos2TextToImagePipeline.from_pretrained(
            self.model_id,
            torch_dtype=torch.bfloat16,
            quantization_config=config,
            use_safetensors=True,
            safety_checker=None,
            requires_safety_checker=False,
        ).to(self.device)
        return pipe


class CosmosRunner(BasePipelineRunner):
    def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image:
        return self.pipe(
            prompt=prompt, negative_prompt=negative_prompt, **kwargs
        ).images


# ===== Kolors =====
class KolorsLoader(BasePipelineLoader):
    def load(self):
        pipe = KolorsPipeline.from_pretrained(
            "Kwai-Kolors/Kolors-diffusers",
            torch_dtype=torch.float16,
            variant="fp16",
        ).to(self.device)
        pipe.enable_model_cpu_offload()
        pipe.enable_xformers_memory_efficient_attention()
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            pipe.scheduler.config, use_karras_sigmas=True
        )
        return pipe


class KolorsRunner(BasePipelineRunner):
    def run(self, prompt: str, **kwargs) -> Image.Image:
        return self.pipe(prompt=prompt, **kwargs).images


# ===== Flux =====
class FluxLoader(BasePipelineLoader):
    def load(self):
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
        pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
        )
        pipe.enable_model_cpu_offload()
        pipe.enable_xformers_memory_efficient_attention()
        pipe.enable_attention_slicing()
        return pipe.to(self.device)


class FluxRunner(BasePipelineRunner):
    def run(self, prompt: str, **kwargs) -> Image.Image:
        return self.pipe(prompt=prompt, **kwargs).images


# ===== Chroma =====
class ChromaLoader(BasePipelineLoader):
    def load(self):
        return ChromaPipeline.from_pretrained(
            "lodestones/Chroma", torch_dtype=torch.bfloat16
        ).to(self.device)


class ChromaRunner(BasePipelineRunner):
    def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image:
        return self.pipe(
            prompt=prompt, negative_prompt=negative_prompt, **kwargs
        ).images


PIPELINE_REGISTRY = {
    "sd35": (SD35Loader, SD35Runner),
    "cosmos": (CosmosLoader, CosmosRunner),
    "kolors": (KolorsLoader, KolorsRunner),
    "flux": (FluxLoader, FluxRunner),
    "chroma": (ChromaLoader, ChromaRunner),
}


def build_hf_image_pipeline(name: str, device="cuda") -> BasePipelineRunner:
    if name not in PIPELINE_REGISTRY:
        raise ValueError(f"Unsupported model: {name}")
    loader_cls, runner_cls = PIPELINE_REGISTRY[name]
    pipe = loader_cls(device=device).load()

    return runner_cls(pipe)


if __name__ == "__main__":
    model_name = "sd35"
    runner = build_hf_image_pipeline(model_name)
    # NOTE: Just for pipeline testing, generation quality at low resolution is poor.
    images = runner.run(
        prompt="A robot holding a sign that says 'Hello'",
        height=512,
        width=512,
        num_inference_steps=10,
        guidance_scale=6,
        num_images_per_prompt=1,
    )

    for i, img in enumerate(images):
        img.save(f"image_{model_name}_{i}.jpg")
embodied_gen/models/text_model.py
CHANGED
@@ -52,6 +52,12 @@ __all__ = [
     "download_kolors_weights",
 ]
 
+PROMPT_APPEND = (
+    "Angled 3D view of one {object}, centered, no cropping, no occlusion, isolated product photo, "
+    "no surroundings, high-quality appearance, vivid colors, on a plain clean surface, 3D style revealing multiple surfaces"
+)
+PROMPT_KAPPEND = "Single {object}, in the center of the image, white background, 3D style, best quality"
+
 
 def download_kolors_weights(local_dir: str = "weights/Kolors") -> None:
     logger.info(f"Download kolors weights from huggingface...")
@@ -179,8 +185,7 @@ def text2img_gen(
     ip_image_size: int = 512,
     seed: int = None,
 ) -> list[Image.Image]:
-    prompt =
-    prompt += ", high quality, high resolution, best quality, white background, 3D style"  # noqa
+    prompt = PROMPT_KAPPEND.format(object=prompt.strip())
     logger.info(f"Processing prompt: {prompt}")
 
     generator = None
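How the two new templates expand (PROMPT_KAPPEND is applied inside text2img_gen; PROMPT_APPEND is consumed by the text-to-3D script below):

from embodied_gen.models.text_model import PROMPT_APPEND, PROMPT_KAPPEND

print(PROMPT_KAPPEND.format(object="red ceramic mug"))
# Single red ceramic mug, in the center of the image, white background, 3D style, best quality
print(PROMPT_APPEND.format(object="red ceramic mug"))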
embodied_gen/scripts/gen_scene3d.py
ADDED
@@ -0,0 +1,191 @@
import logging
import os
import random
import time
import warnings
from dataclasses import dataclass, field
from shutil import copy, rmtree

import torch
import tyro
from huggingface_hub import snapshot_download
from packaging import version

# Suppress warnings
warnings.filterwarnings("ignore", category=FutureWarning)
logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("diffusers").setLevel(logging.ERROR)

# TorchVision monkey patch for >0.16
if version.parse(torch.__version__) >= version.parse("0.16"):
    import sys
    import types

    import torchvision.transforms.functional as TF

    functional_tensor = types.ModuleType(
        "torchvision.transforms.functional_tensor"
    )
    functional_tensor.rgb_to_grayscale = TF.rgb_to_grayscale
    sys.modules["torchvision.transforms.functional_tensor"] = functional_tensor

from gsplat.distributed import cli
from txt2panoimg import Text2360PanoramaImagePipeline
from embodied_gen.trainer.gsplat_trainer import (
    DefaultStrategy,
    GsplatTrainConfig,
)
from embodied_gen.trainer.gsplat_trainer import entrypoint as gsplat_entrypoint
from embodied_gen.trainer.pono2mesh_trainer import Pano2MeshSRPipeline
from embodied_gen.utils.config import Pano2MeshSRConfig
from embodied_gen.utils.gaussian import restore_scene_scale_and_position
from embodied_gen.utils.gpt_clients import GPT_CLIENT
from embodied_gen.utils.log import logger
from embodied_gen.utils.process_media import is_image_file, parse_text_prompts
from embodied_gen.validators.quality_checkers import (
    PanoHeightEstimator,
    PanoImageOccChecker,
)

__all__ = [
    "generate_pano_image",
    "entrypoint",
]


@dataclass
class Scene3DGenConfig:
    prompts: list[str]  # Text desc of indoor room or style reference image.
    output_dir: str
    seed: int | None = None
    real_height: float | None = None  # The real height of the room in meters.
    pano_image_only: bool = False
    disable_pano_check: bool = False
    keep_middle_result: bool = False
    n_retry: int = 7
    gs3d: GsplatTrainConfig = field(
        default_factory=lambda: GsplatTrainConfig(
            strategy=DefaultStrategy(verbose=True),
            max_steps=4000,
            init_opa=0.9,
            opacity_reg=2e-3,
            sh_degree=0,
            means_lr=1e-4,
            scales_lr=1e-3,
        )
    )


def generate_pano_image(
    prompt: str,
    output_path: str,
    pipeline,
    seed: int,
    n_retry: int,
    checker=None,
    num_inference_steps: int = 40,
) -> None:
    for i in range(n_retry):
        logger.info(
            f"GEN Panorama: Retry {i+1}/{n_retry} for prompt: {prompt}, seed: {seed}"
        )
        if is_image_file(prompt):
            raise NotImplementedError("Image mode not implemented yet.")
        else:
            txt_prompt = f"{prompt}, spacious, empty, wide open, open floor, minimal furniture"
            inputs = {
                "prompt": txt_prompt,
                "num_inference_steps": num_inference_steps,
                "upscale": False,
                "seed": seed,
            }
            pano_image = pipeline(inputs)

        pano_image.save(output_path)
        if checker is None:
            break

        flag, response = checker(pano_image)
        logger.warning(f"{response}, image saved in {output_path}")
        if flag is True or flag is None:
            break

        seed = random.randint(0, 100000)

    return


def entrypoint(*args, **kwargs):
    cfg = tyro.cli(Scene3DGenConfig)

    # Init global models.
    model_path = snapshot_download("archerfmy0831/sd-t2i-360panoimage")
    IMG2PANO_PIPE = Text2360PanoramaImagePipeline(
        model_path, torch_dtype=torch.float16, device="cuda"
    )
    PANOMESH_CFG = Pano2MeshSRConfig()
    PANO2MESH_PIPE = Pano2MeshSRPipeline(PANOMESH_CFG)
    PANO_CHECKER = PanoImageOccChecker(GPT_CLIENT, box_hw=[95, 1000])
    PANOHEIGHT_ESTOR = PanoHeightEstimator(GPT_CLIENT)

    prompts = parse_text_prompts(cfg.prompts)
    for idx, prompt in enumerate(prompts):
        start_time = time.time()
        output_dir = os.path.join(cfg.output_dir, f"scene_{idx:04d}")
        os.makedirs(output_dir, exist_ok=True)
        pano_path = os.path.join(output_dir, "pano_image.png")
        with open(f"{output_dir}/prompt.txt", "w") as f:
            f.write(prompt)

        generate_pano_image(
            prompt,
            pano_path,
            IMG2PANO_PIPE,
            cfg.seed if cfg.seed is not None else random.randint(0, 100000),
            cfg.n_retry,
            checker=None if cfg.disable_pano_check else PANO_CHECKER,
        )

        if cfg.pano_image_only:
            continue

        logger.info("GEN and REPAIR Mesh from Panorama...")
        PANO2MESH_PIPE(pano_path, output_dir)

        logger.info("TRAIN 3DGS from Mesh Init and Cube Image...")
        cfg.gs3d.data_dir = output_dir
        cfg.gs3d.result_dir = f"{output_dir}/gaussian"
        cfg.gs3d.adjust_steps(cfg.gs3d.steps_scaler)
        torch.set_default_device("cpu")  # recover default setting.
        cli(gsplat_entrypoint, cfg.gs3d, verbose=True)

        # Clean up the middle results.
        gs_path = (
            f"{cfg.gs3d.result_dir}/ply/point_cloud_{cfg.gs3d.max_steps-1}.ply"
        )
        copy(gs_path, f"{output_dir}/gs_model.ply")
        video_path = f"{cfg.gs3d.result_dir}/renders/video_step{cfg.gs3d.max_steps-1}.mp4"
        copy(video_path, f"{output_dir}/video.mp4")
        gs_cfg_path = f"{cfg.gs3d.result_dir}/cfg.yml"
        copy(gs_cfg_path, f"{output_dir}/gsplat_cfg.yml")
        if not cfg.keep_middle_result:
            rmtree(cfg.gs3d.result_dir, ignore_errors=True)
            os.remove(f"{output_dir}/{PANOMESH_CFG.gs_data_file}")

        real_height = (
            PANOHEIGHT_ESTOR(pano_path)
            if cfg.real_height is None
            else cfg.real_height
        )
        gs_path = os.path.join(output_dir, "gs_model.ply")
        mesh_path = os.path.join(output_dir, "mesh_model.ply")
        restore_scene_scale_and_position(real_height, mesh_path, gs_path)

        elapsed_time = (time.time() - start_time) / 60
        logger.info(
            f"FINISHED 3D scene generation in {output_dir} in {elapsed_time:.2f} mins."
        )


if __name__ == "__main__":
    entrypoint()
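The script is driven by tyro over Scene3DGenConfig; constructing the dataclass directly shows the defaults it parses from the command line (a sketch only, and it assumes the heavy dependencies such as gsplat and txt2panoimg are installed so the module imports cleanly):

from embodied_gen.scripts.gen_scene3d import Scene3DGenConfig

cfg = Scene3DGenConfig(
    prompts=["a modern living room with large windows"],
    output_dir="outputs/scene3d",
    real_height=2.8,  # optional: skips the GPT-based pano height estimator
)
print(cfg.gs3d.max_steps)  # 4000 by default for the 3DGS training stage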
embodied_gen/scripts/imageto3d.py
CHANGED
@@ -16,29 +16,28 @@
 
 
 import argparse
-import logging
 import os
+import random
 import sys
 from glob import glob
 from shutil import copy, copytree, rmtree
 
 import numpy as np
+import torch
 import trimesh
 from PIL import Image
 from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
 from embodied_gen.data.utils import delete_dir, trellis_preprocess
 from embodied_gen.models.delight_model import DelightingModel
 from embodied_gen.models.gs_model import GaussianOperator
-from embodied_gen.models.segment_model import (
-    BMGG14Remover,
-    RembgRemover,
-    SAMPredictor,
-)
+from embodied_gen.models.segment_model import RembgRemover
 from embodied_gen.models.sr_model import ImageRealESRGAN
 from embodied_gen.scripts.render_gs import entrypoint as render_gs_api
 from embodied_gen.utils.gpt_clients import GPT_CLIENT
-from embodied_gen.utils.
+from embodied_gen.utils.log import logger
+from embodied_gen.utils.process_media import merge_images_video
 from embodied_gen.utils.tags import VERSION
+from embodied_gen.utils.trender import render_video
 from embodied_gen.validators.quality_checkers import (
     BaseChecker,
     ImageAestheticChecker,
@@ -52,36 +51,25 @@ current_dir = os.path.dirname(current_file_path)
 sys.path.append(os.path.join(current_dir, "../.."))
 from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
 
-logging.basicConfig(
-    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
-)
-logger = logging.getLogger(__name__)
-
-
 os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
     "~/.cache/torch_extensions"
 )
 os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
 os.environ["SPCONV_ALGO"] = "native"
+random.seed(0)
 
-
+logger.info("Loading Models...")
 DELIGHT = DelightingModel()
 IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
-
 RBG_REMOVER = RembgRemover()
-RBG14_REMOVER = BMGG14Remover()
-SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
 PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
     "microsoft/TRELLIS-image-large"
 )
-PIPELINE.cuda()
+# PIPELINE.cuda()
 SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
 GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
 AESTHETIC_CHECKER = ImageAestheticChecker()
 CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
-TMP_DIR = os.path.join(
-    os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
-)
 
 
 def parse_args():
@@ -95,7 +83,6 @@ def parse_args():
     parser.add_argument(
         "--output_root",
         type=str,
-        required=True,
        help="Root directory for saving outputs.",
     )
     parser.add_argument(
@@ -110,12 +97,26 @@ def parse_args():
         default=None,
         help="The mass in kg to restore the mesh real weight.",
     )
-    parser.add_argument("--asset_type", type=str, default=None)
+    parser.add_argument("--asset_type", type=str, nargs="+", default=None)
     parser.add_argument("--skip_exists", action="store_true")
-    parser.add_argument("--strict_seg", action="store_true")
     parser.add_argument("--version", type=str, default=VERSION)
-    parser.add_argument("--
+    parser.add_argument("--keep_intermediate", action="store_true")
+    parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument(
+        "--n_retry",
+        type=int,
+        default=2,
+    )
+    args, unknown = parser.parse_known_args()
+
+    return args
+
+
+def entrypoint(**kwargs):
+    args = parse_args()
+    for k, v in kwargs.items():
+        if hasattr(args, k) and v is not None:
+            setattr(args, k, v)
 
     assert (
         args.image_path or args.image_root
@@ -125,13 +126,7 @@ def parse_args():
     args.image_path += glob(os.path.join(args.image_root, "*.jpg"))
     args.image_path += glob(os.path.join(args.image_root, "*.jpeg"))
 
-
-
-
-if __name__ == "__main__":
-    args = parse_args()
-
-    for image_path in args.image_path:
+    for idx, image_path in enumerate(args.image_path):
         try:
             filename = os.path.basename(image_path).split(".")[0]
             output_root = args.output_root
@@ -141,7 +136,7 @@ if __name__ == "__main__":
 
             mesh_out = f"{output_root}/{filename}.obj"
             if args.skip_exists and os.path.exists(mesh_out):
-                logger.
+                logger.warning(
                     f"Skip {image_path}, already processed in {mesh_out}"
                 )
                 continue
@@ -149,67 +144,84 @@ if __name__ == "__main__":
             image = Image.open(image_path)
             image.save(f"{output_root}/{filename}_raw.png")
 
-            # Segmentation: Get segmented image using
+            # Segmentation: Get segmented image using Rembg.
             seg_path = f"{output_root}/{filename}_cond.png"
-            if image.mode != "RGBA"
-            try:
-                outputs = PIPELINE.run(
-                    seg_image,
-                    preprocess_image=False,
-                    # Optional parameters
-                    # seed=1,
-                    # sparse_structure_sampler_params={
-                    #     "steps": 12,
-                    #     "cfg_strength": 7.5,
-                    # },
-                    # slat_sampler_params={
-                    #     "steps": 12,
-                    #     "cfg_strength": 3,
-                    # },
-                )
+            seg_image = RBG_REMOVER(image) if image.mode != "RGBA" else image
+            seg_image = trellis_preprocess(seg_image)
+            seg_image.save(seg_path)
+
+            seed = args.seed
+            for try_idx in range(args.n_retry):
+                logger.info(
+                    f"Try: {try_idx + 1}/{args.n_retry}, Seed: {seed}, Prompt: {seg_path}"
+                )
+                # Run the pipeline
+                try:
+                    PIPELINE.cuda()
+                    outputs = PIPELINE.run(
+                        seg_image,
+                        preprocess_image=False,
+                        seed=(
+                            random.randint(0, 100000) if seed is None else seed
+                        ),
+                        # Optional parameters
+                        # sparse_structure_sampler_params={
+                        #     "steps": 12,
+                        #     "cfg_strength": 7.5,
+                        # },
+                        # slat_sampler_params={
+                        #     "steps": 12,
+                        #     "cfg_strength": 3,
+                        # },
+                    )
+                    PIPELINE.cpu()
+                    torch.cuda.empty_cache()
+                except Exception as e:
+                    logger.error(
+                        f"[Pipeline Failed] process {image_path}: {e}, skip."
+                    )
+                    continue
 
-            # Render
-            gs_model = outputs["gaussian"][0]
-            mesh_model = outputs["mesh"][0]
+                gs_model = outputs["gaussian"][0]
+                mesh_model = outputs["mesh"][0]
+
+                # Save the raw Gaussian model
+                gs_path = mesh_out.replace(".obj", "_gs.ply")
+                gs_model.save_ply(gs_path)
+
+                # Rotate mesh and GS by 90 degrees around Z-axis.
+                rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
+                gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
+                mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
+
+                # Additional rotation for GS to align mesh.
+                gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
+                pose = GaussianOperator.trans_to_quatpose(gs_rot)
+                aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
+                GaussianOperator.resave_ply(
+                    in_ply=gs_path,
+                    out_ply=aligned_gs_path,
+                    instance_pose=pose,
+                    device="cpu",
+                )
+                color_path = os.path.join(output_root, "color.png")
+                render_gs_api(aligned_gs_path, color_path)
+
+                geo_flag, geo_result = GEO_CHECKER([color_path])
+                logger.warning(
+                    f"{GEO_CHECKER.__class__.__name__}: {geo_result} for {seg_path}"
+                )
+                if geo_flag is True or geo_flag is None:
+                    break
+
+                seed = random.randint(0, 100000) if seed is not None else None
 
+            # Render the video for generated 3D asset.
             color_images = render_video(gs_model)["color"]
             normal_images = render_video(mesh_model)["normal"]
             video_path = os.path.join(output_root, "gs_mesh.mp4")
             merge_images_video(color_images, normal_images, video_path)
 
-            # Save the raw Gaussian model
-            gs_path = mesh_out.replace(".obj", "_gs.ply")
-            gs_model.save_ply(gs_path)
-
-            # Rotate mesh and GS by 90 degrees around Z-axis.
-            rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
-            gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
-            mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
-
-            # Addtional rotation for GS to align mesh.
-            gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
-            pose = GaussianOperator.trans_to_quatpose(gs_rot)
-            aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
-            GaussianOperator.resave_ply(
-                in_ply=gs_path,
-                out_ply=aligned_gs_path,
-                instance_pose=pose,
-                device="cpu",
-            )
-            color_path = os.path.join(output_root, "color.png")
-            render_gs_api(aligned_gs_path, color_path)
-
             mesh = trimesh.Trimesh(
                 vertices=mesh_model.vertices.cpu().numpy(),
                 faces=mesh_model.faces.cpu().numpy(),
@@ -249,8 +261,8 @@ if __name__ == "__main__":
             min_mass, max_mass = map(float, args.mass_range.split("-"))
             asset_attrs["min_mass"] = min_mass
             asset_attrs["max_mass"] = max_mass
-            if args.asset_type:
-                asset_attrs["category"] = args.asset_type
+            if isinstance(args.asset_type, list) and args.asset_type[idx]:
+                asset_attrs["category"] = args.asset_type[idx]
             if args.version:
                 asset_attrs["version"] = args.version
 
@@ -289,8 +301,8 @@ if __name__ == "__main__":
             ]
             images_list.append(images)
 
-            urdf_convertor.add_quality_tag(urdf_path,
+            qa_results = BaseChecker.validate(CHECKERS, images_list)
+            urdf_convertor.add_quality_tag(urdf_path, qa_results)
 
             # Organize the final result files
             result_dir = f"{output_root}/result"
@@ -303,7 +315,7 @@ if __name__ == "__main__":
                 f"{result_dir}/{urdf_convertor.output_mesh_dir}",
             )
             copy(video_path, f"{result_dir}/video.mp4")
-            if args.
+            if not args.keep_intermediate:
                 delete_dir(output_root, keep_subs=["result"])
 
         except Exception as e:
@@ -311,3 +323,7 @@ if __name__ == "__main__":
             continue
 
     logger.info(f"Processing complete. Outputs saved to {args.output_root}")
+
+
+if __name__ == "__main__":
+    entrypoint()
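With the __main__ block wrapped in entrypoint(**kwargs), the script is importable; textto3d.py drives it the same way. A call sketch with illustrative paths (note the module still loads its models at import time):

from embodied_gen.scripts.imageto3d import entrypoint as imageto3d_api

imageto3d_api(
    image_path=["outputs/images/red_mug.png"],
    output_root="outputs/asset3d/red_mug",
    asset_type=["mug"],
    seed=42,
    n_retry=2,
)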
embodied_gen/scripts/text2image.py
CHANGED
@@ -31,6 +31,7 @@ from embodied_gen.models.text_model import (
     build_text2img_pipeline,
     text2img_gen,
 )
+from embodied_gen.utils.process_media import parse_text_prompts
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -85,7 +86,7 @@ def parse_args():
     parser.add_argument(
         "--seed",
         type=int,
-        default=
+        default=None,
     )
     args = parser.parse_args()
 
@@ -101,14 +102,7 @@ def entrypoint(
         if hasattr(args, k) and v is not None:
             setattr(args, k, v)
 
-    prompts = args.prompts
-    if len(prompts) == 1 and prompts[0].endswith(".txt"):
-        with open(prompts[0], "r") as f:
-            prompts = f.readlines()
-        prompts = [
-            prompt.strip() for prompt in prompts if prompt.strip() != ""
-        ]
-
+    prompts = parse_text_prompts(args.prompts)
     os.makedirs(args.output_root, exist_ok=True)
 
     ip_img_paths = args.ref_image
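The inline .txt handling is replaced by the shared parse_text_prompts helper. A rough sketch of the assumed behavior, mirroring the removed code (the real helper lives in embodied_gen.utils.process_media):

def parse_text_prompts_sketch(prompts: list[str]) -> list[str]:
    # a single .txt argument is read as one prompt per non-empty line
    if len(prompts) == 1 and prompts[0].endswith(".txt"):
        with open(prompts[0], "r") as f:
            prompts = f.readlines()
        prompts = [p.strip() for p in prompts if p.strip() != ""]
    return prompts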
embodied_gen/scripts/textto3d.py
ADDED
@@ -0,0 +1,280 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

import argparse
import os
import random
from collections import defaultdict

import numpy as np
import torch
from PIL import Image
from embodied_gen.models.image_comm_model import build_hf_image_pipeline
from embodied_gen.models.segment_model import RembgRemover
from embodied_gen.models.text_model import PROMPT_APPEND
from embodied_gen.scripts.imageto3d import entrypoint as imageto3d_api
from embodied_gen.utils.gpt_clients import GPT_CLIENT
from embodied_gen.utils.log import logger
from embodied_gen.utils.process_media import (
    check_object_edge_truncated,
    render_asset3d,
)
from embodied_gen.validators.quality_checkers import (
    ImageSegChecker,
    SemanticConsistChecker,
    TextGenAlignChecker,
)

# Avoid huggingface/tokenizers: The current process just got forked.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
random.seed(0)

logger.info("Loading Models...")
SEMANTIC_CHECKER = SemanticConsistChecker(GPT_CLIENT)
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
TXTGEN_CHECKER = TextGenAlignChecker(GPT_CLIENT)
PIPE_IMG = build_hf_image_pipeline(os.environ.get("TEXT_MODEL", "sd35"))
BG_REMOVER = RembgRemover()


__all__ = [
    "text_to_image",
    "text_to_3d",
]


def text_to_image(
    prompt: str,
    save_path: str,
    n_retry: int,
    img_denoise_step: int,
    text_guidance_scale: float,
    n_img_sample: int,
    image_hw: tuple[int, int] = (1024, 1024),
    seed: int = None,
) -> bool:
    select_image = None
    success_flag = False
    assert save_path.endswith(".png"), "Image save path must end with `.png`."
    for try_idx in range(n_retry):
        if select_image is not None:
            select_image[0].save(save_path.replace(".png", "_raw.png"))
            select_image[1].save(save_path)
            break

        f_prompt = PROMPT_APPEND.format(object=prompt)
        logger.info(
            f"Image GEN for {os.path.basename(save_path)}\n"
            f"Try: {try_idx + 1}/{n_retry}, Seed: {seed}, Prompt: {f_prompt}"
        )
        torch.cuda.empty_cache()
        images = PIPE_IMG.run(
            f_prompt,
            num_inference_steps=img_denoise_step,
            guidance_scale=text_guidance_scale,
            num_images_per_prompt=n_img_sample,
            height=image_hw[0],
            width=image_hw[1],
            generator=(
                torch.Generator().manual_seed(seed)
                if seed is not None
                else None
            ),
        )

        for idx in range(len(images)):
            raw_image: Image.Image = images[idx]
            image = BG_REMOVER(raw_image)
            image.save(save_path)
            semantic_flag, semantic_result = SEMANTIC_CHECKER(
                prompt, [image.convert("RGB")]
            )
            seg_flag, seg_result = SEG_CHECKER(
                [raw_image, image.convert("RGB")]
            )
            image_mask = np.array(image)[..., -1]
            edge_flag = check_object_edge_truncated(image_mask)
            logger.warning(
                f"SEMANTIC: {semantic_result}. SEG: {seg_result}. EDGE: {edge_flag}"
            )
            if (
                (edge_flag and semantic_flag and seg_flag)
                or (edge_flag and semantic_flag is None)
                or (edge_flag and seg_flag is None)
            ):
                select_image = [raw_image, image]
                success_flag = True
                break

        seed = random.randint(0, 100000) if seed is not None else None

    return success_flag


def text_to_3d(**kwargs) -> dict:
    args = parse_args()
    for k, v in kwargs.items():
        if hasattr(args, k) and v is not None:
            setattr(args, k, v)

    if args.asset_names is None or len(args.asset_names) == 0:
        args.asset_names = [f"sample3d_{i}" for i in range(len(args.prompts))]
    img_save_dir = os.path.join(args.output_root, "images")
    asset_save_dir = os.path.join(args.output_root, "asset3d")
    os.makedirs(img_save_dir, exist_ok=True)
    os.makedirs(asset_save_dir, exist_ok=True)
    results = defaultdict(dict)
    for prompt, node in zip(args.prompts, args.asset_names):
        success_flag = False
        n_pipe_retry = args.n_pipe_retry
        seed_img = args.seed_img
        seed_3d = args.seed_3d
        while success_flag is False and n_pipe_retry > 0:
            logger.info(
                f"GEN pipeline for node {node}\n"
                f"Try round: {args.n_pipe_retry-n_pipe_retry+1}/{args.n_pipe_retry}, Prompt: {prompt}"
            )
            # Text-to-image GEN
            save_node = node.replace(" ", "_")
            gen_image_path = f"{img_save_dir}/{save_node}.png"
            textgen_flag = text_to_image(
                prompt,
                gen_image_path,
                args.n_image_retry,
                args.img_denoise_step,
                args.text_guidance_scale,
                args.n_img_sample,
                seed=seed_img,
            )

            # Asset 3D GEN
            node_save_dir = f"{asset_save_dir}/{save_node}"
            asset_type = node if "sample3d_" not in node else None
            imageto3d_api(
                image_path=[gen_image_path],
                output_root=node_save_dir,
                asset_type=[asset_type],
                seed=random.randint(0, 100000) if seed_3d is None else seed_3d,
                n_retry=args.n_asset_retry,
                keep_intermediate=args.keep_intermediate,
            )
            mesh_path = f"{node_save_dir}/result/mesh/{save_node}.obj"
            image_path = render_asset3d(
                mesh_path,
                output_root=f"{node_save_dir}/result",
                num_images=6,
                elevation=(30, -30),
                output_subdir="renders",
                no_index_file=True,
            )

            check_text = asset_type if asset_type is not None else prompt
            qa_flag, qa_result = TXTGEN_CHECKER(check_text, image_path)
            logger.warning(
                f"Node {node}, {TXTGEN_CHECKER.__class__.__name__}: {qa_result}"
            )
            results["assets"][node] = f"{node_save_dir}/result"
            results["quality"][node] = qa_result

            if qa_flag is None or qa_flag is True:
                success_flag = True
                break

            n_pipe_retry -= 1
            seed_img = (
                random.randint(0, 100000) if seed_img is not None else None
            )
            seed_3d = (
                random.randint(0, 100000) if seed_3d is not None else None
            )

        torch.cuda.empty_cache()

    return results


def parse_args():
    parser = argparse.ArgumentParser(description="3D Layout Generation Config")
    parser.add_argument("--prompts", nargs="+", help="text descriptions")
    parser.add_argument(
        "--output_root",
        type=str,
        help="Directory to save outputs",
    )
    parser.add_argument(
parser.add_argument(
|
218 |
+
"--asset_names",
|
219 |
+
type=str,
|
220 |
+
nargs="+",
|
221 |
+
default=None,
|
222 |
+
help="Asset names to generate",
|
223 |
+
)
|
224 |
+
parser.add_argument(
|
225 |
+
"--n_img_sample",
|
226 |
+
type=int,
|
227 |
+
default=3,
|
228 |
+
help="Number of image samples to generate",
|
229 |
+
)
|
230 |
+
parser.add_argument(
|
231 |
+
"--text_guidance_scale",
|
232 |
+
type=float,
|
233 |
+
default=7,
|
234 |
+
help="Text-to-image guidance scale",
|
235 |
+
)
|
236 |
+
parser.add_argument(
|
237 |
+
"--img_denoise_step",
|
238 |
+
type=int,
|
239 |
+
default=25,
|
240 |
+
help="Denoising steps for image generation",
|
241 |
+
)
|
242 |
+
parser.add_argument(
|
243 |
+
"--n_image_retry",
|
244 |
+
type=int,
|
245 |
+
default=2,
|
246 |
+
help="Max retry count for image generation",
|
247 |
+
)
|
248 |
+
parser.add_argument(
|
249 |
+
"--n_asset_retry",
|
250 |
+
type=int,
|
251 |
+
default=2,
|
252 |
+
help="Max retry count for 3D generation",
|
253 |
+
)
|
254 |
+
parser.add_argument(
|
255 |
+
"--n_pipe_retry",
|
256 |
+
type=int,
|
257 |
+
default=1,
|
258 |
+
help="Max retry count for 3D asset generation",
|
259 |
+
)
|
260 |
+
parser.add_argument(
|
261 |
+
"--seed_img",
|
262 |
+
type=int,
|
263 |
+
default=None,
|
264 |
+
help="Random seed for image generation",
|
265 |
+
)
|
266 |
+
parser.add_argument(
|
267 |
+
"--seed_3d",
|
268 |
+
type=int,
|
269 |
+
default=0,
|
270 |
+
help="Random seed for 3D generation",
|
271 |
+
)
|
272 |
+
parser.add_argument("--keep_intermediate", action="store_true")
|
273 |
+
|
274 |
+
args, unknown = parser.parse_known_args()
|
275 |
+
|
276 |
+
return args
|
277 |
+
|
278 |
+
|
279 |
+
if __name__ == "__main__":
|
280 |
+
text_to_3d()
|
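Since `text_to_3d` forwards keyword arguments onto the parsed CLI defaults, it can also be driven programmatically. A minimal sketch, assuming the module is importable as `embodied_gen.scripts.textto3d` and using illustrative prompt and path values (not part of this commit):

```python
from embodied_gen.scripts.textto3d import text_to_3d

# Each keyword mirrors a parse_args flag; anything omitted keeps its default.
results = text_to_3d(
    prompts=["A red ceramic mug with a curved handle"],
    asset_names=["mug"],               # omit to fall back to sample3d_{i} names
    output_root="outputs/demo_text3d", # hypothetical output directory
    seed_img=0,
    seed_3d=0,
)
# results["assets"] maps each node to its result directory,
# results["quality"] holds the TextGenAlignChecker verdict per node.
print(results["assets"], results["quality"])
```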
embodied_gen/scripts/textto3d.sh
CHANGED
@@ -2,7 +2,9 @@
 
 # Initialize variables
 prompts=()
+asset_types=()
 output_root=""
+seed=0
 
 # Parse arguments
 while [[ $# -gt 0 ]]; do
@@ -14,10 +16,21 @@ while [[ $# -gt 0 ]]; do
                shift
            done
            ;;
+        --asset_types)
+            shift
+            while [[ $# -gt 0 && ! "$1" =~ ^-- ]]; do
+                asset_types+=("$1")
+                shift
+            done
+            ;;
        --output_root)
            output_root="$2"
            shift 2
            ;;
+        --seed)
+            seed="$2"
+            shift 2
+            ;;
        *)
            echo "Unknown argument: $1"
            exit 1
@@ -28,7 +41,21 @@ done
 # Validate required arguments
 if [[ ${#prompts[@]} -eq 0 || -z "$output_root" ]]; then
     echo "Missing required arguments."
-    echo "Usage: bash run_text2asset3d.sh --prompts \"Prompt1\" \"Prompt2\"
+    echo "Usage: bash run_text2asset3d.sh --prompts \"Prompt1\" \"Prompt2\" \
+    --asset_types \"type1\" \"type2\" --seed <seed_value> --output_root <path>"
+    exit 1
+fi
+
+# If no asset_types provided, default to ""
+if [[ ${#asset_types[@]} -eq 0 ]]; then
+    for (( i=0; i<${#prompts[@]}; i++ )); do
+        asset_types+=("")
+    done
+fi
+
+# Ensure the number of asset_types matches the number of prompts
+if [[ ${#prompts[@]} -ne ${#asset_types[@]} ]]; then
+    echo "The number of asset types must match the number of prompts."
     exit 1
 fi
 
@@ -37,20 +64,30 @@ echo "Prompts:"
 for p in "${prompts[@]}"; do
     echo " - $p"
 done
+# echo "Asset types:"
+# for at in "${asset_types[@]}"; do
+#     echo " - $at"
+# done
 echo "Output root: ${output_root}"
+echo "Seed: ${seed}"
 
-# Concatenate prompts for Python command
+# Concatenate prompts and asset types for Python command
 prompt_args=""
+asset_type_args=""
+for i in "${!prompts[@]}"; do
+    prompt_args+="\"${prompts[$i]}\" "
+    asset_type_args+="\"${asset_types[$i]}\" "
 done
 
+
 # Step 1: Text-to-Image
 eval python3 embodied_gen/scripts/text2image.py \
     --prompts ${prompt_args} \
-    --output_root "${output_root}/images"
+    --output_root "${output_root}/images" \
+    --seed ${seed}
 
 # Step 2: Image-to-3D
 python3 embodied_gen/scripts/imageto3d.py \
     --image_root "${output_root}/images" \
-    --output_root "${output_root}/asset3d"
+    --output_root "${output_root}/asset3d" \
+    --asset_type ${asset_type_args}
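For readers who prefer to stay in Python, a rough equivalent of the wrapper's two steps can be sketched with `subprocess`, assuming the two CLI scripts accept the flags exactly as the wrapper passes them; the paths and prompts below are illustrative:

```python
import subprocess

output_root = "outputs/demo"            # hypothetical output directory
prompts = ["A wooden chair", "A blue vase"]
asset_types = ["chair", "vase"]

# Step 1: text-to-image, mirroring `text2image.py --prompts ... --output_root ... --seed`.
subprocess.run(
    ["python3", "embodied_gen/scripts/text2image.py",
     "--prompts", *prompts,
     "--output_root", f"{output_root}/images",
     "--seed", "0"],
    check=True,
)

# Step 2: image-to-3D, mirroring `imageto3d.py --image_root ... --output_root ... --asset_type`.
subprocess.run(
    ["python3", "embodied_gen/scripts/imageto3d.py",
     "--image_root", f"{output_root}/images",
     "--output_root", f"{output_root}/asset3d",
     "--asset_type", *asset_types],
    check=True,
)
```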
embodied_gen/scripts/texture_gen.sh
CHANGED
@@ -10,10 +10,6 @@ while [[ $# -gt 0 ]]; do
            prompt="$2"
            shift 2
            ;;
-        --uuid)
-            uuid="$2"
-            shift 2
-            ;;
        --output_root)
            output_root="$2"
            shift 2
@@ -26,12 +22,13 @@ while [[ $# -gt 0 ]]; do
 done
 
 
-if [[ -z "$mesh_path" || -z "$prompt" || -z "$
+if [[ -z "$mesh_path" || -z "$prompt" || -z "$output_root" ]]; then
     echo "params missing"
-    echo "usage: bash run.sh --mesh_path <path> --prompt <text> --
+    echo "usage: bash run.sh --mesh_path <path> --prompt <text> --output_root <path>"
     exit 1
 fi
 
+uuid=$(basename "$output_root")
 # Step 1: drender-cli for condition rendering
 drender-cli --mesh_path ${mesh_path} \
     --output_root ${output_root}/condition \
embodied_gen/trainer/gsplat_trainer.py
ADDED
@@ -0,0 +1,678 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
# Part of the code comes from https://github.com/nerfstudio-project/gsplat/blob/main/examples/simple_trainer.py
# Both under the Apache License, Version 2.0.


import json
import os
import time
from collections import defaultdict
from typing import Dict, Optional, Tuple

import cv2
import imageio
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
import tyro
import yaml
from fused_ssim import fused_ssim
from gsplat.distributed import cli
from gsplat.rendering import rasterization
from gsplat.strategy import DefaultStrategy, MCMCStrategy
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from torchmetrics.image import (
    PeakSignalNoiseRatio,
    StructuralSimilarityIndexMeasure,
)
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from typing_extensions import Literal, assert_never
from embodied_gen.data.datasets import PanoGSplatDataset
from embodied_gen.utils.config import GsplatTrainConfig
from embodied_gen.utils.gaussian import (
    create_splats_with_optimizers,
    export_splats,
    resize_pinhole_intrinsics,
    set_random_seed,
)


class Runner:
    """Engine for training and testing from gsplat example.

    Code from https://github.com/nerfstudio-project/gsplat/blob/main/examples/simple_trainer.py
    """

    def __init__(
        self,
        local_rank: int,
        world_rank,
        world_size: int,
        cfg: GsplatTrainConfig,
    ) -> None:
        set_random_seed(42 + local_rank)

        self.cfg = cfg
        self.world_rank = world_rank
        self.local_rank = local_rank
        self.world_size = world_size
        self.device = f"cuda:{local_rank}"

        # Where to dump results.
        os.makedirs(cfg.result_dir, exist_ok=True)

        # Setup output directories.
        self.ckpt_dir = f"{cfg.result_dir}/ckpts"
        os.makedirs(self.ckpt_dir, exist_ok=True)
        self.stats_dir = f"{cfg.result_dir}/stats"
        os.makedirs(self.stats_dir, exist_ok=True)
        self.render_dir = f"{cfg.result_dir}/renders"
        os.makedirs(self.render_dir, exist_ok=True)
        self.ply_dir = f"{cfg.result_dir}/ply"
        os.makedirs(self.ply_dir, exist_ok=True)

        # Tensorboard
        self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb")
        self.trainset = PanoGSplatDataset(cfg.data_dir, split="train")
        self.valset = PanoGSplatDataset(
            cfg.data_dir, split="train", max_sample_num=6
        )
        self.testset = PanoGSplatDataset(cfg.data_dir, split="eval")
        self.scene_scale = cfg.scene_scale

        # Model
        self.splats, self.optimizers = create_splats_with_optimizers(
            self.trainset.points,
            self.trainset.points_rgb,
            init_num_pts=cfg.init_num_pts,
            init_extent=cfg.init_extent,
            init_opacity=cfg.init_opa,
            init_scale=cfg.init_scale,
            means_lr=cfg.means_lr,
            scales_lr=cfg.scales_lr,
            opacities_lr=cfg.opacities_lr,
            quats_lr=cfg.quats_lr,
            sh0_lr=cfg.sh0_lr,
            shN_lr=cfg.shN_lr,
            scene_scale=self.scene_scale,
            sh_degree=cfg.sh_degree,
            sparse_grad=cfg.sparse_grad,
            visible_adam=cfg.visible_adam,
            batch_size=cfg.batch_size,
            feature_dim=None,
            device=self.device,
            world_rank=world_rank,
            world_size=world_size,
        )
        print("Model initialized. Number of GS:", len(self.splats["means"]))

        # Densification Strategy
        self.cfg.strategy.check_sanity(self.splats, self.optimizers)

        if isinstance(self.cfg.strategy, DefaultStrategy):
            self.strategy_state = self.cfg.strategy.initialize_state(
                scene_scale=self.scene_scale
            )
        elif isinstance(self.cfg.strategy, MCMCStrategy):
            self.strategy_state = self.cfg.strategy.initialize_state()
        else:
            assert_never(self.cfg.strategy)

        # Losses & Metrics.
        self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(
            self.device
        )
        self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device)

        if cfg.lpips_net == "alex":
            self.lpips = LearnedPerceptualImagePatchSimilarity(
                net_type="alex", normalize=True
            ).to(self.device)
        elif cfg.lpips_net == "vgg":
            # The 3DGS official repo uses lpips vgg, which is equivalent with the following:
            self.lpips = LearnedPerceptualImagePatchSimilarity(
                net_type="vgg", normalize=False
            ).to(self.device)
        else:
            raise ValueError(f"Unknown LPIPS network: {cfg.lpips_net}")

    def rasterize_splats(
        self,
        camtoworlds: Tensor,
        Ks: Tensor,
        width: int,
        height: int,
        masks: Optional[Tensor] = None,
        rasterize_mode: Optional[Literal["classic", "antialiased"]] = None,
        camera_model: Optional[Literal["pinhole", "ortho", "fisheye"]] = None,
        **kwargs,
    ) -> Tuple[Tensor, Tensor, Dict]:
        means = self.splats["means"]  # [N, 3]
        # quats = F.normalize(self.splats["quats"], dim=-1)  # [N, 4]
        # rasterization does normalization internally
        quats = self.splats["quats"]  # [N, 4]
        scales = torch.exp(self.splats["scales"])  # [N, 3]
        opacities = torch.sigmoid(self.splats["opacities"])  # [N,]
        image_ids = kwargs.pop("image_ids", None)

        colors = torch.cat(
            [self.splats["sh0"], self.splats["shN"]], 1
        )  # [N, K, 3]

        if rasterize_mode is None:
            rasterize_mode = (
                "antialiased" if self.cfg.antialiased else "classic"
            )
        if camera_model is None:
            camera_model = self.cfg.camera_model

        render_colors, render_alphas, info = rasterization(
            means=means,
            quats=quats,
            scales=scales,
            opacities=opacities,
            colors=colors,
            viewmats=torch.linalg.inv(camtoworlds),  # [C, 4, 4]
            Ks=Ks,  # [C, 3, 3]
            width=width,
            height=height,
            packed=self.cfg.packed,
            absgrad=(
                self.cfg.strategy.absgrad
                if isinstance(self.cfg.strategy, DefaultStrategy)
                else False
            ),
            sparse_grad=self.cfg.sparse_grad,
            rasterize_mode=rasterize_mode,
            distributed=self.world_size > 1,
            camera_model=self.cfg.camera_model,
            with_ut=self.cfg.with_ut,
            with_eval3d=self.cfg.with_eval3d,
            **kwargs,
        )
        if masks is not None:
            render_colors[~masks] = 0
        return render_colors, render_alphas, info

    def train(self):
        cfg = self.cfg
        device = self.device
        world_rank = self.world_rank

        # Dump cfg.
        if world_rank == 0:
            with open(f"{cfg.result_dir}/cfg.yml", "w") as f:
                yaml.dump(vars(cfg), f)

        max_steps = cfg.max_steps
        init_step = 0

        schedulers = [
            # means has a learning rate schedule, that end at 0.01 of the initial value
            torch.optim.lr_scheduler.ExponentialLR(
                self.optimizers["means"], gamma=0.01 ** (1.0 / max_steps)
            ),
        ]
        trainloader = torch.utils.data.DataLoader(
            self.trainset,
            batch_size=cfg.batch_size,
            shuffle=True,
            num_workers=4,
            persistent_workers=True,
            pin_memory=True,
        )
        trainloader_iter = iter(trainloader)

        # Training loop.
        global_tic = time.time()
        pbar = tqdm.tqdm(range(init_step, max_steps))
        for step in pbar:
            try:
                data = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = iter(trainloader)
                data = next(trainloader_iter)

            camtoworlds = data["camtoworld"].to(device)  # [1, 4, 4]
            Ks = data["K"].to(device)  # [1, 3, 3]
            pixels = data["image"].to(device) / 255.0  # [1, H, W, 3]
            image_ids = data["image_id"].to(device)
            masks = (
                data["mask"].to(device) if "mask" in data else None
            )  # [1, H, W]
            if cfg.depth_loss:
                points = data["points"].to(device)  # [1, M, 2]
                depths_gt = data["depths"].to(device)  # [1, M]

            height, width = pixels.shape[1:3]

            # sh schedule
            sh_degree_to_use = min(
                step // cfg.sh_degree_interval, cfg.sh_degree
            )

            # forward
            renders, alphas, info = self.rasterize_splats(
                camtoworlds=camtoworlds,
                Ks=Ks,
                width=width,
                height=height,
                sh_degree=sh_degree_to_use,
                near_plane=cfg.near_plane,
                far_plane=cfg.far_plane,
                image_ids=image_ids,
                render_mode="RGB+ED" if cfg.depth_loss else "RGB",
                masks=masks,
            )
            if renders.shape[-1] == 4:
                colors, depths = renders[..., 0:3], renders[..., 3:4]
            else:
                colors, depths = renders, None

            if cfg.random_bkgd:
                bkgd = torch.rand(1, 3, device=device)
                colors = colors + bkgd * (1.0 - alphas)

            self.cfg.strategy.step_pre_backward(
                params=self.splats,
                optimizers=self.optimizers,
                state=self.strategy_state,
                step=step,
                info=info,
            )

            # loss
            l1loss = F.l1_loss(colors, pixels)
            ssimloss = 1.0 - fused_ssim(
                colors.permute(0, 3, 1, 2),
                pixels.permute(0, 3, 1, 2),
                padding="valid",
            )
            loss = (
                l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda
            )
            if cfg.depth_loss:
                # query depths from depth map
                points = torch.stack(
                    [
                        points[:, :, 0] / (width - 1) * 2 - 1,
                        points[:, :, 1] / (height - 1) * 2 - 1,
                    ],
                    dim=-1,
                )  # normalize to [-1, 1]
                grid = points.unsqueeze(2)  # [1, M, 1, 2]
                depths = F.grid_sample(
                    depths.permute(0, 3, 1, 2), grid, align_corners=True
                )  # [1, 1, M, 1]
                depths = depths.squeeze(3).squeeze(1)  # [1, M]
                # calculate loss in disparity space
                disp = torch.where(
                    depths > 0.0, 1.0 / depths, torch.zeros_like(depths)
                )
                disp_gt = 1.0 / depths_gt  # [1, M]
                depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale
                loss += depthloss * cfg.depth_lambda

            # regularizations
            if cfg.opacity_reg > 0.0:
                loss += (
                    cfg.opacity_reg
                    * torch.sigmoid(self.splats["opacities"]).mean()
                )
            if cfg.scale_reg > 0.0:
                loss += cfg.scale_reg * torch.exp(self.splats["scales"]).mean()

            loss.backward()

            desc = (
                f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| "
            )
            if cfg.depth_loss:
                desc += f"depth loss={depthloss.item():.6f}| "
            pbar.set_description(desc)

            # write images (gt and render)
            # if world_rank == 0 and step % 800 == 0:
            #     canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy()
            #     canvas = canvas.reshape(-1, *canvas.shape[2:])
            #     imageio.imwrite(
            #         f"{self.render_dir}/train_rank{self.world_rank}.png",
            #         (canvas * 255).astype(np.uint8),
            #     )

            if (
                world_rank == 0
                and cfg.tb_every > 0
                and step % cfg.tb_every == 0
            ):
                mem = torch.cuda.max_memory_allocated() / 1024**3
                self.writer.add_scalar("train/loss", loss.item(), step)
                self.writer.add_scalar("train/l1loss", l1loss.item(), step)
                self.writer.add_scalar("train/ssimloss", ssimloss.item(), step)
                self.writer.add_scalar(
                    "train/num_GS", len(self.splats["means"]), step
                )
                self.writer.add_scalar("train/mem", mem, step)
                if cfg.depth_loss:
                    self.writer.add_scalar(
                        "train/depthloss", depthloss.item(), step
                    )
                if cfg.tb_save_image:
                    canvas = (
                        torch.cat([pixels, colors], dim=2)
                        .detach()
                        .cpu()
                        .numpy()
                    )
                    canvas = canvas.reshape(-1, *canvas.shape[2:])
                    self.writer.add_image("train/render", canvas, step)
                self.writer.flush()

            # save checkpoint before updating the model
            if (
                step in [i - 1 for i in cfg.save_steps]
                or step == max_steps - 1
            ):
                mem = torch.cuda.max_memory_allocated() / 1024**3
                stats = {
                    "mem": mem,
                    "ellipse_time": time.time() - global_tic,
                    "num_GS": len(self.splats["means"]),
                }
                print("Step: ", step, stats)
                with open(
                    f"{self.stats_dir}/train_step{step:04d}_rank{self.world_rank}.json",
                    "w",
                ) as f:
                    json.dump(stats, f)
                data = {"step": step, "splats": self.splats.state_dict()}
                torch.save(
                    data,
                    f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt",
                )
            if (
                step in [i - 1 for i in cfg.ply_steps] or step == max_steps - 1
            ) and cfg.save_ply:
                sh0 = self.splats["sh0"]
                shN = self.splats["shN"]
                means = self.splats["means"]
                scales = self.splats["scales"]
                quats = self.splats["quats"]
                opacities = self.splats["opacities"]
                export_splats(
                    means=means,
                    scales=scales,
                    quats=quats,
                    opacities=opacities,
                    sh0=sh0,
                    shN=shN,
                    format="ply",
                    save_to=f"{self.ply_dir}/point_cloud_{step}.ply",
                )

            # Turn Gradients into Sparse Tensor before running optimizer
            if cfg.sparse_grad:
                assert (
                    cfg.packed
                ), "Sparse gradients only work with packed mode."
                gaussian_ids = info["gaussian_ids"]
                for k in self.splats.keys():
                    grad = self.splats[k].grad
                    if grad is None or grad.is_sparse:
                        continue
                    self.splats[k].grad = torch.sparse_coo_tensor(
                        indices=gaussian_ids[None],  # [1, nnz]
                        values=grad[gaussian_ids],  # [nnz, ...]
                        size=self.splats[k].size(),  # [N, ...]
                        is_coalesced=len(Ks) == 1,
                    )

            if cfg.visible_adam:
                gaussian_cnt = self.splats.means.shape[0]
                if cfg.packed:
                    visibility_mask = torch.zeros_like(
                        self.splats["opacities"], dtype=bool
                    )
                    visibility_mask.scatter_(0, info["gaussian_ids"], 1)
                else:
                    visibility_mask = (info["radii"] > 0).all(-1).any(0)

            # optimize
            for optimizer in self.optimizers.values():
                if cfg.visible_adam:
                    optimizer.step(visibility_mask)
                else:
                    optimizer.step()
                optimizer.zero_grad(set_to_none=True)
            for scheduler in schedulers:
                scheduler.step()

            # Run post-backward steps after backward and optimizer
            if isinstance(self.cfg.strategy, DefaultStrategy):
                self.cfg.strategy.step_post_backward(
                    params=self.splats,
                    optimizers=self.optimizers,
                    state=self.strategy_state,
                    step=step,
                    info=info,
                    packed=cfg.packed,
                )
            elif isinstance(self.cfg.strategy, MCMCStrategy):
                self.cfg.strategy.step_post_backward(
                    params=self.splats,
                    optimizers=self.optimizers,
                    state=self.strategy_state,
                    step=step,
                    info=info,
                    lr=schedulers[0].get_last_lr()[0],
                )
            else:
                assert_never(self.cfg.strategy)

            # eval the full set
            if step in [i - 1 for i in cfg.eval_steps]:
                self.eval(step)
                self.render_video(step)

    @torch.no_grad()
    def eval(
        self,
        step: int,
        stage: str = "val",
        canvas_h: int = 512,
        canvas_w: int = 1024,
    ):
        """Entry for evaluation."""
        print("Running evaluation...")
        cfg = self.cfg
        device = self.device
        world_rank = self.world_rank

        valloader = torch.utils.data.DataLoader(
            self.valset, batch_size=1, shuffle=False, num_workers=1
        )
        ellipse_time = 0
        metrics = defaultdict(list)
        for i, data in enumerate(valloader):
            camtoworlds = data["camtoworld"].to(device)
            Ks = data["K"].to(device)
            pixels = data["image"].to(device) / 255.0
            height, width = pixels.shape[1:3]
            masks = data["mask"].to(device) if "mask" in data else None

            pixels = pixels.permute(0, 3, 1, 2)  # NHWC -> NCHW
            pixels = F.interpolate(pixels, size=(canvas_h, canvas_w // 2))

            torch.cuda.synchronize()
            tic = time.time()
            colors, _, _ = self.rasterize_splats(
                camtoworlds=camtoworlds,
                Ks=Ks,
                width=width,
                height=height,
                sh_degree=cfg.sh_degree,
                near_plane=cfg.near_plane,
                far_plane=cfg.far_plane,
                masks=masks,
            )  # [1, H, W, 3]
            torch.cuda.synchronize()
            ellipse_time += max(time.time() - tic, 1e-10)

            colors = colors.permute(0, 3, 1, 2)  # NHWC -> NCHW
            colors = F.interpolate(colors, size=(canvas_h, canvas_w // 2))
            colors = torch.clamp(colors, 0.0, 1.0)
            canvas_list = [pixels, colors]

            if world_rank == 0:
                canvas = torch.cat(canvas_list, dim=2).squeeze(0)
                canvas = canvas.permute(1, 2, 0)  # CHW -> HWC
                canvas = (canvas * 255).to(torch.uint8).cpu().numpy()
                cv2.imwrite(
                    f"{self.render_dir}/{stage}_step{step}_{i:04d}.png",
                    canvas[..., ::-1],
                )
                metrics["psnr"].append(self.psnr(colors, pixels))
                metrics["ssim"].append(self.ssim(colors, pixels))
                metrics["lpips"].append(self.lpips(colors, pixels))

        if world_rank == 0:
            ellipse_time /= len(valloader)

            stats = {
                k: torch.stack(v).mean().item() for k, v in metrics.items()
            }
            stats.update(
                {
                    "ellipse_time": ellipse_time,
                    "num_GS": len(self.splats["means"]),
                }
            )
            print(
                f"PSNR: {stats['psnr']:.3f}, SSIM: {stats['ssim']:.4f}, LPIPS: {stats['lpips']:.3f} "
                f"Time: {stats['ellipse_time']:.3f}s/image "
                f"Number of GS: {stats['num_GS']}"
            )
            # save stats as json
            with open(
                f"{self.stats_dir}/{stage}_step{step:04d}.json", "w"
            ) as f:
                json.dump(stats, f)
            # save stats to tensorboard
            for k, v in stats.items():
                self.writer.add_scalar(f"{stage}/{k}", v, step)
            self.writer.flush()

    @torch.no_grad()
    def render_video(
        self, step: int, canvas_h: int = 512, canvas_w: int = 1024
    ):
        testloader = torch.utils.data.DataLoader(
            self.testset, batch_size=1, shuffle=False, num_workers=1
        )

        images_cache = []
        depth_global_min, depth_global_max = float("inf"), -float("inf")
        for data in testloader:
            camtoworlds = data["camtoworld"].to(self.device)
            Ks = resize_pinhole_intrinsics(
                data["K"].squeeze(),
                raw_hw=(data["image_h"].item(), data["image_w"].item()),
                new_hw=(canvas_h, canvas_w // 2),
            ).to(self.device)
            renders, _, _ = self.rasterize_splats(
                camtoworlds=camtoworlds,
                Ks=Ks[None, ...],
                width=canvas_w // 2,
                height=canvas_h,
                sh_degree=self.cfg.sh_degree,
                near_plane=self.cfg.near_plane,
                far_plane=self.cfg.far_plane,
                render_mode="RGB+ED",
            )  # [1, H, W, 4]
            colors = torch.clamp(renders[0, ..., 0:3], 0.0, 1.0)  # [H, W, 3]
            colors = (colors * 255).to(torch.uint8).cpu().numpy()
            depths = renders[0, ..., 3:4]  # [H, W, 1], tensor in device.
            images_cache.append([colors, depths])
            depth_global_min = min(depth_global_min, depths.min().item())
            depth_global_max = max(depth_global_max, depths.max().item())

        video_path = f"{self.render_dir}/video_step{step}.mp4"
        writer = imageio.get_writer(video_path, fps=30)
        for rgb, depth in images_cache:
            depth_normalized = torch.clip(
                (depth - depth_global_min)
                / (depth_global_max - depth_global_min),
                0,
                1,
            )
            depth_normalized = (
                (depth_normalized * 255).to(torch.uint8).cpu().numpy()
            )
            depth_map = cv2.applyColorMap(depth_normalized, cv2.COLORMAP_JET)
            image = np.concatenate([rgb, depth_map], axis=1)
            writer.append_data(image)

        writer.close()


def entrypoint(
    local_rank: int, world_rank, world_size: int, cfg: GsplatTrainConfig
):
    runner = Runner(local_rank, world_rank, world_size, cfg)

    if cfg.ckpt is not None:
        # run eval only
        ckpts = [
            torch.load(file, map_location=runner.device, weights_only=True)
            for file in cfg.ckpt
        ]
        for k in runner.splats.keys():
            runner.splats[k].data = torch.cat(
                [ckpt["splats"][k] for ckpt in ckpts]
            )
        step = ckpts[0]["step"]
        runner.eval(step=step)
        runner.render_video(step=step)
    else:
        runner.train()
        runner.render_video(step=cfg.max_steps - 1)


if __name__ == "__main__":
    configs = {
        "default": (
            "Gaussian splatting training using densification heuristics from the original paper.",
            GsplatTrainConfig(
                strategy=DefaultStrategy(verbose=True),
            ),
        ),
        "mcmc": (
            "Gaussian splatting training using densification from the paper '3D Gaussian Splatting as Markov Chain Monte Carlo'.",
            GsplatTrainConfig(
                init_scale=0.1,
                opacity_reg=0.01,
                scale_reg=0.01,
                strategy=MCMCStrategy(verbose=True),
            ),
        ),
    }
    cfg = tyro.extras.overridable_config_cli(configs)
    cfg.adjust_steps(cfg.steps_scaler)

    cli(entrypoint, cfg, verbose=True)
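A minimal single-process launch sketch for this trainer, assuming `GsplatTrainConfig` exposes `data_dir`, `result_dir`, and `max_steps` as plain fields (they are read that way inside `Runner`) and that its other defaults are usable as-is; the paths are illustrative. From the command line, the same run is dispatched through the tyro subcommands `default` or `mcmc` defined in `__main__`.

```python
from gsplat.strategy import DefaultStrategy

from embodied_gen.trainer.gsplat_trainer import entrypoint
from embodied_gen.utils.config import GsplatTrainConfig

cfg = GsplatTrainConfig(strategy=DefaultStrategy(verbose=True))
cfg.data_dir = "outputs/bg_v2/test3"        # folder holding the dumped GS training data
cfg.result_dir = "outputs/bg_v2/test3/gs"   # ckpts/, stats/, renders/, ply/ are created here
cfg.adjust_steps(cfg.steps_scaler)

# local_rank=0, world_rank=0, world_size=1 -> plain single-GPU training.
# Setting cfg.ckpt to a list of saved ckpt_*.pt files instead would make
# entrypoint skip training and run eval() + render_video() only.
entrypoint(0, 0, 1, cfg)
```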
embodied_gen/trainer/pono2mesh_trainer.py
ADDED
@@ -0,0 +1,538 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Project EmbodiedGen
|
2 |
+
#
|
3 |
+
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
14 |
+
# implied. See the License for the specific language governing
|
15 |
+
# permissions and limitations under the License.
|
16 |
+
|
17 |
+
|
18 |
+
from embodied_gen.utils.monkey_patches import monkey_patch_pano2room
|
19 |
+
|
20 |
+
monkey_patch_pano2room()
|
21 |
+
|
22 |
+
import os
|
23 |
+
|
24 |
+
import cv2
|
25 |
+
import numpy as np
|
26 |
+
import torch
|
27 |
+
import trimesh
|
28 |
+
from equilib import cube2equi, equi2pers
|
29 |
+
from kornia.morphology import dilation
|
30 |
+
from PIL import Image
|
31 |
+
from embodied_gen.models.sr_model import ImageRealESRGAN
|
32 |
+
from embodied_gen.utils.config import Pano2MeshSRConfig
|
33 |
+
from embodied_gen.utils.gaussian import compute_pinhole_intrinsics
|
34 |
+
from embodied_gen.utils.log import logger
|
35 |
+
from thirdparty.pano2room.modules.geo_predictors import PanoJointPredictor
|
36 |
+
from thirdparty.pano2room.modules.geo_predictors.PanoFusionDistancePredictor import (
|
37 |
+
PanoFusionDistancePredictor,
|
38 |
+
)
|
39 |
+
from thirdparty.pano2room.modules.inpainters import PanoPersFusionInpainter
|
40 |
+
from thirdparty.pano2room.modules.mesh_fusion.render import (
|
41 |
+
features_to_world_space_mesh,
|
42 |
+
render_mesh,
|
43 |
+
)
|
44 |
+
from thirdparty.pano2room.modules.mesh_fusion.sup_info import SupInfoPool
|
45 |
+
from thirdparty.pano2room.utils.camera_utils import gen_pano_rays
|
46 |
+
from thirdparty.pano2room.utils.functions import (
|
47 |
+
depth_to_distance,
|
48 |
+
get_cubemap_views_world_to_cam,
|
49 |
+
resize_image_with_aspect_ratio,
|
50 |
+
rot_z_world_to_cam,
|
51 |
+
tensor_to_pil,
|
52 |
+
)
|
53 |
+
|
54 |
+
|
55 |
+
class Pano2MeshSRPipeline:
|
56 |
+
"""Converting panoramic RGB image into 3D mesh representations, followed by inpainting and mesh refinement.
|
57 |
+
|
58 |
+
This class integrates several key components including:
|
59 |
+
- Depth estimation from RGB panorama
|
60 |
+
- Inpainting of missing regions under offsets
|
61 |
+
- RGB-D to mesh conversion
|
62 |
+
- Multi-view mesh repair
|
63 |
+
- 3D Gaussian Splatting (3DGS) dataset generation
|
64 |
+
|
65 |
+
Args:
|
66 |
+
config (Pano2MeshSRConfig): Configuration object containing model and pipeline parameters.
|
67 |
+
|
68 |
+
Example:
|
69 |
+
```python
|
70 |
+
pipeline = Pano2MeshSRPipeline(config)
|
71 |
+
pipeline(pano_image='example.png', output_dir='./output')
|
72 |
+
```
|
73 |
+
"""
|
74 |
+
|
75 |
+
def __init__(self, config: Pano2MeshSRConfig) -> None:
|
76 |
+
self.cfg = config
|
77 |
+
self.device = config.device
|
78 |
+
|
79 |
+
# Init models.
|
80 |
+
self.inpainter = PanoPersFusionInpainter(save_path=None)
|
81 |
+
self.geo_predictor = PanoJointPredictor(save_path=None)
|
82 |
+
self.pano_fusion_distance_predictor = PanoFusionDistancePredictor()
|
83 |
+
self.super_model = ImageRealESRGAN(outscale=self.cfg.upscale_factor)
|
84 |
+
|
85 |
+
# Init poses.
|
86 |
+
cubemap_w2cs = get_cubemap_views_world_to_cam()
|
87 |
+
self.cubemap_w2cs = [p.to(self.device) for p in cubemap_w2cs]
|
88 |
+
self.camera_poses = self.load_camera_poses(self.cfg.trajectory_dir)
|
89 |
+
|
90 |
+
kernel = cv2.getStructuringElement(
|
91 |
+
cv2.MORPH_ELLIPSE, self.cfg.kernel_size
|
92 |
+
)
|
93 |
+
self.kernel = torch.from_numpy(kernel).float().to(self.device)
|
94 |
+
|
95 |
+
def init_mesh_params(self) -> None:
|
96 |
+
torch.set_default_device(self.device)
|
97 |
+
self.inpaint_mask = torch.ones(
|
98 |
+
(self.cfg.cubemap_h, self.cfg.cubemap_w), dtype=torch.bool
|
99 |
+
)
|
100 |
+
self.vertices = torch.empty((3, 0), requires_grad=False)
|
101 |
+
self.colors = torch.empty((3, 0), requires_grad=False)
|
102 |
+
self.faces = torch.empty((3, 0), dtype=torch.long, requires_grad=False)
|
103 |
+
|
104 |
+
@staticmethod
|
105 |
+
def read_camera_pose_file(filepath: str) -> np.ndarray:
|
106 |
+
with open(filepath, "r") as f:
|
107 |
+
values = [float(num) for line in f for num in line.split()]
|
108 |
+
|
109 |
+
return np.array(values).reshape(4, 4)
|
110 |
+
|
111 |
+
def load_camera_poses(
|
112 |
+
self, trajectory_dir: str
|
113 |
+
) -> tuple[np.ndarray, list[torch.Tensor]]:
|
114 |
+
pose_filenames = sorted(
|
115 |
+
[
|
116 |
+
fname
|
117 |
+
for fname in os.listdir(trajectory_dir)
|
118 |
+
if fname.startswith("camera_pose")
|
119 |
+
]
|
120 |
+
)
|
121 |
+
|
122 |
+
pano_pose_world = None
|
123 |
+
relative_poses = []
|
124 |
+
for idx, filename in enumerate(pose_filenames):
|
125 |
+
pose_path = os.path.join(trajectory_dir, filename)
|
126 |
+
pose_matrix = self.read_camera_pose_file(pose_path)
|
127 |
+
|
128 |
+
if pano_pose_world is None:
|
129 |
+
pano_pose_world = pose_matrix.copy()
|
130 |
+
pano_pose_world[0, 3] += self.cfg.pano_center_offset[0]
|
131 |
+
pano_pose_world[2, 3] += self.cfg.pano_center_offset[1]
|
132 |
+
|
133 |
+
# Use different reference for the first 6 cubemap views
|
134 |
+
reference_pose = pose_matrix if idx < 6 else pano_pose_world
|
135 |
+
relative_matrix = pose_matrix @ np.linalg.inv(reference_pose)
|
136 |
+
relative_matrix[0:2, :] *= -1 # flip_xy
|
137 |
+
relative_matrix = (
|
138 |
+
relative_matrix @ rot_z_world_to_cam(180).cpu().numpy()
|
139 |
+
)
|
140 |
+
relative_matrix[:3, 3] *= self.cfg.pose_scale
|
141 |
+
relative_matrix = torch.tensor(
|
142 |
+
relative_matrix, dtype=torch.float32
|
143 |
+
)
|
144 |
+
relative_poses.append(relative_matrix)
|
145 |
+
|
146 |
+
return relative_poses
|
147 |
+
|
148 |
+
def load_inpaint_poses(
|
149 |
+
self, poses: torch.Tensor
|
150 |
+
) -> dict[int, torch.Tensor]:
|
151 |
+
inpaint_poses = dict()
|
152 |
+
sampled_views = poses[:: self.cfg.inpaint_frame_stride]
|
153 |
+
init_pose = torch.eye(4)
|
154 |
+
for idx, w2c_tensor in enumerate(sampled_views):
|
155 |
+
w2c = w2c_tensor.cpu().numpy().astype(np.float32)
|
156 |
+
c2w = np.linalg.inv(w2c)
|
157 |
+
pose_tensor = init_pose.clone()
|
158 |
+
pose_tensor[:3, 3] = torch.from_numpy(c2w[:3, 3])
|
159 |
+
pose_tensor[:3, 3] *= -1
|
160 |
+
inpaint_poses[idx] = pose_tensor.to(self.device)
|
161 |
+
|
162 |
+
return inpaint_poses
|
163 |
+
|
164 |
+
def project(self, world_to_cam: torch.Tensor):
|
165 |
+
(
|
166 |
+
project_image,
|
167 |
+
project_depth,
|
168 |
+
inpaint_mask,
|
169 |
+
_,
|
170 |
+
z_buf,
|
171 |
+
mesh,
|
172 |
+
) = render_mesh(
|
173 |
+
vertices=self.vertices,
|
174 |
+
faces=self.faces,
|
175 |
+
vertex_features=self.colors,
|
176 |
+
H=self.cfg.cubemap_h,
|
177 |
+
W=self.cfg.cubemap_w,
|
178 |
+
fov_in_degrees=self.cfg.fov,
|
179 |
+
RT=world_to_cam,
|
180 |
+
blur_radius=self.cfg.blur_radius,
|
181 |
+
faces_per_pixel=self.cfg.faces_per_pixel,
|
182 |
+
)
|
183 |
+
project_image = project_image * ~inpaint_mask
|
184 |
+
|
185 |
+
return project_image[:3, ...], inpaint_mask, project_depth
|
186 |
+
|
187 |
+
def render_pano(self, pose: torch.Tensor):
|
188 |
+
cubemap_list = []
|
189 |
+
for cubemap_pose in self.cubemap_w2cs:
|
190 |
+
project_pose = cubemap_pose @ pose
|
191 |
+
rgb, inpaint_mask, depth = self.project(project_pose)
|
192 |
+
distance_map = depth_to_distance(depth[None, ...])
|
193 |
+
mask = inpaint_mask[None, ...]
|
194 |
+
cubemap_list.append(torch.cat([rgb, distance_map, mask], dim=0))
|
195 |
+
|
196 |
+
# Set default tensor type for CPU operation in cube2equi
|
197 |
+
with torch.device("cpu"):
|
198 |
+
pano_rgbd = cube2equi(
|
199 |
+
cubemap_list, "list", self.cfg.pano_h, self.cfg.pano_w
|
200 |
+
)
|
201 |
+
|
202 |
+
pano_rgb = pano_rgbd[:3, :, :]
|
203 |
+
pano_depth = pano_rgbd[3:4, :, :].squeeze(0)
|
204 |
+
pano_mask = pano_rgbd[4:, :, :].squeeze(0)
|
205 |
+
|
206 |
+
return pano_rgb, pano_depth, pano_mask
|
207 |
+
|
208 |
+
def rgbd_to_mesh(
|
209 |
+
self,
|
210 |
+
rgb: torch.Tensor,
|
211 |
+
depth: torch.Tensor,
|
212 |
+
inpaint_mask: torch.Tensor,
|
213 |
+
world_to_cam: torch.Tensor = None,
|
214 |
+
using_distance_map: bool = True,
|
215 |
+
) -> None:
|
216 |
+
if world_to_cam is None:
|
217 |
+
world_to_cam = torch.eye(4, dtype=torch.float32).to(self.device)
|
218 |
+
|
219 |
+
if inpaint_mask.sum() == 0:
|
220 |
+
return
|
221 |
+
|
222 |
+
vertices, faces, colors = features_to_world_space_mesh(
|
223 |
+
colors=rgb.squeeze(0),
|
224 |
+
depth=depth,
|
225 |
+
fov_in_degrees=self.cfg.fov,
|
226 |
+
world_to_cam=world_to_cam,
|
227 |
+
mask=inpaint_mask,
|
228 |
+
faces=self.faces,
|
229 |
+
vertices=self.vertices,
|
230 |
+
using_distance_map=using_distance_map,
|
231 |
+
edge_threshold=0.05,
|
232 |
+
)
|
233 |
+
|
234 |
+
faces += self.vertices.shape[1]
|
235 |
+
self.vertices = torch.cat([self.vertices, vertices], dim=1)
|
236 |
+
self.colors = torch.cat([self.colors, colors], dim=1)
|
237 |
+
self.faces = torch.cat([self.faces, faces], dim=1)
|
238 |
+
|
239 |
+
def get_edge_image_by_depth(
|
240 |
+
self, depth: torch.Tensor, dilate_iter: int = 1
|
241 |
+
) -> np.ndarray:
|
242 |
+
if isinstance(depth, torch.Tensor):
|
243 |
+
depth = depth.cpu().detach().numpy()
|
244 |
+
|
245 |
+
gray = (depth / depth.max() * 255).astype(np.uint8)
|
246 |
+
edges = cv2.Canny(gray, 60, 150)
|
247 |
+
if dilate_iter > 0:
|
248 |
+
kernel = np.ones((3, 3), np.uint8)
|
249 |
+
edges = cv2.dilate(edges, kernel, iterations=dilate_iter)
|
250 |
+
|
251 |
+
return edges
|
252 |
+
|
253 |
+
def mesh_repair_by_greedy_view_selection(
|
254 |
+
self, pose_dict: dict[str, torch.Tensor], output_dir: str
|
255 |
+
) -> list:
|
256 |
+
inpainted_panos_w_pose = []
|
257 |
+
while len(pose_dict) > 0:
|
258 |
+
logger.info(f"Repairing mesh left rounds {len(pose_dict)}")
|
259 |
+
sampled_views = []
|
260 |
+
for key, pose in pose_dict.items():
|
261 |
+
pano_rgb, pano_distance, pano_mask = self.render_pano(pose)
|
262 |
+
completeness = torch.sum(1 - pano_mask) / (pano_mask.numel())
|
263 |
+
sampled_views.append((key, completeness.item(), pose))
|
264 |
+
|
265 |
+
if len(sampled_views) == 0:
|
266 |
+
break
|
267 |
+
|
268 |
+
# Find inpainting with least view completeness.
|
269 |
+
sampled_views = sorted(sampled_views, key=lambda x: x[1])
|
270 |
+
key, _, pose = sampled_views[len(sampled_views) * 2 // 3]
|
271 |
+
pose_dict.pop(key)
|
272 |
+
|
273 |
+
pano_rgb, pano_distance, pano_mask = self.render_pano(pose)
|
274 |
+
|
275 |
+
colors = pano_rgb.permute(1, 2, 0).clone()
|
276 |
+
distances = pano_distance.unsqueeze(-1).clone()
|
277 |
+
pano_inpaint_mask = pano_mask.clone()
|
278 |
+
init_pose = pose.clone()
|
279 |
+
normals = None
|
280 |
+
if pano_inpaint_mask.min().item() < 0.5:
|
281 |
+
colors, distances, normals = self.inpaint_panorama(
|
282 |
+
idx=key,
|
283 |
+
colors=colors,
|
284 |
+
distances=distances,
|
285 |
+
pano_mask=pano_inpaint_mask,
|
286 |
+
)
|
287 |
+
|
288 |
+
init_pose[0, 3], init_pose[1, 3], init_pose[2, 3] = (
|
289 |
+
-pose[0, 3],
|
290 |
+
pose[2, 3],
|
291 |
+
0,
|
292 |
+
)
|
293 |
+
rays = gen_pano_rays(
|
294 |
+
init_pose, self.cfg.pano_h, self.cfg.pano_w
|
295 |
+
)
|
296 |
+
conflict_mask = self.sup_pool.geo_check(
|
297 |
+
rays, distances.unsqueeze(-1)
|
298 |
+
) # 0 is conflict, 1 not conflict
|
299 |
+
pano_inpaint_mask *= conflict_mask
|
300 |
+
|
301 |
+
self.rgbd_to_mesh(
|
302 |
+
colors.permute(2, 0, 1),
|
303 |
+
distances,
|
304 |
+
pano_inpaint_mask,
|
305 |
+
world_to_cam=pose,
|
306 |
+
)
|
307 |
+
|
308 |
+
self.sup_pool.register_sup_info(
|
309 |
+
pose=init_pose,
|
310 |
+
mask=pano_inpaint_mask.clone(),
|
311 |
+
rgb=colors,
|
312 |
+
distance=distances.unsqueeze(-1),
|
313 |
+
normal=normals,
|
314 |
+
)
|
315 |
+
|
316 |
+
colors = colors.permute(2, 0, 1).unsqueeze(0)
|
317 |
+
inpainted_panos_w_pose.append([colors, pose])
|
318 |
+
|
319 |
+
if self.cfg.visualize:
|
320 |
+
from embodied_gen.data.utils import DiffrastRender
|
321 |
+
|
322 |
+
tensor_to_pil(pano_rgb.unsqueeze(0)).save(
|
323 |
+
f"{output_dir}/rendered_pano_{key}.jpg"
|
324 |
+
)
|
325 |
+
tensor_to_pil(colors).save(
|
326 |
+
f"{output_dir}/inpainted_pano_{key}.jpg"
|
327 |
+
)
|
328 |
+
norm_depth = DiffrastRender.normalize_map_by_mask(
|
329 |
+
distances, torch.ones_like(distances)
|
330 |
+
)
|
331 |
+
heatmap = (norm_depth.cpu().numpy() * 255).astype(np.uint8)
|
332 |
+
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
|
333 |
+
Image.fromarray(heatmap).save(
|
334 |
+
f"{output_dir}/inpainted_depth_{key}.png"
|
335 |
+
)
|
336 |
+
|
337 |
+
return inpainted_panos_w_pose
|
338 |
+
|
339 |
+
def inpaint_panorama(
|
340 |
+
self,
|
341 |
+
idx: int,
|
342 |
+
colors: torch.Tensor,
|
343 |
+
distances: torch.Tensor,
|
344 |
+
pano_mask: torch.Tensor,
|
345 |
+
) -> tuple[torch.Tensor]:
|
346 |
+
mask = (pano_mask[None, ..., None] > 0.5).float()
|
347 |
+
mask = mask.permute(0, 3, 1, 2)
|
348 |
+
mask = dilation(mask, kernel=self.kernel)
|
349 |
+
mask = mask[0, 0, ..., None] # hwc
|
350 |
+
inpainted_img = self.inpainter.inpaint(idx, colors, mask)
|
351 |
+
inpainted_img = colors * (1 - mask) + inpainted_img * mask
|
352 |
+
        inpainted_distances, inpainted_normals = self.geo_predictor(
            idx,
            inpainted_img,
            distances[..., None],
            mask=mask,
            reg_loss_weight=0.0,
            normal_loss_weight=5e-2,
            normal_tv_loss_weight=5e-2,
        )

        return inpainted_img, inpainted_distances.squeeze(), inpainted_normals

    def preprocess_pano(
        self, image: Image.Image | str
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if isinstance(image, str):
            image = Image.open(image)

        image = image.convert("RGB")

        if image.size[0] < image.size[1]:
            image = image.transpose(Image.TRANSPOSE)

        image = resize_image_with_aspect_ratio(image, self.cfg.pano_w)
        image_rgb = torch.tensor(np.array(image)).permute(2, 0, 1) / 255
        image_rgb = image_rgb.to(self.device)
        image_depth = self.pano_fusion_distance_predictor.predict(
            image_rgb.permute(1, 2, 0)
        )
        image_depth = (
            image_depth / image_depth.max() * self.cfg.depth_scale_factor
        )

        return image_rgb, image_depth

    def pano_to_perpective(
        self, pano_image: torch.Tensor, pitch: float, yaw: float, fov: float
    ) -> torch.Tensor:
        rots = dict(
            roll=0,
            pitch=pitch,
            yaw=yaw,
        )
        perspective = equi2pers(
            equi=pano_image.squeeze(0),
            rots=rots,
            height=self.cfg.cubemap_h,
            width=self.cfg.cubemap_w,
            fov_x=fov,
            mode="bilinear",
        ).unsqueeze(0)

        return perspective

    def pano_to_cubemap(self, pano_rgb: torch.Tensor):
        # Define six canonical cube directions in (pitch, yaw)
        directions = [
            (0, 0),
            (0, 1.5 * np.pi),
            (0, 1.0 * np.pi),
            (0, 0.5 * np.pi),
            (-0.5 * np.pi, 0),
            (0.5 * np.pi, 0),
        ]

        cubemaps_rgb = []
        for pitch, yaw in directions:
            rgb_view = self.pano_to_perpective(
                pano_rgb, pitch, yaw, fov=self.cfg.fov
            )
            cubemaps_rgb.append(rgb_view.cpu())

        return cubemaps_rgb

    def save_mesh(self, output_path: str) -> None:
        vertices_np = self.vertices.T.cpu().numpy()
        colors_np = self.colors.T.cpu().numpy()
        faces_np = self.faces.T.cpu().numpy()
        mesh = trimesh.Trimesh(
            vertices=vertices_np, faces=faces_np, vertex_colors=colors_np
        )

        mesh.export(output_path)

    def mesh_pose_to_gs_pose(self, mesh_pose: torch.Tensor) -> np.ndarray:
        pose = mesh_pose.clone()
        pose[0, :] *= -1
        pose[1, :] *= -1

        Rw2c = pose[:3, :3].cpu().numpy()
        Tw2c = pose[:3, 3:].cpu().numpy()
        yz_reverse = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]])

        Rc2w = (yz_reverse @ Rw2c).T
        Tc2w = -(Rc2w @ yz_reverse @ Tw2c)
        c2w = np.concatenate((Rc2w, Tc2w), axis=1)
        c2w = np.concatenate((c2w, np.array([[0, 0, 0, 1]])), axis=0)

        return c2w

    def __call__(self, pano_image: Image.Image | str, output_dir: str):
        self.init_mesh_params()
        pano_rgb, pano_depth = self.preprocess_pano(pano_image)
        self.sup_pool = SupInfoPool()
        self.sup_pool.register_sup_info(
            pose=torch.eye(4).to(self.device),
            mask=torch.ones([self.cfg.pano_h, self.cfg.pano_w]),
            rgb=pano_rgb.permute(1, 2, 0),
            distance=pano_depth[..., None],
        )
        self.sup_pool.gen_occ_grid(res=256)

        logger.info("Init mesh from pano RGBD image...")
        depth_edge = self.get_edge_image_by_depth(pano_depth)
        inpaint_edge_mask = (
            ~torch.from_numpy(depth_edge).to(self.device).bool()
        )
        self.rgbd_to_mesh(pano_rgb, pano_depth, inpaint_edge_mask)

        repair_poses = self.load_inpaint_poses(self.camera_poses)
        inpainted_panos_w_poses = self.mesh_repair_by_greedy_view_selection(
            repair_poses, output_dir
        )
        torch.cuda.empty_cache()
        torch.set_default_device("cpu")

        if self.cfg.mesh_file is not None:
            mesh_path = os.path.join(output_dir, self.cfg.mesh_file)
            self.save_mesh(mesh_path)

        if self.cfg.gs_data_file is None:
            return

        logger.info(f"Dump data for 3DGS training...")
        points_rgb = (self.colors.clip(0, 1) * 255).to(torch.uint8)
        data = {
            "points": self.vertices.permute(1, 0).cpu().numpy(),  # (N, 3)
            "points_rgb": points_rgb.permute(1, 0).cpu().numpy(),  # (N, 3)
            "train": [],
            "eval": [],
        }
        image_h = self.cfg.cubemap_h * self.cfg.upscale_factor
        image_w = self.cfg.cubemap_w * self.cfg.upscale_factor
        Ks = compute_pinhole_intrinsics(image_w, image_h, self.cfg.fov)
        for idx, (pano_img, pano_pose) in enumerate(inpainted_panos_w_poses):
            cubemaps = self.pano_to_cubemap(pano_img)
            for i in range(len(cubemaps)):
                cubemap = tensor_to_pil(cubemaps[i])
                cubemap = self.super_model(cubemap)
                mesh_pose = self.cubemap_w2cs[i] @ pano_pose
                c2w = self.mesh_pose_to_gs_pose(mesh_pose)
                data["train"].append(
                    {
                        "camtoworld": c2w.astype(np.float32),
                        "K": Ks.astype(np.float32),
                        "image": np.array(cubemap),
                        "image_h": image_h,
                        "image_w": image_w,
                        "image_id": len(cubemaps) * idx + i,
                    }
                )

        # Camera poses for evaluation.
        for idx in range(len(self.camera_poses)):
            c2w = self.mesh_pose_to_gs_pose(self.camera_poses[idx])
            data["eval"].append(
                {
                    "camtoworld": c2w.astype(np.float32),
                    "K": Ks.astype(np.float32),
                    "image_h": image_h,
                    "image_w": image_w,
                    "image_id": idx,
                }
            )

        data_path = os.path.join(output_dir, self.cfg.gs_data_file)
        torch.save(data, data_path)

        return


if __name__ == "__main__":
    output_dir = "outputs/bg_v2/test3"
    input_pano = "apps/assets/example_scene/result_pano.png"
    config = Pano2MeshSRConfig()
    pipeline = Pano2MeshSRPipeline(config)
    pipeline(input_pano, output_dir)
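For reference, here is a minimal sketch of how the `gs_data.pt` file dumped by `__call__` above could be consumed downstream. It assumes only the dictionary layout written in this file ("points", "points_rgb", "train", "eval"); the output path is the one used in the `__main__` example.

# Minimal sketch: inspect the data dumped by the Pano2MeshSR pipeline for 3DGS training.
import torch

data = torch.load("outputs/bg_v2/test3/gs_data.pt")
print(data["points"].shape, data["points_rgb"].shape)  # (N, 3) point cloud and colors

for sample in data["train"]:
    c2w = sample["camtoworld"]   # (4, 4) camera-to-world pose, float32
    K = sample["K"]              # (3, 3) pinhole intrinsics, float32
    image = sample["image"]      # (H, W, 3) uint8 super-resolved cubemap face
    assert image.shape[:2] == (sample["image_h"], sample["image_w"])

# Evaluation entries carry only poses and intrinsics (no images).
print(len(data["train"]), len(data["eval"]))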
embodied_gen/utils/config.py
ADDED
@@ -0,0 +1,190 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

from dataclasses import dataclass, field
from typing import List, Optional, Union

from gsplat.strategy import DefaultStrategy, MCMCStrategy
from typing_extensions import Literal, assert_never

__all__ = [
    "Pano2MeshSRConfig",
    "GsplatTrainConfig",
]


@dataclass
class Pano2MeshSRConfig:
    mesh_file: str = "mesh_model.ply"
    gs_data_file: str = "gs_data.pt"
    device: str = "cuda"
    blur_radius: int = 0
    faces_per_pixel: int = 8
    fov: int = 90
    pano_w: int = 2048
    pano_h: int = 1024
    cubemap_w: int = 512
    cubemap_h: int = 512
    pose_scale: float = 0.6
    pano_center_offset: tuple = (-0.2, 0.3)
    inpaint_frame_stride: int = 20
    trajectory_dir: str = "apps/assets/example_scene/camera_trajectory"
    visualize: bool = False
    depth_scale_factor: float = 3.4092
    kernel_size: tuple = (9, 9)
    upscale_factor: int = 4


@dataclass
class GsplatTrainConfig:
    # Path to the .pt files. If provided, training is skipped and only evaluation runs.
    ckpt: Optional[List[str]] = None
    # Render trajectory path
    render_traj_path: str = "interp"

    # Path to the Mip-NeRF 360 dataset
    data_dir: str = "outputs/bg"
    # Downsample factor for the dataset
    data_factor: int = 4
    # Directory to save results
    result_dir: str = "outputs/bg"
    # Every N images there is a test image
    test_every: int = 8
    # Random crop size for training (experimental)
    patch_size: Optional[int] = None
    # A global scaler that applies to the scene size related parameters
    global_scale: float = 1.0
    # Normalize the world space
    normalize_world_space: bool = True
    # Camera model
    camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole"

    # Port for the viewer server
    port: int = 8080

    # Batch size for training. Learning rates are scaled automatically
    batch_size: int = 1
    # A global factor to scale the number of training steps
    steps_scaler: float = 1.0

    # Number of training steps
    max_steps: int = 30_000
    # Steps to evaluate the model
    eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
    # Steps to save the model
    save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
    # Whether to save ply file (storage size can be large)
    save_ply: bool = True
    # Steps to save the model as ply
    ply_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
    # Whether to disable video generation during training and evaluation
    disable_video: bool = False

    # Initial number of GSs. Ignored if using sfm
    init_num_pts: int = 100_000
    # Initial extent of GSs as a multiple of the camera extent. Ignored if using sfm
    init_extent: float = 3.0
    # Degree of spherical harmonics
    sh_degree: int = 1
    # Turn on another SH degree every this many steps
    sh_degree_interval: int = 1000
    # Initial opacity of GS
    init_opa: float = 0.1
    # Initial scale of GS
    init_scale: float = 1.0
    # Weight for SSIM loss
    ssim_lambda: float = 0.2

    # Near plane clipping distance
    near_plane: float = 0.01
    # Far plane clipping distance
    far_plane: float = 1e10

    # Strategy for GS densification
    strategy: Union[DefaultStrategy, MCMCStrategy] = field(
        default_factory=DefaultStrategy
    )
    # Use packed mode for rasterization, this leads to less memory usage but slightly slower.
    packed: bool = False
    # Use sparse gradients for optimization. (experimental)
    sparse_grad: bool = False
    # Use visible adam from Taming 3DGS. (experimental)
    visible_adam: bool = False
    # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics.
    antialiased: bool = False

    # Use random background for training to discourage transparency
    random_bkgd: bool = False

    # LR for 3D point positions
    means_lr: float = 1.6e-4
    # LR for Gaussian scale factors
    scales_lr: float = 5e-3
    # LR for alpha blending weights
    opacities_lr: float = 5e-2
    # LR for orientation (quaternions)
    quats_lr: float = 1e-3
    # LR for SH band 0 (brightness)
    sh0_lr: float = 2.5e-3
    # LR for higher-order SH (detail)
    shN_lr: float = 2.5e-3 / 20

    # Opacity regularization
    opacity_reg: float = 0.0
    # Scale regularization
    scale_reg: float = 0.0

    # Enable depth loss. (experimental)
    depth_loss: bool = False
    # Weight for depth loss
    depth_lambda: float = 1e-2

    # Dump information to tensorboard every this many steps
    tb_every: int = 200
    # Save training images to tensorboard
    tb_save_image: bool = False

    lpips_net: Literal["vgg", "alex"] = "alex"

    # 3DGUT (unscented transform + eval 3D)
    with_ut: bool = False
    with_eval3d: bool = False

    scene_scale: float = 1.0

    def adjust_steps(self, factor: float):
        self.eval_steps = [int(i * factor) for i in self.eval_steps]
        self.save_steps = [int(i * factor) for i in self.save_steps]
        self.ply_steps = [int(i * factor) for i in self.ply_steps]
        self.max_steps = int(self.max_steps * factor)
        self.sh_degree_interval = int(self.sh_degree_interval * factor)

        strategy = self.strategy
        if isinstance(strategy, DefaultStrategy):
            strategy.refine_start_iter = int(
                strategy.refine_start_iter * factor
            )
            strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor)
            strategy.reset_every = int(strategy.reset_every * factor)
            strategy.refine_every = int(strategy.refine_every * factor)
        elif isinstance(strategy, MCMCStrategy):
            strategy.refine_start_iter = int(
                strategy.refine_start_iter * factor
            )
            strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor)
            strategy.refine_every = int(strategy.refine_every * factor)
        else:
            assert_never(strategy)
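A short usage sketch for `GsplatTrainConfig`: plain dataclass instantiation followed by `adjust_steps`, which rescales every step-based schedule in the config and in the chosen gsplat strategy. The field values below are illustrative only.

# Sketch: build a training config and rescale all step-based schedules.
from gsplat.strategy import MCMCStrategy

from embodied_gen.utils.config import GsplatTrainConfig

cfg = GsplatTrainConfig(
    data_dir="outputs/bg",
    result_dir="outputs/bg",
    max_steps=30_000,
    steps_scaler=0.5,          # train for half the default schedule
    strategy=MCMCStrategy(),   # or DefaultStrategy()
)
cfg.adjust_steps(cfg.steps_scaler)  # scales eval/save/ply/max steps and strategy iterations
print(cfg.max_steps, cfg.eval_steps)  # 15000 [3500, 15000]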
embodied_gen/utils/enum.py
ADDED
@@ -0,0 +1,107 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

from dataclasses import dataclass, field
from enum import Enum

from dataclasses_json import DataClassJsonMixin

__all__ = [
    "RenderItems",
    "Scene3DItemEnum",
    "SpatialRelationEnum",
    "RobotItemEnum",
]


@dataclass
class RenderItems(str, Enum):
    IMAGE = "image_color"
    ALPHA = "image_mask"
    VIEW_NORMAL = "image_view_normal"
    GLOBAL_NORMAL = "image_global_normal"
    POSITION_MAP = "image_position"
    DEPTH = "image_depth"
    ALBEDO = "image_albedo"
    DIFFUSE = "image_diffuse"


@dataclass
class Scene3DItemEnum(str, Enum):
    BACKGROUND = "background"
    CONTEXT = "context"
    ROBOT = "robot"
    MANIPULATED_OBJS = "manipulated_objs"
    DISTRACTOR_OBJS = "distractor_objs"
    OTHERS = "others"

    @classmethod
    def object_list(cls, layout_relation: dict) -> list:
        return (
            [
                layout_relation[cls.BACKGROUND.value],
                layout_relation[cls.CONTEXT.value],
            ]
            + layout_relation[cls.MANIPULATED_OBJS.value]
            + layout_relation[cls.DISTRACTOR_OBJS.value]
        )

    @classmethod
    def object_mapping(cls, layout_relation):
        relation_mapping = {
            # layout_relation[cls.ROBOT.value]: cls.ROBOT.value,
            layout_relation[cls.BACKGROUND.value]: cls.BACKGROUND.value,
            layout_relation[cls.CONTEXT.value]: cls.CONTEXT.value,
        }
        relation_mapping.update(
            {
                item: cls.MANIPULATED_OBJS.value
                for item in layout_relation[cls.MANIPULATED_OBJS.value]
            }
        )
        relation_mapping.update(
            {
                item: cls.DISTRACTOR_OBJS.value
                for item in layout_relation[cls.DISTRACTOR_OBJS.value]
            }
        )

        return relation_mapping


@dataclass
class SpatialRelationEnum(str, Enum):
    ON = "ON"  # objects on the table
    IN = "IN"  # objects in the room
    INSIDE = "INSIDE"  # objects inside the shelf/rack
    FLOOR = "FLOOR"  # objects on the floor of a room/bin


@dataclass
class RobotItemEnum(str, Enum):
    FRANKA = "franka"
    UR5 = "ur5"
    PIPER = "piper"


@dataclass
class LayoutInfo(DataClassJsonMixin):
    tree: dict[str, list]
    relation: dict[str, str | list[str]]
    objs_desc: dict[str, str] = field(default_factory=dict)
    assets: dict[str, str] = field(default_factory=dict)
    quality: dict[str, str] = field(default_factory=dict)
    position: dict[str, list[float]] = field(default_factory=dict)
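A small sketch of how `LayoutInfo` and `Scene3DItemEnum` fit together. The scene content (object names, task description) is made up for illustration; only the dict keys follow the enum values defined above.

# Sketch: describe a tabletop scene and map each node to its semantic role.
from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum

layout = LayoutInfo(
    tree={"kitchen": [("table", "IN")], "table": [("mug", "ON"), ("plate", "ON")]},
    relation={
        "background": "kitchen",
        "context": "table",
        "robot": "franka",
        "manipulated_objs": ["mug"],
        "distractor_objs": ["plate"],
        "task_desc": "pick up the mug",
    },
)

print(Scene3DItemEnum.object_list(layout.relation))
# ['kitchen', 'table', 'mug', 'plate']
print(Scene3DItemEnum.object_mapping(layout.relation)["mug"])
# 'manipulated_objs'
print(layout.to_json())  # DataClassJsonMixin provides JSON (de)serialization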
embodied_gen/utils/gaussian.py
ADDED
@@ -0,0 +1,331 @@
1 |
+
# Project EmbodiedGen
|
2 |
+
#
|
3 |
+
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
14 |
+
# implied. See the License for the specific language governing
|
15 |
+
# permissions and limitations under the License.
|
16 |
+
# Part of the code comes from https://github.com/nerfstudio-project/gsplat
|
17 |
+
# Both under the Apache License, Version 2.0.
|
18 |
+
|
19 |
+
|
20 |
+
import math
|
21 |
+
import random
|
22 |
+
from io import BytesIO
|
23 |
+
from typing import Dict, Literal, Optional, Tuple
|
24 |
+
|
25 |
+
import numpy as np
|
26 |
+
import torch
|
27 |
+
import trimesh
|
28 |
+
from gsplat.optimizers import SelectiveAdam
|
29 |
+
from scipy.spatial.transform import Rotation
|
30 |
+
from sklearn.neighbors import NearestNeighbors
|
31 |
+
from torch import Tensor
|
32 |
+
from embodied_gen.models.gs_model import GaussianOperator
|
33 |
+
|
34 |
+
__all__ = [
|
35 |
+
"set_random_seed",
|
36 |
+
"export_splats",
|
37 |
+
"create_splats_with_optimizers",
|
38 |
+
"compute_pinhole_intrinsics",
|
39 |
+
"resize_pinhole_intrinsics",
|
40 |
+
"restore_scene_scale_and_position",
|
41 |
+
]
|
42 |
+
|
43 |
+
|
44 |
+
def knn(x: Tensor, K: int = 4) -> Tensor:
|
45 |
+
x_np = x.cpu().numpy()
|
46 |
+
model = NearestNeighbors(n_neighbors=K, metric="euclidean").fit(x_np)
|
47 |
+
distances, _ = model.kneighbors(x_np)
|
48 |
+
return torch.from_numpy(distances).to(x)
|
49 |
+
|
50 |
+
|
51 |
+
def rgb_to_sh(rgb: Tensor) -> Tensor:
|
52 |
+
C0 = 0.28209479177387814
|
53 |
+
return (rgb - 0.5) / C0
|
54 |
+
|
55 |
+
|
56 |
+
def set_random_seed(seed: int):
|
57 |
+
random.seed(seed)
|
58 |
+
np.random.seed(seed)
|
59 |
+
torch.manual_seed(seed)
|
60 |
+
|
61 |
+
|
62 |
+
def splat2ply_bytes(
|
63 |
+
means: torch.Tensor,
|
64 |
+
scales: torch.Tensor,
|
65 |
+
quats: torch.Tensor,
|
66 |
+
opacities: torch.Tensor,
|
67 |
+
sh0: torch.Tensor,
|
68 |
+
shN: torch.Tensor,
|
69 |
+
) -> bytes:
|
70 |
+
num_splats = means.shape[0]
|
71 |
+
buffer = BytesIO()
|
72 |
+
|
73 |
+
# Write PLY header
|
74 |
+
buffer.write(b"ply\n")
|
75 |
+
buffer.write(b"format binary_little_endian 1.0\n")
|
76 |
+
buffer.write(f"element vertex {num_splats}\n".encode())
|
77 |
+
buffer.write(b"property float x\n")
|
78 |
+
buffer.write(b"property float y\n")
|
79 |
+
buffer.write(b"property float z\n")
|
80 |
+
for i, data in enumerate([sh0, shN]):
|
81 |
+
prefix = "f_dc" if i == 0 else "f_rest"
|
82 |
+
for j in range(data.shape[1]):
|
83 |
+
buffer.write(f"property float {prefix}_{j}\n".encode())
|
84 |
+
buffer.write(b"property float opacity\n")
|
85 |
+
for i in range(scales.shape[1]):
|
86 |
+
buffer.write(f"property float scale_{i}\n".encode())
|
87 |
+
for i in range(quats.shape[1]):
|
88 |
+
buffer.write(f"property float rot_{i}\n".encode())
|
89 |
+
buffer.write(b"end_header\n")
|
90 |
+
|
91 |
+
# Concatenate all tensors in the correct order
|
92 |
+
splat_data = torch.cat(
|
93 |
+
[means, sh0, shN, opacities.unsqueeze(1), scales, quats], dim=1
|
94 |
+
)
|
95 |
+
# Ensure correct dtype
|
96 |
+
splat_data = splat_data.to(torch.float32)
|
97 |
+
|
98 |
+
# Write binary data
|
99 |
+
float_dtype = np.dtype(np.float32).newbyteorder("<")
|
100 |
+
buffer.write(
|
101 |
+
splat_data.detach().cpu().numpy().astype(float_dtype).tobytes()
|
102 |
+
)
|
103 |
+
|
104 |
+
return buffer.getvalue()
|
105 |
+
|
106 |
+
|
107 |
+
def export_splats(
|
108 |
+
means: torch.Tensor,
|
109 |
+
scales: torch.Tensor,
|
110 |
+
quats: torch.Tensor,
|
111 |
+
opacities: torch.Tensor,
|
112 |
+
sh0: torch.Tensor,
|
113 |
+
shN: torch.Tensor,
|
114 |
+
format: Literal["ply"] = "ply",
|
115 |
+
save_to: Optional[str] = None,
|
116 |
+
) -> bytes:
|
117 |
+
"""Export a Gaussian Splats model to bytes in PLY file format."""
|
118 |
+
total_splats = means.shape[0]
|
119 |
+
assert means.shape == (total_splats, 3), "Means must be of shape (N, 3)"
|
120 |
+
assert scales.shape == (total_splats, 3), "Scales must be of shape (N, 3)"
|
121 |
+
assert quats.shape == (
|
122 |
+
total_splats,
|
123 |
+
4,
|
124 |
+
), "Quaternions must be of shape (N, 4)"
|
125 |
+
assert opacities.shape == (
|
126 |
+
total_splats,
|
127 |
+
), "Opacities must be of shape (N,)"
|
128 |
+
assert sh0.shape == (total_splats, 1, 3), "sh0 must be of shape (N, 1, 3)"
|
129 |
+
assert (
|
130 |
+
shN.ndim == 3 and shN.shape[0] == total_splats and shN.shape[2] == 3
|
131 |
+
), f"shN must be of shape (N, K, 3), got {shN.shape}"
|
132 |
+
|
133 |
+
# Reshape spherical harmonics
|
134 |
+
sh0 = sh0.squeeze(1) # Shape (N, 3)
|
135 |
+
shN = shN.permute(0, 2, 1).reshape(means.shape[0], -1) # Shape (N, K * 3)
|
136 |
+
|
137 |
+
# Check for NaN or Inf values
|
138 |
+
invalid_mask = (
|
139 |
+
torch.isnan(means).any(dim=1)
|
140 |
+
| torch.isinf(means).any(dim=1)
|
141 |
+
| torch.isnan(scales).any(dim=1)
|
142 |
+
| torch.isinf(scales).any(dim=1)
|
143 |
+
| torch.isnan(quats).any(dim=1)
|
144 |
+
| torch.isinf(quats).any(dim=1)
|
145 |
+
| torch.isnan(opacities).any(dim=0)
|
146 |
+
| torch.isinf(opacities).any(dim=0)
|
147 |
+
| torch.isnan(sh0).any(dim=1)
|
148 |
+
| torch.isinf(sh0).any(dim=1)
|
149 |
+
| torch.isnan(shN).any(dim=1)
|
150 |
+
| torch.isinf(shN).any(dim=1)
|
151 |
+
)
|
152 |
+
|
153 |
+
# Filter out invalid entries
|
154 |
+
valid_mask = ~invalid_mask
|
155 |
+
means = means[valid_mask]
|
156 |
+
scales = scales[valid_mask]
|
157 |
+
quats = quats[valid_mask]
|
158 |
+
opacities = opacities[valid_mask]
|
159 |
+
sh0 = sh0[valid_mask]
|
160 |
+
shN = shN[valid_mask]
|
161 |
+
|
162 |
+
if format == "ply":
|
163 |
+
data = splat2ply_bytes(means, scales, quats, opacities, sh0, shN)
|
164 |
+
else:
|
165 |
+
raise ValueError(f"Unsupported format: {format}")
|
166 |
+
|
167 |
+
if save_to:
|
168 |
+
with open(save_to, "wb") as binary_file:
|
169 |
+
binary_file.write(data)
|
170 |
+
|
171 |
+
return data
|
172 |
+
|
173 |
+
|
174 |
+
def create_splats_with_optimizers(
|
175 |
+
points: np.ndarray = None,
|
176 |
+
points_rgb: np.ndarray = None,
|
177 |
+
init_num_pts: int = 100_000,
|
178 |
+
init_extent: float = 3.0,
|
179 |
+
init_opacity: float = 0.1,
|
180 |
+
init_scale: float = 1.0,
|
181 |
+
means_lr: float = 1.6e-4,
|
182 |
+
scales_lr: float = 5e-3,
|
183 |
+
opacities_lr: float = 5e-2,
|
184 |
+
quats_lr: float = 1e-3,
|
185 |
+
sh0_lr: float = 2.5e-3,
|
186 |
+
shN_lr: float = 2.5e-3 / 20,
|
187 |
+
scene_scale: float = 1.0,
|
188 |
+
sh_degree: int = 3,
|
189 |
+
sparse_grad: bool = False,
|
190 |
+
visible_adam: bool = False,
|
191 |
+
batch_size: int = 1,
|
192 |
+
feature_dim: Optional[int] = None,
|
193 |
+
device: str = "cuda",
|
194 |
+
world_rank: int = 0,
|
195 |
+
world_size: int = 1,
|
196 |
+
) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer]]:
|
197 |
+
if points is not None and points_rgb is not None:
|
198 |
+
points = torch.from_numpy(points).float()
|
199 |
+
rgbs = torch.from_numpy(points_rgb / 255.0).float()
|
200 |
+
else:
|
201 |
+
points = (
|
202 |
+
init_extent * scene_scale * (torch.rand((init_num_pts, 3)) * 2 - 1)
|
203 |
+
)
|
204 |
+
rgbs = torch.rand((init_num_pts, 3))
|
205 |
+
|
206 |
+
# Initialize the GS size to be the average dist of the 3 nearest neighbors
|
207 |
+
dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1) # [N,]
|
208 |
+
dist_avg = torch.sqrt(dist2_avg)
|
209 |
+
scales = (
|
210 |
+
torch.log(dist_avg * init_scale).unsqueeze(-1).repeat(1, 3)
|
211 |
+
) # [N, 3]
|
212 |
+
|
213 |
+
# Distribute the GSs to different ranks (also works for single rank)
|
214 |
+
points = points[world_rank::world_size]
|
215 |
+
rgbs = rgbs[world_rank::world_size]
|
216 |
+
scales = scales[world_rank::world_size]
|
217 |
+
|
218 |
+
N = points.shape[0]
|
219 |
+
quats = torch.rand((N, 4)) # [N, 4]
|
220 |
+
opacities = torch.logit(torch.full((N,), init_opacity)) # [N,]
|
221 |
+
|
222 |
+
params = [
|
223 |
+
# name, value, lr
|
224 |
+
("means", torch.nn.Parameter(points), means_lr * scene_scale),
|
225 |
+
("scales", torch.nn.Parameter(scales), scales_lr),
|
226 |
+
("quats", torch.nn.Parameter(quats), quats_lr),
|
227 |
+
("opacities", torch.nn.Parameter(opacities), opacities_lr),
|
228 |
+
]
|
229 |
+
|
230 |
+
if feature_dim is None:
|
231 |
+
# color is SH coefficients.
|
232 |
+
colors = torch.zeros((N, (sh_degree + 1) ** 2, 3)) # [N, K, 3]
|
233 |
+
colors[:, 0, :] = rgb_to_sh(rgbs)
|
234 |
+
params.append(("sh0", torch.nn.Parameter(colors[:, :1, :]), sh0_lr))
|
235 |
+
params.append(("shN", torch.nn.Parameter(colors[:, 1:, :]), shN_lr))
|
236 |
+
else:
|
237 |
+
# features will be used for appearance and view-dependent shading
|
238 |
+
features = torch.rand(N, feature_dim) # [N, feature_dim]
|
239 |
+
params.append(("features", torch.nn.Parameter(features), sh0_lr))
|
240 |
+
colors = torch.logit(rgbs) # [N, 3]
|
241 |
+
params.append(("colors", torch.nn.Parameter(colors), sh0_lr))
|
242 |
+
|
243 |
+
splats = torch.nn.ParameterDict({n: v for n, v, _ in params}).to(device)
|
244 |
+
# Scale learning rate based on batch size, reference:
|
245 |
+
# https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/
|
246 |
+
# Note that this would not make the training exactly equivalent, see
|
247 |
+
# https://arxiv.org/pdf/2402.18824v1
|
248 |
+
BS = batch_size * world_size
|
249 |
+
optimizer_class = None
|
250 |
+
if sparse_grad:
|
251 |
+
optimizer_class = torch.optim.SparseAdam
|
252 |
+
elif visible_adam:
|
253 |
+
optimizer_class = SelectiveAdam
|
254 |
+
else:
|
255 |
+
optimizer_class = torch.optim.Adam
|
256 |
+
optimizers = {
|
257 |
+
name: optimizer_class(
|
258 |
+
[{"params": splats[name], "lr": lr * math.sqrt(BS), "name": name}],
|
259 |
+
eps=1e-15 / math.sqrt(BS),
|
260 |
+
# TODO: check betas logic; when BS is larger than 10, betas[0] will be zero.
|
261 |
+
betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)),
|
262 |
+
)
|
263 |
+
for name, _, lr in params
|
264 |
+
}
|
265 |
+
return splats, optimizers
|
266 |
+
|
267 |
+
|
268 |
+
def compute_pinhole_intrinsics(
|
269 |
+
image_w: int, image_h: int, fov_deg: float
|
270 |
+
) -> np.ndarray:
|
271 |
+
fov_rad = np.deg2rad(fov_deg)
|
272 |
+
fx = image_w / (2 * np.tan(fov_rad / 2))
|
273 |
+
fy = fx # assuming square pixels
|
274 |
+
cx = image_w / 2
|
275 |
+
cy = image_h / 2
|
276 |
+
K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
|
277 |
+
|
278 |
+
return K
|
279 |
+
|
280 |
+
|
281 |
+
def resize_pinhole_intrinsics(
|
282 |
+
raw_K: np.ndarray | torch.Tensor,
|
283 |
+
raw_hw: tuple[int, int],
|
284 |
+
new_hw: tuple[int, int],
|
285 |
+
) -> np.ndarray:
|
286 |
+
raw_h, raw_w = raw_hw
|
287 |
+
new_h, new_w = new_hw
|
288 |
+
|
289 |
+
scale_x = new_w / raw_w
|
290 |
+
scale_y = new_h / raw_h
|
291 |
+
|
292 |
+
new_K = raw_K.copy() if isinstance(raw_K, np.ndarray) else raw_K.clone()
|
293 |
+
new_K[0, 0] *= scale_x # fx
|
294 |
+
new_K[0, 2] *= scale_x # cx
|
295 |
+
new_K[1, 1] *= scale_y # fy
|
296 |
+
new_K[1, 2] *= scale_y # cy
|
297 |
+
|
298 |
+
return new_K
|
299 |
+
|
300 |
+
|
301 |
+
def restore_scene_scale_and_position(
|
302 |
+
real_height: float, mesh_path: str, gs_path: str
|
303 |
+
) -> None:
|
304 |
+
"""Scales a mesh and corresponding GS model to match a given real-world height.
|
305 |
+
|
306 |
+
Uses the 1st and 99th percentile of mesh Z-axis to estimate height,
|
307 |
+
applies scaling and vertical alignment, and updates both the mesh and GS model.
|
308 |
+
|
309 |
+
Args:
|
310 |
+
real_height (float): Target real-world height along the Z axis.
|
311 |
+
mesh_path (str): Path to the input mesh file.
|
312 |
+
gs_path (str): Path to the Gaussian Splatting model file.
|
313 |
+
"""
|
314 |
+
mesh = trimesh.load(mesh_path)
|
315 |
+
z_min = np.percentile(mesh.vertices[:, 1], 1)
|
316 |
+
z_max = np.percentile(mesh.vertices[:, 1], 99)
|
317 |
+
height = z_max - z_min
|
318 |
+
scale = real_height / height
|
319 |
+
|
320 |
+
rot = Rotation.from_quat([0, 1, 0, 0])
|
321 |
+
mesh.vertices = rot.apply(mesh.vertices)
|
322 |
+
mesh.vertices[:, 1] -= z_min
|
323 |
+
mesh.vertices *= scale
|
324 |
+
mesh.export(mesh_path)
|
325 |
+
|
326 |
+
gs_model: GaussianOperator = GaussianOperator.load_from_ply(gs_path)
|
327 |
+
gs_model = gs_model.get_gaussians(
|
328 |
+
instance_pose=torch.tensor([0.0, -z_min, 0, 0, 1, 0, 0])
|
329 |
+
)
|
330 |
+
gs_model.rescale(scale)
|
331 |
+
gs_model.save_to_ply(gs_path)
|
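A brief sketch exercising the intrinsics helpers defined above; it only relies on the formulas in `compute_pinhole_intrinsics` and `resize_pinhole_intrinsics`, and the numbers are illustrative.

# Sketch: pinhole intrinsics helpers from embodied_gen.utils.gaussian.
import numpy as np

from embodied_gen.utils.gaussian import (
    compute_pinhole_intrinsics,
    resize_pinhole_intrinsics,
)

K = compute_pinhole_intrinsics(image_w=2048, image_h=2048, fov_deg=90)
# fx = W / (2 * tan(fov / 2)) = 1024 for a 90 degree horizontal FOV.
assert np.isclose(K[0, 0], 1024.0)

# Downscale the intrinsics to match a 512x512 render target.
K_small = resize_pinhole_intrinsics(K, raw_hw=(2048, 2048), new_hw=(512, 512))
assert np.isclose(K_small[0, 2], 256.0)  # cx scales with the width ratio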
embodied_gen/utils/gpt_clients.py
CHANGED
@@ -30,12 +30,20 @@ from tenacity import (
|
|
30 |
stop_after_delay,
|
31 |
wait_random_exponential,
|
32 |
)
|
33 |
-
from embodied_gen.utils.process_media import
|
34 |
|
35 |
-
logging.
|
|
|
36 |
logger = logging.getLogger(__name__)
|
37 |
|
38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
class GPTclient:
|
40 |
"""A client to interact with the GPT model via OpenAI or Azure API."""
|
41 |
|
@@ -45,6 +53,7 @@ class GPTclient:
|
|
45 |
api_key: str,
|
46 |
model_name: str = "yfb-gpt-4o",
|
47 |
api_version: str = None,
|
|
|
48 |
verbose: bool = False,
|
49 |
):
|
50 |
if api_version is not None:
|
@@ -63,6 +72,9 @@ class GPTclient:
|
|
63 |
self.model_name = model_name
|
64 |
self.image_formats = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif"}
|
65 |
self.verbose = verbose
|
|
|
|
|
|
|
66 |
logger.info(f"Using GPT model: {self.model_name}.")
|
67 |
|
68 |
@retry(
|
@@ -77,6 +89,7 @@ class GPTclient:
|
|
77 |
text_prompt: str,
|
78 |
image_base64: Optional[list[str | Image.Image]] = None,
|
79 |
system_role: Optional[str] = None,
|
|
|
80 |
) -> Optional[str]:
|
81 |
"""Queries the GPT model with a text and optional image prompts.
|
82 |
|
@@ -86,6 +99,7 @@ class GPTclient:
|
|
86 |
or local image paths or PIL.Image to accompany the text prompt.
|
87 |
system_role (Optional[str]): Optional system-level instructions
|
88 |
that specify the behavior of the assistant.
|
|
|
89 |
|
90 |
Returns:
|
91 |
Optional[str]: The response content generated by the model based on
|
@@ -103,11 +117,11 @@ class GPTclient:
|
|
103 |
|
104 |
# Process images if provided
|
105 |
if image_base64 is not None:
|
106 |
-
|
107 |
-
image_base64
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
for img in image_base64:
|
112 |
if isinstance(img, Image.Image):
|
113 |
buffer = BytesIO()
|
@@ -142,8 +156,11 @@ class GPTclient:
|
|
142 |
"frequency_penalty": 0,
|
143 |
"presence_penalty": 0,
|
144 |
"stop": None,
|
|
|
145 |
}
|
146 |
-
|
|
|
|
|
147 |
|
148 |
response = None
|
149 |
try:
|
@@ -159,8 +176,28 @@ class GPTclient:
|
|
159 |
|
160 |
return response
|
161 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
|
163 |
-
|
|
|
164 |
config = yaml.safe_load(f)
|
165 |
|
166 |
agent_type = config["agent_type"]
|
@@ -177,32 +214,5 @@ GPT_CLIENT = GPTclient(
|
|
177 |
api_key=api_key,
|
178 |
api_version=api_version,
|
179 |
model_name=model_name,
|
|
|
180 |
)
|
181 |
-
|
182 |
-
if __name__ == "__main__":
|
183 |
-
if "openrouter" in GPT_CLIENT.endpoint:
|
184 |
-
response = GPT_CLIENT.query(
|
185 |
-
text_prompt="What is the content in each image?",
|
186 |
-
image_base64=combine_images_to_base64(
|
187 |
-
[
|
188 |
-
"apps/assets/example_image/sample_02.jpg",
|
189 |
-
"apps/assets/example_image/sample_03.jpg",
|
190 |
-
]
|
191 |
-
), # input raw image_path if only one image
|
192 |
-
)
|
193 |
-
print(response)
|
194 |
-
else:
|
195 |
-
response = GPT_CLIENT.query(
|
196 |
-
text_prompt="What is the content in the images?",
|
197 |
-
image_base64=[
|
198 |
-
Image.open("apps/assets/example_image/sample_02.jpg"),
|
199 |
-
Image.open("apps/assets/example_image/sample_03.jpg"),
|
200 |
-
],
|
201 |
-
)
|
202 |
-
print(response)
|
203 |
-
|
204 |
-
# test2: text prompt
|
205 |
-
response = GPT_CLIENT.query(
|
206 |
-
text_prompt="What is the capital of China?"
|
207 |
-
)
|
208 |
-
print(response)
|
|
|
30 |
stop_after_delay,
|
31 |
wait_random_exponential,
|
32 |
)
|
33 |
+
from embodied_gen.utils.process_media import combine_images_to_grid
|
34 |
|
35 |
+
logging.getLogger("httpx").setLevel(logging.WARNING)
|
36 |
+
logging.basicConfig(level=logging.WARNING)
|
37 |
logger = logging.getLogger(__name__)
|
38 |
|
39 |
|
40 |
+
__all__ = [
|
41 |
+
"GPTclient",
|
42 |
+
]
|
43 |
+
|
44 |
+
CONFIG_FILE = "embodied_gen/utils/gpt_config.yaml"
|
45 |
+
|
46 |
+
|
47 |
class GPTclient:
|
48 |
"""A client to interact with the GPT model via OpenAI or Azure API."""
|
49 |
|
|
|
53 |
api_key: str,
|
54 |
model_name: str = "yfb-gpt-4o",
|
55 |
api_version: str = None,
|
56 |
+
check_connection: bool = True,
|
57 |
verbose: bool = False,
|
58 |
):
|
59 |
if api_version is not None:
|
|
|
72 |
self.model_name = model_name
|
73 |
self.image_formats = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif"}
|
74 |
self.verbose = verbose
|
75 |
+
if check_connection:
|
76 |
+
self.check_connection()
|
77 |
+
|
78 |
logger.info(f"Using GPT model: {self.model_name}.")
|
79 |
|
80 |
@retry(
|
|
|
89 |
text_prompt: str,
|
90 |
image_base64: Optional[list[str | Image.Image]] = None,
|
91 |
system_role: Optional[str] = None,
|
92 |
+
params: Optional[dict] = None,
|
93 |
) -> Optional[str]:
|
94 |
"""Queries the GPT model with a text and optional image prompts.
|
95 |
|
|
|
99 |
or local image paths or PIL.Image to accompany the text prompt.
|
100 |
system_role (Optional[str]): Optional system-level instructions
|
101 |
that specify the behavior of the assistant.
|
102 |
+
params (Optional[dict]): Additional parameters for GPT setting.
|
103 |
|
104 |
Returns:
|
105 |
Optional[str]: The response content generated by the model based on
|
|
|
117 |
|
118 |
# Process images if provided
|
119 |
if image_base64 is not None:
|
120 |
+
if not isinstance(image_base64, list):
|
121 |
+
image_base64 = [image_base64]
|
122 |
+
# Temporary workaround: OpenRouter cannot take multiple images, so merge them into one grid.
|
123 |
+
if "openrouter" in self.endpoint:
|
124 |
+
image_base64 = combine_images_to_grid(image_base64)
|
125 |
for img in image_base64:
|
126 |
if isinstance(img, Image.Image):
|
127 |
buffer = BytesIO()
|
|
|
156 |
"frequency_penalty": 0,
|
157 |
"presence_penalty": 0,
|
158 |
"stop": None,
|
159 |
+
"model": self.model_name,
|
160 |
}
|
161 |
+
|
162 |
+
if params:
|
163 |
+
payload.update(params)
|
164 |
|
165 |
response = None
|
166 |
try:
|
|
|
176 |
|
177 |
return response
|
178 |
|
179 |
+
def check_connection(self) -> None:
|
180 |
+
"""Check whether the GPT API connection is working."""
|
181 |
+
try:
|
182 |
+
response = self.completion_with_backoff(
|
183 |
+
messages=[
|
184 |
+
{"role": "system", "content": "You are a test system."},
|
185 |
+
{"role": "user", "content": "Hello"},
|
186 |
+
],
|
187 |
+
model=self.model_name,
|
188 |
+
temperature=0,
|
189 |
+
max_tokens=100,
|
190 |
+
)
|
191 |
+
content = response.choices[0].message.content
|
192 |
+
logger.info(f"Connection check success.")
|
193 |
+
except Exception as e:
|
194 |
+
raise ConnectionError(
|
195 |
+
f"Failed to connect to GPT API at {self.endpoint}, "
|
196 |
+
f"please check setting in `{CONFIG_FILE}` and `README`."
|
197 |
+
)
|
198 |
|
199 |
+
|
200 |
+
with open(CONFIG_FILE, "r") as f:
|
201 |
config = yaml.safe_load(f)
|
202 |
|
203 |
agent_type = config["agent_type"]
|
|
|
214 |
api_key=api_key,
|
215 |
api_version=api_version,
|
216 |
model_name=model_name,
|
217 |
+
check_connection=False,
|
218 |
)
|
|
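A brief usage sketch for the updated client. It uses only the `query` signature shown in this diff (including the new `params` argument) and the example image paths bundled with the repo; the `params` values are illustrative.

# Sketch: query the configured GPT client with a text prompt plus images.
from PIL import Image

from embodied_gen.utils.gpt_clients import GPT_CLIENT

# On OpenRouter endpoints, multiple images are merged into a single grid automatically.
response = GPT_CLIENT.query(
    text_prompt="What is the content of each image?",
    image_base64=[
        Image.open("apps/assets/example_image/sample_02.jpg"),
        Image.open("apps/assets/example_image/sample_03.jpg"),
    ],
    params={"temperature": 0.2},  # forwarded into the request payload
)
print(response)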
embodied_gen/utils/log.py
ADDED
@@ -0,0 +1,48 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

import logging

from colorlog import ColoredFormatter

__all__ = [
    "logger",
]

LOG_FORMAT = (
    "%(log_color)s[%(asctime)s] %(levelname)-8s | %(message)s%(reset)s"
)
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

formatter = ColoredFormatter(
    LOG_FORMAT,
    datefmt=DATE_FORMAT,
    log_colors={
        "DEBUG": "cyan",
        "INFO": "green",
        "WARNING": "yellow",
        "ERROR": "red",
        "CRITICAL": "bold_red",
    },
)

handler = logging.StreamHandler()
handler.setFormatter(formatter)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(handler)
logger.propagate = False
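Usage is a one-line import; the messages below are illustrative.

# Sketch: use the shared colored logger.
from embodied_gen.utils.log import logger

logger.info("Scene generation started.")
logger.warning("Falling back to CPU rendering.")
logger.error("Mesh export failed.")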
embodied_gen/utils/monkey_patches.py
ADDED
@@ -0,0 +1,152 @@
1 |
+
# Project EmbodiedGen
|
2 |
+
#
|
3 |
+
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
14 |
+
# implied. See the License for the specific language governing
|
15 |
+
# permissions and limitations under the License.
|
16 |
+
|
17 |
+
import os
|
18 |
+
import sys
|
19 |
+
import zipfile
|
20 |
+
|
21 |
+
import torch
|
22 |
+
from huggingface_hub import hf_hub_download
|
23 |
+
from omegaconf import OmegaConf
|
24 |
+
from PIL import Image
|
25 |
+
from torchvision import transforms
|
26 |
+
|
27 |
+
|
28 |
+
def monkey_patch_pano2room():
|
29 |
+
current_file_path = os.path.abspath(__file__)
|
30 |
+
current_dir = os.path.dirname(current_file_path)
|
31 |
+
sys.path.append(os.path.join(current_dir, "../.."))
|
32 |
+
sys.path.append(os.path.join(current_dir, "../../thirdparty/pano2room"))
|
33 |
+
from thirdparty.pano2room.modules.geo_predictors.omnidata.omnidata_normal_predictor import (
|
34 |
+
OmnidataNormalPredictor,
|
35 |
+
)
|
36 |
+
from thirdparty.pano2room.modules.geo_predictors.omnidata.omnidata_predictor import (
|
37 |
+
OmnidataPredictor,
|
38 |
+
)
|
39 |
+
|
40 |
+
def patched_omni_depth_init(self):
|
41 |
+
self.img_size = 384
|
42 |
+
self.model = torch.hub.load(
|
43 |
+
'alexsax/omnidata_models', 'depth_dpt_hybrid_384'
|
44 |
+
)
|
45 |
+
self.model.eval()
|
46 |
+
self.trans_totensor = transforms.Compose(
|
47 |
+
[
|
48 |
+
transforms.Resize(self.img_size, interpolation=Image.BILINEAR),
|
49 |
+
transforms.CenterCrop(self.img_size),
|
50 |
+
transforms.Normalize(mean=0.5, std=0.5),
|
51 |
+
]
|
52 |
+
)
|
53 |
+
|
54 |
+
OmnidataPredictor.__init__ = patched_omni_depth_init
|
55 |
+
|
56 |
+
def patched_omni_normal_init(self):
|
57 |
+
self.img_size = 384
|
58 |
+
self.model = torch.hub.load(
|
59 |
+
'alexsax/omnidata_models', 'surface_normal_dpt_hybrid_384'
|
60 |
+
)
|
61 |
+
self.model.eval()
|
62 |
+
self.trans_totensor = transforms.Compose(
|
63 |
+
[
|
64 |
+
transforms.Resize(self.img_size, interpolation=Image.BILINEAR),
|
65 |
+
transforms.CenterCrop(self.img_size),
|
66 |
+
transforms.Normalize(mean=0.5, std=0.5),
|
67 |
+
]
|
68 |
+
)
|
69 |
+
|
70 |
+
OmnidataNormalPredictor.__init__ = patched_omni_normal_init
|
71 |
+
|
72 |
+
def patched_panojoint_init(self, save_path=None):
|
73 |
+
self.depth_predictor = OmnidataPredictor()
|
74 |
+
self.normal_predictor = OmnidataNormalPredictor()
|
75 |
+
self.save_path = save_path
|
76 |
+
|
77 |
+
from modules.geo_predictors import PanoJointPredictor
|
78 |
+
|
79 |
+
PanoJointPredictor.__init__ = patched_panojoint_init
|
80 |
+
|
81 |
+
# NOTE: We use gsplat instead.
|
82 |
+
# import depth_diff_gaussian_rasterization_min as ddgr
|
83 |
+
# from dataclasses import dataclass
|
84 |
+
# @dataclass
|
85 |
+
# class PatchedGaussianRasterizationSettings:
|
86 |
+
# image_height: int
|
87 |
+
# image_width: int
|
88 |
+
# tanfovx: float
|
89 |
+
# tanfovy: float
|
90 |
+
# bg: torch.Tensor
|
91 |
+
# scale_modifier: float
|
92 |
+
# viewmatrix: torch.Tensor
|
93 |
+
# projmatrix: torch.Tensor
|
94 |
+
# sh_degree: int
|
95 |
+
# campos: torch.Tensor
|
96 |
+
# prefiltered: bool
|
97 |
+
# debug: bool = False
|
98 |
+
# ddgr.GaussianRasterizationSettings = PatchedGaussianRasterizationSettings
|
99 |
+
|
100 |
+
# disable get_has_ddp_rank print in `BaseInpaintingTrainingModule`
|
101 |
+
os.environ["NODE_RANK"] = "0"
|
102 |
+
|
103 |
+
from thirdparty.pano2room.modules.inpainters.lama.saicinpainting.training.trainers import (
|
104 |
+
load_checkpoint,
|
105 |
+
)
|
106 |
+
from thirdparty.pano2room.modules.inpainters.lama_inpainter import (
|
107 |
+
LamaInpainter,
|
108 |
+
)
|
109 |
+
|
110 |
+
def patched_lama_inpaint_init(self):
|
111 |
+
zip_path = hf_hub_download(
|
112 |
+
repo_id="smartywu/big-lama",
|
113 |
+
filename="big-lama.zip",
|
114 |
+
repo_type="model",
|
115 |
+
)
|
116 |
+
extract_dir = os.path.splitext(zip_path)[0]
|
117 |
+
|
118 |
+
if not os.path.exists(extract_dir):
|
119 |
+
os.makedirs(extract_dir, exist_ok=True)
|
120 |
+
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
121 |
+
zip_ref.extractall(extract_dir)
|
122 |
+
|
123 |
+
config_path = os.path.join(extract_dir, 'big-lama', 'config.yaml')
|
124 |
+
checkpoint_path = os.path.join(
|
125 |
+
extract_dir, 'big-lama/models/best.ckpt'
|
126 |
+
)
|
127 |
+
train_config = OmegaConf.load(config_path)
|
128 |
+
train_config.training_model.predict_only = True
|
129 |
+
train_config.visualizer.kind = 'noop'
|
130 |
+
|
131 |
+
self.model = load_checkpoint(
|
132 |
+
train_config, checkpoint_path, strict=False, map_location='cpu'
|
133 |
+
)
|
134 |
+
self.model.freeze()
|
135 |
+
|
136 |
+
LamaInpainter.__init__ = patched_lama_inpaint_init
|
137 |
+
|
138 |
+
from diffusers import StableDiffusionInpaintPipeline
|
139 |
+
from thirdparty.pano2room.modules.inpainters.SDFT_inpainter import (
|
140 |
+
SDFTInpainter,
|
141 |
+
)
|
142 |
+
|
143 |
+
def patched_sd_inpaint_init(self, subset_name=None):
|
144 |
+
super(SDFTInpainter, self).__init__()
|
145 |
+
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
146 |
+
"stabilityai/stable-diffusion-2-inpainting",
|
147 |
+
torch_dtype=torch.float16,
|
148 |
+
).to("cuda")
|
149 |
+
pipe.enable_model_cpu_offload()
|
150 |
+
self.inpaint_pipe = pipe
|
151 |
+
|
152 |
+
SDFTInpainter.__init__ = patched_sd_inpaint_init
|
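A minimal sketch of how these patches are meant to be applied: call the patcher before constructing any Pano2Room predictor or inpainter so that the patched `__init__` methods (Hub/HF checkpoint loading) take effect. The import order shown is an assumption based on the patch function above.

# Sketch: apply the Pano2Room monkey patches before using the third-party modules.
from embodied_gen.utils.monkey_patches import monkey_patch_pano2room

monkey_patch_pano2room()  # swaps in hub/HF checkpoint loading for Omnidata, LaMa, SD inpainting

# Instantiate the patched classes only after patching.
from thirdparty.pano2room.modules.inpainters.lama_inpainter import LamaInpainter

inpainter = LamaInpainter()  # now downloads weights from the smartywu/big-lama repo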
embodied_gen/utils/process_media.py
CHANGED
@@ -15,34 +15,25 @@
|
|
15 |
# permissions and limitations under the License.
|
16 |
|
17 |
|
18 |
-
import base64
|
19 |
import logging
|
20 |
import math
|
|
|
21 |
import os
|
22 |
-
import
|
23 |
from glob import glob
|
24 |
-
from io import BytesIO
|
25 |
from typing import Union
|
26 |
|
27 |
import cv2
|
28 |
import imageio
|
|
|
|
|
29 |
import numpy as np
|
30 |
-
import PIL.Image as Image
|
31 |
import spaces
|
32 |
-
import
|
33 |
from moviepy.editor import VideoFileClip, clips_array
|
34 |
-
from
|
35 |
from embodied_gen.data.differentiable_render import entrypoint as render_api
|
36 |
-
|
37 |
-
current_file_path = os.path.abspath(__file__)
|
38 |
-
current_dir = os.path.dirname(current_file_path)
|
39 |
-
sys.path.append(os.path.join(current_dir, "../.."))
|
40 |
-
from thirdparty.TRELLIS.trellis.renderers.mesh_renderer import MeshRenderer
|
41 |
-
from thirdparty.TRELLIS.trellis.representations import MeshExtractResult
|
42 |
-
from thirdparty.TRELLIS.trellis.utils.render_utils import (
|
43 |
-
render_frames,
|
44 |
-
yaw_pitch_r_fov_to_extrinsics_intrinsics,
|
45 |
-
)
|
46 |
|
47 |
logging.basicConfig(level=logging.INFO)
|
48 |
logger = logging.getLogger(__name__)
|
@@ -53,9 +44,11 @@ __all__ = [
|
|
53 |
"merge_images_video",
|
54 |
"filter_small_connected_components",
|
55 |
"filter_image_small_connected_components",
|
56 |
-
"
|
57 |
-
"
|
58 |
-
"
|
|
|
|
|
59 |
]
|
60 |
|
61 |
|
@@ -66,12 +59,14 @@ def render_asset3d(
|
|
66 |
distance: float = 5.0,
|
67 |
num_images: int = 1,
|
68 |
elevation: list[float] = (0.0,),
|
69 |
-
pbr_light_factor: float = 1.
|
70 |
return_key: str = "image_color/*",
|
71 |
output_subdir: str = "renders",
|
72 |
gen_color_mp4: bool = False,
|
73 |
gen_viewnormal_mp4: bool = False,
|
74 |
gen_glonormal_mp4: bool = False,
|
|
|
|
|
75 |
) -> list[str]:
|
76 |
input_args = dict(
|
77 |
mesh_path=mesh_path,
|
@@ -81,14 +76,13 @@ def render_asset3d(
|
|
81 |
num_images=num_images,
|
82 |
elevation=elevation,
|
83 |
pbr_light_factor=pbr_light_factor,
|
84 |
-
with_mtl=
|
|
|
|
|
|
|
|
|
85 |
)
|
86 |
-
|
87 |
-
input_args["gen_color_mp4"] = True
|
88 |
-
if gen_viewnormal_mp4:
|
89 |
-
input_args["gen_viewnormal_mp4"] = True
|
90 |
-
if gen_glonormal_mp4:
|
91 |
-
input_args["gen_glonormal_mp4"] = True
|
92 |
try:
|
93 |
_ = render_api(**input_args)
|
94 |
except Exception as e:
|
@@ -168,12 +162,15 @@ def filter_image_small_connected_components(
|
|
168 |
return image
|
169 |
|
170 |
|
171 |
-
def
|
172 |
images: list[str | Image.Image],
|
173 |
cat_row_col: tuple[int, int] = None,
|
174 |
target_wh: tuple[int, int] = (512, 512),
|
175 |
-
) -> str:
|
176 |
n_images = len(images)
|
|
|
|
|
|
|
177 |
if cat_row_col is None:
|
178 |
n_col = math.ceil(math.sqrt(n_images))
|
179 |
n_row = math.ceil(n_images / n_col)
|
@@ -182,88 +179,229 @@ def combine_images_to_base64(
|
|
182 |
|
183 |
images = [
|
184 |
Image.open(p).convert("RGB") if isinstance(p, str) else p
|
185 |
-
for p in images
|
186 |
]
|
187 |
images = [img.resize(target_wh) for img in images]
|
188 |
|
189 |
grid_w, grid_h = n_col * target_wh[0], n_row * target_wh[1]
|
190 |
-
grid = Image.new("RGB", (grid_w, grid_h), (
|
191 |
|
192 |
for idx, img in enumerate(images):
|
193 |
row, col = divmod(idx, n_col)
|
194 |
grid.paste(img, (col * target_wh[0], row * target_wh[1]))
|
195 |
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
200 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
)
|
217 |
-
normal = np.clip(
|
218 |
-
normal.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255
|
219 |
-
).astype(np.uint8)
|
220 |
-
rets["normal"].append(normal)
|
221 |
|
222 |
-
|
|
|
|
|
223 |
|
224 |
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
):
|
235 |
-
yaws = torch.linspace(0, 2 * 3.1415, num_frames)
|
236 |
-
yaws = yaws.tolist()
|
237 |
-
pitch = [0.5] * num_frames
|
238 |
-
extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(
|
239 |
-
yaws, pitch, r, fov
|
240 |
-
)
|
241 |
-
render_fn = (
|
242 |
-
render_mesh if isinstance(sample, MeshExtractResult) else render_frames
|
243 |
-
)
|
244 |
-
result = render_fn(
|
245 |
-
sample,
|
246 |
-
extrinsics,
|
247 |
-
intrinsics,
|
248 |
-
{"resolution": resolution, "bg_color": bg_color},
|
249 |
-
**kwargs,
|
250 |
-
)
|
251 |
|
252 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
253 |
|
254 |
|
255 |
if __name__ == "__main__":
|
256 |
-
# Example usage:
|
257 |
merge_video_video(
|
258 |
"outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh_glo_normal.mp4", # noqa
|
259 |
"outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh.mp4", # noqa
|
260 |
"merge.mp4",
|
261 |
)
|
262 |
-
|
263 |
-
image_base64 = combine_images_to_base64(
|
264 |
-
[
|
265 |
-
"apps/assets/example_image/sample_00.jpg",
|
266 |
-
"apps/assets/example_image/sample_01.jpg",
|
267 |
-
"apps/assets/example_image/sample_02.jpg",
|
268 |
-
]
|
269 |
-
)
|
|
|
15 |
# permissions and limitations under the License.
|
16 |
|
17 |
|
|
|
18 |
import logging
|
19 |
import math
|
20 |
+
import mimetypes
|
21 |
import os
|
22 |
+
import textwrap
|
23 |
from glob import glob
|
|
|
24 |
from typing import Union
|
25 |
|
26 |
import cv2
|
27 |
import imageio
|
28 |
+
import matplotlib.pyplot as plt
|
29 |
+
import networkx as nx
|
30 |
import numpy as np
|
|
|
31 |
import spaces
|
32 |
+
from matplotlib.patches import Patch
|
33 |
from moviepy.editor import VideoFileClip, clips_array
|
34 |
+
from PIL import Image
|
35 |
from embodied_gen.data.differentiable_render import entrypoint as render_api
|
36 |
+
from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
logging.basicConfig(level=logging.INFO)
|
39 |
logger = logging.getLogger(__name__)
|
|
|
44 |
"merge_images_video",
|
45 |
"filter_small_connected_components",
|
46 |
"filter_image_small_connected_components",
|
47 |
+
"combine_images_to_grid",
|
48 |
+
"SceneTreeVisualizer",
|
49 |
+
"is_image_file",
|
50 |
+
"parse_text_prompts",
|
51 |
+
"check_object_edge_truncated",
|
52 |
]
|
53 |
|
54 |
|
|
|
59 |
distance: float = 5.0,
|
60 |
num_images: int = 1,
|
61 |
elevation: list[float] = (0.0,),
|
62 |
+
pbr_light_factor: float = 1.2,
|
63 |
return_key: str = "image_color/*",
|
64 |
output_subdir: str = "renders",
|
65 |
gen_color_mp4: bool = False,
|
66 |
gen_viewnormal_mp4: bool = False,
|
67 |
gen_glonormal_mp4: bool = False,
|
68 |
+
no_index_file: bool = False,
|
69 |
+
with_mtl: bool = True,
|
70 |
) -> list[str]:
|
71 |
input_args = dict(
|
72 |
mesh_path=mesh_path,
|
|
|
76 |
num_images=num_images,
|
77 |
elevation=elevation,
|
78 |
pbr_light_factor=pbr_light_factor,
|
79 |
+
with_mtl=with_mtl,
|
80 |
+
gen_color_mp4=gen_color_mp4,
|
81 |
+
gen_viewnormal_mp4=gen_viewnormal_mp4,
|
82 |
+
gen_glonormal_mp4=gen_glonormal_mp4,
|
83 |
+
no_index_file=no_index_file,
|
84 |
)
|
85 |
+
|
|
|
|
|
|
|
|
|
|
|
86 |
try:
|
87 |
_ = render_api(**input_args)
|
88 |
except Exception as e:
|
|
|
162 |
return image
|
163 |
|
164 |
|
165 |
+
def combine_images_to_grid(
|
166 |
images: list[str | Image.Image],
|
167 |
     cat_row_col: tuple[int, int] = None,
     target_wh: tuple[int, int] = (512, 512),
+) -> list[str | Image.Image]:
     n_images = len(images)
+    if n_images == 1:
+        return images
+
     if cat_row_col is None:
         n_col = math.ceil(math.sqrt(n_images))
         n_row = math.ceil(n_images / n_col)
 
     images = [
         Image.open(p).convert("RGB") if isinstance(p, str) else p
+        for p in images
     ]
     images = [img.resize(target_wh) for img in images]
 
     grid_w, grid_h = n_col * target_wh[0], n_row * target_wh[1]
+    grid = Image.new("RGB", (grid_w, grid_h), (0, 0, 0))
 
     for idx, img in enumerate(images):
         row, col = divmod(idx, n_col)
         grid.paste(img, (col * target_wh[0], row * target_wh[1]))
 
+    return [grid]
+
+
+class SceneTreeVisualizer:
+    def __init__(self, layout_info: LayoutInfo) -> None:
+        self.tree = layout_info.tree
+        self.relation = layout_info.relation
+        self.objs_desc = layout_info.objs_desc
+        self.G = nx.DiGraph()
+        self.root = self._find_root()
+        self._build_graph()
+
+        self.role_colors = {
+            Scene3DItemEnum.BACKGROUND.value: "plum",
+            Scene3DItemEnum.CONTEXT.value: "lightblue",
+            Scene3DItemEnum.ROBOT.value: "lightcoral",
+            Scene3DItemEnum.MANIPULATED_OBJS.value: "lightgreen",
+            Scene3DItemEnum.DISTRACTOR_OBJS.value: "lightgray",
+            Scene3DItemEnum.OTHERS.value: "orange",
+        }
+
+    def _find_root(self) -> str:
+        children = {c for cs in self.tree.values() for c, _ in cs}
+        parents = set(self.tree.keys())
+        roots = parents - children
+        if not roots:
+            raise ValueError("No root node found.")
+        return next(iter(roots))
+
+    def _build_graph(self):
+        for parent, children in self.tree.items():
+            for child, relation in children:
+                self.G.add_edge(parent, child, relation=relation)
+
+    def _get_node_role(self, node: str) -> str:
+        if node == self.relation.get(Scene3DItemEnum.BACKGROUND.value):
+            return Scene3DItemEnum.BACKGROUND.value
+        if node == self.relation.get(Scene3DItemEnum.CONTEXT.value):
+            return Scene3DItemEnum.CONTEXT.value
+        if node == self.relation.get(Scene3DItemEnum.ROBOT.value):
+            return Scene3DItemEnum.ROBOT.value
+        if node in self.relation.get(
+            Scene3DItemEnum.MANIPULATED_OBJS.value, []
+        ):
+            return Scene3DItemEnum.MANIPULATED_OBJS.value
+        if node in self.relation.get(
+            Scene3DItemEnum.DISTRACTOR_OBJS.value, []
+        ):
+            return Scene3DItemEnum.DISTRACTOR_OBJS.value
+        return Scene3DItemEnum.OTHERS.value
+
+    def _get_positions(
+        self, root, width=1.0, vert_gap=0.1, vert_loc=1, xcenter=0.5, pos=None
+    ):
+        if pos is None:
+            pos = {root: (xcenter, vert_loc)}
+        else:
+            pos[root] = (xcenter, vert_loc)
+
+        children = list(self.G.successors(root))
+        if children:
+            dx = width / len(children)
+            next_x = xcenter - width / 2 - dx / 2
+            for child in children:
+                next_x += dx
+                pos = self._get_positions(
+                    child,
+                    width=dx,
+                    vert_gap=vert_gap,
+                    vert_loc=vert_loc - vert_gap,
+                    xcenter=next_x,
+                    pos=pos,
+                )
+        return pos
+
+    def render(
+        self,
+        save_path: str,
+        figsize=(8, 6),
+        dpi=300,
+        title: str = "Scene 3D Hierarchy Tree",
+    ):
+        node_colors = [
+            self.role_colors[self._get_node_role(n)] for n in self.G.nodes
+        ]
+        pos = self._get_positions(self.root)
+
+        plt.figure(figsize=figsize)
+        nx.draw(
+            self.G,
+            pos,
+            with_labels=True,
+            arrows=False,
+            node_size=2000,
+            node_color=node_colors,
+            font_size=10,
+            font_weight="bold",
+        )
 
+        # Draw edge labels
+        edge_labels = nx.get_edge_attributes(self.G, "relation")
+        nx.draw_networkx_edge_labels(
+            self.G,
+            pos,
+            edge_labels=edge_labels,
+            font_size=9,
+            font_color="black",
+        )
 
+        # Draw small description text under each node (if available)
+        for node, (x, y) in pos.items():
+            desc = self.objs_desc.get(node)
+            if desc:
+                wrapped = "\n".join(textwrap.wrap(desc, width=30))
+                plt.text(
+                    x,
+                    y - 0.006,
+                    wrapped,
+                    fontsize=6,
+                    ha="center",
+                    va="top",
+                    wrap=True,
+                    color="black",
+                    bbox=dict(
+                        facecolor="dimgray",
+                        edgecolor="darkgray",
+                        alpha=0.1,
+                        boxstyle="round,pad=0.2",
+                    ),
+                )
+
+        plt.title(title, fontsize=12)
+        task_desc = self.relation.get("task_desc", "")
+        if task_desc:
+            plt.suptitle(
+                f"Task Description: {task_desc}", fontsize=10, y=0.999
+            )
+
+        plt.axis("off")
+
+        legend_handles = [
+            Patch(facecolor=color, edgecolor='black', label=role)
+            for role, color in self.role_colors.items()
+        ]
+        plt.legend(
+            handles=legend_handles,
+            loc="lower center",
+            ncol=3,
+            bbox_to_anchor=(0.5, -0.1),
+            fontsize=9,
         )
 
+        os.makedirs(os.path.dirname(save_path), exist_ok=True)
+        plt.savefig(save_path, dpi=dpi, bbox_inches="tight")
+        plt.close()
 
 
+def load_scene_dict(file_path: str) -> dict:
+    scene_dict = {}
+    with open(file_path, "r", encoding='utf-8') as f:
+        for line in f:
+            line = line.strip()
+            if not line or ":" not in line:
+                continue
+            scene_id, desc = line.split(":", 1)
+            scene_dict[scene_id.strip()] = desc.strip()
 
+    return scene_dict
+
+
+def is_image_file(filename: str) -> bool:
+    mime_type, _ = mimetypes.guess_type(filename)
+
+    return mime_type is not None and mime_type.startswith('image')
+
+
+def parse_text_prompts(prompts: list[str]) -> list[str]:
+    if len(prompts) == 1 and prompts[0].endswith(".txt"):
+        with open(prompts[0], "r") as f:
+            prompts = [
+                line.strip()
+                for line in f
+                if line.strip() and not line.strip().startswith("#")
+            ]
+    return prompts
+
+
+def check_object_edge_truncated(
+    mask: np.ndarray, edge_threshold: int = 5
+) -> bool:
+    """Checks if a binary object mask is truncated at the image edges.
+
+    Args:
+        mask: A 2D binary NumPy array where nonzero values indicate the object region.
+        edge_threshold: Number of pixels from each image edge to consider for truncation.
+            Defaults to 5.
+
+    Returns:
+        True if the object is fully enclosed (not truncated).
+        False if the object touches or crosses any image boundary.
+    """
+    top = mask[:edge_threshold, :].any()
+    bottom = mask[-edge_threshold:, :].any()
+    left = mask[:, :edge_threshold].any()
+    right = mask[:, -edge_threshold:].any()
+
+    return not (top or bottom or left or right)
 
 
 if __name__ == "__main__":
     merge_video_video(
         "outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh_glo_normal.mp4",  # noqa
         "outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh.mp4",  # noqa
         "merge.mp4",
     )
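
The edge-truncation helper added above is self-contained, so a short usage sketch may help; the array sizes below are illustrative only, and the import assumes the helper lives in embodied_gen/utils/process_media.py as this commit indicates.

    import numpy as np

    from embodied_gen.utils.process_media import check_object_edge_truncated

    # A small square object well inside a 256x256 frame: not truncated.
    mask = np.zeros((256, 256), dtype=np.uint8)
    mask[100:150, 100:150] = 1
    print(check_object_edge_truncated(mask))  # True

    # Extend the object into the top border: now reported as truncated.
    mask[:3, 120:130] = 1
    print(check_object_edge_truncated(mask, edge_threshold=5))  # False
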
embodied_gen/utils/tags.py CHANGED
@@ -1 +1 @@
-VERSION = "v0.1.
+VERSION = "v0.1.2"
embodied_gen/utils/trender.py ADDED
@@ -0,0 +1,90 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+import os
+import sys
+
+import numpy as np
+import spaces
+import torch
+from tqdm import tqdm
+
+current_file_path = os.path.abspath(__file__)
+current_dir = os.path.dirname(current_file_path)
+sys.path.append(os.path.join(current_dir, "../.."))
+from thirdparty.TRELLIS.trellis.renderers.mesh_renderer import MeshRenderer
+from thirdparty.TRELLIS.trellis.representations import MeshExtractResult
+from thirdparty.TRELLIS.trellis.utils.render_utils import (
+    render_frames,
+    yaw_pitch_r_fov_to_extrinsics_intrinsics,
+)
+
+__all__ = [
+    "render_video",
+]
+
+
+@spaces.GPU
+def render_mesh(sample, extrinsics, intrinsics, options={}, **kwargs):
+    renderer = MeshRenderer()
+    renderer.rendering_options.resolution = options.get("resolution", 512)
+    renderer.rendering_options.near = options.get("near", 1)
+    renderer.rendering_options.far = options.get("far", 100)
+    renderer.rendering_options.ssaa = options.get("ssaa", 4)
+    rets = {}
+    for extr, intr in tqdm(zip(extrinsics, intrinsics), desc="Rendering"):
+        res = renderer.render(sample, extr, intr)
+        if "normal" not in rets:
+            rets["normal"] = []
+        normal = torch.lerp(
+            torch.zeros_like(res["normal"]), res["normal"], res["mask"]
+        )
+        normal = np.clip(
+            normal.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255
+        ).astype(np.uint8)
+        rets["normal"].append(normal)
+
+    return rets
+
+
+@spaces.GPU
+def render_video(
+    sample,
+    resolution=512,
+    bg_color=(0, 0, 0),
+    num_frames=300,
+    r=2,
+    fov=40,
+    **kwargs,
+):
+    yaws = torch.linspace(0, 2 * 3.1415, num_frames)
+    yaws = yaws.tolist()
+    pitch = [0.5] * num_frames
+    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(
+        yaws, pitch, r, fov
+    )
+    render_fn = (
+        render_mesh if isinstance(sample, MeshExtractResult) else render_frames
+    )
+    result = render_fn(
+        sample,
+        extrinsics,
+        intrinsics,
+        {"resolution": resolution, "bg_color": bg_color},
+        **kwargs,
+    )
+
+    return result
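
A minimal sketch of how the new render_video entry point can be used; `mesh` is assumed to be a TRELLIS MeshExtractResult produced earlier in the pipeline, and saving the frames with imageio is only one possible way to persist them.

    import imageio

    from embodied_gen.utils.trender import render_video

    # For a MeshExtractResult input, render_mesh is dispatched and returns
    # {"normal": [HxWx3 uint8 frames, ...]} rendered over a full yaw sweep.
    frames = render_video(mesh, resolution=512, num_frames=120, r=2, fov=40)
    imageio.mimsave("mesh_normal.mp4", frames["normal"], fps=30)
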
embodied_gen/validators/aesthetic_predictor.py CHANGED
@@ -102,7 +102,7 @@ class AestheticPredictor:
     def _load_sac_model(self, model_path, input_size):
         """Load the SAC model."""
         model = self.MLP(input_size)
-        ckpt = torch.load(model_path)
+        ckpt = torch.load(model_path, weights_only=True)
         model.load_state_dict(ckpt)
         model.to(self.device)
         model.eval()
@@ -135,15 +135,3 @@ class AestheticPredictor:
         )
 
         return prediction.item()
-
-
-if __name__ == "__main__":
-    # Configuration
-    img_path = "apps/assets/example_image/sample_00.jpg"
-
-    # Initialize the predictor
-    predictor = AestheticPredictor()
-
-    # Predict the aesthetic score
-    score = predictor.predict(img_path)
-    print("Aesthetic score predicted by the model:", score)
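
The torch.load change above opts into restricted unpickling: with weights_only=True (available in recent PyTorch releases), only tensors and plain containers are deserialized, so a tampered checkpoint cannot execute arbitrary pickled code; checkpoints that store custom Python objects will raise instead. A hedged sketch with a placeholder filename:

    import torch

    # "sac_mlp.pth" is a placeholder; AestheticPredictor resolves the real checkpoint path itself.
    ckpt = torch.load("sac_mlp.pth", map_location="cpu", weights_only=True)
    # ckpt is a plain state_dict of tensors, safe to pass to load_state_dict().
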
embodied_gen/validators/quality_checkers.py CHANGED
@@ -16,17 +16,29 @@
 
 
 import logging
-import
+import random
 
+import json_repair
+from PIL import Image
 from embodied_gen.utils.gpt_clients import GPT_CLIENT, GPTclient
-from embodied_gen.utils.process_media import render_asset3d
 from embodied_gen.validators.aesthetic_predictor import AestheticPredictor
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 
+__all__ = [
+    "MeshGeoChecker",
+    "ImageSegChecker",
+    "ImageAestheticChecker",
+    "SemanticConsistChecker",
+    "TextGenAlignChecker",
+    "PanoImageGenChecker",
+    "PanoHeightEstimator",
+    "PanoImageOccChecker",
+]
+
+
 class BaseChecker:
     def __init__(self, prompt: str = None, verbose: bool = False) -> None:
         self.prompt = prompt
@@ -37,16 +49,20 @@ class BaseChecker:
             "Subclasses must implement the query method."
         )
 
-    def __call__(self, *args, **kwargs) -> bool:
+    def __call__(self, *args, **kwargs) -> tuple[bool, str]:
         response = self.query(*args, **kwargs)
-        if
-            response = "Error when calling gpt api."
-        if self.verbose and response != "YES":
+        if self.verbose:
             logger.info(response)
 
+        if response is None:
+            flag = None
+            response = (
+                "Error when calling GPT api, check config in "
+                "`embodied_gen/utils/gpt_config.yaml` or net connection."
+            )
+        else:
+            flag = "YES" in response
+            response = "YES" if flag else response
 
         return flag, response
 
@@ -92,21 +108,29 @@ class MeshGeoChecker(BaseChecker):
         self.gpt_client = gpt_client
         if self.prompt is None:
             self.prompt = """
+You are an expert in evaluating the geometry quality of generated 3D asset.
+You will be given rendered views of a generated 3D asset with black background.
+Your task is to evaluate the quality of the 3D asset generation,
+including geometry, structure, and appearance, based on the rendered views.
+Criteria:
+- Is the object in the image a single, complete, and well-formed instance,
+without truncation, missing parts, overlapping duplicates, or redundant geometry?
+- Minor flaws, asymmetries, or simplifications (e.g., less detail on sides or back,
+soft edges) are acceptable if the object is structurally sound and recognizable.
+- Only evaluate geometry. Do not assess texture quality.
+- The asset should not contain any unrelated elements, such as
+ground planes, platforms, or background props (e.g., paper, flooring).
+
+If all the above criteria are met, return "YES". Otherwise, return
+"NO" followed by a brief explanation (no more than 20 words).
+
+Example:
+Images show a yellow cup standing on a flat white plane -> NO
+-> Response: NO: extra white surface under the object.
+Image shows a chair with simplified back legs and soft edges -> YES
             """
 
-    def query(self, image_paths: str) -> str:
-        # Hardcode tmp because of the openrouter can't input multi images.
-        if "openrouter" in self.gpt_client.endpoint:
-            from embodied_gen.utils.process_media import (
-                combine_images_to_base64,
-            )
-
-            image_paths = combine_images_to_base64(image_paths)
+    def query(self, image_paths: list[str | Image.Image]) -> str:
 
         return self.gpt_client.query(
             text_prompt=self.prompt,
@@ -137,14 +161,19 @@ class ImageSegChecker(BaseChecker):
         self.gpt_client = gpt_client
         if self.prompt is None:
             self.prompt = """
+Task: Evaluate the quality of object segmentation between two images:
+the first is the original, the second is the segmented result.
+
+Criteria:
+- The main foreground object should be clearly extracted (not the background).
+- The object must appear realistic, with reasonable geometry and color.
+- The object should be geometrically complete, with no missing, truncated, or cropped parts.
+- The object must be centered, with a margin on all sides.
+- Ignore minor imperfections (e.g., small holes or fine edge artifacts).
+
+Output Rules:
+If segmentation is acceptable, respond with "YES" (and nothing else).
+If not acceptable, respond with "NO", followed by a brief reason (max 20 words).
             """
 
     def query(self, image_paths: list[str]) -> str:
@@ -152,13 +181,6 @@ class ImageSegChecker(BaseChecker):
             raise ValueError(
                 "ImageSegChecker requires exactly two images: [raw_image, seg_image]."  # noqa
             )
-        # Hardcode tmp because of the openrouter can't input multi images.
-        if "openrouter" in self.gpt_client.endpoint:
-            from embodied_gen.utils.process_media import (
-                combine_images_to_base64,
-            )
-
-            image_paths = combine_images_to_base64(image_paths)
 
         return self.gpt_client.query(
             text_prompt=self.prompt,
@@ -201,42 +223,358 @@ class ImageAestheticChecker(BaseChecker):
         return avg_score > self.thresh, avg_score
 
 
+class SemanticConsistChecker(BaseChecker):
+    def __init__(
+        self,
+        gpt_client: GPTclient,
+        prompt: str = None,
+        verbose: bool = False,
+    ) -> None:
+        super().__init__(prompt, verbose)
+        self.gpt_client = gpt_client
+        if self.prompt is None:
+            self.prompt = """
+You are an expert in image-text consistency assessment.
+You will be given:
+- A short text description of an object.
+- An segmented image of the same object with the background removed.
+
+Criteria:
+- The image must visually match the text description in terms of object type, structure, geometry, and color.
+- The object must appear realistic, with reasonable geometry (e.g., a table must have a stable number
+of legs with a reasonable distribution. Count the number of legs visible in the image. (strict) For tables,
+fewer than four legs or if the legs are unevenly distributed, are not allowed. Do not assume
+hidden legs unless they are clearly visible.)
+- Geometric completeness is required: the object must not have missing, truncated, or cropped parts.
+- The image must contain exactly one object. Multiple distinct objects are not allowed.
+A single composite object (e.g., a chair with legs) is acceptable.
+- The object should be shown from a slightly angled (three-quarter) perspective,
+not a flat, front-facing view showing only one surface.
+
+Instructions:
+- If all criteria are met, return `"YES"`.
+- Otherwise, return "NO" with a brief explanation (max 20 words).
+
+Respond in exactly one of the following formats:
+YES
+or
+NO: brief explanation.
+
+Input:
+{}
+"""
+
+    def query(self, text: str, image: list[Image.Image | str]) -> str:
+
+        return self.gpt_client.query(
+            text_prompt=self.prompt.format(text),
+            image_base64=image,
        )
 
+
+class TextGenAlignChecker(BaseChecker):
+    def __init__(
+        self,
+        gpt_client: GPTclient,
+        prompt: str = None,
+        verbose: bool = False,
+    ) -> None:
+        super().__init__(prompt, verbose)
+        self.gpt_client = gpt_client
+        if self.prompt is None:
+            self.prompt = """
+You are an expert in evaluating the quality of generated 3D assets.
+You will be given:
+- A text description of an object: TEXT
+- Rendered views of the generated 3D asset.
+
+Your task is to:
+1. Determine whether the generated 3D asset roughly reflects the object class
+or a semantically adjacent category described in the text.
+2. Evaluate the geometry quality of the 3D asset generation based on the rendered views.
+
+Criteria:
+- Determine if the generated 3D asset belongs to the text described or a similar category.
+- Focus on functional similarity: if the object serves the same general
+purpose (e.g., writing, placing items), it should be accepted.
+- Is the geometry complete and well-formed, with no missing parts,
+distortions, visual artifacts, or redundant structures?
+- Does the number of object instances match the description?
+There should be only one object unless otherwise specified.
+- Minor flaws in geometry or texture are acceptable, high tolerance for texture quality defects.
+- Minor simplifications in geometry or texture (e.g. soft edges, less detail)
+are acceptable if the object is still recognizable.
+- The asset should not contain any unrelated elements, such as
+ground planes, platforms, or background props (e.g., paper, flooring).
+
+Example:
+Text: "yellow cup"
+Image: shows a yellow cup standing on a flat white plane -> NO: extra surface under the object.
+
+Instructions:
+- If the quality of generated asset is acceptable and faithfully represents the text, return "YES".
+- Otherwise, return "NO" followed by a brief explanation (no more than 20 words).
+
+Respond in exactly one of the following formats:
+YES
+or
+NO: brief explanation
+
+Input:
+Text description: {}
+"""
+
+    def query(self, text: str, image: list[Image.Image | str]) -> str:
+
+        return self.gpt_client.query(
+            text_prompt=self.prompt.format(text),
+            image_base64=image,
+        )
+
+
+class PanoImageGenChecker(BaseChecker):
+    """A checker class that validates the quality and realism of generated panoramic indoor images.
+
+    Attributes:
+        gpt_client (GPTclient): A GPT client instance used to query for image validation.
+        prompt (str): The instruction prompt passed to the GPT model. If None, a default prompt is used.
+        verbose (bool): Whether to print internal processing information for debugging.
+    """
+
+    def __init__(
+        self,
+        gpt_client: GPTclient,
+        prompt: str = None,
+        verbose: bool = False,
+    ) -> None:
+        super().__init__(prompt, verbose)
+        self.gpt_client = gpt_client
+        if self.prompt is None:
+            self.prompt = """
+You are a panoramic image analyzer specializing in indoor room structure validation.
+
+Given a generated panoramic image, assess if it meets all the criteria:
+- Floor Space: >=30 percent of the floor is free of objects or obstructions.
+- Visual Clarity: Floor, walls, and ceiling are clear, with no distortion, blur, noise.
+- Structural Continuity: Surfaces form plausible, continuous geometry
+without breaks, floating parts, or abrupt cuts.
+- Spatial Completeness: Full 360° coverage without missing areas,
+seams, gaps, or stitching artifacts.
+Instructions:
+- If all criteria are met, reply with "YES".
+- Otherwise, reply with "NO: <brief explanation>" (max 20 words).
+
+Respond exactly as:
+"YES"
+or
+"NO: brief explanation."
+"""
+
+    def query(self, image_paths: str | Image.Image) -> str:
+
+        return self.gpt_client.query(
+            text_prompt=self.prompt,
+            image_base64=image_paths,
+        )
+
+
+class PanoImageOccChecker(BaseChecker):
+    """Checks for physical obstacles in the bottom-center region of a panoramic image.
+
+    This class crops a specified region from the input panoramic image and uses
+    a GPT client to determine whether any physical obstacles there.
+
+    Args:
+        gpt_client (GPTclient): The GPT-based client used for visual reasoning.
+        box_hw (tuple[int, int]): The height and width of the crop box.
+        prompt (str, optional): Custom prompt for the GPT client. Defaults to a predefined one.
+        verbose (bool, optional): Whether to print verbose logs. Defaults to False.
+    """
+
+    def __init__(
+        self,
+        gpt_client: GPTclient,
+        box_hw: tuple[int, int],
+        prompt: str = None,
+        verbose: bool = False,
+    ) -> None:
+        super().__init__(prompt, verbose)
+        self.gpt_client = gpt_client
+        self.box_hw = box_hw
+        if self.prompt is None:
+            self.prompt = """
+This image is a cropped region from the bottom-center of a panoramic view.
+Please determine whether there is any obstacle present, such as furniture, tables, or other physical objects.
+Ignore floor textures, rugs, carpets, shadows, and lighting effects: they do not count as obstacles.
+Only consider real, physical objects that could block walking or movement.
+
+Instructions:
+- If there is no obstacle, reply: "YES".
+- Otherwise, reply: "NO: <brief explanation>" (max 20 words).
+
+Respond exactly as:
+"YES"
+or
+"NO: brief explanation."
+"""
+
+    def query(self, image_paths: str | Image.Image) -> str:
+        if isinstance(image_paths, str):
+            image_paths = Image.open(image_paths)
+
+        w, h = image_paths.size
+        image_paths = image_paths.crop(
+            (
+                (w - self.box_hw[1]) // 2,
+                h - self.box_hw[0],
+                (w + self.box_hw[1]) // 2,
+                h,
+            )
+        )
+
+        return self.gpt_client.query(
+            text_prompt=self.prompt,
+            image_base64=image_paths,
+        )
+
+
+class PanoHeightEstimator(object):
+    """Estimate the real ceiling height of an indoor space from a 360° panoramic image.
+
+    Attributes:
+        gpt_client (GPTclient): The GPT client used to perform image-based reasoning and return height estimates.
+        default_value (float): The fallback height in meters if parsing the GPT output fails.
+        prompt (str): The textual instruction used to guide the GPT model for height estimation.
+    """
+
+    def __init__(
+        self,
+        gpt_client: GPTclient,
+        default_value: float = 3.5,
+    ) -> None:
+        self.gpt_client = gpt_client
+        self.default_value = default_value
+        self.prompt = """
+You are an expert in building height estimation and panoramic image analysis.
+Your task is to analyze a 360° indoor panoramic image and estimate the **actual height** of the space in meters.
+
+Consider the following visual cues:
+1. Ceiling visibility and reference objects (doors, windows, furniture, appliances).
+2. Floor features or level differences.
+3. Room type (e.g., residential, office, commercial).
+4. Object-to-ceiling proportions (e.g., height of doors relative to ceiling).
+5. Architectural elements (e.g., chandeliers, shelves, kitchen cabinets).
+
+Input: A full 360° panoramic indoor photo.
+Output: A single number in meters representing the estimated room height. Only return the number (e.g., `3.2`)
+"""
+
+    def __call__(self, image_paths: str | Image.Image) -> float:
+        result = self.gpt_client.query(
+            text_prompt=self.prompt,
+            image_base64=image_paths,
+        )
+        try:
+            result = float(result.strip())
+        except Exception as e:
+            logger.error(
+                f"Parser error: failed convert {result} to float, {e}, use default value {self.default_value}."
            )
+            result = self.default_value
+
+        return result
+
 
+class SemanticMatcher(BaseChecker):
+    def __init__(
+        self,
+        gpt_client: GPTclient,
+        prompt: str = None,
+        verbose: bool = False,
+        seed: int = None,
+    ) -> None:
+        super().__init__(prompt, verbose)
+        self.gpt_client = gpt_client
+        self.seed = seed
+        random.seed(seed)
+        if self.prompt is None:
+            self.prompt = """
+You are an expert in semantic similarity and scene retrieval.
+You will be given:
+- A dictionary where each key is a scene ID, and each value is a scene description.
+- A query text describing a target scene.
+
+Your task:
+return_num = 2
+- Find the <return_num> most semantically similar scene IDs to the query text.
+- If there are fewer than <return_num> distinct relevant matches, repeat the closest ones to make a list of <return_num>.
+- Only output the list of <return_num> scene IDs, sorted from most to less similar.
+- Do NOT use markdown, JSON code blocks, or any formatting syntax, only return a plain list like ["id1", ...].
+
+Input example:
+Dictionary:
+"{{
+"t_scene_008": "A study room with full bookshelves and a lamp in the corner.",
+"t_scene_019": "A child's bedroom with pink walls and a small desk.",
+"t_scene_020": "A living room with a wooden floor.",
+"t_scene_021": "A living room with toys scattered on the floor.",
+...
+"t_scene_office_001": "A very spacious, modern open-plan office with wide desks and no people, panoramic view."
+}}"
+Text:
+"A traditional indoor room"
+Output:
+'["t_scene_office_001", ...]'
+
+Input:
+Dictionary:
+{context}
+Text:
+{text}
+Output:
+<topk_key_list>
+"""
 
+    def query(
+        self, text: str, context: dict, rand: bool = True, params: dict = None
+    ) -> str:
+        match_list = self.gpt_client.query(
+            self.prompt.format(context=context, text=text),
+            params=params,
+        )
+        match_list = json_repair.loads(match_list)
+        result = random.choice(match_list) if rand else match_list[0]
+
+        return result
+
+
+def test_semantic_matcher(
+    bg_file: str = "outputs/bg_scenes/bg_scene_list.txt",
+):
+    bg_file = "outputs/bg_scenes/bg_scene_list.txt"
+    scene_dict = {}
+    with open(bg_file, "r") as f:
+        for line in f:
+            line = line.strip()
+            if not line or ":" not in line:
+                continue
+            scene_id, desc = line.split(":", 1)
+            scene_dict[scene_id.strip()] = desc.strip()
+
+    office_scene = scene_dict.get("t_scene_office_001")
+    text = "bright kitchen"
+    SCENE_MATCHER = SemanticMatcher(GPT_CLIENT)
+    # gpt_params = {
+    #     "temperature": 0.8,
+    #     "max_tokens": 500,
+    #     "top_p": 0.8,
+    #     "frequency_penalty": 0.3,
+    #     "presence_penalty": 0.3,
+    # }
+    gpt_params = None
+    match_key = SCENE_MATCHER.query(text, str(scene_dict))
+    print(match_key, ",", scene_dict[match_key])
+
+
+if __name__ == "__main__":
+    test_semantic_matcher()
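
All of these checkers go through the BaseChecker.__call__ contract defined above: it returns (flag, response), where flag is True when the model replied "YES", False otherwise, and None when the GPT call itself failed. A minimal sketch with placeholder image paths (the ImageSegChecker constructor signature is assumed to match the other checkers):

    from embodied_gen.utils.gpt_clients import GPT_CLIENT
    from embodied_gen.validators.quality_checkers import ImageSegChecker

    checker = ImageSegChecker(GPT_CLIENT, verbose=True)
    flag, response = checker(["raw_image.png", "segmented_image.png"])
    if flag is None:
        print("GPT endpoint unreachable or misconfigured:", response)
    elif not flag:
        print("Segmentation rejected:", response)
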
embodied_gen/validators/urdf_convertor.py CHANGED
@@ -101,34 +101,42 @@ class URDFGenerator(object):
         prompt_template = (
             view_desc
             + """of the 3D object asset,
+category: {category}.
+You are an expert in 3D object analysis and physical property estimation.
+Give the category of this object asset (within 3 words), (if category is
+already provided, use it directly), accurately describe this 3D object asset (within 15 words),
+Determine the pose of the object in the first image and estimate the true vertical height
+(vertical projection) range of the object (in meters), i.e., how tall the object appears from top
+to bottom in the front view (first) image. also weight range (unit: kilogram), the average
+static friction coefficient of the object relative to rubber and the average dynamic friction
+coefficient of the object relative to rubber. Return response format as shown in Output Example.
+
+Output Example:
+Category: cup
+Description: shiny golden cup with floral design
+Height: 0.1-0.15 m
+Weight: 0.3-0.6 kg
+Static friction coefficient: 0.6
+Dynamic friction coefficient: 0.5
+
+IMPORTANT: Estimating Vertical Height from the First (Front View) Image.
+- The "vertical height" refers to the real-world vertical size of the object
+as projected in the first image, aligned with the image's vertical axis.
+- For flat objects like plates or disks or book, if their face is visible in the front view,
+use the diameter as the vertical height. If the edge is visible, use the thickness instead.
+- This is not necessarily the full length of the object, but how tall it appears
+in the first image vertically, based on its pose and orientation.
+- For objects(e.g., spoons, forks, writing instruments etc.) at an angle showing in
+the first image, tilted at 45° will appear shorter vertically than when upright.
+Estimate the vertical projection of their real length based on its pose.
+For example:
+- A pen standing upright in the first view (aligned with the image's vertical axis)
+full body visible in the first image: -> vertical height ~= 0.14-0.20 m
+- A pen lying flat in the front view (showing thickness) -> vertical height ~= 0.018-0.025 m
+- Tilted pen in the first image (e.g., ~45° angle): vertical height ~= 0.07-0.12 m
+- Use the rest views(except the first image) to help determine the object's 3D pose and orientation.
+Assume the object is in real-world scale and estimate the approximate vertical height
+(in meters) based on how large it appears vertically in the first image.
         """
         )
 
@@ -297,20 +305,24 @@ class URDFGenerator(object):
         if not os.path.exists(urdf_path):
             raise FileNotFoundError(f"URDF file not found: {urdf_path}")
 
+        mesh_attr = None
         tree = ET.parse(urdf_path)
         root = tree.getroot()
        extra_info = root.find(attr_root)
        if extra_info is not None:
            scale_element = extra_info.find(attr_name)
            if scale_element is not None:
+                mesh_attr = scale_element.text
+                try:
+                    mesh_attr = float(mesh_attr)
+                except ValueError as e:
+                    pass
 
-        return
+        return mesh_attr
 
     @staticmethod
     def add_quality_tag(
-        urdf_path: str, results, output_path: str = None
+        urdf_path: str, results: list, output_path: str = None
     ) -> None:
         if output_path is None:
             output_path = urdf_path
@@ -366,17 +378,11 @@ class URDFGenerator(object):
             output_root,
             num_images=self.render_view_num,
             output_subdir=self.output_render_dir,
+            no_index_file=True,
         )
 
-        # Hardcode tmp because of the openrouter can't input multi images.
-        if "openrouter" in self.gpt_client.endpoint:
-            from embodied_gen.utils.process_media import (
-                combine_images_to_base64,
-            )
-
-            image_path = combine_images_to_base64(image_path)
-
         response = self.gpt_client.query(text_prompt, image_path)
+        # logger.info(response)
         if response is None:
             asset_attrs = {
                 "category": category.lower(),
@@ -412,14 +418,18 @@
 if __name__ == "__main__":
     urdf_gen = URDFGenerator(GPT_CLIENT, render_view_num=4)
     urdf_path = urdf_gen(
-        mesh_path="outputs/
+        mesh_path="outputs/layout2/asset3d/marker/result/mesh/marker.obj",
         output_root="outputs/test_urdf",
+        category="marker",
         # min_height=1.0,
         # max_height=1.2,
         version=VERSION,
     )
 
+    URDFGenerator.add_quality_tag(
+        urdf_path, [[urdf_gen.__class__.__name__, "OK"]]
+    )
+
     # zip_files(
     #     input_paths=[
     #         "scripts/apps/tmp/2umpdum3e5n/URDF_sample/mesh",