From 0cc9f022a79db3133a2a006ee3d8c49dce6a8196 Mon Sep 17 00:00:00 2001 From: Muzi Wong <178915912+MuziWong@users.noreply.github.com> Date: Mon, 29 Jun 2026 16:23:25 +0800 Subject: [PATCH 1/4] feat: add text/image-to-3d-scene pipeline --- embodichain/gen_sim/prompt2scene/.gitignore | 7 + embodichain/gen_sim/prompt2scene/__init__.py | 15 + .../prompt2scene/agent_tools/__init__.py | 1 + .../agent_tools/clients/__init__.py | 31 + .../prompt2scene/agent_tools/clients/base.py | 131 ++++ .../agent_tools/clients/common.py | 139 ++++ .../agent_tools/clients/config.py | 50 ++ .../geometry_generation_client/__init__.py | 49 ++ .../geometry_generation_client/client.py | 213 ++++++ .../geometry_generation_client/parser.py | 255 +++++++ .../geometry_generation_client/schemas.py | 134 ++++ .../image_generation_client/__init__.py | 39 ++ .../clients/image_generation_client/client.py | 117 ++++ .../clients/image_generation_client/parser.py | 65 ++ .../image_generation_client/schemas.py | 72 ++ .../image_segmentation_client/__init__.py | 61 ++ .../image_segmentation_client/client.py | 132 ++++ .../image_segmentation_client/parser.py | 218 ++++++ .../image_segmentation_client/schemas.py | 103 +++ .../image_segmentation_client/utils.py | 322 +++++++++ .../blender_rendering_manager/__init__.py | 31 + .../blender_rendering_manager/manager.py | 175 +++++ .../blender_rendering_manager/schemas.py | 39 ++ .../geometry_generation_manager/__init__.py | 45 ++ .../geometry_generation_manager/manager.py | 209 ++++++ .../geometry_generation_manager/schemas.py | 105 +++ .../managers/geometry_manager/__init__.py | 69 ++ .../managers/geometry_manager/manager.py | 584 ++++++++++++++++ .../geometry_manager/scene_geometry.py | 567 ++++++++++++++++ .../managers/geometry_manager/schemas.py | 201 ++++++ .../image_generation_manager/__init__.py | 35 + .../image_generation_manager/manager.py | 76 +++ .../image_generation_manager/schemas.py | 43 ++ .../managers/image_scene_manager/__init__.py | 29 + 
.../managers/image_scene_manager/alignment.py | 537 +++++++++++++++ .../managers/image_scene_manager/manifests.py | 212 ++++++ .../managers/image_scene_manager/prompts.py | 106 +++ .../managers/image_scene_manager/schemas.py | 71 ++ .../image_segmentation_manager/__init__.py | 33 + .../image_segmentation_manager/manager.py | 90 +++ .../image_segmentation_manager/schemas.py | 48 ++ .../managers/matplotlib_manager/__init__.py | 43 ++ .../managers/matplotlib_manager/manager.py | 401 +++++++++++ .../managers/matplotlib_manager/schemas.py | 101 +++ .../managers/metric_scale_manager/__init__.py | 37 + .../managers/metric_scale_manager/manager.py | 431 ++++++++++++ .../managers/metric_scale_manager/schemas.py | 73 ++ .../managers/optimization_manager/__init__.py | 37 + .../managers/optimization_manager/manager.py | 633 +++++++++++++++++ .../managers/simready_manager/__init__.py | 35 + .../managers/simready_manager/manager.py | 396 +++++++++++ .../managers/simready_manager/schemas.py | 58 ++ .../managers/simulation_manager/__init__.py | 31 + .../managers/simulation_manager/manager.py | 124 ++++ .../managers/simulation_manager/schemas.py | 42 ++ .../table_clutter_fit_manager/__init__.py | 23 + .../table_clutter_fit_manager/manager.py | 298 ++++++++ .../managers/text_layout_manager/__init__.py | 33 + .../managers/text_layout_manager/layout.py | 383 +++++++++++ .../text_layout_manager/optimization.py | 404 +++++++++++ .../managers/text_layout_manager/settle.py | 429 ++++++++++++ .../agent_tools/servers/__init__.py | 16 + .../agent_tools/tools/__init__.py | 19 + .../agent_tools/tools/gym_export.py | 319 +++++++++ .../tools/image_scene_asset_generation.py | 636 ++++++++++++++++++ .../agent_tools/tools/table_fit_scene.py | 105 +++ .../tools/text_asset_generation.py | 294 ++++++++ .../agent_tools/tools/text_clutter_layout.py | 62 ++ .../tools/text_scene_metric_scale.py | 161 +++++ .../gen_sim/prompt2scene/cli/__init__.py | 19 + embodichain/gen_sim/prompt2scene/cli/start.py | 90 
+++ .../prompt2scene/configs/client_config.json | 21 + .../prompt2scene/configs/llm_config.json | 11 + .../gen_sim/prompt2scene/llms/__init__.py | 31 + .../gen_sim/prompt2scene/llms/config.py | 49 ++ .../prompt2scene/llms/openai_compatible.py | 115 ++++ .../gen_sim/prompt2scene/pipeline/__init__.py | 25 + .../gen_sim/prompt2scene/pipeline/runner.py | 239 +++++++ .../gen_sim/prompt2scene/prompts/__init__.py | 48 ++ .../gen_sim/prompt2scene/prompts/base.py | 79 +++ .../prompt2scene/prompts/data/__init__.py | 21 + .../prompts/data/image_relations.yaml | 238 +++++++ .../prompts/data/scene_intake.yaml | 468 +++++++++++++ .../prompts/data/text_relations.yaml | 110 +++ .../prompts/data/unified_scene_gen.yaml | 225 +++++++ .../gen_sim/prompt2scene/utils/__init__.py | 39 ++ embodichain/gen_sim/prompt2scene/utils/io.py | 66 ++ embodichain/gen_sim/prompt2scene/utils/log.py | 62 ++ .../prompt2scene/workflows/__init__.py | 41 ++ .../prompt2scene/workflows/artifact_writer.py | 271 ++++++++ .../prompt2scene/workflows/attempt_state.py | 30 + .../workflows/image_relations/__init__.py | 24 + .../workflows/image_relations/graph.py | 189 ++++++ .../workflows/image_relations/nodes.py | 511 ++++++++++++++ .../workflows/image_relations/prompts.py | 113 ++++ .../workflows/image_relations/schema.py | 250 +++++++ .../workflows/image_relations/state.py | 42 ++ .../workflows/image_relations/utils.py | 435 ++++++++++++ .../prompt2scene/workflows/llm_output.py | 285 ++++++++ .../gen_sim/prompt2scene/workflows/request.py | 110 +++ .../workflows/scene_intake/__init__.py | 24 + .../workflows/scene_intake/graph.py | 142 ++++ .../workflows/scene_intake/nodes.py | 211 ++++++ .../workflows/scene_intake/prompts.py | 197 ++++++ .../workflows/scene_intake/schema.py | 244 +++++++ .../workflows/scene_intake/state.py | 37 + .../workflows/scene_intake/utils.py | 229 +++++++ .../gen_sim/prompt2scene/workflows/spatial.py | 309 +++++++++ .../prompt2scene/workflows/stage_errors.py | 40 ++ 
.../workflows/text_relations/__init__.py | 24 + .../workflows/text_relations/graph.py | 124 ++++ .../workflows/text_relations/nodes.py | 144 ++++ .../workflows/text_relations/prompts.py | 55 ++ .../workflows/text_relations/schema.py | 164 +++++ .../workflows/text_relations/state.py | 42 ++ .../workflows/text_relations/utils.py | 191 ++++++ .../workflows/unified_scene/__init__.py | 19 + .../workflows/unified_scene/graph.py | 97 +++ .../workflows/unified_scene/nodes.py | 57 ++ .../workflows/unified_scene/schema.py | 157 +++++ .../workflows/unified_scene/state.py | 45 ++ .../workflows/unified_scene/utils.py | 332 +++++++++ .../workflows/unified_scene_gen/__init__.py | 27 + .../workflows/unified_scene_gen/graph.py | 106 +++ .../workflows/unified_scene_gen/nodes.py | 392 +++++++++++ .../workflows/unified_scene_gen/paths.py | 102 +++ .../workflows/unified_scene_gen/prompts.py | 141 ++++ .../unified_scene_gen/scene_update.py | 76 +++ .../workflows/unified_scene_gen/schema.py | 71 ++ .../workflows/unified_scene_gen/state.py | 40 ++ 130 files changed, 19179 insertions(+) create mode 100644 embodichain/gen_sim/prompt2scene/.gitignore create mode 100644 embodichain/gen_sim/prompt2scene/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/base.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/common.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/config.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/geometry_generation_client/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/geometry_generation_client/client.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/geometry_generation_client/parser.py create mode 100644 
embodichain/gen_sim/prompt2scene/agent_tools/clients/geometry_generation_client/schemas.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/image_generation_client/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/image_generation_client/client.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/image_generation_client/parser.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/image_generation_client/schemas.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/client.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/parser.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/schemas.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/utils.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/blender_rendering_manager/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/blender_rendering_manager/manager.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/blender_rendering_manager/schemas.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_generation_manager/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_generation_manager/manager.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_generation_manager/schemas.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/manager.py create mode 100644 
embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/scene_geometry.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/schemas.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/image_generation_manager/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/image_generation_manager/manager.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/image_generation_manager/schemas.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/alignment.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/manifests.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/prompts.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/schemas.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/image_segmentation_manager/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/image_segmentation_manager/manager.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/image_segmentation_manager/schemas.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/matplotlib_manager/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/matplotlib_manager/manager.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/matplotlib_manager/schemas.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/manager.py create mode 100644 
embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/schemas.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/optimization_manager/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/optimization_manager/manager.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/manager.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/schemas.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/simulation_manager/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/simulation_manager/manager.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/simulation_manager/schemas.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/manager.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/layout.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/optimization.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/settle.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/servers/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/tools/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/tools/gym_export.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/tools/image_scene_asset_generation.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/tools/table_fit_scene.py create 
mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/tools/text_asset_generation.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/tools/text_clutter_layout.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/tools/text_scene_metric_scale.py create mode 100644 embodichain/gen_sim/prompt2scene/cli/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/cli/start.py create mode 100644 embodichain/gen_sim/prompt2scene/configs/client_config.json create mode 100644 embodichain/gen_sim/prompt2scene/configs/llm_config.json create mode 100644 embodichain/gen_sim/prompt2scene/llms/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/llms/config.py create mode 100644 embodichain/gen_sim/prompt2scene/llms/openai_compatible.py create mode 100644 embodichain/gen_sim/prompt2scene/pipeline/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/pipeline/runner.py create mode 100644 embodichain/gen_sim/prompt2scene/prompts/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/prompts/base.py create mode 100644 embodichain/gen_sim/prompt2scene/prompts/data/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/prompts/data/image_relations.yaml create mode 100644 embodichain/gen_sim/prompt2scene/prompts/data/scene_intake.yaml create mode 100644 embodichain/gen_sim/prompt2scene/prompts/data/text_relations.yaml create mode 100644 embodichain/gen_sim/prompt2scene/prompts/data/unified_scene_gen.yaml create mode 100644 embodichain/gen_sim/prompt2scene/utils/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/utils/io.py create mode 100644 embodichain/gen_sim/prompt2scene/utils/log.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/artifact_writer.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/attempt_state.py create mode 100644 
embodichain/gen_sim/prompt2scene/workflows/image_relations/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/image_relations/graph.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/image_relations/nodes.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/image_relations/prompts.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/image_relations/schema.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/image_relations/state.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/image_relations/utils.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/llm_output.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/request.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/scene_intake/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/scene_intake/graph.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/scene_intake/nodes.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/scene_intake/prompts.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/scene_intake/schema.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/scene_intake/state.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/scene_intake/utils.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/spatial.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/stage_errors.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/text_relations/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/text_relations/graph.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/text_relations/nodes.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/text_relations/prompts.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/text_relations/schema.py create mode 100644 
embodichain/gen_sim/prompt2scene/workflows/text_relations/state.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/text_relations/utils.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/unified_scene/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/unified_scene/graph.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/unified_scene/nodes.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/unified_scene/schema.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/unified_scene/state.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/unified_scene/utils.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/graph.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/nodes.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/paths.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/prompts.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/scene_update.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/schema.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/state.py diff --git a/embodichain/gen_sim/prompt2scene/.gitignore b/embodichain/gen_sim/prompt2scene/.gitignore new file mode 100644 index 00000000..75f4908e --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/.gitignore @@ -0,0 +1,7 @@ +cli/preview* +cli/export* +agent_tools/servers/geometry_generation_server/* + +# Python cache +__pycache__/ +*.py[cod] diff --git a/embodichain/gen_sim/prompt2scene/__init__.py b/embodichain/gen_sim/prompt2scene/__init__.py new file mode 100644 index 00000000..01ece10d --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/__init__.py @@ -0,0 +1,15 @@ +# 
---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- \ No newline at end of file diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/__init__.py b/embodichain/gen_sim/prompt2scene/agent_tools/__init__.py new file mode 100644 index 00000000..a4b11ff0 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/__init__.py @@ -0,0 +1 @@ +"""Internal client + External server for agent tool calling.""" diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/clients/__init__.py b/embodichain/gen_sim/prompt2scene/agent_tools/clients/__init__.py new file mode 100644 index 00000000..3afc32bd --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/clients/__init__.py @@ -0,0 +1,31 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.clients.base import BaseHttpClient +from embodichain.gen_sim.prompt2scene.agent_tools.clients.common import ClientError +from embodichain.gen_sim.prompt2scene.agent_tools.clients.config import ( + DEFAULT_CLIENT_CONFIG_PATH, + load_client_config, +) + +__all__ = [ + "BaseHttpClient", + "ClientError", + "DEFAULT_CLIENT_CONFIG_PATH", + "load_client_config", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/clients/base.py b/embodichain/gen_sim/prompt2scene/agent_tools/clients/base.py new file mode 100644 index 00000000..8981602f --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/clients/base.py @@ -0,0 +1,131 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class BaseHttpClient:
    """Shared HTTP client behavior for agent-tool service clients.

    Connection settings come from one section of the client config file,
    loaded through ``load_client_config``. Explicit constructor arguments
    take precedence over the configured values.
    """

    def __init__(
        self,
        *,
        config_key: str,
        server_name: str,
        base_url: str | None = None,
        timeout_s: int | None = None,
        config_path: Path | None = None,
        session: requests.Session | None = None,
        trust_env: bool = True,
    ) -> None:
        """Initialize common service client fields from config.

        Args:
            config_key: Name of the config section to load.
            server_name: Service name used in log and error messages.
            base_url: Optional override for the configured ``base_url``.
            timeout_s: Optional override for the configured ``timeout_s``.
                A falsy value (``None`` or ``0``) falls back to the config
                value, then to 120.
            config_path: Optional config file path forwarded to
                ``load_client_config``.
            session: Session to reuse; a new ``requests.Session`` is created
                when omitted.
            trust_env: Assigned to ``session.trust_env`` (this also applies
                to a caller-supplied session).
        """
        self.config = load_client_config(config_key, config_path)
        self.server_name = server_name
        # Trailing slashes are stripped so paths can be appended verbatim.
        self.base_url = (base_url or str(self.config["base_url"])).rstrip("/")
        self.timeout_s = int(timeout_s or self.config.get("timeout_s", 120))
        self.health_path = str(self.config.get("health_path", "/health"))
        self.session = session or requests.Session()
        self.session.trust_env = trust_env
        log_info(f"{self.server_name} client initialized for {self.base_url}")

    def health_check(self) -> bool:
        """Check whether the configured service is healthy.

        Returns:
            ``True`` when the health endpoint answers with a non-error status,
            ``False`` on any failure (the failure is logged, never raised).
        """
        try:
            response = self.session.get(
                f"{self.base_url}{self.health_path}",
                timeout=5,
            )
            response.raise_for_status()
            return True
        except Exception as exc:
            # Deliberately best-effort: a health probe must never raise.
            log_warning(f"{self.server_name} health check failed: {exc}")
            return False

    def _wait_before_retry(self, reason: str, attempt: int, attempts: int) -> None:
        """Log a retry warning and sleep with capped exponential backoff."""
        log_warning(
            f"{self.server_name} {reason}; retrying " f"({attempt + 1}/{attempts})."
        )
        time.sleep(min(2**attempt, 60))

    def post_with_retries(
        self,
        request_fn: Callable[[], requests.Response],
        *,
        max_retries: int,
        error_cls: type[ClientError] = ClientError,
        request_label: str | None = None,
    ) -> requests.Response | ClientError:
        """Run a POST request function with retry and HTTP error handling.

        Args:
            request_fn: Zero-argument callable that performs the request.
            max_retries: Total number of attempts. Values below 1 are treated
                as a single attempt so the request is always sent once.
            error_cls: ``ClientError`` (sub)class used for HTTP error results.
            request_label: When given, each attempt is announced through
                ``log_api_request_start``.

        Returns:
            The successful response, or an ``error_cls`` instance describing
            an HTTP error response that was not retried (any status below
            500, or a 5xx status on the final attempt).

        Raises:
            ConnectionError: Every attempt failed to connect.
            TimeoutError: The request timed out (timeouts are not retried).
            RuntimeError: An HTTP error carried no response object.
        """
        # A non-positive max_retries used to skip the loop entirely and fall
        # through to the "failed unexpectedly" error without sending anything.
        attempts = max(1, max_retries)
        for attempt in range(attempts):
            is_last_attempt = attempt >= attempts - 1
            try:
                if request_label is not None:
                    log_api_request_start(
                        step=self.server_name,
                        request=request_label,
                        attempt=attempt + 1,
                    )
                response = request_fn()
                response.raise_for_status()
                return response

            # NOTE: this handler must stay ahead of the Timeout handler so
            # that connection-level failures are retried.
            except requests.exceptions.ConnectionError as exc:
                if not is_last_attempt:
                    self._wait_before_retry("connection failed", attempt, attempts)
                    continue
                raise ConnectionError(
                    f"Failed to connect to {self.server_name} at {self.base_url}"
                ) from exc

            except requests.exceptions.HTTPError as exc:
                response = exc.response
                if response is None:
                    raise RuntimeError(
                        f"{self.server_name} HTTP request failed."
                    ) from exc
                # Only server-side (5xx) failures are worth retrying.
                if response.status_code >= 500 and not is_last_attempt:
                    self._wait_before_retry("server error", attempt, attempts)
                    continue
                return build_client_error(
                    response,
                    server_name=self.server_name,
                    error_cls=error_cls,
                )

            except requests.exceptions.Timeout as exc:
                raise TimeoutError(f"{self.server_name} request timed out.") from exc

        # Defensive fallthrough; every loop path above returns or raises.
        raise RuntimeError(f"{self.server_name} request failed unexpectedly.")
@dataclass(frozen=True)
class ClientError:
    """Common HTTP client error response."""

    # Human-readable message built from the error response.
    error_message: str
    # HTTP status code of the failed response, when known.
    status_code: int | None = None
    # Value of the response Content-Type header, when present.
    content_type: str | None = None
    # Copy of the response headers.
    headers: dict[str, str] = field(default_factory=dict)
    # Parsed JSON body when it is a JSON object, otherwise None.
    raw_response: dict[str, Any] | None = None


def validate_png_response(
    response: requests.Response,
    png_bytes: bytes,
) -> None:
    """Validate that a response declares and carries PNG data.

    Args:
        response: Response whose ``Content-Type`` header is inspected.
        png_bytes: Body bytes checked for the PNG file signature.

    Raises:
        RuntimeError: The content type is not ``image/png`` or the bytes do
            not start with the PNG signature.
    """
    content_type = response.headers.get("Content-Type", "")
    if "image/png" not in content_type.lower():
        raise RuntimeError(
            "Image generation server returned non-PNG content: "
            f"{content_type or 'unknown'}"
        )
    if not png_bytes.startswith(b"\x89PNG\r\n\x1a\n"):
        raise RuntimeError("Image generation server returned invalid PNG bytes.")


def validate_required_strings(fields: dict[str, object]) -> None:
    """Validate required client request string fields.

    Args:
        fields: Mapping of field name to the value supplied by the caller.

    Raises:
        ValueError: A value is ``None`` or is blank once converted to text.
    """
    for field_name, value in fields.items():
        # str(None) is the non-empty text "None", so a missing value has to be
        # rejected explicitly instead of relying on the blank-string check.
        if value is None or not str(value).strip():
            raise ValueError(f"{field_name} must be non-empty.")


def format_http_error(response: requests.Response, *, server_name: str) -> str:
    """Format an HTTP error response from an agent-tool server.

    The message uses the first string found under ``error``,
    ``error_message``, ``message`` or ``detail`` in a JSON object body, and
    falls back to the bare status code otherwise.
    """
    fallback = f"{server_name} HTTP error: {response.status_code}"
    try:
        response_data = response.json()
    except ValueError:
        return fallback
    # A JSON array/string/number body has no keys to inspect; calling .get on
    # it used to raise AttributeError while formatting the error.
    if not isinstance(response_data, dict):
        return fallback

    error_message = first_string(
        response_data,
        "error",
        "error_message",
        "message",
        "detail",
    )
    if error_message:
        return f"{server_name} error: {error_message}"
    return fallback


def parse_error_response(response: requests.Response) -> dict[str, Any] | None:
    """Parse an error response body as a JSON object if possible.

    Returns:
        The parsed body when it is a JSON object, otherwise ``None``.
    """
    try:
        response_data = response.json()
    except ValueError:
        return None
    return response_data if isinstance(response_data, dict) else None


def build_client_error(
    response: requests.Response,
    *,
    server_name: str,
    error_cls: type[ClientError] = ClientError,
) -> ClientError:
    """Build a common client error dataclass from an HTTP response.

    Args:
        response: The failed HTTP response.
        server_name: Service name used as the message prefix.
        error_cls: ``ClientError`` (sub)class to instantiate.
    """
    return error_cls(
        error_message=format_http_error(
            response,
            server_name=server_name,
        ),
        status_code=response.status_code,
        content_type=response.headers.get("Content-Type"),
        headers=dict(response.headers),
        raw_response=parse_error_response(response),
    )


def parse_json_object_response(
    response: requests.Response,
    *,
    server_name: str,
) -> dict[str, Any]:
    """Parse an HTTP response body as a JSON object.

    Raises:
        RuntimeError: The body is not valid JSON or is not a JSON object.
    """
    try:
        response_data = response.json()
    except ValueError as exc:
        raise RuntimeError(
            f"{server_name} returned invalid JSON content: "
            f"{response.headers.get('Content-Type') or 'unknown'}"
        ) from exc
    if not isinstance(response_data, dict):
        raise RuntimeError(f"{server_name} response must be a JSON object.")
    return response_data


def first_string(data: dict[str, Any], *keys: str) -> str | None:
    """Return the first string value for the given keys.

    Keys are tried in order; missing keys and non-string values are skipped.
    Returns ``None`` when no key maps to a string.
    """
    for key in keys:
        value = data.get(key)
        if isinstance(value, str):
            return value
    return None
# Default config file: <package root>/configs/client_config.json, where the
# package root is the directory two levels above the one containing this
# module.
DEFAULT_CLIENT_CONFIG_PATH = (
    Path(__file__).resolve().parent.parent.parent / "configs" / "client_config.json"
)


def load_client_config(
    config_key: str,
    config_path: Path | None = None,
) -> dict[str, Any]:
    """Load one agent-tool client config section.

    Args:
        config_key: Top-level key of the section to return.
        config_path: Config file to read; defaults to
            ``DEFAULT_CLIENT_CONFIG_PATH``.

    Returns:
        The section as a dict; it is guaranteed to contain a truthy
        ``base_url`` entry.

    Raises:
        FileNotFoundError: The config file does not exist.
        ValueError: The file is not a JSON object, the section is missing or
            is not an object, or the section has no ``base_url``.
    """
    resolved_config_path = (config_path or DEFAULT_CLIENT_CONFIG_PATH).resolve()
    if not resolved_config_path.is_file():
        raise FileNotFoundError(f"Client config not found: {resolved_config_path}")

    with resolved_config_path.open("r", encoding="utf-8") as f:
        raw_config = json.load(f)

    # A top-level array/string/number has no sections; report it as a config
    # error instead of failing with AttributeError on ``.get``.
    if not isinstance(raw_config, dict):
        raise ValueError(
            f"Client config {resolved_config_path} must contain a JSON object."
        )

    config = raw_config.get(config_key)
    if not isinstance(config, dict):
        raise ValueError(
            f"Client config section {config_key!r} not found in "
            f"{resolved_config_path}"
        )
    if not config.get("base_url"):
        raise ValueError(f"Client config section {config_key!r} requires base_url.")
    return config
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.clients.geometry_generation_client.client import ( + GeometryGenerationClient, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.geometry_generation_client.schemas import ( + GeometryGenerationError, + GeometryGenerationResult, + GeometryGenerationServerRequest, + GeometryGenerationServerResponse, + MultiObjectGenerationError, + MultiObjectGenerationObject, + MultiObjectGenerationResult, + MultiObjectGenerationServerRequest, + MultiObjectGenerationServerResponse, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.config import ( + DEFAULT_CLIENT_CONFIG_PATH, +) + +__all__ = [ + "DEFAULT_CLIENT_CONFIG_PATH", + "GeometryGenerationClient", + "GeometryGenerationError", + "GeometryGenerationResult", + "GeometryGenerationServerRequest", + "GeometryGenerationServerResponse", + "MultiObjectGenerationError", + "MultiObjectGenerationObject", + "MultiObjectGenerationResult", + "MultiObjectGenerationServerRequest", + "MultiObjectGenerationServerResponse", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/clients/geometry_generation_client/client.py b/embodichain/gen_sim/prompt2scene/agent_tools/clients/geometry_generation_client/client.py new file mode 100644 index 00000000..0615c6d2 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/clients/geometry_generation_client/client.py @@ -0,0 +1,213 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""Client for the SAM3D geometry generation server.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import requests + +from embodichain.gen_sim.prompt2scene.agent_tools.clients.base import BaseHttpClient +from embodichain.gen_sim.prompt2scene.agent_tools.clients.common import ( + validate_required_strings, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.config import ( + DEFAULT_CLIENT_CONFIG_PATH, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.geometry_generation_client.parser import ( + parse_geometry_generation_response, + parse_multi_object_generation_response, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.geometry_generation_client.schemas import ( + GeometryGenerationError, + GeometryGenerationResult, + GeometryGenerationServerRequest, + GeometryGenerationServerResponse, + MultiObjectGenerationError, + MultiObjectGenerationObject, + MultiObjectGenerationServerRequest, + MultiObjectGenerationServerResponse, +) + +__all__ = [ + "DEFAULT_CLIENT_CONFIG_PATH", + "GeometryGenerationClient", + "GeometryGenerationError", + "GeometryGenerationResult", + "GeometryGenerationServerRequest", + "GeometryGenerationServerResponse", + "MultiObjectGenerationError", + "MultiObjectGenerationObject", + "MultiObjectGenerationServerRequest", + "MultiObjectGenerationServerResponse", +] + + +class GeometryGenerationClient(BaseHttpClient): + """Client for making single-object SAM3D geometry generation 
requests.""" + + def __init__( + self, + *, + base_url: str | None = None, + timeout_s: int | None = None, + config_path: Path | None = None, + config_key: str = "sam3d_generation", + session: requests.Session | None = None, + ) -> None: + """Initialize the geometry generation client.""" + super().__init__( + config_key=config_key, + server_name="Geometry generation server", + base_url=base_url, + timeout_s=timeout_s, + config_path=config_path, + session=session, + trust_env=False, + ) + self.generate_single_object_path = str( + self.config.get("generate_single_object_path", "/generate_single_object") + ) + self.generate_multiple_objects_path = str( + self.config.get( + "generate_multiple_objects_path", "/generate_multiple_objects" + ) + ) + + def generate( + self, + request: GeometryGenerationServerRequest, + *, + max_retries: int = 3, + ) -> GeometryGenerationServerResponse | GeometryGenerationError: + """Generate one GLB mesh from an object image and save it locally.""" + _validate_request(request) + url = f"{self.base_url}{self.generate_single_object_path}" + response = self.post_with_retries( + lambda: _post_geometry_generation_request(self, url, request), + max_retries=max_retries, + error_cls=GeometryGenerationError, + request_label="geometry_generation", + ) + if isinstance(response, GeometryGenerationError): + return response + return parse_geometry_generation_response(response, request) + + def generate_multiple_objects( + self, + request: MultiObjectGenerationServerRequest, + *, + output_dir: Path | None = None, + max_retries: int = 3, + ) -> MultiObjectGenerationServerResponse | MultiObjectGenerationError: + """Generate multiple GLB meshes from one image and multiple masks.""" + _validate_multi_object_request(request) + url = f"{self.base_url}{self.generate_multiple_objects_path}" + response = self.post_with_retries( + lambda: _post_multi_object_generation_request(self, url, request), + max_retries=max_retries, + error_cls=MultiObjectGenerationError, + 
request_label="multi_object_geometry_generation", + ) + if isinstance(response, MultiObjectGenerationError): + return response + return parse_multi_object_generation_response( + response, + self.base_url, + output_dir=output_dir, + session=self.session, + ) + + +def _validate_request(request: GeometryGenerationServerRequest) -> None: + validate_required_strings( + { + "Geometry generation image_path": request.image_path, + "Geometry generation output_path": request.output_path, + } + ) + image_path = Path(request.image_path).expanduser() + if not image_path.is_file(): + raise FileNotFoundError(f"Geometry generation input not found: {image_path}") + if not str(request.output_path).lower().endswith(".glb"): + raise ValueError("Geometry generation output_path must be a GLB file path.") + + +def _post_geometry_generation_request( + client: GeometryGenerationClient, + url: str, + request: GeometryGenerationServerRequest, +) -> requests.Response: + with _open_image_file(request.image_path) as image_file: + return client.session.post( + url, + data=request.to_form_data(), + files={ + "image": ( + Path(request.image_path).name, + image_file, + ) + }, + timeout=(10, client.timeout_s), + ) + + +def _open_image_file(image_path: str | Path) -> Any: + return Path(image_path).expanduser().resolve().open("rb") + + +def _validate_multi_object_request( + request: MultiObjectGenerationServerRequest, +) -> None: + validate_required_strings( + {"Multi-object geometry generation image_path": request.image_path} + ) + image_path = Path(request.image_path).expanduser() + if not image_path.is_file(): + raise FileNotFoundError( + f"Multi-object geometry generation input not found: {image_path}" + ) + if not request.mask_paths: + raise ValueError("mask_paths must be non-empty.") + for mask_path in request.mask_paths: + if not Path(mask_path).expanduser().is_file(): + raise FileNotFoundError( + f"Multi-object geometry mask not found: {mask_path}" + ) + + +def 
def _post_multi_object_generation_request(
    client: GeometryGenerationClient,
    url: str,
    request: MultiObjectGenerationServerRequest,
) -> requests.Response:
    """POST the scene image and all mask files as one multipart request.

    The image is sent under the ``image`` field and every mask under a
    repeated ``masks`` field, alongside ``request.to_form_data()``.

    Args:
        client: Client supplying the HTTP session and read timeout.
        url: Fully-built endpoint URL.
        request: Request carrying the image path and mask paths.

    Returns:
        The raw HTTP response from the server.
    """
    # Imported here because the module does not import contextlib at top level.
    from contextlib import ExitStack

    # Every handle is registered on the stack as soon as it is opened, so all
    # of them (including the image) are closed even if a later open or the
    # POST itself raises.
    with ExitStack() as stack:
        image_file = stack.enter_context(_open_image_file(request.image_path))
        files = [("image", (Path(request.image_path).name, image_file))]
        for mask_path in request.mask_paths:
            mask_file = stack.enter_context(
                Path(mask_path).expanduser().resolve().open("rb")
            )
            files.append(("masks", (Path(mask_path).name, mask_file)))
        return client.session.post(
            url,
            data=request.to_form_data(),
            files=files,
            timeout=(10, client.timeout_s),
        )
def parse_geometry_generation_response(
    response: requests.Response,
    request: GeometryGenerationServerRequest,
) -> GeometryGenerationServerResponse:
    """Validate a GLB response body, save it, and wrap it in a response object.

    Args:
        response: HTTP response whose body should be a binary GLB.
        request: Originating request; its ``output_path`` is the save target.

    Returns:
        A successful response object pointing at the written GLB file.
    """
    payload = response.content
    _validate_glb_response(response, payload)
    saved_to = _write_glb_output(request, payload)
    response_headers = response.headers
    return GeometryGenerationServerResponse(
        ok=True,
        status="ok",
        result=GeometryGenerationResult(geometry_path=str(saved_to)),
        status_code=response.status_code,
        content_type=response_headers.get("Content-Type"),
        headers=dict(response_headers),
    )
def parse_multi_object_generation_response(
    response: requests.Response,
    base_url: str,
    *,
    output_dir: Path | None = None,
    session: requests.Session | None = None,
) -> MultiObjectGenerationServerResponse:
    """Turn a multi-object JSON response into a typed response object.

    Args:
        response: HTTP response carrying the JSON body.
        base_url: Server base URL; trailing slashes are stripped before it is
            used to resolve mesh paths.
        output_dir: Forwarded to the per-object parser; when given, meshes are
            downloaded there.
        session: Optional HTTP session forwarded for mesh downloads.

    Returns:
        The parsed successful response.

    Raises:
        RuntimeError: ``ok`` is not the boolean ``True`` or ``result`` is not
            a JSON object.
    """
    payload = _parse_json_body(response)

    # Only the literal boolean True counts as success.
    if payload.get("ok", False) is not True:
        error_msg = payload.get("error", "ok is not true")
        raise RuntimeError(
            f"Multi-object geometry generation failed: {error_msg}"
        )

    result_section = payload.get("result")
    if not isinstance(result_section, dict):
        raise RuntimeError(
            "Multi-object geometry generation response missing 'result' object"
        )

    parsed_objects = _parse_multi_object_items(
        result_section,
        base_url.rstrip("/"),
        output_dir=output_dir,
        session=session,
    )
    response_headers = response.headers
    return MultiObjectGenerationServerResponse(
        ok=True,
        status=str(payload.get("status") or "ok"),
        result=MultiObjectGenerationResult(objects=parsed_objects),
        status_code=response.status_code,
        content_type=response_headers.get("Content-Type"),
        headers=dict(response_headers),
    )
def _resolve_or_download_glb(
    base_url: str,
    mesh_rel_path: str,
    *,
    name: str,
    index: int,
    output_dir: Path | None,
    session: requests.Session | None,
) -> str:
    """Return the mesh URL, or download the mesh and return its local path.

    Args:
        base_url: Server base URL used to resolve ``mesh_rel_path``.
        mesh_rel_path: Mesh path or absolute URL reported by the server.
        name: Object name from the server response, used for the filename.
        index: Position of the object, used when ``name`` gives no filename.
        output_dir: Download directory; ``None`` means no download.
        session: Optional HTTP session forwarded to the download.

    Returns:
        The remote URL when ``output_dir`` is ``None``, otherwise the path of
        the downloaded ``.glb`` file inside ``output_dir``.
    """
    url = _join_url(base_url, mesh_rel_path)
    if output_dir is None:
        return url

    output_dir = output_dir.expanduser().resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    # ``name`` comes from the server's JSON. Keep only its final path
    # component so values like "../x" or "/abs/x" cannot write outside
    # output_dir; backslashes are normalised so Windows-style separators are
    # stripped on every platform.
    safe_name = Path(name.replace("\\", "/")).name if name else ""
    if safe_name in {"", ".", ".."}:
        safe_name = str(index)

    dest = output_dir / f"{safe_name}.glb"
    _download_glb(url, dest, session=session)
    return str(dest)
def _download_glb(
    url: str,
    dest: Path,
    *,
    session: requests.Session | None,
) -> None:
    """Download a GLB from the geometry server and write it to ``dest``.

    Args:
        url: Absolute URL of the GLB file.
        dest: Local file path to write; the parent directory must exist.
        session: Session to reuse. When ``None`` a temporary session is
            created and closed before returning.

    Raises:
        requests.HTTPError: The server answered with an error status.
        RuntimeError: The downloaded body is not GLB content.
    """
    owns_session = session is None
    http = requests.Session() if owns_session else session
    try:
        r = http.get(url, timeout=30)
        r.raise_for_status()
        _validate_glb_response(r, r.content)
        dest.write_bytes(r.content)
    finally:
        # Close only a session created here; a caller-provided session stays
        # open for the caller to reuse.
        if owns_session:
            http.close()
    log_info(f"Generated geometry written: {dest}")
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path + +from embodichain.gen_sim.prompt2scene.agent_tools.clients.common import ClientError + +__all__ = [ + "GeometryGenerationError", + "GeometryGenerationResult", + "GeometryGenerationServerRequest", + "GeometryGenerationServerResponse", + "MultiObjectGenerationError", + "MultiObjectGenerationObject", + "MultiObjectGenerationResult", + "MultiObjectGenerationServerRequest", + "MultiObjectGenerationServerResponse", +] + + +@dataclass(frozen=True) +class GeometryGenerationServerRequest: + """Request sent to the Geometry Generation server. + + Args: + image_path: Local object image path. + output_path: Local output GLB path where the client saves the generated geometry. + """ + + image_path: str | Path + output_path: str | Path + + def to_form_data(self) -> dict[str, str]: + """Convert the request to the geometry server multipart form fields.""" + return {} + + +@dataclass(frozen=True) +class GeometryGenerationResult: + """Successful Geometry Generation result.""" + + geometry_path: str + + +@dataclass(frozen=True) +class GeometryGenerationServerResponse: + """Parsed successful response from the Geometry Generation server.""" + + ok: bool + result: GeometryGenerationResult + status: str | None = None + error: str | None = None + status_code: int | None = None + content_type: str | None = None + headers: dict[str, str] = field(default_factory=dict) + + +@dataclass(frozen=True) +class GeometryGenerationError(ClientError): + """Geometry generation failure returned by the server.""" + + +@dataclass(frozen=True) +class MultiObjectGenerationServerRequest: + """Request sent to the Geometry Generation server (multi-object). + + Args: + image_path: Local scene RGB image path. + mask_paths: Local mask PNG file paths (one per object). 
@dataclass(frozen=True)
class MultiObjectGenerationResult:
    """Container for every object produced by one multi-object request."""

    # Generated objects, in the order they were parsed from the response.
    objects: list[MultiObjectGenerationObject]

    @property
    def geometry_paths(self) -> list[str]:
        """Return the ``geometry_path`` of each object, in order."""
        collected = []
        for generated in self.objects:
            collected.append(generated.geometry_path)
        return collected
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_generation_client.client import ( + ImageGenerationClient, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_generation_client.schemas import ( + ImageGenerationError, + ImageGenerationResult, + ImageGenerationServerRequest, + ImageGenerationServerResponse, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.config import ( + DEFAULT_CLIENT_CONFIG_PATH, +) + +__all__ = [ + "DEFAULT_CLIENT_CONFIG_PATH", + "ImageGenerationClient", + "ImageGenerationError", + "ImageGenerationResult", + "ImageGenerationServerRequest", + "ImageGenerationServerResponse", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_generation_client/client.py b/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_generation_client/client.py new file mode 100644 index 00000000..6f23d47b --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_generation_client/client.py @@ -0,0 +1,117 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""Client for the Z-Image image generation server.""" + +from __future__ import annotations + +from pathlib import Path + +import requests + +from embodichain.gen_sim.prompt2scene.agent_tools.clients.base import BaseHttpClient +from embodichain.gen_sim.prompt2scene.agent_tools.clients.common import ( + validate_required_strings, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.config import ( + DEFAULT_CLIENT_CONFIG_PATH, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_generation_client.parser import ( + parse_generation_response, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_generation_client.schemas import ( + ImageGenerationError, + ImageGenerationResult, + ImageGenerationServerRequest, + ImageGenerationServerResponse, +) + +__all__ = [ + "DEFAULT_CLIENT_CONFIG_PATH", + "ImageGenerationClient", + "ImageGenerationError", + "ImageGenerationResult", + "ImageGenerationServerRequest", + "ImageGenerationServerResponse", +] + + +class ImageGenerationClient(BaseHttpClient): + """Client for making single-image Z-Image generation requests.""" + + def __init__( + self, + *, + base_url: str | None = None, + timeout_s: int | None = None, + config_path: Path | None = None, + config_key: str = "zimage", + session: requests.Session | None = None, + ) -> None: + """Initialize the image generation client.""" + super().__init__( + config_key=config_key, + server_name="Image generation server", + base_url=base_url, + timeout_s=timeout_s, 
def _validate_request(request: ImageGenerationServerRequest) -> None:
    """Check an image generation request before it is sent.

    The prompt and output path are handed to ``validate_required_strings``
    (presumably rejecting empty values — confirm against that helper), then
    the output path must end in ``.png``, compared case-insensitively.

    Raises:
        ValueError: The output path does not end in ``.png``.
    """
    required_fields = {
        "Image generation prompt": request.prompt,
        "Image generation output_path": request.output_path,
    }
    validate_required_strings(required_fields)

    has_png_suffix = str(request.output_path).lower().endswith(".png")
    if has_png_suffix:
        return
    raise ValueError("Image generation output_path must be a PNG file path.")
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from pathlib import Path + +import requests + +from embodichain.gen_sim.prompt2scene.agent_tools.clients.common import ( + validate_png_response, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_generation_client.schemas import ( + ImageGenerationResult, + ImageGenerationServerRequest, + ImageGenerationServerResponse, +) +from embodichain.gen_sim.prompt2scene.utils.log import log_info + +__all__ = ["parse_generation_response"] + + +def parse_generation_response( + response: requests.Response, + request: ImageGenerationServerRequest, +) -> ImageGenerationServerResponse: + """Parse a Z-Image PNG response and save it to the request output path.""" + png_bytes = response.content + validate_png_response(response, png_bytes) + output_path = _write_png_output(request, png_bytes) + result = ImageGenerationResult(image_path=str(output_path)) + return ImageGenerationServerResponse( + ok=True, + status="ok", + result=result, + status_code=response.status_code, + content_type=response.headers.get("Content-Type"), + headers=dict(response.headers), + ) + + +def _write_png_output( + request: ImageGenerationServerRequest, + png_bytes: bytes, +) -> Path: + output_path = Path(request.output_path).expanduser().resolve() + output_path.parent.mkdir(parents=True, exist_ok=True) + 
@dataclass(frozen=True)
class ImageGenerationServerRequest:
    """Immutable description of one Z-Image generation call.

    Args:
        prompt: Text prompt the server turns into an image.
        output_path: Local PNG path where the client stores the response.
    """

    prompt: str
    output_path: str | Path

    def to_dict(self) -> dict[str, Any]:
        """Build the JSON payload; only the prompt is sent to the server."""
        payload: dict[str, Any] = {}
        payload["prompt"] = self.prompt
        return payload
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.clients.config import ( + DEFAULT_CLIENT_CONFIG_PATH, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_segmentation_client.client import ( + ImageSegmentationClient, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_segmentation_client.schemas import ( + ImageSegmentationCandidate, + ImageSegmentationError, + ImageSegmentationResult, + ImageSegmentationServerRequest, + ImageSegmentationServerResponse, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_segmentation_client.utils import ( + apply_mask_to_alpha, + bbox_iou, + decode_rle_mask, + draw_labeled_bboxes, + draw_numbered_bboxes, + draw_numbered_masks, + is_usable_segmentation_candidate, + save_candidate_rgba_and_mask, + sort_segments_by_bbox, +) + +__all__ = [ + "DEFAULT_CLIENT_CONFIG_PATH", + "ImageSegmentationCandidate", + "ImageSegmentationClient", + "ImageSegmentationError", + "ImageSegmentationResult", + "ImageSegmentationServerRequest", + "ImageSegmentationServerResponse", + "apply_mask_to_alpha", + "bbox_iou", + "decode_rle_mask", + "draw_labeled_bboxes", + "draw_numbered_bboxes", + "draw_numbered_masks", + "is_usable_segmentation_candidate", + "save_candidate_rgba_and_mask", + "sort_segments_by_bbox", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/client.py b/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/client.py new file mode 100644 index 00000000..1a880bb6 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/client.py @@ -0,0 +1,132 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. 
class ImageSegmentationClient(BaseHttpClient):
    """HTTP client that submits one image and a text prompt to the SAM3 server.

    The endpoint path comes from the client configuration key
    ``segment_single_object_path`` and falls back to
    ``/segment_single_object`` when the key is absent.
    """

    def __init__(
        self,
        *,
        base_url: str | None = None,
        timeout_s: int | None = None,
        config_path: Path | None = None,
        config_key: str = "sam3_segmentation",
        session: requests.Session | None = None,
    ) -> None:
        """Configure the base HTTP client and resolve the segmentation path.

        Args:
            base_url: Optional server base URL forwarded to the base client.
            timeout_s: Optional timeout in seconds forwarded to the base client.
            config_path: Optional client configuration file path.
            config_key: Configuration section used for this client.
            session: Optional pre-built ``requests`` session.
        """
        super().__init__(
            config_key=config_key,
            server_name="Image segmentation server",
            base_url=base_url,
            timeout_s=timeout_s,
            config_path=config_path,
            session=session,
            trust_env=False,
        )
        endpoint = self.config.get(
            "segment_single_object_path",
            "/segment_single_object",
        )
        self.segmentation_path = str(endpoint)

    def segment(
        self,
        request: ImageSegmentationServerRequest,
        *,
        max_retries: int = 3,
    ) -> ImageSegmentationServerResponse | ImageSegmentationError:
        """Run prompt-based segmentation on a single local image.

        Args:
            request: Prompt and local image path to segment.
            max_retries: Retry budget handed to ``post_with_retries``.

        Returns:
            The parsed server response, or the ``ImageSegmentationError``
            produced by ``post_with_retries`` when the request failed.

        Raises:
            FileNotFoundError: If the input image does not exist locally.
        """
        _validate_request(request)
        endpoint_url = f"{self.base_url}{self.segmentation_path}"

        def send() -> requests.Response:
            return _post_segmentation_request(self, endpoint_url, request)

        outcome = self.post_with_retries(
            send,
            max_retries=max_retries,
            error_cls=ImageSegmentationError,
            request_label="image_segmentation",
        )
        if not isinstance(outcome, ImageSegmentationError):
            return parse_segmentation_response(outcome, request)
        return outcome


def _validate_request(request: ImageSegmentationServerRequest) -> None:
    """Check that the request names an existing local image file."""
    required = {"Image segmentation image_path": request.image_path}
    validate_required_strings(required)
    candidate_path = Path(request.image_path).expanduser()
    if candidate_path.is_file():
        return
    raise FileNotFoundError(f"Image segmentation input not found: {candidate_path}")


def _post_segmentation_request(
    client: ImageSegmentationClient,
    url: str,
    request: ImageSegmentationServerRequest,
) -> requests.Response:
    """Upload the image as multipart form data together with the form fields."""
    upload_name = Path(request.image_path).name
    form_fields = request.to_form_data()
    # The file stays open only for the duration of the POST call.
    with _open_image_file(request.image_path) as image_file:
        return client.session.post(
            url,
            data=form_fields,
            files={"image": (upload_name, image_file)},
            # (connect timeout, read timeout) in seconds.
            timeout=(10, client.timeout_s),
        )


def _open_image_file(image_path: str | Path) -> Any:
    """Open the image for binary reading after expanding and resolving the path."""
    resolved = Path(image_path).expanduser().resolve()
    return resolved.open("rb")
a/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/parser.py b/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/parser.py new file mode 100644 index 00000000..762a1b43 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/parser.py @@ -0,0 +1,218 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
SERVER_NAME = "Image segmentation server"

# Keys that may hold a list of per-instance dictionaries, in lookup order.
_CANDIDATE_LIST_KEYS = ("instances", "candidates", "segmentations", "detections")

# Candidate keys mapped onto dedicated fields; all other keys go to ``metadata``.
_KNOWN_CANDIDATE_KEYS = frozenset(
    {
        "candidate_id",
        "id",
        "index",
        "bbox_xyxy",
        "box_xyxy",
        "box",
        "bbox",
        "score",
        "mask_rle",
        "mask",
        "segmentation",
        "mask_path",
        "label",
    }
)

# int()/float() raise OverflowError for infinities and oversized integers,
# both of which Python's JSON decoder can produce.
_CONVERSION_ERRORS = (TypeError, ValueError, OverflowError)


def parse_segmentation_response(
    response: requests.Response,
    request: ImageSegmentationServerRequest,
) -> ImageSegmentationServerResponse:
    """Parse a SAM3 server JSON response into typed segmentation records.

    Args:
        response: HTTP response returned by the segmentation server.
        request: Request that produced the response; supplies fallbacks for
            ``image_path`` and ``prompt`` when the server omits them.

    Returns:
        The typed response. ``ok`` defaults to ``True`` and ``status`` to
        ``"ok"`` when the payload does not carry them.
    """
    response_data = parse_json_object_response(
        response,
        server_name=SERVER_NAME,
    )
    result = _parse_segmentation_result(response_data, request)
    return ImageSegmentationServerResponse(
        ok=bool(response_data.get("ok", True)),
        status=_string_or_none(response_data.get("status")) or "ok",
        result=result,
        status_code=response.status_code,
        content_type=response.headers.get("Content-Type"),
        headers=dict(response.headers),
    )


def _parse_segmentation_result(
    response_data: dict[str, Any],
    request: ImageSegmentationServerRequest,
) -> ImageSegmentationResult:
    """Build the result record from ``result``, ``data`` or the payload itself."""
    result_data = response_data.get("result")
    if not isinstance(result_data, dict):
        result_data = response_data.get("data")
    if not isinstance(result_data, dict):
        # Flat payload: the result fields live at the top level.
        result_data = response_data

    return ImageSegmentationResult(
        image_path=_string_or_none(result_data.get("image_path"))
        or str(request.image_path),
        prompt=_string_or_none(result_data.get("prompt")) or request.prompt,
        candidates=_parse_candidates(result_data),
        request_id=_string_or_none(result_data.get("request_id")),
        elapsed_sec=_float_or_none(result_data.get("elapsed_sec")),
        count=_int_or_none(result_data.get("count")),
        image_width=_parse_image_width(result_data),
        image_height=_parse_image_height(result_data),
        box_format=_string_or_none(result_data.get("box_format")) or "xyxy",
        mask_format=_string_or_none(result_data.get("mask_format")) or "rle",
    )


def _parse_candidates(result_data: dict[str, Any]) -> list[ImageSegmentationCandidate]:
    """Read candidates from a per-instance list or from parallel arrays."""
    # Preferred layout: the first key holding a list wins; non-dict entries
    # are skipped while the remaining entries keep their original index.
    for key in _CANDIDATE_LIST_KEYS:
        items = result_data.get(key)
        if isinstance(items, list):
            return [
                _parse_candidate_item(item, index)
                for index, item in enumerate(items)
                if isinstance(item, dict)
            ]

    # Fallback layout: parallel ``boxes`` / ``scores`` / ``masks`` arrays.
    boxes = result_data.get("boxes", [])
    scores = result_data.get("scores", [])
    masks = result_data.get("masks", [])
    if not isinstance(boxes, list):
        return []

    return [
        ImageSegmentationCandidate(
            candidate_id=f"candidate_{index}",
            bbox_xyxy=_float_list(box),
            score=_float_or_zero(_list_get(scores, index)),
            mask_rle=_mask_or_none(_list_get(masks, index)),
        )
        for index, box in enumerate(boxes)
    ]


def _parse_candidate_item(
    item: dict[str, Any],
    index: int,
) -> ImageSegmentationCandidate:
    """Normalize one per-instance dictionary into a candidate record.

    Accepts several alternative key spellings for the id, box and mask; keys
    that are not recognized are preserved in ``metadata``.
    """
    mask_value = item.get("mask_rle") or item.get("mask") or item.get("segmentation")
    return ImageSegmentationCandidate(
        candidate_id=_string_or_none(item.get("candidate_id"))
        or _string_or_none(item.get("id"))
        or _index_id_or_none(item.get("index"))
        or f"candidate_{index}",
        bbox_xyxy=_float_list(
            item.get("bbox_xyxy")
            or item.get("box_xyxy")
            or item.get("box")
            or item.get("bbox")
        ),
        score=_float_or_zero(item.get("score")),
        mask_rle=_mask_or_none(mask_value),
        mask_path=_string_or_none(item.get("mask_path")),
        label=_string_or_none(item.get("label")),
        metadata={k: v for k, v in item.items() if k not in _KNOWN_CANDIDATE_KEYS},
    )


def _list_get(values: Any, index: int) -> Any:
    """Return ``values[index]`` or ``None`` when not a list or out of range."""
    if not isinstance(values, list) or index >= len(values):
        return None
    return values[index]


def _float_list(value: Any) -> list[float]:
    """Convert a list to floats, dropping entries that cannot be converted."""
    if not isinstance(value, list):
        return []
    parsed: list[float] = []
    for item in value:
        try:
            parsed.append(float(item))
        except _CONVERSION_ERRORS:
            continue
    return parsed


def _float_or_zero(value: Any) -> float:
    """Convert to float, returning ``0.0`` when conversion fails."""
    try:
        return float(value)
    except _CONVERSION_ERRORS:
        return 0.0


def _float_or_none(value: Any) -> float | None:
    """Convert to float, returning ``None`` when conversion fails."""
    try:
        return float(value)
    except _CONVERSION_ERRORS:
        return None


def _int_or_none(value: Any) -> int | None:
    """Convert to int, returning ``None`` when conversion fails."""
    try:
        return int(value)
    except _CONVERSION_ERRORS:
        return None


def _string_or_none(value: Any) -> str | None:
    """Return the value only when it is already a string."""
    return value if isinstance(value, str) else None


def _mask_or_none(value: Any) -> dict[str, Any] | None:
    """Return the value only when it is a dictionary."""
    return value if isinstance(value, dict) else None


def _index_id_or_none(value: Any) -> str | None:
    """Build ``candidate_<n>`` from an integer-like index, else ``None``."""
    index = _int_or_none(value)
    return f"candidate_{index}" if index is not None else None


def _parse_image_dimension(
    result_data: dict[str, Any],
    nested_key: str,
    flat_key: str,
) -> int | None:
    """Read a dimension from ``image_size[nested_key]``, else from ``flat_key``."""
    image_size = result_data.get("image_size")
    if isinstance(image_size, dict):
        nested_value = _int_or_none(image_size.get(nested_key))
        if nested_value is not None:
            return nested_value
    return _int_or_none(result_data.get(flat_key))


def _parse_image_width(result_data: dict[str, Any]) -> int | None:
    """Return the image width from ``image_size.width`` or ``image_width``."""
    return _parse_image_dimension(result_data, "width", "image_width")


def _parse_image_height(result_data: dict[str, Any]) -> int | None:
    """Return the image height from ``image_size.height`` or ``image_height``."""
    return _parse_image_dimension(result_data, "height", "image_height")
@dataclass(frozen=True)
class ImageSegmentationServerRequest:
    """Request sent to the SAM3 server.

    Args:
        prompt: Short text concept prompt.
        image_path: Local input image path.
        score_threshold: Minimum score sent as the ``score_threshold`` form
            field. Defaults to ``0.0``.
        max_instances: Value sent as the ``max_instances`` form field.
            Defaults to ``5``.
    """

    prompt: str
    image_path: str | Path
    # Trailing fields with defaults keep existing two-argument callers working.
    score_threshold: float = 0.0
    max_instances: int = 5

    def to_form_data(self) -> dict[str, str]:
        """Convert the request to the SAM3 server multipart form fields.

        The image itself is not part of the form fields; only the prompt and
        the tuning parameters are, each rendered as a string.
        """
        return {
            "prompt": self.prompt,
            "score_threshold": str(self.score_threshold),
            "max_instances": str(self.max_instances),
        }


@dataclass(frozen=True)
class ImageSegmentationCandidate:
    """One SAM3 segmentation candidate for a prompted concept.

    SAM3 image inference returns parallel masks, boxes, and scores. The client
    normalizes one aligned mask/box/score item into this candidate record.
    """

    candidate_id: str
    bbox_xyxy: list[float]
    score: float
    mask_rle: dict[str, Any] | None = None
    mask_path: str | None = None
    label: str | None = None
    # Response keys that have no dedicated field.
    metadata: dict[str, Any] = field(default_factory=dict)


@dataclass(frozen=True)
class ImageSegmentationResult:
    """Successful SAM3 segmentation result."""

    image_path: str
    prompt: str
    candidates: list[ImageSegmentationCandidate]
    request_id: str | None = None
    elapsed_sec: float | None = None
    count: int | None = None
    image_width: int | None = None
    image_height: int | None = None
    box_format: str = "xyxy"
    mask_format: str | None = None


@dataclass(frozen=True)
class ImageSegmentationServerResponse:
    """Parsed successful response from the SAM3 server."""

    ok: bool
    result: ImageSegmentationResult
    status: str | None = None
    error: str | None = None
    status_code: int | None = None
    content_type: str | None = None
    headers: dict[str, str] = field(default_factory=dict)


@dataclass(frozen=True)
class ImageSegmentationError(ClientError):
    """Image segmentation failure returned by the server."""
# Semi-transparent RGBA fills cycled through when overlaying masks.
_MASK_OVERLAY_COLORS = (
    (255, 64, 64, 110),
    (64, 160, 255, 110),
    (64, 220, 120, 110),
    (255, 190, 64, 110),
    (190, 96, 255, 110),
    (255, 96, 190, 110),
)

# Pixels of red background drawn around each text label.
_LABEL_PADDING = 8


def decode_rle_mask(mask_rle: dict[str, Any]) -> Image.Image:
    """Decode an uncompressed SAM3 RLE mask into a grayscale PIL image.

    Args:
        mask_rle: Mapping with ``size`` (``[height, width]``), ``counts`` (a
            list of run lengths) and optional ``starts_with`` (truthy when the
            first run is foreground; defaults to background).

    Returns:
        An ``"L"`` mode image where foreground pixels are 255 and the rest 0.

    Raises:
        ValueError: If ``size`` or ``counts`` is malformed, a run is negative,
            or the runs do not cover exactly ``height * width`` pixels.
    """
    size = mask_rle.get("size")
    counts = mask_rle.get("counts")
    if not _is_size_pair(size):
        raise ValueError("SAM3 mask_rle requires size=[height, width].")
    if not isinstance(counts, list):
        raise ValueError("SAM3 mask_rle counts must be an uncompressed list.")

    height = int(size[0])
    width = int(size[1])
    expected_pixels = height * width
    starts_with = int(mask_rle.get("starts_with", 0))
    value = 255 if starts_with else 0
    # The buffer starts zeroed, so only foreground runs need to be written.
    pixels = bytearray(expected_pixels)
    offset = 0

    for count_value in counts:
        count = int(count_value)
        if count < 0:
            raise ValueError("SAM3 mask_rle counts must be non-negative.")
        next_offset = offset + count
        if next_offset > expected_pixels:
            raise ValueError("SAM3 mask_rle counts exceed the expected image size.")
        if value:
            pixels[offset:next_offset] = b"\xff" * count
        offset = next_offset
        # Runs alternate between background and foreground.
        value = 0 if value else 255

    if offset != expected_pixels:
        raise ValueError(
            "SAM3 mask_rle counts do not cover the expected image size: "
            f"{offset} != {expected_pixels}."
        )
    return Image.frombytes("L", (width, height), bytes(pixels))


def apply_mask_to_alpha(
    image_path: str | Path,
    mask: Image.Image,
) -> Image.Image:
    """Return an RGBA image whose alpha channel is the provided mask.

    The mask is resized with nearest-neighbour sampling when its size differs
    from the image size.
    """
    image = _open_converted(image_path, "RGBA")
    alpha = mask.convert("L")
    if alpha.size != image.size:
        alpha = alpha.resize(image.size, Image.Resampling.NEAREST)
    image.putalpha(alpha)
    return image


def save_candidate_rgba_and_mask(
    *,
    image_path: str | Path,
    candidate: ImageSegmentationCandidate,
    output_dir: str | Path,
    prefix: str | None = None,
) -> dict[str, str]:
    """Save one candidate's mask image and RGBA image for SAM3D input.

    Args:
        image_path: Source image the candidate was segmented from.
        candidate: Candidate whose ``mask_rle`` is decoded.
        output_dir: Directory for the output files; created if missing.
        prefix: Filename prefix; defaults to the candidate id.

    Returns:
        Mapping with ``mask_path`` and ``rgba_path`` of the written PNG files.

    Raises:
        ValueError: If the candidate has no ``mask_rle``.
    """
    if candidate.mask_rle is None:
        raise ValueError(f"Candidate {candidate.candidate_id} has no mask_rle.")

    output_dir = Path(output_dir).expanduser().resolve()
    output_dir.mkdir(parents=True, exist_ok=True)
    filename_prefix = prefix or candidate.candidate_id
    mask_path = output_dir / f"{filename_prefix}_mask.png"
    rgba_path = output_dir / f"{filename_prefix}_rgba.png"

    mask = decode_rle_mask(candidate.mask_rle)
    mask.save(mask_path)
    rgba = apply_mask_to_alpha(image_path, mask)
    rgba.save(rgba_path)
    log_info(f"SAM3 mask written: {mask_path}")
    log_info(f"SAM3 RGBA image written: {rgba_path}")
    return {
        "mask_path": str(mask_path),
        "rgba_path": str(rgba_path),
    }


def draw_numbered_bboxes(
    *,
    image_path: str | Path,
    segments: list[dict[str, Any]],
    output_path: str | Path,
) -> Path:
    """Draw numbered bounding boxes for visual segmentation verification.

    Segments are numbered from 1 in list order and each must provide
    ``bbox_xyxy``. Returns the resolved path of the saved image.
    """
    image = _open_converted(image_path, "RGB")
    draw = ImageDraw.Draw(image)
    font = _load_label_font(image.width)
    for index, segment in enumerate(segments, start=1):
        _draw_bbox_label(
            draw=draw,
            bbox_xyxy=segment["bbox_xyxy"],
            label=str(index),
            font=font,
        )
    return _save_image(image, output_path)


def draw_numbered_masks(
    *,
    image_path: str | Path,
    segments: list[dict[str, Any]],
    output_path: str | Path,
) -> Path:
    """Draw numbered segmentation masks for visual segmentation verification.

    Segments without ``mask_rle`` are skipped but still consume their number.
    Returns the resolved path of the saved image.
    """
    image = _open_converted(image_path, "RGBA")
    overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))
    draw_overlay = ImageDraw.Draw(overlay)
    font = _load_label_font(image.width)

    for index, segment in enumerate(segments, start=1):
        mask_rle = segment.get("mask_rle")
        if mask_rle is None:
            continue
        mask = decode_rle_mask(mask_rle)
        if mask.size != image.size:
            mask = mask.resize(image.size, Image.Resampling.NEAREST)
        color = _MASK_OVERLAY_COLORS[(index - 1) % len(_MASK_OVERLAY_COLORS)]
        color_layer = Image.new("RGBA", image.size, color)
        transparent = Image.new("RGBA", image.size)
        # Keep the tint only where the mask is set, then stack it on the overlay.
        overlay.alpha_composite(Image.composite(color_layer, transparent, mask))
        _draw_mask_label(
            draw=draw_overlay,
            segment=segment,
            mask=mask,
            label=str(index),
            font=font,
        )

    result = Image.alpha_composite(image, overlay).convert("RGB")
    return _save_image(result, output_path)


def draw_labeled_bboxes(
    *,
    image_path: str | Path,
    boxes: list[dict[str, Any]],
    output_path: str | Path,
) -> Path:
    """Draw labeled bounding boxes for final segmentation visualization.

    Each entry must provide a four-value ``bbox_xyxy`` and a ``label``.
    Returns the resolved path of the saved image.
    """
    image = _open_converted(image_path, "RGB")
    draw = ImageDraw.Draw(image)
    font = _load_label_font(image.width)
    for box in boxes:
        x1, y1, x2, y2 = box["bbox_xyxy"]
        _draw_bbox_label(
            draw=draw,
            bbox_xyxy=[x1, y1, x2, y2],
            label=str(box["label"]),
            font=font,
        )
    return _save_image(image, output_path)


def sort_segments_by_bbox(segments: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Sort segments by top-left image position, then by descending score.

    The key is ``(y1, x1, -score)``; a new list is returned.
    """
    return sorted(
        segments,
        key=lambda segment: (
            float(segment["bbox_xyxy"][1]),
            float(segment["bbox_xyxy"][0]),
            -float(segment["score"]),
        ),
    )


def bbox_iou(bbox_a: list[float], bbox_b: list[float]) -> float:
    """Compute IoU for two xyxy bounding boxes.

    Returns ``0.0`` when the union area is not positive.
    """
    ax1, ay1, ax2, ay2 = bbox_a
    bx1, by1, bx2, by2 = bbox_b
    ix1 = max(ax1, bx1)
    iy1 = max(ay1, by1)
    ix2 = min(ax2, bx2)
    iy2 = min(ay2, by2)
    iw = max(0.0, ix2 - ix1)
    ih = max(0.0, iy2 - iy1)
    intersection = iw * ih
    area_a = max(0.0, ax2 - ax1) * max(0.0, ay2 - ay1)
    area_b = max(0.0, bx2 - bx1) * max(0.0, by2 - by1)
    union = area_a + area_b - intersection
    return intersection / union if union > 0 else 0.0


def is_usable_segmentation_candidate(
    candidate: ImageSegmentationCandidate,
) -> bool:
    """Return whether a candidate has the fields needed by downstream stages.

    A usable candidate has a mask and a bounding box of exactly four values.
    """
    return candidate.mask_rle is not None and len(candidate.bbox_xyxy) == 4


def _is_size_pair(value: Any) -> bool:
    """Return whether the value is a list of exactly two integers."""
    return (
        isinstance(value, list)
        and len(value) == 2
        and isinstance(value[0], int)
        and isinstance(value[1], int)
    )


def _open_converted(image_path: str | Path, mode: str) -> Image.Image:
    """Open an image, convert it to ``mode`` and close the source file.

    ``convert`` returns an independent copy, so the result remains valid
    after the source handle is closed.
    """
    with Image.open(image_path) as source:
        return source.convert(mode)


def _save_image(image: Image.Image, output_path: str | Path) -> Path:
    """Save the image, creating parent directories; return the resolved path."""
    output_path = Path(output_path).expanduser().resolve()
    output_path.parent.mkdir(parents=True, exist_ok=True)
    image.save(output_path)
    return output_path


def _load_label_font(image_width: int) -> ImageFont.ImageFont:
    """Load a bold label font scaled to the image width (at least 24 px).

    Falls back to Pillow's default font when DejaVu Sans Bold is unavailable.
    """
    font_size = max(24, image_width // 80)
    try:
        return ImageFont.truetype("DejaVuSans-Bold.ttf", font_size)
    except OSError:
        return ImageFont.load_default()


def _draw_label_tag(
    *,
    draw: ImageDraw.ImageDraw,
    position: tuple[float, float],
    label: str,
    font: ImageFont.ImageFont,
) -> None:
    """Draw white label text on a padded red rectangle anchored at ``position``."""
    label_box = draw.textbbox(position, label, font=font)
    draw.rectangle(
        (
            label_box[0] - _LABEL_PADDING,
            label_box[1] - _LABEL_PADDING,
            label_box[2] + _LABEL_PADDING,
            label_box[3] + _LABEL_PADDING,
        ),
        fill="red",
    )
    draw.text(position, label, fill="white", font=font)


def _draw_bbox_label(
    *,
    draw: ImageDraw.ImageDraw,
    bbox_xyxy: list[float],
    label: str,
    font: ImageFont.ImageFont,
) -> None:
    """Outline a box in red and place its label at the top-left corner."""
    x1, y1, x2, y2 = bbox_xyxy
    draw.rectangle((x1, y1, x2, y2), outline="red", width=6)
    _draw_label_tag(draw=draw, position=(x1, y1), label=label, font=font)


def _draw_mask_label(
    *,
    draw: ImageDraw.ImageDraw,
    segment: dict[str, Any],
    mask: Image.Image,
    label: str,
    font: ImageFont.ImageFont,
) -> None:
    """Place a label at the centre of the mask's bounding box.

    When the mask is empty, the centre of the segment's ``bbox_xyxy`` is used.
    """
    bbox = mask.getbbox()
    x1, y1, x2, y2 = segment["bbox_xyxy"] if bbox is None else bbox
    center = (float(x1 + x2) * 0.5, float(y1 + y2) * 0.5)
    _draw_label_tag(draw=draw, position=center, label=label, font=font)
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.blender_rendering_manager.manager import ( + BlenderRenderingManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.blender_rendering_manager.schemas import ( + RenderObjectScenesRequest, + RenderObjectScenesResult, +) + +__all__ = [ + "BlenderRenderingManager", + "RenderObjectScenesRequest", + "RenderObjectScenesResult", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/blender_rendering_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/blender_rendering_manager/manager.py new file mode 100644 index 00000000..8617f297 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/blender_rendering_manager/manager.py @@ -0,0 +1,175 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
class BlenderRenderingManager:
    """Render simulation scenes through Blender's background CLI."""

    def render_object_scenes(
        self,
        request: RenderObjectScenesRequest,
    ) -> RenderObjectScenesResult:
        """Render a front-oblique view of a collection of Z-up scenes.

        Each scene is exported as a Y-up GLB into a temporary directory that
        is removed after rendering, then all GLBs are rendered into one image.

        Args:
            request: Scenes to render, the output image path and the Blender
                timeout in seconds.

        Returns:
            The result carrying the resolved output image path.

        Raises:
            RuntimeError: If Blender exits with a non-zero status.
            subprocess.TimeoutExpired: If Blender exceeds the timeout.
            FileNotFoundError: If Blender finished without writing the image.
        """
        output_path = request.output_path.expanduser().resolve()
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with tempfile.TemporaryDirectory(prefix="p2s_blender_render_") as tmp_dir:
            glb_paths = self._export_y_up_scenes(
                request.object_scenes,
                Path(tmp_dir),
            )
            self._render_glbs(
                glb_paths,
                output_path,
                timeout_seconds=request.timeout_seconds,
            )
        return RenderObjectScenesResult(output_path=output_path)

    @staticmethod
    def _safe_file_stem(object_id: object) -> str:
        """Turn an object id into a filename-safe stem.

        Characters other than alphanumerics, ``.``, ``_`` and ``-`` become
        ``_`` so that ids containing path separators cannot escape or break
        the export directory. An empty id yields ``"object"``.
        """
        cleaned = "".join(
            char if char.isalnum() or char in "._-" else "_"
            for char in str(object_id)
        )
        return cleaned or "object"

    @staticmethod
    def _export_y_up_scenes(
        object_scenes: list[tuple[str, object]],
        output_dir: Path,
    ) -> list[Path]:
        """Export a rotated copy of every scene as a GLB file.

        Each scene must provide ``copy()``, ``apply_transform(matrix)`` and
        ``export(path)``; the input scenes are left untouched because the
        transform is applied to the copy.

        Returns:
            One GLB path per input scene, in input order. File names carry
            the scene's position, so duplicate ids do not overwrite each other.
        """
        # Maps (x, y, z) to (x, z, -y).
        z_up_to_y_up = np.array(
            [
                [1.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 1.0, 0.0],
                [0.0, -1.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 1.0],
            ],
            dtype=np.float64,
        )
        paths: list[Path] = []
        for index, (object_id, scene) in enumerate(object_scenes):
            stem = BlenderRenderingManager._safe_file_stem(object_id)
            path = output_dir / f"{index:03d}_{stem}_render.glb"
            copied = scene.copy()
            copied.apply_transform(z_up_to_y_up)
            copied.export(path)
            paths.append(path)
        return paths

    @classmethod
    def _render_glbs(
        cls,
        glb_paths: list[Path],
        output_path: Path,
        *,
        timeout_seconds: int,
    ) -> None:
        """Run Blender in background mode on a generated render script.

        Raises:
            RuntimeError: If Blender exits with a non-zero status; the message
                carries the last 4000 characters of stderr.
            subprocess.TimeoutExpired: If Blender exceeds ``timeout_seconds``.
            FileNotFoundError: If the output image is missing afterwards.
        """
        script = cls._front_oblique_script(glb_paths, output_path)
        # delete=False so Blender can open the script after it is closed here;
        # the finally clause removes it even if writing the script fails.
        script_file = tempfile.NamedTemporaryFile(
            mode="w",
            suffix=".py",
            encoding="utf-8",
            delete=False,
        )
        script_path = Path(script_file.name)
        try:
            with script_file:
                script_file.write(script)
            subprocess.run(
                ["blender", "--background", "--python", str(script_path)],
                check=True,
                timeout=timeout_seconds,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
            )
        except subprocess.CalledProcessError as exc:
            stderr_tail = (exc.stderr or "").strip()[-4000:]
            raise RuntimeError(
                f"Blender front-oblique render failed:\n{stderr_tail}"
            ) from exc
        finally:
            script_path.unlink(missing_ok=True)
        if not output_path.is_file():
            raise FileNotFoundError(f"Blender render was not written: {output_path}")

    @staticmethod
    def _front_oblique_script(glb_paths: list[Path], output_path: Path) -> str:
        """Build the Blender Python script that imports the GLBs and renders.

        Paths are embedded as JSON wrapped in a Python string literal, so
        quotes and backslashes in paths cannot break the generated script.
        """
        object_paths_json = json.dumps([str(path.resolve()) for path in glb_paths])
        output_path_json = json.dumps(str(output_path.resolve()))
        return f"""\
import bpy
import json
import mathutils

object_paths = json.loads({object_paths_json!r})
output_path = json.loads({output_path_json!r})
bpy.ops.object.select_all(action="SELECT")
bpy.ops.object.delete()
for path in object_paths:
    bpy.ops.import_scene.gltf(filepath=path)
mesh_objects = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"]
if not mesh_objects:
    raise RuntimeError("No mesh objects were imported.")
min_corner = mathutils.Vector((float("inf"), float("inf"), float("inf")))
max_corner = mathutils.Vector((float("-inf"), float("-inf"), float("-inf")))
for obj in mesh_objects:
    for corner in obj.bound_box:
        world = obj.matrix_world @ mathutils.Vector(corner)
        min_corner.x = min(min_corner.x, world.x)
        min_corner.y = min(min_corner.y, world.y)
        min_corner.z = min(min_corner.z, world.z)
        max_corner.x = max(max_corner.x, world.x)
        max_corner.y = max(max_corner.y, world.y)
        max_corner.z = max(max_corner.z, world.z)
center = (min_corner + max_corner) * 0.5
span_x = max(max_corner.x - min_corner.x, 1.0e-4)
span_y = max(max_corner.y - min_corner.y, 1.0e-4)
span_z = max(max_corner.z - min_corner.z, 1.0e-4)
camera_data = bpy.data.cameras.new("front_oblique_camera")
camera = bpy.data.objects.new("front_oblique_camera", camera_data)
bpy.context.collection.objects.link(camera)
view_distance = max(span_x, span_y, span_z) * 2.4
camera.location = (center.x, center.y - view_distance, center.z + view_distance * 0.75)
camera.rotation_euler = (center - camera.location).to_track_quat("-Z", "Y").to_euler()
camera_data.type = "ORTHO"
camera_data.ortho_scale = max(span_x, span_y, span_z * 1.8) * 1.35
bpy.context.scene.camera = camera
light_data = bpy.data.lights.new("front_oblique_area_light", "AREA")
light = bpy.data.objects.new("front_oblique_area_light", light_data)
bpy.context.collection.objects.link(light)
light.location = camera.location
light_data.energy = 600.0
light_data.size = max(span_x, span_y) * 2.0
bpy.context.scene.world.color = (1.0, 1.0, 1.0)
try:
    bpy.context.scene.render.engine = "BLENDER_EEVEE_NEXT"
except Exception:
    bpy.context.scene.render.engine = "BLENDER_EEVEE"
bpy.context.scene.render.resolution_x = 768
bpy.context.scene.render.resolution_y = 768
bpy.context.scene.render.film_transparent = False
bpy.context.scene.view_settings.view_transform = "Standard"
bpy.context.scene.view_settings.look = "Medium High Contrast"
bpy.context.scene.render.filepath = output_path
bpy.ops.render.render(write_still=True)
"""
__all__ = ["RenderObjectScenesRequest", "RenderObjectScenesResult"]


@dataclass(frozen=True)
class RenderObjectScenesRequest:
    """Input for one Blender render of several internal Z-up object scenes."""

    # (object id, scene) pairs; the scene type is left open (``Any``).
    object_scenes: list[tuple[str, Any]]
    # Where the rendered image is written.
    output_path: Path
    # Upper bound, in seconds, for the render; 180 unless overridden.
    timeout_seconds: int = 180


@dataclass(frozen=True)
class RenderObjectScenesResult:
    """Outcome of rendering object scenes."""

    # Path of the rendered image.
    output_path: Path
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_generation_manager.manager import ( + GeometryGenerationManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_generation_manager.schemas import ( + GeometryGenerationRequest, + GeometryGenerationResult, + MultiObjectGenerationObject, + MultiObjectGenerationRequest, + MultiObjectGenerationResult, + RgbaImageToGeometryRequest, + RgbaImagesToGeometriesObject, + RgbaImagesToGeometriesRequest, + RgbaImagesToGeometriesResult, +) + +__all__ = [ + "GeometryGenerationManager", + "GeometryGenerationRequest", + "GeometryGenerationResult", + "MultiObjectGenerationObject", + "MultiObjectGenerationRequest", + "MultiObjectGenerationResult", + "RgbaImageToGeometryRequest", + "RgbaImagesToGeometriesObject", + "RgbaImagesToGeometriesRequest", + "RgbaImagesToGeometriesResult", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_generation_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_generation_manager/manager.py new file mode 100644 index 00000000..d30ea09a --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_generation_manager/manager.py @@ -0,0 +1,209 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. 
# Pillow modes that carry an alpha band: straight alpha ("RGBA", "LA", "PA")
# and premultiplied alpha ("RGBa", "La").
_ALPHA_IMAGE_MODES = frozenset({"RGBA", "LA", "PA", "RGBa", "La"})


class GeometryGenerationManager:
    """Geometry generation domain operations."""

    def __init__(self, *, client: GeometryGenerationClient | None = None) -> None:
        """Store the service client, creating a default one when none is given."""
        self.client = client or GeometryGenerationClient()

    def generate_single_object_mesh(
        self,
        request: GeometryGenerationRequest,
    ) -> GeometryGenerationResult:
        """Generate one mesh from one image through the geometry service.

        Args:
            request: Input image path and ``.glb`` output path.

        Returns:
            The resolved path of the geometry reported by the service.

        Raises:
            FileNotFoundError: The input image does not exist.
            ValueError: The output path is not a ``.glb`` path or is a
                directory.
            RuntimeError: The service answered with an error response.
        """
        image_path = request.image_path.expanduser().resolve()
        output_path = request.output_path.expanduser().resolve()
        _validate_single_object_request(image_path=image_path, output_path=output_path)

        response = self.client.generate(
            GeometryGenerationServerRequest(
                image_path=image_path,
                output_path=output_path,
            ),
        )
        if isinstance(response, GeometryGenerationError):
            raise RuntimeError(response.error_message)

        return GeometryGenerationResult(
            output_path=Path(response.result.geometry_path).expanduser().resolve(),
        )

    def generate_multi_object_meshes(
        self,
        request: MultiObjectGenerationRequest,
    ) -> MultiObjectGenerationResult:
        """Generate one mesh per mask from a single scene image.

        Args:
            request: Scene image, object mask paths and output directory.

        Returns:
            The generated objects with resolved geometry paths and the
            rotation, translation and scale reported by the service.

        Raises:
            FileNotFoundError: The image or one of the masks does not exist.
            ValueError: ``mask_paths`` is empty or ``output_dir`` exists but is
                not a directory.
            RuntimeError: The service answered with an error response.
        """
        image_path = request.image_path.expanduser().resolve()
        output_dir = request.output_dir.expanduser().resolve()
        _validate_multi_object_request(
            image_path=image_path,
            mask_paths=request.mask_paths,
            output_dir=output_dir,
        )

        response = self.client.generate_multiple_objects(
            MultiObjectGenerationServerRequest(
                image_path=image_path,
                mask_paths=[p.expanduser().resolve() for p in request.mask_paths],
            ),
            output_dir=output_dir,
        )
        if isinstance(response, MultiObjectGenerationError):
            raise RuntimeError(response.error_message)

        objects = [
            MultiObjectGenerationObject(
                name=item.name,
                geometry_path=Path(item.geometry_path).expanduser().resolve(),
                rotation_quaternion_wxyz=item.rotation_quaternion_wxyz,
                translation=item.translation,
                scale=item.scale,
            )
            for item in response.result.objects
        ]
        return MultiObjectGenerationResult(objects=objects)

    def convert_rgba_image_to_geometry(
        self,
        request: RgbaImageToGeometryRequest,
    ) -> Path:
        """Convert one image with transparency into one mesh.

        Args:
            request: Input image path and ``.glb`` output path.

        Returns:
            The resolved path of the generated mesh.

        Raises:
            FileNotFoundError: The image does not exist.
            ValueError: The image has no alpha channel, or the output path is
                invalid.
            RuntimeError: The service answered with an error response.
        """
        image_path = request.image_path.expanduser().resolve()
        output_path = request.output_path.expanduser().resolve()
        _validate_rgba_image(image_path)

        result = self.generate_single_object_mesh(
            GeometryGenerationRequest(image_path=image_path, output_path=output_path)
        )
        return _postprocess_mesh(result.output_path)

    def convert_rgba_images_to_geometries(
        self,
        request: RgbaImagesToGeometriesRequest,
    ) -> RgbaImagesToGeometriesResult:
        """Convert a scene image plus object masks into one mesh per mask.

        Args:
            request: Scene image, object mask paths and output directory.

        Returns:
            The generated objects with resolved geometry paths and placement.

        Raises:
            FileNotFoundError: The image or one of the masks does not exist.
            ValueError: ``mask_paths`` is empty or the output directory is
                invalid.
            RuntimeError: The service answered with an error response.
        """
        image_path = request.image_path.expanduser().resolve()
        output_dir = request.output_dir.expanduser().resolve()
        _validate_rgba_images_request(image_path, request.mask_paths)

        result = self.generate_multi_object_meshes(
            MultiObjectGenerationRequest(
                image_path=image_path,
                mask_paths=request.mask_paths,
                output_dir=output_dir,
            )
        )
        objects = [
            RgbaImagesToGeometriesObject(
                name=item.name,
                geometry_path=_postprocess_mesh(item.geometry_path),
                rotation_quaternion_wxyz=item.rotation_quaternion_wxyz,
                translation=item.translation,
                scale=item.scale,
            )
            for item in result.objects
        ]
        return RgbaImagesToGeometriesResult(objects=objects)


def _validate_single_object_request(*, image_path: Path, output_path: Path) -> None:
    """Check the paths of a single-object request.

    Raises:
        FileNotFoundError: ``image_path`` is not an existing file.
        ValueError: ``output_path`` lacks a ``.glb`` suffix or is a directory.
    """
    if not image_path.is_file():
        raise FileNotFoundError(f"Geometry generation input not found: {image_path}")
    if output_path.suffix.lower() != ".glb":
        raise ValueError("Geometry generation output_path must be a GLB file path.")
    # ``is_dir()`` is already False for a missing path; no ``exists()`` needed.
    if output_path.is_dir():
        raise ValueError(f"Geometry generation output_path is a directory: {output_path}")


def _validate_multi_object_request(
    *,
    image_path: Path,
    mask_paths: list[Path],
    output_dir: Path,
) -> None:
    """Check the paths of a multi-object request.

    Raises:
        FileNotFoundError: The image or a mask is not an existing file.
        ValueError: ``mask_paths`` is empty or ``output_dir`` exists without
            being a directory.
    """
    if not image_path.is_file():
        raise FileNotFoundError(
            f"Multi-object geometry generation input not found: {image_path}"
        )
    if not mask_paths:
        raise ValueError("mask_paths must be non-empty.")
    for mask_path in mask_paths:
        mask_path_resolved = mask_path.expanduser().resolve()
        if not mask_path_resolved.is_file():
            raise FileNotFoundError(
                f"Multi-object geometry mask not found: {mask_path_resolved}"
            )
    if output_dir.exists() and not output_dir.is_dir():
        raise ValueError(
            f"Multi-object geometry output_dir is not a directory: {output_dir}"
        )


def _validate_rgba_image(image_path: Path) -> None:
    """Require an existing image that carries transparency information.

    An image is accepted when its mode has an alpha band, or when it is a
    palette image with a ``transparency`` entry in its info dictionary.

    Raises:
        FileNotFoundError: ``image_path`` is not an existing file.
        ValueError: The image has no transparency information.
    """
    if not image_path.is_file():
        raise FileNotFoundError(f"RGBA image not found: {image_path}")

    with Image.open(image_path) as image:
        if image.mode in _ALPHA_IMAGE_MODES:
            return
        if image.mode == "P" and "transparency" in image.info:
            return
        raise ValueError(
            "Geometry tool requires an image with an alpha channel, "
            f"got mode={image.mode!r}: {image_path}"
        )


def _validate_rgba_images_request(
    image_path: Path,
    mask_paths: list[Path],
) -> None:
    """Check the scene image and masks of a multi-object conversion.

    Raises:
        FileNotFoundError: The image or a mask is not an existing file.
        ValueError: ``mask_paths`` is empty.
    """
    if not image_path.is_file():
        raise FileNotFoundError(f"Scene image not found: {image_path}")
    # Opening the file makes Pillow reject anything it cannot identify as an
    # image; the opened image itself is not used.
    with Image.open(image_path):
        pass
    if not mask_paths:
        raise ValueError("mask_paths must be non-empty.")
    for mask_path in mask_paths:
        mask_path_resolved = mask_path.expanduser().resolve()
        if not mask_path_resolved.is_file():
            # Report the path that was actually checked, as the multi-object
            # validator does.
            raise FileNotFoundError(f"Mask not found: {mask_path_resolved}")


def _postprocess_mesh(mesh_path: Path) -> Path:
    """Return ``mesh_path`` with ``~`` expanded and resolved to an absolute path."""
    return mesh_path.expanduser().resolve()
@dataclass(frozen=True)
class RgbaImageToGeometryRequest:
    """Input for turning a single RGBA asset image into a single mesh."""

    # Source image.
    image_path: Path
    # Destination mesh file.
    output_path: Path


@dataclass(frozen=True)
class RgbaImagesToGeometriesRequest:
    """Input for turning a scene image and its object masks into meshes."""

    # Scene image shared by all masks.
    image_path: Path
    # One mask file per object.
    mask_paths: list[Path]
    # Directory for the generated meshes.
    output_dir: Path


@dataclass(frozen=True)
class RgbaImagesToGeometriesObject:
    """A generated object mesh together with its placement in the scene."""

    name: str
    geometry_path: Path
    # Rotation as a quaternion in (w, x, y, z) order, per the field name.
    rotation_quaternion_wxyz: list[float]
    translation: list[float]
    scale: list[float]


@dataclass(frozen=True)
class RgbaImagesToGeometriesResult:
    """Outcome of a multi-object RGBA conversion."""

    objects: list[RgbaImagesToGeometriesObject]

    @property
    def geometry_paths(self) -> list[Path]:
        """Mesh paths of all objects, in the order of ``objects``."""
        return [entry.geometry_path for entry in self.objects]


@dataclass(frozen=True)
class GeometryGenerationRequest:
    """Input for generating a single object mesh from a single image."""

    # Source image.
    image_path: Path
    # Destination mesh file.
    output_path: Path


@dataclass(frozen=True)
class GeometryGenerationResult:
    """Path of the generated mesh."""

    output_path: Path


@dataclass(frozen=True)
class MultiObjectGenerationRequest:
    """Input for generating several object meshes from one image and masks."""

    # Scene image shared by all masks.
    image_path: Path
    # One mask file per object.
    mask_paths: list[Path]
    # Directory for the generated meshes.
    output_dir: Path


@dataclass(frozen=True)
class MultiObjectGenerationObject:
    """A generated object mesh together with its placement in the scene."""

    name: str
    geometry_path: Path
    # Rotation as a quaternion in (w, x, y, z) order, per the field name.
    rotation_quaternion_wxyz: list[float]
    translation: list[float]
    scale: list[float]


@dataclass(frozen=True)
class MultiObjectGenerationResult:
    """Outcome of multi-object geometry generation."""

    objects: list[MultiObjectGenerationObject]

    @property
    def geometry_paths(self) -> list[Path]:
        """Mesh paths of all objects, in the order of ``objects``."""
        return [entry.geometry_path for entry in self.objects]
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager.manager import ( + DEFAULT_INPUT_UP_AXIS, + DEFAULT_UP_AXIS, + GeometryManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager.schemas import ( + AlignToAxisRequest, + AlignToAxisResult, + AlignXYLongAxisRequest, + AlignXYLongAxisResult, + CenterMeshRequest, + CenterMeshResult, + ConvertUpAxisRequest, + ConvertUpAxisResult, + DetectTabletopRequest, + DetectTabletopResult, + ExportMeshRequest, + ExportMeshResult, + LoadMeshRequest, + LoadMeshResult, + NormalizeRequest, + NormalizeResult, + PlaceAbovePlaneRequest, + PlaceAbovePlaneResult, + SupportPlaneCandidate, +) + +__all__ = [ + "AlignToAxisRequest", + "AlignToAxisResult", + "AlignXYLongAxisRequest", + "AlignXYLongAxisResult", + "CenterMeshRequest", + "CenterMeshResult", + "ConvertUpAxisRequest", + "ConvertUpAxisResult", + "DEFAULT_INPUT_UP_AXIS", + "DEFAULT_UP_AXIS", + "DetectTabletopRequest", + "DetectTabletopResult", + "ExportMeshRequest", + "ExportMeshResult", + "GeometryManager", + "LoadMeshRequest", + "LoadMeshResult", + "NormalizeRequest", + "NormalizeResult", + "PlaceAbovePlaneRequest", + "PlaceAbovePlaneResult", + "SupportPlaneCandidate", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/manager.py new file mode 100644 index 00000000..2e5c88ab --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/manager.py @@ -0,0 +1,584 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
__all__ = ["GeometryManager"]

# Up axes used by ``convert_up_axis`` when the request leaves them unset:
# Y-up on input, Z-up on output.
DEFAULT_INPUT_UP_AXIS = [0.0, 1.0, 0.0]
DEFAULT_UP_AXIS = [0.0, 0.0, 1.0]


class GeometryManager:
    """Manager for mesh geometry operations.

    Provides typed methods for mesh I/O, axis conversion, bounding-box
    transforms, tabletop plane detection, and PCA alignment, following
    the same pattern as service clients.

    Every method is static. Methods that transform a mesh work on a copy and
    leave the mesh in the request untouched.
    """

    @staticmethod
    def load_mesh(request: LoadMeshRequest) -> LoadMeshResult:
        """Load a GLB/mesh file as one Trimesh object.

        When the file loads as a scene, every dumped geometry that has
        ``vertices`` and ``faces`` is concatenated into a single mesh.

        Raises:
            FileNotFoundError: The mesh file does not exist.
            ValueError: The loaded scene contains no mesh geometry.
        """
        mesh_path = request.mesh_path.expanduser().resolve()
        if not mesh_path.is_file():
            raise FileNotFoundError(f"Mesh file not found: {mesh_path}")

        loaded = trimesh.load(mesh_path, force=None)
        if isinstance(loaded, trimesh.Scene):
            geometries = [
                g
                for g in loaded.dump(concatenate=False)
                if hasattr(g, "vertices") and hasattr(g, "faces")
            ]
            if not geometries:
                raise ValueError(f"Scene contains no mesh geometry: {mesh_path}")
            return LoadMeshResult(mesh=trimesh.util.concatenate(geometries))
        return LoadMeshResult(mesh=loaded)

    @staticmethod
    def export_mesh(request: ExportMeshRequest) -> ExportMeshResult:
        """Export a mesh and return the resolved output path.

        Missing parent directories are created first.

        Raises:
            FileNotFoundError: No file exists at the output path afterwards.
        """
        output_path = request.output_path.expanduser().resolve()
        output_path.parent.mkdir(parents=True, exist_ok=True)
        request.mesh.export(output_path)
        if not output_path.is_file():
            raise FileNotFoundError(f"Mesh was not written: {output_path}")
        return ExportMeshResult(output_path=output_path)

    @staticmethod
    def convert_up_axis(request: ConvertUpAxisRequest) -> ConvertUpAxisResult:
        """Convert a mesh from one up-axis convention to another.

        Falsy axes in the request fall back to ``DEFAULT_INPUT_UP_AXIS`` and
        ``DEFAULT_UP_AXIS``.
        """
        mesh = GeometryManager._align_vector_to_axis(
            request.mesh,
            source_axis=request.input_up_axis or DEFAULT_INPUT_UP_AXIS,
            target_axis=request.output_up_axis or DEFAULT_UP_AXIS,
        )
        return ConvertUpAxisResult(mesh=mesh)

    @staticmethod
    def center_by_bbox(request: CenterMeshRequest) -> CenterMeshResult:
        """Center a mesh by its bounding box.

        Returns:
            The translated copy and the bounding-box centre it had before the
            translation.

        Raises:
            ValueError: The mesh is empty or its bounds are not ``(2, 3)``.
        """
        GeometryManager._validate_mesh(request.mesh)

        bounds = np.asarray(request.mesh.bounds, dtype=float)
        if bounds.shape != (2, 3):
            raise ValueError("Mesh bounds must have shape (2, 3).")

        bbox_center = (bounds[0] + bounds[1]) * 0.5
        centered = request.mesh.copy()
        centered.apply_translation(-bbox_center)
        return CenterMeshResult(
            mesh=centered,
            bbox_center=[float(v) for v in bbox_center],
        )

    @staticmethod
    def align_to_axis(request: AlignToAxisRequest) -> AlignToAxisResult:
        """Rotate a mesh so a source vector aligns to a target axis."""
        mesh = GeometryManager._align_vector_to_axis(
            request.mesh,
            source_axis=request.source_axis,
            target_axis=request.target_axis,
        )
        return AlignToAxisResult(mesh=mesh)

    @staticmethod
    def place_above_plane(
        request: PlaceAbovePlaneRequest,
    ) -> PlaceAbovePlaneResult:
        """Translate a mesh so its AABB bottom is above the XY plane.

        After the translation the lowest Z bound equals ``request.clearance``.

        Raises:
            ValueError: ``clearance`` is negative or the bounds are not
                ``(2, 3)``.
        """
        if request.clearance < 0.0:
            raise ValueError("clearance must be non-negative.")

        bounds = np.asarray(request.mesh.bounds, dtype=float)
        if bounds.shape != (2, 3):
            raise ValueError("Mesh bounds must have shape (2, 3).")

        min_z = float(bounds[0][2])
        placed = request.mesh.copy()
        placed.apply_translation([0.0, 0.0, request.clearance - min_z])
        return PlaceAbovePlaneResult(mesh=placed)

    @staticmethod
    def normalize(request: NormalizeRequest) -> NormalizeResult:
        """Scale a mesh so its longest bounding-box axis equals target_size.

        The extents are read from the mesh's oriented bounding box
        (``bounding_box_oriented``), not from its axis-aligned bounds.

        Raises:
            ValueError: ``target_size`` is not positive, or the longest extent
                is zero, negative or not finite.
        """
        if request.target_size <= 0.0:
            raise ValueError("target_size must be positive.")

        extents = np.asarray(
            request.mesh.bounding_box_oriented.primitive.extents, dtype=float
        )
        longest = float(np.max(extents))
        # A zero extent would divide by zero and a NaN extent would yield a
        # NaN scale; fail the same way the other validations in this class do.
        if not np.isfinite(longest) or longest <= 0.0:
            raise ValueError("Mesh bounding-box extents must be positive and finite.")
        scale_factor = request.target_size / longest
        normalized = request.mesh.copy()
        normalized.apply_scale(scale_factor)
        return NormalizeResult(mesh=normalized, scale_factor=scale_factor)

    @staticmethod
    def mesh_aabb_size(mesh: Any) -> Any:
        """Return a mesh AABB size vector.

        Raises:
            ValueError: The bounds are not ``(2, 3)`` or any axis has a
                non-positive size.
        """
        bounds = np.asarray(mesh.bounds, dtype=np.float64)
        if bounds.shape != (2, 3):
            raise ValueError("Mesh bounds must have shape (2, 3).")
        size = bounds[1] - bounds[0]
        if np.any(size <= 0.0):
            raise ValueError(f"Mesh AABB size must be positive, got {size.tolist()}.")
        return size

    @staticmethod
    def bbox_ratio(size: Any) -> Any:
        """Return bbox dimensions normalized by the largest axis.

        Raises:
            ValueError: The largest dimension is not positive.
        """
        size = np.asarray(size, dtype=np.float64)
        max_size = float(np.max(size))
        if max_size <= 0.0:
            raise ValueError("bbox size max must be positive.")
        return size / max_size

    @staticmethod
    def best_axis_bbox_scale_match(
        *,
        source_size_cm: Any,
        target_size_cm: Any,
    ) -> dict[str, Any]:
        """Match target bbox axes to source axes and return a scale candidate.

        All six orderings of the target axes are tried. The ordering whose
        normalised proportions are closest to the source's (smallest mean
        absolute log-ratio) wins; on a tie the first ordering tried is kept.

        Returns:
            A dictionary describing the winning ordering. ``scale_factor`` is
            the median of the three per-axis ``target / source`` ratios.

        Raises:
            ValueError: An input does not have shape ``(3,)`` or contains a
                non-positive value.
        """
        source = np.asarray(source_size_cm, dtype=np.float64)
        target = np.asarray(target_size_cm, dtype=np.float64)
        if source.shape != (3,) or target.shape != (3,):
            raise ValueError("source_size_cm and target_size_cm must have shape (3,).")
        if np.any(source <= 0.0) or np.any(target <= 0.0):
            raise ValueError("source_size_cm and target_size_cm must be positive.")

        source_ratio = GeometryManager.bbox_ratio(source)
        best: dict[str, Any] | None = None
        for permutation in [
            (0, 1, 2),
            (0, 2, 1),
            (1, 0, 2),
            (1, 2, 0),
            (2, 0, 1),
            (2, 1, 0),
        ]:
            target_perm = target[list(permutation)]
            target_ratio = GeometryManager.bbox_ratio(target_perm)
            ratio_error = GeometryManager._mean_abs_log_ratio_error(
                source_ratio,
                target_ratio,
            )
            per_axis_scale = target_perm / source
            candidate = {
                "target_permutation": list(permutation),
                "source_size_cm": source.tolist(),
                "target_size_cm_original_order": target.tolist(),
                "target_size_cm_matched_to_source_axes": target_perm.tolist(),
                "source_ratio": source_ratio.tolist(),
                "target_ratio_matched": target_ratio.tolist(),
                "per_axis_scale": per_axis_scale.tolist(),
                "scale_factor": float(np.median(per_axis_scale)),
                "shape_ratio_error": float(ratio_error),
            }
            # Strict ``<`` keeps the earliest ordering when errors are equal.
            if best is None or ratio_error < float(best["shape_ratio_error"]):
                best = candidate
        if best is None:
            raise ValueError("Failed to match bbox axes.")
        return best

    @staticmethod
    def scene_to_mesh(scene: Any) -> Any:
        """Convert a trimesh Scene or mesh-like object to one mesh.

        A ``trimesh.Trimesh`` input is returned as is.

        Raises:
            ValueError: The dumped scene contains no ``trimesh.Trimesh``.
        """
        if isinstance(scene, trimesh.Trimesh):
            return scene
        dumped = scene.dump(concatenate=True)
        if isinstance(dumped, trimesh.Trimesh):
            return dumped
        meshes = [item for item in dumped if isinstance(item, trimesh.Trimesh)]
        if not meshes:
            raise ValueError("Scene contains no mesh geometry.")
        return trimesh.util.concatenate(meshes)

    @staticmethod
    def detect_tabletop(
        request: DetectTabletopRequest,
    ) -> DetectTabletopResult:
        """Detect the most likely tabletop plane in a mesh.

        Returns:
            The highest-scoring plane candidate, its normal oriented by
            ``_orient_plane_normal``, and the full candidate list.

        Raises:
            ValueError: The mesh is invalid or no candidate plane was found.
        """
        candidates = GeometryManager._find_support_plane_candidates(
            request.mesh,
            normal_angle_tol_deg=request.normal_angle_tol_deg,
            plane_distance_tol=request.plane_distance_tol,
            min_area_ratio=request.min_area_ratio,
            max_candidates=request.max_candidates,
        )
        selected = GeometryManager._select_tabletop_plane(candidates)
        oriented_normal = GeometryManager._orient_plane_normal(
            request.mesh,
            plane_normal=selected.normal,
            plane_center=selected.center,
        )
        return DetectTabletopResult(
            selected=selected,
            oriented_normal=oriented_normal,
            candidates=candidates,
        )

    @staticmethod
    def align_xy_long_axis(
        request: AlignXYLongAxisRequest,
    ) -> AlignXYLongAxisResult:
        """Rotate a table so its XY-projected long axis aligns with the Y axis.

        The long axis is the principal component with the largest eigenvalue
        of the XY vertex covariance. The mesh is rotated about Z by the
        smaller of the two angles that bring that axis onto +Y or -Y.

        Returns:
            The rotated copy and the applied yaw in degrees.

        Raises:
            ValueError: Fewer than two vertices are available for the PCA.
        """
        vertices = np.asarray(request.mesh.vertices, dtype=float)
        xy_vertices = GeometryManager._select_xy_vertices(
            request.mesh, vertices, request.face_indices
        )
        if xy_vertices.shape[0] < 2:
            raise ValueError(
                "Mesh must contain at least two vertices for PCA alignment."
            )

        centered_xy = xy_vertices - np.mean(xy_vertices, axis=0)
        covariance = centered_xy.T @ centered_xy / max(centered_xy.shape[0] - 1, 1)
        eigenvalues, eigenvectors = np.linalg.eigh(covariance)
        long_axis = eigenvectors[:, int(np.argmax(eigenvalues))]
        if float(np.linalg.norm(long_axis)) == 0.0:
            raise ValueError("PCA long axis is degenerate.")

        axis_angle = float(np.arctan2(long_axis[1], long_axis[0]))
        rotation_angle = GeometryManager._minimal_angle_to_align_axis(
            axis_angle, np.pi / 2.0
        )
        rotation = GeometryManager._z_axis_rotation_transform(rotation_angle)
        aligned = request.mesh.copy()
        aligned.apply_transform(rotation)
        return AlignXYLongAxisResult(
            mesh=aligned,
            yaw_angle_degrees=float(np.rad2deg(rotation_angle)),
        )

    @staticmethod
    def _align_vector_to_axis(
        mesh: Any,
        *,
        source_axis: list[float],
        target_axis: list[float],
    ) -> Any:
        """Return a copy of ``mesh`` rotated so ``source_axis`` maps onto ``target_axis``.

        Raises:
            ValueError: Either axis is the zero vector.
        """
        source = GeometryManager._normalize(
            np.asarray(source_axis, dtype=float)
        )
        target = GeometryManager._normalize(
            np.asarray(target_axis, dtype=float)
        )
        # ``_normalize`` returns a zero vector unchanged, so the norm is still
        # zero here exactly when the caller passed a zero axis.
        if np.linalg.norm(source) == 0:
            raise ValueError("source_axis must be non-zero.")
        if np.linalg.norm(target) == 0:
            raise ValueError("target_axis must be non-zero.")

        transform = GeometryManager._rotation_transform_between_vectors(
            source, target
        )
        aligned = mesh.copy()
        aligned.apply_transform(transform)
        return aligned

    @staticmethod
    def _find_support_plane_candidates(
        mesh: Any,
        *,
        normal_angle_tol_deg: float = 8.0,
        plane_distance_tol: float | None = None,
        min_area_ratio: float = 0.02,
        max_candidates: int = 24,
    ) -> list[SupportPlaneCandidate]:
        """Group faces into near-coplanar patches and score each patch.

        Faces are visited from largest to smallest area. Each unused face
        seeds a patch made of all unused faces whose normal lies within
        ``normal_angle_tol_deg`` of the seed normal and whose centre lies
        within ``plane_distance_tol`` of the seed plane.

        Args:
            mesh: Mesh providing ``face_normals``, ``triangles_center``,
                ``area_faces``, ``vertices`` and ``extents``.
            normal_angle_tol_deg: Maximum angle between a face normal and the
                seed normal.
            plane_distance_tol: Maximum distance of a face centre from the
                seed plane. ``None`` selects 1% of the norm of the mesh
                extents, with a floor of ``1e-4``.
            min_area_ratio: Patches below this fraction of the total face area
                are dropped.
            max_candidates: Maximum number of candidates returned.

        Returns:
            Candidates sorted by descending score.

        Raises:
            ValueError: The mesh is empty or has no positive face area.
        """
        GeometryManager._validate_mesh(mesh)

        normals = np.asarray(mesh.face_normals, dtype=float)
        centers = np.asarray(mesh.triangles_center, dtype=float)
        areas = np.asarray(mesh.area_faces, dtype=float)
        vertices = np.asarray(mesh.vertices, dtype=float)
        total_area = float(np.sum(areas))
        if total_area <= 0:
            raise ValueError("Mesh has no positive face area.")

        if plane_distance_tol is None:
            extent = float(
                np.linalg.norm(np.asarray(mesh.extents, dtype=float))
            )
            plane_distance_tol = max(extent * 0.01, 1e-4)

        cos_tol = float(np.cos(np.deg2rad(normal_angle_tol_deg)))
        min_area = total_area * min_area_ratio
        order = np.argsort(-areas)
        used = np.zeros(len(areas), dtype=bool)
        candidates: list[SupportPlaneCandidate] = []

        for seed_index in order:
            if used[seed_index]:
                continue
            seed_normal = GeometryManager._normalize(normals[seed_index])
            if np.linalg.norm(seed_normal) == 0:
                used[seed_index] = True
                continue

            seed_center = centers[seed_index]
            seed_offset = float(np.dot(seed_normal, seed_center))
            normal_match = normals @ seed_normal >= cos_tol
            offsets = centers @ seed_normal
            plane_match = np.abs(offsets - seed_offset) <= plane_distance_tol
            face_mask = normal_match & plane_match & ~used
            face_indices = np.flatnonzero(face_mask)
            if len(face_indices) == 0:
                used[seed_index] = True
                continue

            # The faces are consumed before the area test, so a patch that is
            # too small is dropped for good and cannot join a later patch.
            used[face_indices] = True
            area = float(np.sum(areas[face_indices]))
            if area < min_area:
                continue

            weighted_normal = GeometryManager._normalize(
                np.sum(
                    normals[face_indices] * areas[face_indices, None], axis=0
                ),
            )
            center = (
                np.sum(
                    centers[face_indices] * areas[face_indices, None], axis=0
                )
                / area
            )
            candidate = GeometryManager._build_candidate(
                normal=weighted_normal,
                center=center,
                area=area,
                face_indices=face_indices,
                vertices=vertices,
            )
            candidates.append(candidate)

        candidates.sort(key=lambda c: c.score, reverse=True)
        return candidates[:max_candidates]

    @staticmethod
    def _select_tabletop_plane(
        candidates: list[SupportPlaneCandidate],
    ) -> SupportPlaneCandidate:
        """Return the candidate with the highest score.

        Raises:
            ValueError: ``candidates`` is empty.
        """
        if not candidates:
            raise ValueError("No support-plane candidates were found.")
        return max(candidates, key=lambda c: c.score)

    @staticmethod
    def _orient_plane_normal(
        mesh: Any,
        *,
        plane_normal: list[float],
        plane_center: list[float],
    ) -> list[float]:
        """Return the unit plane normal pointing away from most of the mesh.

        Signed vertex distances to the plane are summed separately for each
        side; the normal is flipped when the side it points to carries the
        larger sum.

        Raises:
            ValueError: The mesh is empty or ``plane_normal`` is zero.
        """
        GeometryManager._validate_mesh(mesh)

        normal = GeometryManager._normalize(
            np.asarray(plane_normal, dtype=float)
        )
        center = np.asarray(plane_center, dtype=float)
        if np.linalg.norm(normal) == 0:
            raise ValueError("plane_normal must be non-zero.")

        vertices = np.asarray(mesh.vertices, dtype=float)
        signed_distances = (vertices - center) @ normal
        positive_mask = signed_distances > 1e-6
        negative_mask = signed_distances < -1e-6
        positive_score = float(np.sum(np.abs(signed_distances[positive_mask])))
        negative_score = float(np.sum(np.abs(signed_distances[negative_mask])))

        if positive_score > negative_score:
            normal = -normal
        return [float(v) for v in normal]

    @staticmethod
    def _build_candidate(
        *,
        normal: Any,
        center: Any,
        area: float,
        face_indices: Any,
        vertices: Any,
    ) -> SupportPlaneCandidate:
        """Score one plane patch and wrap it in a ``SupportPlaneCandidate``.

        The score is the patch area multiplied by an asymmetry factor: the
        ratio between the larger and the smaller per-side sum of absolute
        vertex distances to the plane, capped at 10.
        """
        signed_distances = (vertices - center) @ normal
        below_mask = signed_distances < -1e-6
        above_mask = signed_distances > 1e-6
        below_count = int(np.count_nonzero(below_mask))
        above_count = int(np.count_nonzero(above_mask))
        below_score = float(np.sum(np.abs(signed_distances[below_mask])))
        above_score = float(np.sum(np.abs(signed_distances[above_mask])))

        smaller_score = min(below_score, above_score)
        larger_score = max(below_score, above_score)
        # The 1e-9 terms keep the ratio defined when one side is empty.
        asymmetry_score = min(
            (larger_score + 1e-9) / (smaller_score + 1e-9), 10.0
        )
        score = float(area * asymmetry_score)
        return SupportPlaneCandidate(
            normal=[float(v) for v in normal],
            center=[float(v) for v in center],
            area=area,
            face_indices=[int(i) for i in face_indices],
            below_vertex_count=below_count,
            above_vertex_count=above_count,
            below_area_score=below_score,
            above_area_score=above_score,
            score=score,
        )

    @staticmethod
    def _select_xy_vertices(
        mesh: Any,
        vertices: Any,
        face_indices: list[int] | None,
    ) -> Any:
        """Return XY coordinates of all vertices, or of those used by ``face_indices``."""
        if face_indices is None:
            return vertices[:, :2]

        faces = np.asarray(mesh.faces, dtype=int)
        selected_faces = faces[np.asarray(face_indices, dtype=int)]
        selected_vertex_indices = np.unique(selected_faces.reshape(-1))
        return vertices[selected_vertex_indices, :2]

    @staticmethod
    def _minimal_angle_to_align_axis(
        source_angle: float, target_angle: float
    ) -> float:
        """Return the smaller rotation taking an undirected axis onto the target.

        An axis has no direction, so the target and the target plus pi are
        both valid; the candidate with the smaller magnitude is returned.
        """
        candidates = [
            GeometryManager._wrap_to_pi(target_angle - source_angle),
            GeometryManager._wrap_to_pi(target_angle + np.pi - source_angle),
        ]
        return min(candidates, key=abs)

    @staticmethod
    def _wrap_to_pi(angle: float) -> float:
        """Wrap an angle in radians into ``[-pi, pi)``."""
        two_pi = 2.0 * np.pi
        return (angle + np.pi) % two_pi - np.pi

    @staticmethod
    def _z_axis_rotation_transform(angle: float) -> Any:
        """Return the 4x4 homogeneous rotation by ``angle`` radians about Z."""
        c = float(np.cos(angle))
        s = float(np.sin(angle))
        transform = np.eye(4)
        transform[:3, :3] = np.array(
            [
                [c, -s, 0.0],
                [s, c, 0.0],
                [0.0, 0.0, 1.0],
            ],
            dtype=float,
        )
        return transform

    @staticmethod
    def _rotation_transform_between_vectors(
        source: Any, target: Any
    ) -> Any:
        """Return the 4x4 rotation taking unit vector ``source`` onto ``target``.

        Parallel vectors give the identity. Anti-parallel vectors are rotated
        by pi about an axis orthogonal to ``source``, because their cross
        product is zero and defines no axis.
        """
        dot = float(np.clip(np.dot(source, target), -1.0, 1.0))
        transform = np.eye(4)
        if dot > 1.0 - 1e-8:
            return transform
        if dot < -1.0 + 1e-8:
            axis = GeometryManager._orthogonal_axis(source)
            rotation = GeometryManager._axis_angle_rotation(axis, np.pi)
        else:
            axis = GeometryManager._normalize(np.cross(source, target))
            angle = float(np.arccos(dot))
            rotation = GeometryManager._axis_angle_rotation(axis, angle)
        transform[:3, :3] = rotation
        return transform

    @staticmethod
    def _axis_angle_rotation(axis: Any, angle: float) -> Any:
        """Return the 3x3 rotation by ``angle`` radians about ``axis``."""
        axis = GeometryManager._normalize(axis)
        x, y, z = axis
        c = float(np.cos(angle))
        s = float(np.sin(angle))
        one_c = 1.0 - c
        return np.array(
            [
                [c + x * x * one_c, x * y * one_c - z * s, x * z * one_c + y * s],
                [y * x * one_c + z * s, c + y * y * one_c, y * z * one_c - x * s],
                [z * x * one_c - y * s, z * y * one_c + x * s, c + z * z * one_c],
            ],
            dtype=float,
        )

    @staticmethod
    def _orthogonal_axis(vector: Any) -> Any:
        """Return a unit vector orthogonal to ``vector``.

        The X axis is used as the helper direction unless ``vector`` is nearly
        parallel to it, in which case the Y axis is used.
        """
        axis = np.array([1.0, 0.0, 0.0])
        if abs(float(np.dot(vector, axis))) > 0.9:
            axis = np.array([0.0, 1.0, 0.0])
        return GeometryManager._normalize(np.cross(vector, axis))

    @staticmethod
    def _normalize(vector: Any) -> Any:
        """Return ``vector`` scaled to unit length; a zero vector is returned unchanged."""
        norm = float(np.linalg.norm(vector))
        if norm == 0.0:
            return vector
        return vector / norm

    @staticmethod
    def _mean_abs_log_ratio_error(lhs: Any, rhs: Any) -> float:
        """Return the mean absolute log-ratio between two arrays.

        Both inputs are clamped to at least ``1e-6`` before the logarithm.
        """
        eps = 1.0e-6
        lhs = np.maximum(np.asarray(lhs, dtype=np.float64), eps)
        rhs = np.maximum(np.asarray(rhs, dtype=np.float64), eps)
        return float(np.mean(np.abs(np.log(lhs / rhs))))

    @staticmethod
    def _validate_mesh(mesh: Any) -> None:
        """Raise ``ValueError`` unless ``mesh`` has non-empty ``vertices`` and ``faces``."""
        if not hasattr(mesh, "vertices") or not hasattr(mesh, "faces"):
            raise ValueError("Loaded geometry is not a mesh.")
        if len(mesh.vertices) == 0 or len(mesh.faces) == 0:
            raise ValueError("Mesh must contain vertices and faces.")
+# ---------------------------------------------------------------------------- + + +from __future__ import annotations + +from typing import Any + +import numpy as np + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager import ( + DetectTabletopRequest, + GeometryManager, +) + +__all__ = [ + "_compose_json_matrices", + "_compose_simready_to_aligned_matrix", + "_decompose_transform_matrix", + "_aabb_bottom_to_xy_plane_transform", + "_aabb_center", + "_compose_sam3d_multi_object_transform", + "_copy_scene_with_transform", + "_estimate_support_normal", + "_glb_to_sam3d_local_matrix", + "_load_scene_with_transform", + "_matrix_from_json", + "_quaternion_wxyz_to_matrix", + "_rotation_between_vectors", + "_row_linear_to_trimesh_matrix", + "_scale_transform", + "_scene_to_mesh", + "_support_normal_flip_transform", + "_transform_point", + "_validate_vector", + "_xy_aabb_center", + "_xy_aabb_size", + "_z_up_to_glb_y_up_transform", + "_z_yaw_transform", +] + + +def _compose_json_matrices(*values: Any) -> list[list[float]]: + matrices = [np.asarray(value, dtype=np.float64) for value in values] + if any(matrix.shape != (4, 4) for matrix in matrices): + return [] + result = np.eye(4, dtype=np.float64) + for matrix in matrices: + result = result @ matrix + return result.tolist() + + +def _compose_simready_to_aligned_matrix( + *, raw_to_aligned_matrix: Any, raw_to_simready_matrix: Any +) -> list[list[float]]: + raw_to_aligned = np.asarray(raw_to_aligned_matrix, dtype=np.float64) + raw_to_simready = np.asarray(raw_to_simready_matrix, dtype=np.float64) + if raw_to_aligned.shape != (4, 4) or raw_to_simready.shape != (4, 4): + return [] + try: + return (raw_to_aligned @ np.linalg.inv(raw_to_simready)).tolist() + except np.linalg.LinAlgError: + return [] + + +def _decompose_transform_matrix(matrix_value: Any) -> dict[str, Any]: + matrix = np.asarray(matrix_value, dtype=np.float64) + if matrix.shape != (4, 4): + return {"translation": [], "rotation_matrix": 
[], "scale": []} + linear = matrix[:3, :3] + scale = np.linalg.norm(linear, axis=0) + rotation = np.eye(3, dtype=np.float64) + for index in range(3): + if scale[index] > 1.0e-12: + rotation[:, index] = linear[:, index] / scale[index] + return { + "translation": matrix[:3, 3].tolist(), + "rotation_matrix": rotation.tolist(), + "scale": scale.tolist(), + } + + +def _support_normal_flip_transform( + *, + support_normal: np.ndarray, + normal_alignment: np.ndarray, +) -> np.ndarray: + flipped_normal_alignment = _rotation_between_vectors( + -support_normal, + np.array([0.0, 0.0, 1.0], dtype=np.float64), + ) + return flipped_normal_alignment @ np.linalg.inv(normal_alignment) + + +def _z_yaw_transform(yaw_degrees: float) -> np.ndarray: + angle = np.deg2rad(yaw_degrees) + c = float(np.cos(angle)) + s = float(np.sin(angle)) + transform = np.eye(4, dtype=np.float64) + transform[:3, :3] = np.array( + [ + [c, -s, 0.0], + [s, c, 0.0], + [0.0, 0.0, 1.0], + ], + dtype=np.float64, + ) + return transform + + +def _z_up_to_glb_y_up_transform() -> np.ndarray: + return _rotation_between_vectors( + np.array([0.0, 0.0, 1.0], dtype=np.float64), + np.array([0.0, 1.0, 0.0], dtype=np.float64), + ) + + +def _copy_scene_with_transform(scene: Any, transform: np.ndarray) -> Any: + copied = scene.copy() + copied.apply_transform(transform) + return copied + + +def _matrix_from_json(value: Any, *, name: str) -> np.ndarray: + matrix = np.asarray(value, dtype=np.float64) + if matrix.shape != (4, 4): + raise ValueError(f"{name} must be a 4x4 matrix.") + return matrix + + +def _load_scene_with_transform( + *, + path: Path, + transform: np.ndarray, + trimesh: Any, +) -> Any: + scene = trimesh.load(path, force="scene") + scene.apply_transform(transform) + return scene + + +def _scene_to_mesh(scene: Any, *, trimesh: Any) -> Any: + if isinstance(scene, trimesh.Trimesh): + return scene + dumped = scene.dump(concatenate=True) + if isinstance(dumped, trimesh.Trimesh): + return dumped + meshes = [item for item 
in dumped if isinstance(item, trimesh.Trimesh)] + if not meshes: + raise ValueError("Scene contains no mesh geometry.") + return trimesh.util.concatenate(meshes) + + +def _estimate_support_normal(mesh: Any) -> np.ndarray: + geom = GeometryManager() + try: + detect_result = geom.detect_tabletop(DetectTabletopRequest(mesh=mesh)) + normal = np.asarray(detect_result.oriented_normal, dtype=np.float64) + norm = np.linalg.norm(normal) + if norm > 0.0: + return normal / norm + except Exception: + pass + + normals = np.asarray(mesh.face_normals, dtype=np.float64) + areas = np.asarray(mesh.area_faces, dtype=np.float64) + if normals.size == 0 or areas.size == 0: + return np.array([0.0, 0.0, 1.0], dtype=np.float64) + normal = normals[int(np.argmax(areas))] + norm = np.linalg.norm(normal) + if norm == 0.0: + return np.array([0.0, 0.0, 1.0], dtype=np.float64) + return normal / norm + + +def _rotation_between_vectors(source: np.ndarray, target: np.ndarray) -> np.ndarray: + source = source / np.linalg.norm(source) + target = target / np.linalg.norm(target) + cross = np.cross(source, target) + dot = float(np.clip(np.dot(source, target), -1.0, 1.0)) + if np.linalg.norm(cross) < 1e-8: + if dot > 0.0: + return np.eye(4, dtype=np.float64) + axis = np.array([1.0, 0.0, 0.0], dtype=np.float64) + if abs(float(np.dot(source, axis))) > 0.9: + axis = np.array([0.0, 1.0, 0.0], dtype=np.float64) + cross = np.cross(source, axis) + axis = cross / np.linalg.norm(cross) + angle = float(np.arccos(dot)) + skew = np.array( + [ + [0.0, -axis[2], axis[1]], + [axis[2], 0.0, -axis[0]], + [-axis[1], axis[0], 0.0], + ], + dtype=np.float64, + ) + rotation = ( + np.eye(3, dtype=np.float64) + + np.sin(angle) * skew + + (1.0 - np.cos(angle)) * (skew @ skew) + ) + transform = np.eye(4, dtype=np.float64) + transform[:3, :3] = rotation + return transform + + +def _transform_point(transform: np.ndarray, point: np.ndarray) -> np.ndarray: + homogeneous = np.ones(4, dtype=np.float64) + homogeneous[:3] = point + return 
(transform @ homogeneous)[:3] + + +def _aabb_center(bounds: np.ndarray) -> np.ndarray: + return 0.5 * ( + np.asarray(bounds[0], dtype=np.float64) + + np.asarray(bounds[1], dtype=np.float64) + ) + + +def _xy_aabb_center(bounds: np.ndarray) -> np.ndarray: + bounds = np.asarray(bounds, dtype=np.float64) + return 0.5 * (bounds[0, :2] + bounds[1, :2]) + + +def _xy_aabb_size(bounds: np.ndarray) -> np.ndarray: + bounds = np.asarray(bounds, dtype=np.float64) + return np.maximum(bounds[1, :2] - bounds[0, :2], 1e-6) + + +def _aabb_bottom_to_xy_plane_transform(bounds: np.ndarray) -> np.ndarray: + bounds = np.asarray(bounds, dtype=np.float64) + min_z = float(bounds[0][2]) + transform = np.eye(4, dtype=np.float64) + transform[:3, 3] = [0.0, 0.0, -min_z] + return transform + + +def _scale_transform(scale: float) -> np.ndarray: + transform = np.eye(4, dtype=np.float64) + transform[:3, :3] *= float(scale) + return transform + + +def _compose_sam3d_multi_object_transform( + *, + rotation_quaternion_wxyz: list[float], + translation: list[float], + scale: list[float], +) -> np.ndarray: + """Compose the transform equivalent to the old baked multi-object export.""" + rotation = _quaternion_wxyz_to_matrix(rotation_quaternion_wxyz) + scale_matrix = np.diag(_validate_vector(scale, expected_len=3, name="scale")) + linear_row = _glb_to_sam3d_local_matrix() @ scale_matrix @ rotation + return _row_linear_to_trimesh_matrix( + linear_row=linear_row, + translation=translation, + ) + + +def _row_linear_to_trimesh_matrix( + *, + linear_row: np.ndarray, + translation: list[float], +) -> np.ndarray: + """Convert a row-vector linear transform to trimesh's 4x4 matrix format.""" + translation_vector = _validate_vector( + translation, + expected_len=3, + name="translation", + ) + transform = np.eye(4, dtype=np.float64) + transform[:3, :3] = linear_row.T + transform[:3, 3] = translation_vector + return transform + + +def _validate_vector( + values: list[float], + *, + expected_len: int, + name: str, +) 
-> np.ndarray: + """Validate and convert a numeric vector.""" + if len(values) != expected_len: + raise ValueError(f"{name} must have {expected_len} values") + return np.asarray(values, dtype=np.float64) + + +def _glb_to_sam3d_local_matrix() -> np.ndarray: + """Return the basis conversion used by the old baked multi-object exporter.""" + return np.array( + [ + [1.0, 0.0, 0.0], + [0.0, 0.0, 1.0], + [0.0, -1.0, 0.0], + ], + dtype=np.float64, + ) + + +def _quaternion_wxyz_to_matrix(quaternion: list[float]) -> np.ndarray: + """Convert a wxyz quaternion to a 3x3 rotation matrix.""" + if len(quaternion) != 4: + raise ValueError("rotation_quaternion_wxyz must have 4 values") + w, x, y, z = [float(v) for v in quaternion] + norm = np.sqrt(w * w + x * x + y * y + z * z) + if norm == 0.0: + raise ValueError("rotation quaternion must be non-zero") + w, x, y, z = w / norm, x / norm, y / norm, z / norm + return np.array( + [ + [ + 1.0 - 2.0 * (y * y + z * z), + 2.0 * (x * y - z * w), + 2.0 * (x * z + y * w), + ], + [ + 2.0 * (x * y + z * w), + 1.0 - 2.0 * (x * x + z * z), + 2.0 * (y * z - x * w), + ], + [ + 2.0 * (x * z - y * w), + 2.0 * (y * z + x * w), + 1.0 - 2.0 * (x * x + y * y), + ], + ], + dtype=np.float64, + ) + + +def _detect_table_fit_support_quad( + mesh: Any, + *, + target_aspect: float, +) -> dict[str, Any]: + geom = GeometryManager() + detect = geom.detect_tabletop(DetectTabletopRequest(mesh=mesh)) + faces = np.asarray(mesh.faces, dtype=np.int64) + vertices = np.asarray(mesh.vertices, dtype=np.float64) + support_vertices = vertices[ + np.unique(faces[np.asarray(detect.selected.face_indices, dtype=np.int64)]) + ] + hull_xy = _table_fit_convex_hull_2d(support_vertices[:, :2]) + quad = _largest_centered_table_fit_inscribed_rect( + hull_xy, + target_aspect=max(float(target_aspect), 1.0e-6), + ) + center_z = float(np.mean(support_vertices[:, 2])) + return { + "method": "sampled_centered_inscribed_rectangle_on_support_convex_hull", + "normal": detect.oriented_normal, + 
def _largest_centered_table_fit_inscribed_rect(
    hull_xy: np.ndarray,
    *,
    target_aspect: float,
    yaw_samples: int = 180,
) -> dict[str, Any]:
    """Search for the largest fixed-aspect rectangle that fits inside a hull.

    This is a sampled search, not an exact solver: for every sampled yaw in
    ``[0, pi)`` and for each of two candidate centres (the mean of the hull
    points and the centre of their axis-aligned bounding box) the rectangle
    depth is found by bisection, keeping ``width = target_aspect * depth``.
    A size is accepted when all four corners pass
    ``_table_fit_point_in_convex_polygon`` against ``hull_xy``.

    Args:
        hull_xy: Hull points with at least three rows of (x, y).
            NOTE(review): the corner test presumably requires the points to
            be ordered around a convex polygon — confirm against the caller.
        target_aspect: Width-to-depth ratio of the rectangle; used as a
            divisor, so it must be non-zero.
        yaw_samples: Number of evenly spaced yaw angles to try.

    Returns:
        The best candidate as a dict with keys ``center_xy``, ``size_xy``
        (``[width, depth]``), ``yaw_radians``, ``corners_xy`` and ``area``.

    Raises:
        ValueError: If ``hull_xy`` has fewer than three rows, or if no
            candidate was produced (e.g. ``yaw_samples`` is 0).
    """
    if hull_xy.shape[0] < 3:
        raise ValueError("Support hull must contain at least 3 points.")
    best: dict[str, Any] | None = None
    # Two candidate rectangle centres, both expressed in the hull's own frame.
    centers = [
        np.mean(hull_xy, axis=0),
        0.5 * (np.min(hull_xy, axis=0) + np.max(hull_xy, axis=0)),
    ]
    for yaw in np.linspace(0.0, np.pi, yaw_samples, endpoint=False):
        # rot maps hull-frame points into the frame where the rectangle is
        # axis-aligned; inv_rot maps them back.
        rot = _table_fit_rot2(-yaw)
        inv_rot = _table_fit_rot2(yaw)
        rotated_hull = hull_xy @ rot.T
        for center_world in centers:
            center = center_world @ rot.T
            lo = 0.0
            # Initial upper bound for the depth, taken from the rotated
            # hull's bounding box (never smaller than 1e-6).
            bbox_size = np.max(rotated_hull, axis=0) - np.min(rotated_hull, axis=0)
            hi = float(max(bbox_size[0] / target_aspect, bbox_size[1], 1.0e-6))
            # Fixed 40-step bisection on the depth: grow while all corners
            # stay inside the hull, shrink otherwise.
            for _ in range(40):
                mid = 0.5 * (lo + hi)
                width = target_aspect * mid
                depth = mid
                corners = _table_fit_rect_corners(
                    center=center,
                    width=width,
                    depth=depth,
                )
                corners_world = corners @ inv_rot.T
                if all(
                    _table_fit_point_in_convex_polygon(point, hull_xy)
                    for point in corners_world
                ):
                    lo = mid
                else:
                    hi = mid
            # lo is the largest depth that was verified to fit.
            width = target_aspect * lo
            depth = lo
            area = width * depth
            corners_world = (
                _table_fit_rect_corners(center=center, width=width, depth=depth)
                @ inv_rot.T
            )
            candidate = {
                "center_xy": center_world.tolist(),
                "size_xy": [float(width), float(depth)],
                "yaw_radians": float(yaw),
                "corners_xy": corners_world.tolist(),
                "area": float(area),
            }
            # Strict comparison: on ties the earliest candidate wins.
            if best is None or area > float(best["area"]):
                best = candidate
    if best is None:
        raise ValueError("Failed to estimate an inscribed support rectangle.")
    return best
+def _load_table_fit_scene_internal_z( + path: Path, + *, + trimesh: Any, + y_to_z: np.ndarray, +) -> Any: + if not path.is_file(): + raise FileNotFoundError(f"GLB not found: {path}") + scene = trimesh.load(path, force="scene") + scene.apply_transform(y_to_z) + return scene + + +def _table_fit_scene_union_bounds(scenes: list[Any], *, trimesh: Any) -> np.ndarray: + bounds = [ + np.asarray(_scene_to_mesh(scene, trimesh=trimesh).bounds, dtype=np.float64) + for scene in scenes + ] + return np.vstack( + [ + np.vstack([item[0] for item in bounds]).min(axis=0), + np.vstack([item[1] for item in bounds]).max(axis=0), + ] + ) + + +def _table_fit_bounds_xy_manifest( + bounds: np.ndarray, + *, + unit_scale: float, +) -> dict[str, Any]: + min_xy = bounds[0, :2] * unit_scale + max_xy = bounds[1, :2] * unit_scale + size_xy = max_xy - min_xy + center_xy = 0.5 * (min_xy + max_xy) + return { + "unit": "cm", + "min_xy": min_xy.tolist(), + "max_xy": max_xy.tolist(), + "center_xy": center_xy.tolist(), + "size_xy": size_xy.tolist(), + "area": float(size_xy[0] * size_xy[1]), + } + + +def _table_fit_uniform_xy_scale_transform( + *, + center_xy: np.ndarray, + scale: float, +) -> np.ndarray: + center = np.eye(4, dtype=np.float64) + center[:3, 3] = [float(center_xy[0]), float(center_xy[1]), 0.0] + uncenter = np.eye(4, dtype=np.float64) + uncenter[:3, 3] = [-float(center_xy[0]), -float(center_xy[1]), 0.0] + scale_mat = np.eye(4, dtype=np.float64) + scale_mat[0, 0] = float(scale) + scale_mat[1, 1] = float(scale) + return center @ scale_mat @ uncenter + + +def _table_fit_safe_positive_ratio(numerator: float, denominator: float) -> float: + return max(float(numerator) / max(float(denominator), 1.0e-6), 1.0e-6) + + +def _table_fit_convex_hull_2d(points: np.ndarray) -> np.ndarray: + unique = sorted({(float(x), float(y)) for x, y in np.asarray(points)[:, :2]}) + if len(unique) <= 1: + return np.asarray(unique, dtype=np.float64) + + def cross( + o: tuple[float, float], + a: tuple[float, float], + b: 
tuple[float, float], + ) -> float: + return (a[0] - o[0]) * (b[1] - o[1]) - (a[1] - o[1]) * (b[0] - o[0]) + + lower: list[tuple[float, float]] = [] + for point in unique: + while len(lower) >= 2 and cross(lower[-2], lower[-1], point) <= 0.0: + lower.pop() + lower.append(point) + upper: list[tuple[float, float]] = [] + for point in reversed(unique): + while len(upper) >= 2 and cross(upper[-2], upper[-1], point) <= 0.0: + upper.pop() + upper.append(point) + return np.asarray(lower[:-1] + upper[:-1], dtype=np.float64) + + +def _table_fit_point_in_convex_polygon( + point: np.ndarray, + polygon: np.ndarray, +) -> bool: + previous = 0.0 + for index in range(len(polygon)): + a = polygon[index] + b = polygon[(index + 1) % len(polygon)] + cross = float(np.cross(b - a, point - a)) + if abs(cross) < 1.0e-9: + continue + if previous == 0.0: + previous = cross + elif cross * previous < -1.0e-9: + return False + return True + + +def _table_fit_rect_corners( + *, + center: np.ndarray, + width: float, + depth: float, +) -> np.ndarray: + half_w = 0.5 * float(width) + half_d = 0.5 * float(depth) + return np.asarray( + [ + [center[0] - half_w, center[1] - half_d], + [center[0] + half_w, center[1] - half_d], + [center[0] + half_w, center[1] + half_d], + [center[0] - half_w, center[1] + half_d], + ], + dtype=np.float64, + ) + + +def _table_fit_rot2(angle: float) -> np.ndarray: + c = float(np.cos(angle)) + s = float(np.sin(angle)) + return np.asarray([[c, -s], [s, c]], dtype=np.float64) diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/schemas.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/schemas.py new file mode 100644 index 00000000..f001720f --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/schemas.py @@ -0,0 +1,201 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. 
@dataclass(frozen=True)
class SupportPlaneCandidate:
    """Candidate planar tabletop support surface.

    Immutable record describing one planar region of a mesh that might be
    the tabletop. Instances appear as ``DetectTabletopResult.selected`` and
    in ``DetectTabletopResult.candidates``.
    """

    # Plane normal. NOTE(review): presumably three components, unit length
    # and not yet oriented (DetectTabletopResult carries a separate
    # oriented_normal) — confirm against the detector.
    normal: list[float]
    # Representative point of the planar region.
    # NOTE(review): how it is computed is not visible here.
    center: list[float]
    # Surface area of the planar region.
    area: float
    # Indices of the faces that make up this region; consumers use them to
    # index the mesh's ``faces`` array.
    face_indices: list[int]
    # Vertex counts on either side of the plane.
    # NOTE(review): "below"/"above" are presumably relative to ``normal`` — confirm.
    below_vertex_count: int
    above_vertex_count: int
    # Area-based scores for the two sides of the plane.
    # NOTE(review): exact definition lives in the detector — confirm.
    below_area_score: float
    above_area_score: float
    # Overall ranking score of this candidate.
    # NOTE(review): presumably higher is better — confirm.
    score: float
@dataclass(frozen=True)
class DetectTabletopRequest:
    """Request to detect the most likely tabletop plane in a mesh.

    Passed to ``GeometryManager.detect_tabletop``; only ``mesh`` is required,
    the remaining fields tune the detection and have defaults.
    """

    # Mesh to analyse. Typed as Any; NOTE(review): presumably a
    # trimesh.Trimesh exposing vertices/faces — confirm.
    mesh: Any
    # Angular tolerance in degrees (per the field name).
    # NOTE(review): presumably used to group faces with similar normals — confirm.
    normal_angle_tol_deg: float = 8.0
    # Distance tolerance for treating faces as coplanar; None leaves the
    # choice to the detector. NOTE(review): the fallback value is not
    # visible here.
    plane_distance_tol: float | None = None
    # Minimum area ratio for a candidate plane.
    # NOTE(review): presumably relative to the total mesh area — confirm.
    min_area_ratio: float = 0.02
    # Upper limit on the number of candidate planes considered.
    max_candidates: int = 24
@dataclass(frozen=True)
class NormalizeRequest:
    """Request to normalize a mesh to a target size.

    The matching ``NormalizeResult`` reports the normalized mesh together
    with the applied ``scale_factor``.
    """

    # Mesh to rescale. Typed as Any; NOTE(review): presumably a
    # trimesh.Trimesh — confirm.
    mesh: Any
    # Size the mesh should have after normalization.
    # NOTE(review): which extent is measured (largest AABB edge, diagonal,
    # ...) is decided by the consumer and not visible here.
    target_size: float = 1.0
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_generation_manager.manager import ( + ASSET_IMAGE_PROMPT_SUFFIX, + ImageGenerationManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_generation_manager.schemas import ( + ImageGenerationRequest, + ImageGenerationResult, + TextToAssetImageRequest, +) + +__all__ = [ + "ASSET_IMAGE_PROMPT_SUFFIX", + "ImageGenerationManager", + "ImageGenerationRequest", + "ImageGenerationResult", + "TextToAssetImageRequest", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_generation_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_generation_manager/manager.py new file mode 100644 index 00000000..6406f74d --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_generation_manager/manager.py @@ -0,0 +1,76 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def _build_asset_image_prompt(prompt: str) -> str:
    """Return ``prompt`` with the asset-image suffix appended exactly once.

    The prompt is stripped of surrounding whitespace first. If it already
    contains the suffix text it is returned as is, so running an already
    processed prompt through this function again does not duplicate the
    suffix.

    Args:
        prompt: Free-form text describing the asset.

    Returns:
        The stripped prompt, followed by ``", "`` and
        ``ASSET_IMAGE_PROMPT_SUFFIX`` unless the suffix is already present.

    Raises:
        ValueError: If the prompt is empty or whitespace only.
    """
    prompt = prompt.strip()
    if not prompt:
        raise ValueError("Text-to-asset image prompt must be non-empty.")
    # The suffix constant ends with a space, but the prompt has just been
    # stripped. Comparing against the raw constant therefore never matches
    # a prompt that ends with the suffix, and the suffix was appended a
    # second time. Compare against the stripped suffix instead.
    suffix_marker = ASSET_IMAGE_PROMPT_SUFFIX.strip()
    if suffix_marker in prompt:
        return prompt
    return f"{prompt}, {ASSET_IMAGE_PROMPT_SUFFIX}"
@dataclass(frozen=True)
class TextToAssetImageRequest:
    """Request for generating an asset image from a text prompt.

    Consumed by ``ImageGenerationManager.generate_asset_image_from_text``,
    which decorates the prompt with the asset-image suffix before delegating
    to plain image generation.
    """

    # Text description of the asset; must be non-empty after stripping,
    # otherwise prompt building raises ValueError.
    prompt: str
    # Destination of the generated image; the manager expands ``~`` and
    # resolves it to an absolute path before calling the client.
    output_path: Path
Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_scene_manager.alignment import ( + _export_support_aligned_layout_glbs, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_scene_manager.manifests import ( + _write_multi_object_layout_manifests, +) + +__all__ = [ + "_export_support_aligned_layout_glbs", + "_write_multi_object_layout_manifests", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/alignment.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/alignment.py new file mode 100644 index 00000000..6d7084f4 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/alignment.py @@ -0,0 +1,537 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + + +from __future__ import annotations + +import traceback +from pathlib import Path +from typing import Any + +import numpy as np + +from embodichain.gen_sim.prompt2scene.workflows.llm_output import ( + call_structured_json_model_step, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_scene_manager.prompts import ( + build_up_down_flip_check_messages, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.metric_scale_manager import ( + GlobalMetricScaleRequest, + MetricScaleManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_scene_manager.schemas import ( + UP_DOWN_FLIP_CHECK_JSON_SCHEMA, +) + +UP_DOWN_FLIP_CHECK_CONFIDENCE_THRESHOLD = 0.6 +UNIFIED_SCENE_STEP = "unified_scene" +from embodichain.gen_sim.prompt2scene.agent_tools.managers.blender_rendering_manager import ( + BlenderRenderingManager, + RenderObjectScenesRequest, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.matplotlib_manager import ( + MatplotlibManager, + RenderImageComparisonRequest, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager.scene_geometry import ( + _aabb_center, + _copy_scene_with_transform, + _estimate_support_normal, + _load_scene_with_transform, + _matrix_from_json, + _rotation_between_vectors, + _scale_transform, + _scene_to_mesh, + _support_normal_flip_transform, + _xy_aabb_center, + _z_up_to_glb_y_up_transform, + _z_yaw_transform, +) +from embodichain.gen_sim.prompt2scene.utils.io import ( + 
def _export_support_aligned_layout_glbs(
    *,
    table: dict[str, Any],
    objects: list[dict[str, Any]],
    spatial_relations: list[dict[str, Any]],
    original_image_path: Path | None,
    llm: Any | None,
    output_dir: Path,
    output_root: Path,
) -> dict[str, Any]:
    """Export layout-baked GLBs aligned by support normal and left-right order.

    The table's support reference GLB and every object GLB are loaded with
    their layout transforms applied. The objects are then rotated by the
    rotation computed from the estimated support normal and ``[0, 0, 1]``,
    centred on the clutter AABB, yaw/flip aligned, metrically scaled, settled
    and packed, and finally written to ``output_dir`` as
    ``<object_id>_aligned.glb``.

    Args:
        table: Table record. Reads ``support_reference_geometry_path``
            (falling back to ``raw_geometry_path``) and
            ``support_reference_transform_matrix`` (falling back to
            ``transform_matrix``).
        objects: Object records. Only entries that have both
            ``raw_geometry_path`` and ``transform_matrix`` are exported.
        spatial_relations: Relation records; only ``left_of`` relations take
            part in the yaw search.
        original_image_path: Reference image for the VLM up/down check, if any.
        llm: Model handle for the VLM up/down check, if any.
        output_dir: Directory that receives the aligned GLBs and check renders.
        output_root: Root used to resolve relative input paths and to
            relativise the paths reported in the result.

    Returns:
        A manifest dict with ``status`` set to ``"ok"`` that records the
        alignment matrices, the metric scale, the footprint layout, the yaw
        sampling metadata, the up/down check result and the exported objects.

    Raises:
        RuntimeError: If ``trimesh`` cannot be imported.
        FileNotFoundError: If the support reference GLB is not a file.
        ValueError: If no object has both a raw GLB path and a transform.
    """
    try:
        import trimesh
    except ImportError as exc:
        raise RuntimeError("Support-aligned GLB export requires trimesh.") from exc

    output_dir.mkdir(parents=True, exist_ok=True)
    support_reference_path = _resolve_generated_path(
        table.get("support_reference_geometry_path") or table.get("raw_geometry_path"),
        output_root,
    )
    # Objects lacking either a raw GLB or a layout transform are skipped.
    object_paths = [
        (
            str(item["id"]),
            _resolve_generated_path(item.get("raw_geometry_path"), output_root),
            item.get("transform_matrix"),
        )
        for item in objects
        if item.get("raw_geometry_path") and item.get("transform_matrix")
    ]
    if not support_reference_path.is_file():
        raise FileNotFoundError(
            f"Support reference table GLB not found: {support_reference_path}"
        )
    support_reference_transform = _matrix_from_json(
        table.get("support_reference_transform_matrix")
        or table.get("transform_matrix"),
        name="table.support_reference_transform_matrix",
    )
    if not object_paths:
        raise ValueError("No raw object GLBs with transform matrices available.")

    support_reference_scene = trimesh.load(support_reference_path, force="scene")
    support_reference_scene.apply_transform(support_reference_transform)
    object_scenes = [
        (
            object_id,
            _load_scene_with_transform(
                path=path,
                transform=_matrix_from_json(
                    transform,
                    name=f"{object_id}.transform_matrix",
                ),
                trimesh=trimesh,
            ),
        )
        for object_id, path, transform in object_paths
    ]
    table_mesh = _scene_to_mesh(support_reference_scene, trimesh=trimesh)
    support_normal = _estimate_support_normal(table_mesh)
    # Internal working frame is Z-up (see "internal_up_axis" in the result).
    normal_alignment = _rotation_between_vectors(
        support_normal,
        np.array([0.0, 0.0, 1.0]),
    )

    for _, scene in object_scenes:
        scene.apply_transform(normal_alignment)

    # Joint AABB of all objects: row 0 is the min corner, row 1 the max corner.
    object_bounds = [
        _scene_to_mesh(scene, trimesh=trimesh).bounds for _, scene in object_scenes
    ]
    clutter_bounds = np.vstack(
        [
            np.vstack([bounds[0] for bounds in object_bounds]).min(axis=0),
            np.vstack([bounds[1] for bounds in object_bounds]).max(axis=0),
        ]
    )
    clutter_center = 0.5 * (clutter_bounds[0] + clutter_bounds[1])
    # Pure translation that moves the clutter AABB centre to the origin.
    center_transform = np.eye(4, dtype=np.float64)
    center_transform[:3, 3] = [
        -float(clutter_center[0]),
        -float(clutter_center[1]),
        -float(clutter_center[2]),
    ]

    for _, scene in object_scenes:
        scene.apply_transform(center_transform)

    alignment_candidates = _build_up_down_alignment_candidates(
        object_scenes=object_scenes,
        support_normal=support_normal,
        normal_alignment=normal_alignment,
        spatial_relations=spatial_relations,
        trimesh=trimesh,
    )
    vlm_check_dir = output_dir / "vlm_up_down_flip_check"
    up_down_flip_check_result = _run_aligned_up_down_flip_vlm_check(
        llm=llm,
        original_image_path=original_image_path,
        normal_object_scenes=alignment_candidates["normal"]["object_scenes"],
        flipped_object_scenes=alignment_candidates["flipped"]["object_scenes"],
        output_dir=vlm_check_dir,
    )
    # A skipped or failed check carries no "selected_variant"; use "normal".
    selected_variant = str(
        up_down_flip_check_result.get("selected_variant") or "normal"
    )
    if selected_variant not in alignment_candidates:
        selected_variant = "normal"
    selected_candidate = alignment_candidates[selected_variant]
    object_scenes = selected_candidate["object_scenes"]
    selected_extra_transform = selected_candidate["extra_transform"]
    apply_up_down_flip = selected_variant == "flipped"

    global_metric_scale = MetricScaleManager.compute_global_from_object_scenes(
        GlobalMetricScaleRequest(
            objects=objects,
            object_scenes=object_scenes,
        )
    )
    metric_scale_transform = _scale_transform(global_metric_scale["scale_factor"])
    if float(global_metric_scale["scale_factor"]) != 1.0:
        for _, scene in object_scenes:
            scene.apply_transform(metric_scale_transform)

    footprint_result = _settle_and_pack_object_footprints(
        object_scenes=object_scenes,
        output_dir=output_dir / "footprint_layout",
        output_root=output_root,
        trimesh=trimesh,
    )
    object_scenes = footprint_result["object_scenes"]

    # The axis conversion is applied to a copy at export time only, so the
    # in-memory scenes stay in the internal frame for the AABB manifest below.
    output_axis_transform = _z_up_to_glb_y_up_transform()
    object_outputs = []
    for object_id, scene in object_scenes:
        object_output = output_dir / f"{object_id}_aligned.glb"
        _copy_scene_with_transform(scene, output_axis_transform).export(object_output)
        object_outputs.append(
            {
                "id": object_id,
                "aligned_geometry_path": relative_path(str(object_output), output_root),
            }
        )

    # Composition mirrors the order the transforms were applied to the scenes:
    # normal alignment, then centring, then flip/yaw, then metric scale.
    alignment_matrix = selected_extra_transform @ center_transform @ normal_alignment
    scaled_alignment_matrix = metric_scale_transform @ alignment_matrix
    final_clutter_aabb_2d_cm = _object_scenes_xy_aabb_manifest(
        object_scenes=object_scenes,
        trimesh=trimesh,
        unit_scale=100.0,
        unit="cm",
    )
    return {
        "status": "ok",
        "output_dir": relative_path(str(output_dir), output_root),
        "support_normal": support_normal.tolist(),
        "clutter_aabb_center_before_centering": clutter_center.tolist(),
        "alignment_matrix": scaled_alignment_matrix.tolist(),
        "pre_metric_scale_alignment_matrix": alignment_matrix.tolist(),
        "global_metric_scale": global_metric_scale,
        "final_clutter_2d_aabb_cm": final_clutter_aabb_2d_cm,
        "internal_up_axis": [0.0, 0.0, 1.0],
        "glb_output_up_axis": [0.0, 1.0, 0.0],
        "glb_output_axis_transform": output_axis_transform.tolist(),
        "selected_up_down_variant": selected_variant,
        "applied_up_down_flip": apply_up_down_flip,
        "selected_extra_transform": selected_extra_transform.tolist(),
        "object_alignment_matrices": {
            object_id: (object_transform @ scaled_alignment_matrix).tolist()
            for object_id, object_transform in footprint_result[
                "object_layout_transforms"
            ].items()
        },
        "footprint_layout": footprint_result["manifest"],
        "yaw_sampling": {
            "sample_count_per_variant": 360,
            "score_type": "center_left_of_hard_count",
            "top_view_plane": "XY",
            "yaw_axis": "Z",
            "left_right_axis": "X",
            "front_back_axis": "Y",
            "front_direction": "+Y",
            "normal": alignment_candidates["normal"]["yaw_metadata"],
            "flipped": alignment_candidates["flipped"]["yaw_metadata"],
        },
        "up_down_flip_check": up_down_flip_check_result,
        "objects": object_outputs,
    }


def _build_up_down_alignment_candidates(
    *,
    object_scenes: list[tuple[str, Any]],
    support_normal: np.ndarray,
    normal_alignment: np.ndarray,
    spatial_relations: list[dict[str, Any]],
    trimesh: Any,
) -> dict[str, dict[str, Any]]:
    """Build the ``"normal"`` and ``"flipped"`` yaw-aligned scene candidates.

    Each candidate copies the input scenes with its pre-yaw transform
    (identity for ``"normal"``, the support-normal flip for ``"flipped"``),
    picks the best yaw for the ``left_of`` relations and applies it.

    Args:
        object_scenes: ``(object_id, scene)`` pairs; the inputs are copied,
            not transformed in place.
        support_normal: Estimated support normal, forwarded to the flip helper.
        normal_alignment: Normal-alignment transform, forwarded to the flip
            helper.
        spatial_relations: Raw relation records; filtered to ``left_of``.
        trimesh: The imported ``trimesh`` module.

    Returns:
        Mapping from variant name to a dict holding ``object_scenes``,
        ``pre_yaw_transform``, ``yaw_transform``, ``extra_transform``
        (``yaw_transform @ pre_yaw_transform``) and ``yaw_metadata``.
    """
    flip_transform = _support_normal_flip_transform(
        support_normal=support_normal,
        normal_alignment=normal_alignment,
    )
    directional_relations = _spatial_directional_relations(spatial_relations)
    candidates: dict[str, dict[str, Any]] = {}
    for variant, pre_yaw_transform in [
        ("normal", np.eye(4, dtype=np.float64)),
        ("flipped", flip_transform),
    ]:
        candidate_object_scenes = [
            (object_id, _copy_scene_with_transform(scene, pre_yaw_transform))
            for object_id, scene in object_scenes
        ]
        object_bounds = {
            object_id: np.asarray(
                _scene_to_mesh(scene, trimesh=trimesh).bounds,
                dtype=np.float64,
            )
            for object_id, scene in candidate_object_scenes
        }
        yaw_metadata = _best_spatial_yaw(
            object_bounds=object_bounds,
            relations=directional_relations,
        )
        yaw_transform = _z_yaw_transform(
            float(yaw_metadata["yaw_degrees"]),
        )
        for _, scene in candidate_object_scenes:
            scene.apply_transform(yaw_transform)
        candidates[variant] = {
            "object_scenes": candidate_object_scenes,
            "pre_yaw_transform": pre_yaw_transform,
            "yaw_transform": yaw_transform,
            "extra_transform": yaw_transform @ pre_yaw_transform,
            "yaw_metadata": yaw_metadata,
        }
    return candidates


def _best_spatial_yaw(
    *,
    object_bounds: dict[str, np.ndarray],
    relations: list[dict[str, str]],
) -> dict[str, Any]:
    """Pick the integer yaw in ``[0, 360)`` that satisfies most relations.

    Every whole-degree yaw is scored with ``_center_left_of_score`` on the
    rotated AABB centres. A higher satisfied-relation count wins; ties are
    broken by the larger summed gap, and remaining ties keep the smallest yaw.

    Args:
        object_bounds: Mapping from object id to its AABB bounds.
        relations: Normalised ``left_of`` relations.

    Returns:
        Dict with ``yaw_degrees``, ``score``, ``raw_gap_sum``,
        ``relation_count``, ``score_type`` and ``relation_scores``. With no
        relations the yaw is ``0`` and ``relation_scores`` is empty.
    """
    if not relations:
        return {
            "yaw_degrees": 0,
            "score": 0,
            "raw_gap_sum": 0.0,
            "relation_count": 0,
            "score_type": "center_left_of_hard_count",
            # Same keys as the searched result, so consumers of the
            # serialized metadata can rely on one schema.
            "relation_scores": [],
        }

    object_centers = {
        object_id: _aabb_center(bounds) for object_id, bounds in object_bounds.items()
    }
    best_yaw = 0
    # Starts below any reachable score so the first sample is always taken.
    best_score = -1
    best_raw_gap_sum = float("-inf")
    best_relation_scores: list[dict[str, Any]] = []
    for yaw_degrees in range(360):
        rotation = _z_yaw_transform(float(yaw_degrees))
        rotated_centers = {
            object_id: _transform_point(rotation, center)
            for object_id, center in object_centers.items()
        }
        score, raw_gap_sum, relation_scores = _center_left_of_score(
            centers=rotated_centers,
            relations=relations,
        )
        if score > best_score or (
            score == best_score and raw_gap_sum > best_raw_gap_sum
        ):
            best_yaw = yaw_degrees
            best_score = score
            best_raw_gap_sum = raw_gap_sum
            best_relation_scores = relation_scores
    return {
        "yaw_degrees": best_yaw,
        "score": best_score,
        "raw_gap_sum": best_raw_gap_sum,
        "relation_count": len(relations),
        "score_type": "center_left_of_hard_count",
        "relation_scores": best_relation_scores,
    }


def _spatial_directional_relations(
    spatial_relations: list[dict[str, Any]],
) -> list[dict[str, str]]:
    """Return the distinct, well-formed ``left_of`` relations in input order.

    Records are dropped when the subject or object is empty, when both are
    the same id, when the relation is not ``left_of``, or when the same
    ``(subject, relation, object)`` triple was already seen.
    """
    relations: list[dict[str, str]] = []
    seen: set[tuple[str, str, str]] = set()
    for relation in spatial_relations:
        subject = str(relation.get("subject") or "")
        object_id = str(relation.get("object") or "")
        relation_name = str(relation.get("relation") or "")
        if (
            not subject
            or not object_id
            or subject == object_id
            or relation_name != "left_of"
        ):
            continue
        key = (subject, relation_name, object_id)
        if key in seen:
            continue
        seen.add(key)
        relations.append(
            {
                "subject": subject,
                "relation": relation_name,
                "object": object_id,
            }
        )
    return relations


def _center_left_of_score(
    centers: dict[str, np.ndarray],
    relations: list[dict[str, str]],
) -> tuple[int, float, list[dict[str, Any]]]:
    """Score how many ``left_of`` relations the given centres satisfy.

    A relation is satisfied when the subject's X coordinate is strictly
    smaller than the object's. Relations naming an unknown id are ignored.

    Returns:
        ``(satisfied_count, summed_gap, per_relation_details)`` where each gap
        is ``object_x - subject_x``.
    """
    score = 0
    raw_gap_sum = 0.0
    relation_scores: list[dict[str, Any]] = []
    for relation in relations:
        subject = relation["subject"]
        object_id = relation["object"]
        if subject not in centers or object_id not in centers:
            continue
        subject_center = centers[subject]
        object_center = centers[object_id]
        gap = float(object_center[0] - subject_center[0])
        relation_score = 1 if gap > 0.0 else 0
        score += relation_score
        raw_gap_sum += gap
        relation_scores.append(
            {
                "subject": subject,
                "relation": "left_of",
                "object": object_id,
                "gap": gap,
                "score": relation_score,
            }
        )
    return score, raw_gap_sum, relation_scores


def _transform_point(transform: np.ndarray, point: np.ndarray) -> np.ndarray:
    """Apply a 4x4 homogeneous transform to a 3D point and return the 3D result."""
    homogeneous = np.ones(4, dtype=np.float64)
    homogeneous[:3] = point
    return (transform @ homogeneous)[:3]


def _resolve_generated_path(value: Any, output_root: Path) -> Path:
    """Resolve ``value`` to an absolute path, relative to ``output_root`` if needed.

    A falsy ``value`` is treated as the empty path, which resolves to
    ``output_root`` itself.
    """
    path = Path(str(value or "")).expanduser()
    if path.is_absolute():
        return path.resolve()
    return (output_root / path).resolve()


def _run_aligned_up_down_flip_vlm_check(
    *,
    llm: Any | None,
    original_image_path: Path | None,
    normal_object_scenes: list[tuple[str, Any]],
    flipped_object_scenes: list[tuple[str, Any]],
    output_dir: Path,
) -> dict[str, Any]:
    """Ask a VLM whether the normal or the flipped candidate matches the image.

    Both candidates are rendered and combined into one comparison image before
    the model is consulted, so the renders exist even when the check is
    skipped for a missing model or reference image. The flipped variant is
    selected only when the model picks number 2 with a confidence of at least
    ``UP_DOWN_FLIP_CHECK_CONFIDENCE_THRESHOLD``.

    Args:
        llm: Model handle; ``None`` skips the model call.
        original_image_path: Reference image; ``None`` or a missing file skips
            the model call.
        normal_object_scenes: Scenes of the un-flipped candidate.
        flipped_object_scenes: Scenes of the flipped candidate.
        output_dir: Directory for the renders and the raw model output.

    Returns:
        Result dict whose ``status`` is ``"skipped"``, ``"ok"`` or
        ``"failed"``. Failures never raise; the traceback is stored in
        ``reason``. If the raw model output cannot be written to disk, the
        error is recorded under ``raw_output_persist_error``.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    result: dict[str, Any] = {
        "status": "skipped",
        "applied_up_down_flip": False,
        "confidence_threshold": UP_DOWN_FLIP_CHECK_CONFIDENCE_THRESHOLD,
        "reason": "",
    }
    if not normal_object_scenes or not flipped_object_scenes:
        result["reason"] = "missing_object_scenes"
        return result

    try:
        normal_render_path = output_dir / "normal_object_only_front_oblique_view.png"
        flipped_render_path = output_dir / "flipped_object_only_front_oblique_view.png"
        comparison_image_path = output_dir / "numbered_up_down_candidates.png"
        BlenderRenderingManager().render_object_scenes(
            RenderObjectScenesRequest(
                object_scenes=normal_object_scenes,
                output_path=normal_render_path,
            )
        )
        BlenderRenderingManager().render_object_scenes(
            RenderObjectScenesRequest(
                object_scenes=flipped_object_scenes,
                output_path=flipped_render_path,
            )
        )
        MatplotlibManager(figsize=(12, 6), dpi=180).render_image_comparison(
            RenderImageComparisonRequest(
                first_image_path=normal_render_path,
                second_image_path=flipped_render_path,
                output_path=comparison_image_path,
            )
        )
        if llm is None:
            result["reason"] = "missing_llm"
            return result
        if original_image_path is None or not original_image_path.is_file():
            result["reason"] = "missing_original_image"
            return result

        raw_model_output = call_structured_json_model_step(
            llm=llm,
            schema=UP_DOWN_FLIP_CHECK_JSON_SCHEMA,
            messages=build_up_down_flip_check_messages(
                original_image_path=original_image_path,
                comparison_image_path=comparison_image_path,
            ),
            context="Unified scene aligned up-down flip check",
            step_name=UNIFIED_SCENE_STEP,
            output_root=None,
            attempt_count=0,
        )
        # Persist the raw VLM output alongside the comparison renders. This is
        # deliberately best-effort: a write failure must not discard the
        # model's answer, but it is recorded rather than silently dropped.
        try:
            import json

            vlm_result_path = output_dir / "vlm_flip_check_result.json"
            vlm_result_path.write_text(
                json.dumps(raw_model_output, ensure_ascii=False, indent=2),
                encoding="utf-8",
            )
        except Exception as exc:
            result["raw_output_persist_error"] = f"{type(exc).__name__}: {exc}"
        confidence = float(raw_model_output.get("confidence", 0.0))
        selected_number = int(raw_model_output.get("selected_number", 1))
        if selected_number not in {1, 2}:
            selected_number = 1
        model_selected_variant = "flipped" if selected_number == 2 else "normal"
        should_apply = (
            model_selected_variant == "flipped"
            and confidence >= UP_DOWN_FLIP_CHECK_CONFIDENCE_THRESHOLD
        )
        # A low-confidence "flipped" vote is downgraded to "normal"; the
        # model's own choice is still reported under the model_* keys.
        selected_variant = "flipped" if should_apply else "normal"
        selected_number = 2 if selected_variant == "flipped" else 1
        result.update(
            {
                "status": "ok",
                "selected_number": selected_number,
                "selected_variant": selected_variant,
                "applied_up_down_flip": should_apply,
                "model_selected_number": raw_model_output.get("selected_number"),
                "model_selected_variant": model_selected_variant,
                "confidence": confidence,
                "reason": str(raw_model_output.get("reason", "")),
            }
        )
        return result
    except Exception:
        # Rendering or model failures fall back to the "normal" variant in the
        # caller; the traceback is kept for diagnosis.
        result.update(
            {
                "status": "failed",
                "reason": traceback.format_exc(),
            }
        )
        return result
def _write_multi_object_layout_manifests(
    *,
    glb_gen_dir: Path,
    output_root: Path,
    table: dict[str, Any] | None,
    objects: list[dict[str, Any]],
    alignment: dict[str, Any] | None,
) -> dict[str, str]:
    """Write the simready-to-aligned manifest JSON and report where it went.

    Returns:
        A dict whose ``simready_to_aligned_manifest_path`` is the manifest
        path expressed relative to ``output_root``.
    """
    manifest_file = glb_gen_dir / "simready_to_aligned_manifest.json"
    manifest = _simready_to_aligned_manifest(
        table=table,
        items=objects,
        alignment=alignment,
        output_root=output_root,
    )
    write_json(manifest_file, manifest)
    reported_path = relative_path(str(manifest_file), output_root)
    return {"simready_to_aligned_manifest_path": reported_path}


def _simready_to_aligned_manifest(
    *,
    table: dict[str, Any] | None,
    items: list[dict[str, Any]],
    alignment: dict[str, Any] | None,
    output_root: Path,
) -> dict[str, Any]:
    """Assemble the manifest linking simready GLBs to their aligned GLBs.

    A missing ``alignment`` is treated as an empty dict, so every alignment
    field falls back to its default.
    """
    info = alignment or {}
    shared_matrix = info.get("alignment_matrix", [])
    axis_transform = info.get("glb_output_axis_transform", [])
    per_object_matrices = info.get("object_alignment_matrices", {})
    aligned_paths = _aligned_outputs_by_id(info)

    table_entry = None
    if table is not None:
        table_entry = _simready_manifest_table_item(table, output_root=output_root)

    item_entries = []
    for entry in items:
        item_entries.append(
            _simready_to_aligned_manifest_item(
                entry,
                aligned_by_id=aligned_paths,
                alignment_matrix=shared_matrix,
                object_alignment_matrices=per_object_matrices,
                glb_output_axis_transform=axis_transform,
                output_root=output_root,
            )
        )

    note = (
        "Aligned GLBs are generated from raw_downloads plus SAM3D layout "
        "matrices in memory; simready paths are recorded here as the "
        "simulation-ready counterpart for each raw GLB."
    )
    return {
        "note": note,
        "alignment_status": info.get("status", ""),
        "alignment_reason": info.get("reason", ""),
        "selected_up_down_variant": info.get("selected_up_down_variant", ""),
        "applied_up_down_flip": info.get("applied_up_down_flip", False),
        "alignment_matrix": shared_matrix,
        "global_metric_scale": info.get("global_metric_scale"),
        "final_clutter_2d_aabb_cm": info.get("final_clutter_2d_aabb_cm"),
        "glb_output_axis_transform": axis_transform,
        "table": table_entry,
        "items": item_entries,
    }


def _aligned_outputs_by_id(alignment: dict[str, Any]) -> dict[str, str]:
    """Map each exported object id to its aligned GLB path.

    Entries that are not dicts or have no truthy ``id`` are ignored; a missing
    path becomes the empty string.
    """
    exported = alignment.get("objects", []) or []
    return {
        str(record["id"]): str(record.get("aligned_geometry_path", ""))
        for record in exported
        if isinstance(record, dict) and record.get("id")
    }


def _simready_manifest_table_item(
    item: dict[str, Any],
    *,
    output_root: Path,
) -> dict[str, Any]:
    """Build the manifest entry for the table record.

    Geometry paths are resolved and then reported relative to ``output_root``;
    an unset path is reported as the empty string.
    """

    def _relative_or_blank(key: str) -> str:
        # Unset paths stay blank instead of resolving to output_root itself.
        if not item.get(key):
            return ""
        resolved = _resolve_generated_path(item.get(key), output_root)
        return relative_path(str(resolved), output_root)

    entry: dict[str, Any] = {
        "id": item.get("id", ""),
        "name": item.get("name", ""),
        "kind": item.get("kind", "table"),
        "status": item.get("status", ""),
    }
    entry["simready_geometry_path"] = _relative_or_blank("simready_geometry_path")
    entry["support_reference_geometry_path"] = _relative_or_blank(
        "support_reference_geometry_path"
    )
    entry["table_asset_source"] = item.get("table_asset_source", "")
    entry["support_normal_source"] = item.get("support_normal_source", "")
    entry["is_complete_visible_table"] = item.get("is_complete_visible_table", False)
    entry["complete_table_description"] = item.get("complete_table_description", "")
    return entry


def _simready_to_aligned_manifest_item(
    item: dict[str, Any],
    *,
    aligned_by_id: dict[str, str],
    alignment_matrix: Any,
    object_alignment_matrices: Any,
    glb_output_axis_transform: Any,
    output_root: Path,
) -> dict[str, Any]:
    """Build the manifest entry for one object, including its transform.

    The per-object alignment matrix is used when one is available for the
    object's id; otherwise the shared ``alignment_matrix`` is used.
    """
    object_key = str(item.get("id", ""))
    if isinstance(object_alignment_matrices, dict):
        chosen_alignment = object_alignment_matrices.get(object_key, alignment_matrix)
    else:
        chosen_alignment = alignment_matrix

    raw_to_aligned = _compose_json_matrices(
        glb_output_axis_transform,
        chosen_alignment,
        item.get("transform_matrix", []),
    )
    simready_to_aligned = _compose_simready_to_aligned_matrix(
        raw_to_aligned_matrix=raw_to_aligned,
        raw_to_simready_matrix=item.get("raw_to_simready_glb_matrix", []),
    )
    parts = _decompose_transform_matrix(simready_to_aligned)
    return {
        "id": object_key,
        "name": item.get("name", ""),
        "kind": item.get("kind", ""),
        "simready_geometry_path": item.get("simready_geometry_path", ""),
        "aligned_geometry_path": aligned_by_id.get(object_key, ""),
        "metric_scale": _trim_metric_scale(item.get("metric_scale")),
        "simready_to_aligned_matrix": simready_to_aligned,
        "translation": parts["translation"],
        "rotation_matrix": parts["rotation_matrix"],
        "scale": parts["scale"],
    }


def _trim_metric_scale(value: Any) -> dict[str, Any] | None:
    """Return a copy of ``value`` without its file-path keys, or ``None``.

    Anything that is not a dict yields ``None``; the input is never mutated.
    """
    if not isinstance(value, dict):
        return None
    dropped = ("result_path", "raw_model_output_path")
    return {key: entry for key, entry in value.items() if key not in dropped}


def _resolve_generated_path(value: Any, output_root: Path) -> Path:
    """Resolve ``value`` to an absolute path, anchored at ``output_root`` if relative."""
    candidate = Path(str(value or "")).expanduser()
    if not candidate.is_absolute():
        candidate = output_root / candidate
    return candidate.resolve()
UNIFIED_SCENE_GEN_PROMPT_NAME = "unified_scene_gen.yaml"


def build_image_metric_scale_messages(
    *,
    bbox_name_image_path: Path,
    objects_json: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Build the chat messages for the image metric-scale estimate.

    The user turn carries the rendered instruction text, with the objects
    embedded as pretty-printed JSON, followed by the annotated image as a
    data URL.
    """
    system_message = {
        "role": "system",
        "content": render_prompt(
            UNIFIED_SCENE_GEN_PROMPT_NAME,
            prompt_key="image_metric_scale_system",
        ),
    }
    serialized_objects = json.dumps(objects_json, ensure_ascii=False, indent=2)
    instruction_part = {
        "type": "text",
        "text": render_prompt(
            UNIFIED_SCENE_GEN_PROMPT_NAME,
            {"objects_json": serialized_objects},
            prompt_key="image_metric_scale_user",
        ),
    }
    image_part = {
        "type": "image_url",
        "image_url": {"url": image_to_data_url(bbox_name_image_path)},
    }
    user_message = {"role": "user", "content": [instruction_part, image_part]}
    return [system_message, user_message]


def build_up_down_flip_check_messages(
    *,
    original_image_path: Path,
    comparison_image_path: Path,
) -> list[dict[str, Any]]:
    """Build the chat messages for the up/down flip check.

    The user turn carries the rendered instruction text, then the original
    image, then the comparison image, both as data URLs.
    """
    system_message = {
        "role": "system",
        "content": render_prompt(
            UNIFIED_SCENE_GEN_PROMPT_NAME,
            prompt_key="up_down_flip_check_system",
        ),
    }
    user_parts: list[dict[str, Any]] = [
        {
            "type": "text",
            "text": render_prompt(
                UNIFIED_SCENE_GEN_PROMPT_NAME,
                prompt_key="up_down_flip_check_user",
            ),
        }
    ]
    # Image order matters to the prompt: original first, comparison second.
    for image_path in (original_image_path, comparison_image_path):
        user_parts.append(
            {
                "type": "image_url",
                "image_url": {"url": image_to_data_url(image_path)},
            }
        )
    return [system_message, {"role": "user", "content": user_parts}]
# JSON schema for the structured reply of the up/down flip check.
# "selected_number" refers to the numbered candidates in the comparison image.
UP_DOWN_FLIP_CHECK_JSON_SCHEMA: dict[str, Any] = {
    "title": "AlignedUpDownFlipCheckOutput",
    "type": "object",
    "additionalProperties": False,
    "properties": {
        "selected_number": {"type": "integer", "enum": [1, 2]},
        "confidence": {"type": "number", "minimum": 0.0, "maximum": 1.0},
        "reason": {"type": "string"},
    },
    "required": ["selected_number", "confidence", "reason"],
}

# JSON schema for the structured reply of the image metric-scale estimate:
# one entry per object with a three-number bounding-box size, a confidence in
# [0, 1] and a free-text reason.
IMAGE_METRIC_SCALE_JSON_SCHEMA: dict[str, Any] = {
    "title": "ImageMetricScaleEstimate",
    "type": "object",
    "additionalProperties": False,
    "properties": {
        "object_scales": {
            "type": "array",
            "items": {
                "type": "object",
                "additionalProperties": False,
                "properties": {
                    "object_id": {"type": "string"},
                    # Exactly three strictly positive numbers; the key name
                    # indicates centimetres.
                    "bbox_dims_cm": {
                        "type": "array",
                        "minItems": 3,
                        "maxItems": 3,
                        "items": {
                            "type": "number",
                            "minimum": 1.0e-6,
                        },
                    },
                    "confidence": {
                        "type": "number",
                        "minimum": 0.0,
                        "maximum": 1.0,
                    },
                    "reason": {"type": "string"},
                },
                "required": ["object_id", "bbox_dims_cm", "confidence", "reason"],
            },
        },
    },
    "required": ["object_scales"],
}
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_segmentation_manager.manager import ( + ImageSegmentationManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_segmentation_manager.schemas import ( + AssetImageToRgbaRequest, + ImageSegmentationRequest, + ImageSegmentationResult, +) + +__all__ = [ + "AssetImageToRgbaRequest", + "ImageSegmentationManager", + "ImageSegmentationRequest", + "ImageSegmentationResult", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_segmentation_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_segmentation_manager/manager.py new file mode 100644 index 00000000..052b8d7d --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_segmentation_manager/manager.py @@ -0,0 +1,90 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class ImageSegmentationManager:
    """Image segmentation domain operations."""

    def __init__(self, *, client: ImageSegmentationClient | None = None) -> None:
        """Use ``client`` when given, otherwise create a default client."""
        # Compare against None so an injected client that happens to be falsy
        # is not silently replaced.
        self.client = client if client is not None else ImageSegmentationClient()

    def segment_image(
        self,
        request: ImageSegmentationRequest,
    ) -> ImageSegmentationResult:
        """Segment one image with one text prompt.

        Args:
            request: Image path and prompt. The path is expanded and resolved
                and the prompt is stripped before being sent to the client.

        Returns:
            The candidates reported by the segmentation client.

        Raises:
            FileNotFoundError: If the image does not exist.
            ValueError: If the prompt is blank.
            RuntimeError: If the client reports an error.
        """
        image_path = request.image_path.expanduser().resolve()
        _validate_segment_request(image_path=image_path, prompt=request.prompt)

        response = self.client.segment(
            ImageSegmentationServerRequest(
                prompt=request.prompt.strip(),
                image_path=image_path,
            ),
        )
        if isinstance(response, ImageSegmentationError):
            raise RuntimeError(response.error_message)

        return ImageSegmentationResult(candidates=list(response.result.candidates))

    def convert_asset_image_to_rgba(
        self,
        request: AssetImageToRgbaRequest,
    ) -> Path:
        """Cut the prompted object out of an image and save it as RGBA.

        The mask of the first segmentation candidate is applied as the alpha
        channel.

        Args:
            request: Source image, prompt and destination path.

        Returns:
            The resolved path of the written RGBA image.

        Raises:
            ValueError: If segmentation returns no candidates or the first
                candidate has no mask.
            FileNotFoundError: If the input image is missing or the output
                file does not exist after saving.
            RuntimeError: If the segmentation client reports an error.
        """
        # Resolve once so segmentation and masking read the very same file,
        # even when the request uses "~" or a relative path.
        image_path = request.image_path.expanduser().resolve()
        segmentation_result = self.segment_image(
            ImageSegmentationRequest(
                image_path=image_path,
                prompt=request.prompt,
            )
        )
        if not segmentation_result.candidates:
            raise ValueError("Image segmentation returned no candidates.")

        candidate = segmentation_result.candidates[0]
        if candidate.mask_rle is None:
            raise ValueError(f"Candidate {candidate.candidate_id} has no mask_rle.")

        output_path = request.output_path.expanduser().resolve()
        output_path.parent.mkdir(parents=True, exist_ok=True)
        mask = decode_rle_mask(candidate.mask_rle)
        rgba = apply_mask_to_alpha(image_path, mask)
        rgba.save(output_path)
        if not output_path.is_file():
            raise FileNotFoundError(f"RGBA image was not written: {output_path}")
        return output_path


def _validate_segment_request(*, image_path: Path, prompt: str) -> None:
    """Reject a segmentation request with a missing image or a blank prompt.

    Raises:
        FileNotFoundError: If ``image_path`` is not an existing file.
        ValueError: If ``prompt`` is empty or whitespace only.
    """
    if not image_path.is_file():
        raise FileNotFoundError(f"Image segmentation input not found: {image_path}")
    if not prompt.strip():
        raise ValueError("Image segmentation prompt must be non-empty.")
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path

from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_segmentation_client import (
    ImageSegmentationCandidate,
)


@dataclass(frozen=True)
class AssetImageToRgbaRequest:
    """Request for converting an asset image to an RGBA cutout."""

    # Source image that is segmented and then masked.
    image_path: Path
    # Text prompt forwarded to segmentation; must not be blank.
    prompt: str
    # Destination of the RGBA image; the manager creates parent directories.
    output_path: Path


@dataclass(frozen=True)
class ImageSegmentationRequest:
    """Request for segmenting one image with one text prompt."""

    # Image to segment; the manager expands "~" and resolves it before use.
    image_path: Path
    # Text prompt; the manager strips it and rejects blank prompts.
    prompt: str


@dataclass(frozen=True)
class ImageSegmentationResult:
    """Segmentation candidates."""

    # Candidates as returned by the segmentation client; may be empty.
    candidates: list[ImageSegmentationCandidate]
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.matplotlib_manager.manager import ( + MatplotlibManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.matplotlib_manager.schemas import ( + RenderFootprintLayoutRequest, + RenderFootprintLayoutResult, + RenderImageComparisonRequest, + RenderImageComparisonResult, + RenderSupportRegionRequest, + RenderSupportRegionResult, + RenderXYComparisonRequest, + RenderXYComparisonResult, +) + +__all__ = [ + "MatplotlibManager", + "RenderFootprintLayoutRequest", + "RenderFootprintLayoutResult", + "RenderImageComparisonRequest", + "RenderImageComparisonResult", + "RenderSupportRegionRequest", + "RenderSupportRegionResult", + "RenderXYComparisonRequest", + "RenderXYComparisonResult", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/matplotlib_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/matplotlib_manager/manager.py new file mode 100644 index 00000000..1feb13c3 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/matplotlib_manager/manager.py @@ -0,0 +1,401 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
__all__ = ["MatplotlibManager"]


class MatplotlibManager:
    """Manager for mesh visualization via matplotlib.

    Wraps matplotlib rendering with typed request/response methods,
    following the same pattern as service clients.

    Every ``render_*`` method closes its pyplot figure in a ``finally``
    clause, so a failure while drawing or saving does not leave the figure
    registered with pyplot.
    """

    def __init__(
        self,
        *,
        figsize: tuple[float, float] = (8, 8),
        dpi: int = 180,
    ) -> None:
        """Initialize the matplotlib manager.

        Args:
            figsize: Default figure size for rendered images.
            dpi: Output image resolution.
        """
        self._figsize = figsize
        self._dpi = dpi

    def render_footprint_layout(
        self,
        request: RenderFootprintLayoutRequest,
    ) -> RenderFootprintLayoutResult:
        """Render labeled XY footprints with full-length coordinate axes.

        Each object is drawn as an axis-aligned rectangle centred on its
        entry in ``request.centers`` with the extent from ``request.xy_sizes``.

        Note:
            When ``request.object_ids`` is empty the method returns the
            resolved output path without writing any image.

        Raises:
            KeyError: If an id in ``object_ids`` is missing from ``centers``
                or ``xy_sizes``.
        """
        output_path = request.output_path.expanduser().resolve()
        output_path.parent.mkdir(parents=True, exist_ok=True)
        if not request.object_ids:
            return RenderFootprintLayoutResult(output_path=output_path)

        centers = {
            object_id: np.asarray(request.centers[object_id], dtype=float)
            for object_id in request.object_ids
        }
        sizes = {
            object_id: np.asarray(request.xy_sizes[object_id], dtype=float)
            for object_id in request.object_ids
        }
        footprint_mins = np.vstack(
            [
                centers[object_id] - 0.5 * sizes[object_id]
                for object_id in request.object_ids
            ]
        )
        footprint_maxs = np.vstack(
            [
                centers[object_id] + 0.5 * sizes[object_id]
                for object_id in request.object_ids
            ]
        )
        data_min = footprint_mins.min(axis=0)
        data_max = footprint_maxs.max(axis=0)
        # Pad by 12% of the larger span, with floors so a degenerate layout
        # still yields a non-empty viewport.
        span = np.maximum(data_max - data_min, 1.0e-6)
        padding = max(float(span.max()) * 0.12, 1.0e-3)
        x_limits = (float(data_min[0] - padding), float(data_max[0] + padding))
        y_limits = (float(data_min[1] - padding), float(data_max[1] + padding))

        fig, ax = plt.subplots(figsize=self._figsize)
        try:
            for object_id in request.object_ids:
                center = centers[object_id]
                size = sizes[object_id]
                ax.add_patch(
                    Rectangle(
                        (center[0] - 0.5 * size[0], center[1] - 0.5 * size[1]),
                        size[0],
                        size[1],
                        facecolor=(0.35, 0.60, 0.95, 0.30),
                        edgecolor=(0.08, 0.22, 0.60, 1.0),
                        linewidth=1.5,
                    )
                )
                # Shorten ids for display: drop "interact_" and a trailing "_0".
                label = object_id.replace("interact_", "").removesuffix("_0")
                ax.text(
                    center[0],
                    center[1],
                    label,
                    ha="center",
                    va="center",
                    fontsize=9,
                    color="black",
                )

            self._draw_full_xy_axes(ax, x_limits=x_limits, y_limits=y_limits)
            ax.set_xlim(*x_limits)
            ax.set_ylim(*y_limits)
            ax.set_aspect("equal", adjustable="box")
            ax.set_title(request.title)
            ax.grid(True, linestyle=":", linewidth=0.6, alpha=0.30)
            fig.tight_layout()
            fig.savefig(output_path, dpi=self._dpi)
        finally:
            plt.close(fig)
        return RenderFootprintLayoutResult(output_path=output_path)

    def render_image_comparison(
        self,
        request: RenderImageComparisonRequest,
    ) -> RenderImageComparisonResult:
        """Render two images side by side with numbered labels.

        Uses a fixed 12x6 figure rather than the manager's default size and
        saves with a white background.
        """
        output_path = request.output_path.expanduser().resolve()
        output_path.parent.mkdir(parents=True, exist_ok=True)
        first_image = plt.imread(request.first_image_path.expanduser().resolve())
        second_image = plt.imread(request.second_image_path.expanduser().resolve())

        fig, axes = plt.subplots(1, 2, figsize=(12, 6))
        try:
            for ax, image, label in (
                (axes[0], first_image, request.first_label),
                (axes[1], second_image, request.second_label),
            ):
                ax.imshow(image)
                ax.set_title(label, fontsize=16, loc="left")
                ax.axis("off")
            fig.tight_layout()
            fig.savefig(output_path, dpi=self._dpi, facecolor="white")
        finally:
            plt.close(fig)
        return RenderImageComparisonResult(output_path=output_path)

    @staticmethod
    def _draw_full_xy_axes(
        ax: Any,
        *,
        x_limits: tuple[float, float],
        y_limits: tuple[float, float],
    ) -> None:
        """Draw axes across the full viewport, centered on the data bounds.

        The arrows cross at the midpoint of ``x_limits``/``y_limits`` and the
        marker labeled "Origin" is placed at that midpoint, which is not
        necessarily data coordinate (0, 0).
        """
        axis_color = "#303030"
        x_center = 0.5 * (x_limits[0] + x_limits[1])
        y_center = 0.5 * (y_limits[0] + y_limits[1])
        # Horizontal axis (X) — spans full width, positioned at vertical centre.
        ax.annotate(
            "",
            xy=(x_limits[1], y_center),
            xytext=(x_limits[0], y_center),
            arrowprops={"arrowstyle": "->", "color": axis_color, "lw": 1.8},
            zorder=8,
        )
        # Vertical axis (Y) — spans full height, positioned at horizontal centre.
        ax.annotate(
            "",
            xy=(x_center, y_limits[1]),
            xytext=(x_center, y_limits[0]),
            arrowprops={"arrowstyle": "->", "color": axis_color, "lw": 1.8},
            zorder=8,
        )
        x_span = x_limits[1] - x_limits[0]
        y_span = y_limits[1] - y_limits[0]
        ax.text(
            x_limits[1] - 0.03 * x_span,
            y_center + 0.02 * y_span,
            "+X",
            ha="right",
            va="bottom",
            color=axis_color,
            fontsize=11,
        )
        ax.text(
            x_center + 0.02 * x_span,
            y_limits[1] - 0.03 * y_span,
            "+Y",
            ha="left",
            va="top",
            color=axis_color,
            fontsize=11,
        )
        # Mark the origin at the centre.
        ax.plot(x_center, y_center, "o", color=axis_color, markersize=6, zorder=9)
        ax.text(
            x_center + 0.015 * x_span,
            y_center + 0.015 * y_span,
            "Origin",
            fontsize=8,
            color=axis_color,
            ha="left",
            va="bottom",
            zorder=9,
        )

    def render_selected_support_region(
        self, request: RenderSupportRegionRequest
    ) -> RenderSupportRegionResult:
        """Render a mesh with the selected support region highlighted.

        The whole mesh is drawn faintly and the faces listed in
        ``request.face_indices`` are drawn again in opaque red on top.
        """
        output_path = request.output_path.expanduser().resolve()
        output_path.parent.mkdir(parents=True, exist_ok=True)

        vertices = np.asarray(request.mesh.vertices, dtype=float)
        faces = np.asarray(request.mesh.faces, dtype=int)
        selected_faces = faces[np.asarray(request.face_indices, dtype=int)]

        fig = plt.figure(figsize=self._figsize)
        try:
            ax = fig.add_subplot(111, projection="3d")
            ax.add_collection3d(
                Poly3DCollection(
                    vertices[faces],
                    facecolors=(0.65, 0.68, 0.72, 0.16),
                    edgecolors=(0.35, 0.37, 0.40, 0.08),
                    linewidths=0.15,
                )
            )
            ax.add_collection3d(
                Poly3DCollection(
                    vertices[selected_faces],
                    facecolors=(1.0, 0.18, 0.05, 0.88),
                    edgecolors=(0.55, 0.02, 0.0, 1.0),
                    linewidths=0.8,
                )
            )
            self._set_equal_axes(ax, vertices)
            ax.view_init(elev=25.0, azim=-45.0)
            ax.set_xlabel("X")
            ax.set_ylabel("Y")
            ax.set_zlabel("Z")
            ax.set_title("Selected Support Region")
            fig.tight_layout()
            fig.savefig(output_path, dpi=self._dpi)
        finally:
            plt.close(fig)
        return RenderSupportRegionResult(output_path=output_path)

    def render_xy_alignment_comparison(
        self, request: RenderXYComparisonRequest
    ) -> RenderXYComparisonResult:
        """Render before/after XY projections for PCA yaw alignment.

        Both panels share one square viewport computed from the union of the
        two meshes so the projections are directly comparable.
        """
        output_path = request.output_path.expanduser().resolve()
        output_path.parent.mkdir(parents=True, exist_ok=True)

        before_polygons, before_xy = self._xy_polygons_and_vertices(request.before_mesh)
        after_polygons, after_xy = self._xy_polygons_and_vertices(request.after_mesh)
        center, view_half = self._xy_view_bounds(before_xy, after_xy)

        fig, axes = plt.subplots(1, 2, figsize=self._figsize)
        try:
            self._draw_xy_projection(
                axes[0],
                before_polygons,
                before_xy,
                "Before PCA yaw",
                center,
                view_half,
            )
            self._draw_xy_projection(
                axes[1],
                after_polygons,
                after_xy,
                f"After PCA yaw ({request.angle_degrees:.2f} deg)",
                center,
                view_half,
            )
            fig.tight_layout()
            fig.savefig(output_path, dpi=self._dpi)
        finally:
            plt.close(fig)
        return RenderXYComparisonResult(output_path=output_path)

    @staticmethod
    def _xy_polygons_and_vertices(mesh: Any) -> tuple[Any, Any]:
        """Return per-face XY polygons and the XY columns of all vertices."""
        vertices = np.asarray(mesh.vertices, dtype=float)
        faces = np.asarray(mesh.faces, dtype=int)
        return vertices[faces][:, :, :2], vertices[:, :2]

    @staticmethod
    def _xy_view_bounds(before_xy: Any, after_xy: Any) -> tuple[Any, float]:
        """Return the shared view centre and half-width for both point sets.

        The half-width is 65% of the larger span of the combined bounds,
        floored at 0.5.
        """
        values = np.concatenate([before_xy, after_xy], axis=0)
        bounds_min = values.min(axis=0)
        bounds_max = values.max(axis=0)
        center = 0.5 * (bounds_min + bounds_max)
        span = np.maximum(bounds_max - bounds_min, 1e-3)
        view_half = max(float(span.max()) * 0.65, 0.5)
        return center, view_half

    def _draw_xy_projection(
        self,
        ax: Any,
        polygons_xy: Any,
        vertices_xy: Any,
        title: str,
        center: Any,
        view_half: float,
    ) -> None:
        """Draw one XY projection panel: faces, AABB, axes arrows and grid."""
        ax.add_collection(
            PolyCollection(
                polygons_xy,
                facecolors=(0.24, 0.50, 0.90, 0.28),
                edgecolors=(0.05, 0.16, 0.35, 0.20),
                linewidths=0.20,
            )
        )
        self._draw_xy_aabb(ax, vertices_xy)
        self._add_xy_axes(ax, view_half)
        ax.set_xlim(center[0] - view_half, center[0] + view_half)
        ax.set_ylim(center[1] - view_half, center[1] + view_half)
        ax.set_aspect("equal", adjustable="box")
        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        ax.set_title(title)
        ax.grid(True, which="major", linestyle="-", linewidth=0.7, alpha=0.35)
        ax.minorticks_on()
        ax.grid(True, which="minor", linestyle=":", linewidth=0.45, alpha=0.25)

    @staticmethod
    def _draw_xy_aabb(ax: Any, vertices_xy: Any) -> None:
        """Outline the axis-aligned bounding box of the XY vertices in red."""
        bounds_min = vertices_xy.min(axis=0)
        bounds_max = vertices_xy.max(axis=0)
        width, height = bounds_max - bounds_min
        ax.add_patch(
            Rectangle(
                (bounds_min[0], bounds_min[1]),
                width,
                height,
                fill=False,
                edgecolor="#d62828",
                linewidth=1.6,
                linestyle="-",
                alpha=0.95,
            )
        )

    @staticmethod
    def _add_xy_axes(ax: Any, view_half: float) -> None:
        """Draw +X/+Y arrows and an "Origin" marker at data coordinate (0, 0).

        The arrow length is 35% of ``view_half``, floored at 0.2.
        """
        arrow_len = max(view_half * 0.35, 0.2)
        ax.scatter([0.0], [0.0], color="black", s=22, zorder=8)
        ax.text(0.0, 0.0, " Origin", fontsize=9, ha="left", va="bottom")
        ax.arrow(
            0.0,
            0.0,
            arrow_len,
            0.0,
            width=arrow_len * 0.015,
            head_width=arrow_len * 0.06,
            head_length=arrow_len * 0.08,
            color="#d62828",
            length_includes_head=True,
            zorder=9,
        )
        ax.text(arrow_len * 1.08, 0.0, "+X", color="#d62828", fontsize=11)
        ax.arrow(
            0.0,
            0.0,
            0.0,
            arrow_len,
            width=arrow_len * 0.015,
            head_width=arrow_len * 0.06,
            head_length=arrow_len * 0.08,
            color="#2a9d8f",
            length_includes_head=True,
            zorder=9,
        )
        ax.text(0.0, arrow_len * 1.08, "+Y", color="#2a9d8f", fontsize=11)

    @staticmethod
    def _set_equal_axes(ax: Any, vertices: Any) -> None:
        """Set equal-length X/Y/Z limits centred on the vertex bounds."""
        mins = np.min(vertices, axis=0)
        maxs = np.max(vertices, axis=0)
        center = (mins + maxs) * 0.5
        # Half of the largest extent, floored so a single point still works.
        radius = max(float(np.max(maxs - mins)) * 0.5, 1e-6)
        ax.set_xlim(center[0] - radius, center[0] + radius)
        ax.set_ylim(center[1] - radius, center[1] + radius)
        ax.set_zlim(center[2] - radius, center[2] + radius)
a/embodichain/gen_sim/prompt2scene/agent_tools/managers/matplotlib_manager/schemas.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/matplotlib_manager/schemas.py new file mode 100644 index 00000000..764383f3 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/matplotlib_manager/schemas.py @@ -0,0 +1,101 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Any

__all__ = [
    "RenderFootprintLayoutRequest",
    "RenderFootprintLayoutResult",
    "RenderImageComparisonRequest",
    "RenderImageComparisonResult",
    "RenderSupportRegionRequest",
    "RenderSupportRegionResult",
    "RenderXYComparisonRequest",
    "RenderXYComparisonResult",
]


@dataclass(frozen=True)
class RenderFootprintLayoutRequest:
    """Request to render labeled top-down object footprints."""

    # Ids to draw, in order; each must be a key of `centers` and `xy_sizes`.
    object_ids: list[str]
    # Object id -> footprint centre; converted to a float array, and the
    # first two components are used as X and Y.
    centers: dict[str, Any]
    # Object id -> footprint extent along X and Y (same shape as the centre).
    xy_sizes: dict[str, Any]
    # Destination image; parent directories are created by the manager.
    output_path: Path
    # Axes title.
    title: str = ""


@dataclass(frozen=True)
class RenderFootprintLayoutResult:
    """Result of rendering a footprint layout."""

    # Resolved output path. NOTE: also returned when no objects were given,
    # in which case the manager writes no file.
    output_path: Path


@dataclass(frozen=True)
class RenderImageComparisonRequest:
    """Request to render two labeled images side by side."""

    # Image shown in the left panel.
    first_image_path: Path
    # Image shown in the right panel.
    second_image_path: Path
    # Destination image; parent directories are created by the manager.
    output_path: Path
    # Title of the left panel.
    first_label: str = "1: normal"
    # Title of the right panel.
    second_label: str = "2: flipped"


@dataclass(frozen=True)
class RenderImageComparisonResult:
    """Result of rendering an image comparison."""

    # Resolved path of the written comparison image.
    output_path: Path


@dataclass(frozen=True)
class RenderSupportRegionRequest:
    """Request to render a mesh with the selected support region highlighted."""

    # Mesh-like object exposing `vertices` and `faces`.
    mesh: Any
    # Indices into `mesh.faces` that are drawn highlighted.
    face_indices: list[int]
    # Destination image; parent directories are created by the manager.
    output_path: Path


@dataclass(frozen=True)
class RenderSupportRegionResult:
    """Result of rendering the support region."""

    # Resolved path of the written image.
    output_path: Path


@dataclass(frozen=True)
class RenderXYComparisonRequest:
    """Request to render before/after XY projections for PCA yaw alignment."""

    # Mesh-like object (exposing `vertices` and `faces`) before alignment.
    before_mesh: Any
    # Mesh-like object (exposing `vertices` and `faces`) after alignment.
    after_mesh: Any
    # Angle shown in the "After" panel title, formatted with two decimals.
    angle_degrees: float
    # Destination image; parent directories are created by the manager.
    output_path: Path


@dataclass(frozen=True)
class RenderXYComparisonResult:
    """Result of rendering the XY alignment comparison."""

    # Resolved path of the written image.
    output_path: Path
a/embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/__init__.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/__init__.py new file mode 100644 index 00000000..8eca3510 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/__init__.py @@ -0,0 +1,37 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.metric_scale_manager.manager import ( + METRIC_SCALE_ENABLED, + MetricScaleManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.metric_scale_manager.schemas import ( + EstimateMetricScalesRequest, + EstimateMetricScalesResult, + GlobalMetricScaleRequest, + MetricScaleObjectInput, +) + +__all__ = [ + "METRIC_SCALE_ENABLED", + "EstimateMetricScalesRequest", + "EstimateMetricScalesResult", + "GlobalMetricScaleRequest", + "MetricScaleManager", + "MetricScaleObjectInput", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/manager.py new file mode 100644 index 00000000..ce1d47e9 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/manager.py @@ -0,0 +1,431 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
__all__ = ["METRIC_SCALE_ENABLED", "MetricScaleManager"]

METRIC_SCALE_ENABLED = True


class MetricScaleManager:
    """Manager for metric scale estimation and scale aggregation.

    All methods are static. Values that originate from the language model
    (``bbox_dims_cm``, ``confidence``, ``scale_factor``) are treated as
    untrusted: malformed entries produce a failed or skipped record instead
    of aborting the whole step.
    """

    @staticmethod
    def estimate_metric_scales(
        request: EstimateMetricScalesRequest,
    ) -> EstimateMetricScalesResult:
        """Call an LLM and convert bbox-size predictions into scale factors.

        Args:
            request: Objects, prompt messages, schema and model handle. When
                ``raw_output_path`` is set, the raw model payload is written
                there as JSON.

        Returns:
            A result with ``status="ok"``, one scale record per object, the
            measured object payload and the raw model output.
        """
        object_payload = MetricScaleManager.build_object_payload(request.objects)
        raw_model_output_path = (
            request.raw_output_path.expanduser().resolve()
            if request.raw_output_path is not None
            else None
        )
        raw_model_output = call_structured_json_model_step(
            llm=request.llm,
            schema=request.schema,
            messages=request.messages,
            context=request.context,
            step_name=request.step_name,
            output_root=None,
            attempt_count=0,
            raw_output_writer=(
                (lambda payload: write_json(raw_model_output_path, payload))
                if raw_model_output_path is not None
                else None
            ),
        )
        object_scales = MetricScaleManager.apply_model_output(
            object_payload=object_payload,
            raw_model_output=raw_model_output,
            method=request.method,
        )
        return EstimateMetricScalesResult(
            status="ok",
            object_scales=object_scales,
            object_payload=object_payload,
            raw_model_output=raw_model_output,
        )

    @staticmethod
    def build_object_payload(
        objects: list[MetricScaleObjectInput],
    ) -> list[dict[str, Any]]:
        """Build object payload with normalized mesh bbox measurements.

        Loads each object's mesh and records its AABB size and bbox ratio
        next to the object's id, name and description.
        """
        geom = GeometryManager()
        payload: list[dict[str, Any]] = []
        for obj in objects:
            mesh = geom.load_mesh(LoadMeshRequest(mesh_path=obj.mesh_path)).mesh
            normalized_bbox_size_m = GeometryManager.mesh_aabb_size(mesh)
            payload.append(
                {
                    "object_id": obj.object_id,
                    "object_name": obj.object_name,
                    "object_description": obj.object_description,
                    "normalized_bbox_size_m": normalized_bbox_size_m.tolist(),
                    "normalized_bbox_ratio": GeometryManager.bbox_ratio(
                        normalized_bbox_size_m
                    ).tolist(),
                }
            )
        return payload

    @staticmethod
    def object_prompt_payload(
        objects: list[MetricScaleObjectInput],
    ) -> list[dict[str, str]]:
        """Return the lightweight object payload intended for LLM prompts."""
        return [
            {
                "object_id": obj.object_id,
                "object_name": obj.object_name,
                "object_description": obj.object_description,
            }
            for obj in objects
        ]

    @staticmethod
    def _coerce_confidence(value: Any) -> float:
        """Return ``value`` as a finite float, or 0.0 if it is not one.

        Confidence is advisory metadata from the model, so an unparsable or
        non-finite value must not discard an otherwise usable estimate.
        """
        try:
            confidence = float(value)
        except (TypeError, ValueError):
            return 0.0
        return confidence if np.isfinite(confidence) else 0.0

    @staticmethod
    def apply_model_output(
        *,
        object_payload: list[dict[str, Any]],
        raw_model_output: dict[str, Any],
        method: str,
    ) -> list[dict[str, Any]]:
        """Convert model bbox predictions into per-object metric-scale records.

        Args:
            object_payload: Records from :meth:`build_object_payload`.
            raw_model_output: Model response; its ``"object_scales"`` entry is
                used when it is a list, otherwise it is treated as empty.
            method: Label copied into every returned record.

        Returns:
            One record per payload entry, in payload order. Objects the model
            did not answer for get a failed record with reason
            ``"missing_object_scale_from_model"``.
        """
        scale_items = (
            raw_model_output.get("object_scales")
            if isinstance(raw_model_output, dict)
            else None
        )
        if not isinstance(scale_items, (list, tuple)):
            scale_items = []
        model_by_id = {
            str(item.get("object_id", "")): item
            for item in scale_items
            if isinstance(item, dict)
        }
        estimates: list[dict[str, Any]] = []
        for payload in object_payload:
            object_id = str(payload.get("object_id", ""))
            model_item = model_by_id.get(object_id)
            if model_item is None:
                estimates.append(
                    MetricScaleManager.failure(
                        object_id=object_id,
                        reason="missing_object_scale_from_model",
                        method=method,
                    )
                )
                continue
            estimates.append(
                MetricScaleManager.select_candidate(
                    object_id=object_id,
                    object_name=str(payload.get("object_name", "")),
                    object_description=str(payload.get("object_description", "")),
                    bbox_dims_cm=model_item.get("bbox_dims_cm", []),
                    confidence=MetricScaleManager._coerce_confidence(
                        model_item.get("confidence", 0.0)
                    ),
                    reason=str(model_item.get("reason", "")),
                    normalized_bbox_size_m=np.asarray(
                        payload["normalized_bbox_size_m"],
                        dtype=np.float64,
                    ),
                    method=method,
                )
            )
        return estimates

    @staticmethod
    def apply_to_objects(
        *,
        objects: list[dict[str, Any]],
        object_scales: list[dict[str, Any]],
    ) -> None:
        """Attach metric-scale records to object dictionaries by object id.

        Mutates ``objects`` in place; objects without a matching record are
        left untouched.
        """
        scale_by_id = {str(item.get("object_id", "")): item for item in object_scales}
        for obj in objects:
            object_id = str(obj.get("id", ""))
            if object_id in scale_by_id:
                obj["metric_scale"] = scale_by_id[object_id]

    @staticmethod
    def select_candidate(
        *,
        object_id: str,
        object_name: str,
        object_description: str,
        bbox_dims_cm: Any,
        confidence: float,
        reason: str,
        normalized_bbox_size_m: np.ndarray,
        method: str,
    ) -> dict[str, Any]:
        """Select a scale factor from predicted real-world bbox dimensions.

        Returns:
            An ``"ok"`` record with the chosen scale factor, or a failed
            record with reason ``"invalid_bbox_dims_cm"`` when the predicted
            dimensions cannot be used.
        """
        try:
            selected = MetricScaleManager.compute_from_bbox_dims(
                bbox_dims_cm=bbox_dims_cm,
                confidence=confidence,
                reason=reason,
                normalized_bbox_size_m=normalized_bbox_size_m,
            )
        except (TypeError, ValueError):
            return MetricScaleManager.failure(
                object_id=object_id,
                reason="invalid_bbox_dims_cm",
                method=method,
            )
        # Converted after the try block: the same conversion already succeeded
        # inside compute_from_bbox_dims, so it cannot fail here.
        size_m = np.asarray(normalized_bbox_size_m, dtype=np.float64)
        normalized_bbox_size_cm = size_m * 100.0
        return {
            "status": "ok",
            "method": method,
            "object_id": object_id,
            "object_name": object_name,
            "object_description": object_description,
            "normalized_bbox_size_m": size_m.tolist(),
            "normalized_bbox_size_cm": normalized_bbox_size_cm.tolist(),
            "normalized_bbox_ratio": GeometryManager.bbox_ratio(
                normalized_bbox_size_m
            ).tolist(),
            "bbox_dims_cm": selected["bbox_dims_cm"],
            "axis_match": selected["axis_match"],
            "scale_factor": selected["scale_factor"],
            "confidence": selected["confidence"],
            "reason": selected["reason"],
            "unit_note": "scale_factor is not baked into this GLB.",
        }

    @staticmethod
    def compute_from_bbox_dims(
        *,
        bbox_dims_cm: Any,
        confidence: float,
        reason: str,
        normalized_bbox_size_m: np.ndarray,
    ) -> dict[str, Any]:
        """Compute one scale candidate from model-predicted bbox dimensions.

        Raises:
            TypeError: If ``bbox_dims_cm`` is not iterable or holds values
                that cannot be converted to float.
            ValueError: If ``bbox_dims_cm`` does not contain exactly three
                finite, positive values.
        """
        dims_cm = np.asarray(
            [float(value) for value in bbox_dims_cm],
            dtype=np.float64,
        )
        # NaN compares false against 0.0, so finiteness is checked explicitly;
        # otherwise a NaN/inf dimension would yield an "ok" record.
        if (
            dims_cm.shape != (3,)
            or not np.all(np.isfinite(dims_cm))
            or np.any(dims_cm <= 0.0)
        ):
            raise ValueError("bbox_dims_cm must contain three positive values.")
        normalized_bbox_size_cm = (
            np.asarray(normalized_bbox_size_m, dtype=np.float64) * 100.0
        )
        axis_match = GeometryManager.best_axis_bbox_scale_match(
            source_size_cm=normalized_bbox_size_cm,
            target_size_cm=dims_cm,
        )
        return {
            "bbox_dims_cm": dims_cm.tolist(),
            "axis_match": axis_match,
            "scale_factor": float(axis_match["scale_factor"]),
            "confidence": confidence,
            "reason": reason,
        }

    @staticmethod
    def failure(
        *,
        object_id: str,
        reason: str,
        method: str,
    ) -> dict[str, Any]:
        """Build a failed per-object metric-scale record (scale factor 1.0)."""
        return {
            "status": "failed",
            "method": method,
            "object_id": object_id,
            "scale_factor": 1.0,
            "reason": reason,
        }

    @staticmethod
    def set_for_all_objects(
        *,
        objects: list[dict[str, Any]],
        status: str,
        reason: str,
        method: str,
    ) -> None:
        """Attach the same fallback metric-scale status to all objects.

        Mutates ``objects`` in place, overwriting any existing
        ``"metric_scale"`` entry with a record whose scale factor is 1.0.
        """
        for obj in objects:
            obj["metric_scale"] = {
                "status": status,
                "method": method,
                "object_id": str(obj.get("id", "")),
                "scale_factor": 1.0,
                "reason": reason,
            }

    @staticmethod
    def compute_global_from_object_scenes(
        request: GlobalMetricScaleRequest,
    ) -> dict[str, Any]:
        """Aggregate object metric scales into one global scale for a scene layout.

        For each ``(object_id, scene)`` pair the per-object scale factor is
        divided by the median ratio between the current scene AABB and the
        recorded normalized AABB (both sorted, so axis order does not
        matter). The mean of these effective scales is clamped to
        ``[request.min_scale, request.max_scale]``.

        Returns:
            A report with ``status`` of ``"disabled"``, ``"fallback"`` (no
            usable object; scale 1.0) or ``"ok"``, plus the per-object
            ``used`` and ``skipped`` lists.
        """
        if not METRIC_SCALE_ENABLED:
            return {
                "status": "disabled",
                "method": "metric_scale_disabled",
                "scale_factor": 1.0,
                "object_count": len(request.objects),
                "used_count": 0,
                "skipped_count": len(request.objects),
                "used": [],
                "skipped": [
                    {"id": str(item.get("id", "")), "reason": "metric_scale_disabled"}
                    for item in request.objects
                ],
                "unit_note": (
                    "Metric scale is disabled; aligned GLBs keep simready "
                    "normalized size."
                ),
            }

        used: list[dict[str, Any]] = []
        skipped: list[dict[str, Any]] = []
        object_by_id = {str(item.get("id", "")): item for item in request.objects}
        for object_id, scene in request.object_scenes:
            item = object_by_id.get(object_id)
            if item is None:
                skipped.append({"id": object_id, "reason": "missing_object_record"})
                continue
            metric_scale = item.get("metric_scale")
            if not isinstance(metric_scale, dict):
                skipped.append({"id": object_id, "reason": "missing_metric_scale"})
                continue
            if metric_scale.get("status") != "ok":
                skipped.append(
                    {
                        "id": object_id,
                        "reason": str(metric_scale.get("status") or "not_ok"),
                    }
                )
                continue

            # A non-numeric scale factor (e.g. None) is skipped like any other
            # invalid value instead of aborting the aggregation.
            try:
                scale_factor_simready = float(metric_scale.get("scale_factor", 1.0))
            except (TypeError, ValueError):
                scale_factor_simready = float("nan")
            if not np.isfinite(scale_factor_simready) or scale_factor_simready <= 0.0:
                skipped.append(
                    {"id": object_id, "reason": "invalid_simready_scale_factor"}
                )
                continue
            try:
                simready_size_m = np.asarray(
                    [float(v) for v in metric_scale.get("normalized_bbox_size_m", [])],
                    dtype=np.float64,
                )
            except (TypeError, ValueError):
                skipped.append(
                    {"id": object_id, "reason": "invalid_normalized_bbox_size_m"}
                )
                continue
            if simready_size_m.shape != (3,) or np.any(simready_size_m <= 0.0):
                skipped.append(
                    {"id": object_id, "reason": "invalid_normalized_bbox_size_m"}
                )
                continue

            # Guard against a missing or malformed bounds value before
            # indexing the min/max corners.
            current_bounds = np.asarray(
                GeometryManager.scene_to_mesh(scene).bounds, dtype=np.float64
            )
            if current_bounds.shape != (2, 3):
                skipped.append({"id": object_id, "reason": "invalid_current_scene_aabb"})
                continue
            current_size_m = current_bounds[1] - current_bounds[0]
            if np.any(current_size_m <= 0.0):
                skipped.append({"id": object_id, "reason": "invalid_current_scene_aabb"})
                continue

            geo_ratio = np.sort(current_size_m) / np.sort(simready_size_m)
            geo_scale = float(np.median(geo_ratio))
            if not np.isfinite(geo_scale) or geo_scale <= 0.0:
                skipped.append({"id": object_id, "reason": "non_positive_geo_scale"})
                continue

            effective_scale = scale_factor_simready / geo_scale
            if not np.isfinite(effective_scale) or effective_scale <= 0.0:
                skipped.append(
                    {"id": object_id, "reason": "non_positive_effective_scale"}
                )
                continue

            used.append(
                {
                    "id": object_id,
                    "effective_scale": effective_scale,
                    "scale_factor_simready": scale_factor_simready,
                    "geo_scale": geo_scale,
                    "simready_bbox_size_m": simready_size_m.tolist(),
                    "simready_bbox_size_cm": (simready_size_m * 100.0).tolist(),
                    "current_scene_bbox_size_m": current_size_m.tolist(),
                    "current_scene_bbox_size_cm": (current_size_m * 100.0).tolist(),
                    "target_bbox_dims_cm": metric_scale.get("bbox_dims_cm"),
                    "confidence": metric_scale.get("confidence"),
                }
            )

        if not used:
            return {
                "status": "fallback",
                "method": "simready_reference_geo_ratio_mean_with_clamp",
                "scale_factor": 1.0,
                "raw_scale_factor": 1.0,
                "was_clamped": False,
                "clamp": {"min": request.min_scale, "max": request.max_scale},
                "object_count": len(request.objects),
                "used_count": 0,
                "skipped_count": len(skipped),
                "used": [],
                "skipped": skipped,
                "unit_note": (
                    "No valid metric scale was available; image clutter keeps the "
                    "SAM3D layout scale without an additional metric scale."
                ),
            }

        raw_scale_factor = float(np.mean([item["effective_scale"] for item in used]))
        scale_factor = float(
            np.clip(raw_scale_factor, request.min_scale, request.max_scale)
        )
        return {
            "status": "ok",
            "method": "simready_reference_geo_ratio_mean_with_clamp",
            "scale_factor": scale_factor,
            "raw_scale_factor": raw_scale_factor,
            "was_clamped": bool(scale_factor != raw_scale_factor),
            "clamp": {"min": request.min_scale, "max": request.max_scale},
            "object_count": len(request.objects),
            "used_count": len(used),
            "skipped_count": len(skipped),
            "used": used,
            "skipped": skipped,
            "unit_note": (
                "Global scale derived from scene-level VLM per-object scale_factor "
                "divided by the geometric scale ratio between simready normalized "
                "bbox and current aligned scene bbox (sorted, permutation-invariant). "
                f"Aggregated via mean across objects, clamped to "
                f"[{request.min_scale:.2f}, {request.max_scale:.2f}]."
            ),
        }
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Any

__all__ = [
    "EstimateMetricScalesRequest",
    "EstimateMetricScalesResult",
    "GlobalMetricScaleRequest",
    "MetricScaleObjectInput",
]


@dataclass(frozen=True)
class MetricScaleObjectInput:
    """Object input for metric-scale estimation.

    Immutable record describing one object whose metric scale should be
    estimated.
    """

    # Identifier of the object.
    object_id: str
    # Short name of the object.
    object_name: str
    # Free-text description of the object.
    object_description: str
    # Location of the object's mesh.
    # NOTE(review): presumably the normalized mesh on disk -- confirm against
    # the metric-scale manager that consumes this record.
    mesh_path: Path


@dataclass(frozen=True)
class EstimateMetricScalesRequest:
    """Request to estimate metric scale for a set of normalized objects."""

    # Objects to estimate a metric scale for.
    objects: list[MetricScaleObjectInput]
    # NOTE(review): presumably chat-style messages sent to ``llm`` -- confirm.
    messages: list[dict[str, Any]]
    # NOTE(review): presumably the structured-output schema for the model
    # reply -- confirm.
    schema: dict[str, Any]
    # Model handle; deliberately untyped here.
    llm: Any
    context: str
    method: str
    # Step label; defaults to "metric_scale".
    step_name: str = "metric_scale"
    # Optional path; ``None`` when not provided.
    # NOTE(review): presumably where the raw model output is persisted --
    # confirm against the manager.
    raw_output_path: Path | None = None


@dataclass(frozen=True)
class EstimateMetricScalesResult:
    """Result of estimating metric scale for normalized objects."""

    status: str
    object_scales: list[dict[str, Any]]
    object_payload: list[dict[str, Any]]
    # ``None`` when no raw model output is attached.
    raw_model_output: dict[str, Any] | None = None
    # Empty string when no reason is attached.
    reason: str = ""


@dataclass(frozen=True)
class GlobalMetricScaleRequest:
    """Request to aggregate per-object metric scales into one scene scale.

    ``min_scale`` and ``max_scale`` are the clamp bounds for the aggregated
    scale factor.

    Raises:
        ValueError: If ``min_scale`` is greater than ``max_scale`` or either
            bound is NaN.
    """

    # Per-object metric-scale records.
    objects: list[dict[str, Any]]
    # ``(object_id, scene)`` pairs.
    object_scenes: list[tuple[str, Any]]
    # Lower clamp bound for the aggregated scale factor.
    min_scale: float = 0.10
    # Upper clamp bound for the aggregated scale factor.
    max_scale: float = 10.00

    def __post_init__(self) -> None:
        # Inverted clamp bounds would make the clamped result meaningless, so
        # fail early.  The negated ``<=`` also rejects NaN bounds.
        if not (self.min_scale <= self.max_scale):
            raise ValueError(
                "GlobalMetricScaleRequest requires min_scale <= max_scale, "
                f"got min_scale={self.min_scale!r}, "
                f"max_scale={self.max_scale!r}."
            )
Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.optimization_manager.manager import ( + _center_xy_aabb_layout, + _footprint_layout_diagnostics, + _object_scenes_xy_aabb_manifest, + _settle_and_pack_object_footprints, + _xy_aabb_overlap, + _xy_union_area, + _xy_union_bounds, +) + +__all__ = [ + "_center_xy_aabb_layout", + "_footprint_layout_diagnostics", + "_object_scenes_xy_aabb_manifest", + "_settle_and_pack_object_footprints", + "_xy_aabb_overlap", + "_xy_union_area", + "_xy_union_bounds", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/optimization_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/optimization_manager/manager.py new file mode 100644 index 00000000..d7ed1348 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/optimization_manager/manager.py @@ -0,0 +1,633 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------


from __future__ import annotations

import tempfile
import traceback
from pathlib import Path
from typing import Any

import numpy as np

from embodichain.gen_sim.prompt2scene.agent_tools.managers.simulation_manager import (
    SimulationManager,
)
from embodichain.gen_sim.prompt2scene.agent_tools.managers.simulation_manager.schemas import (
    GravityDropRequest,
)
from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager.scene_geometry import (
    _aabb_bottom_to_xy_plane_transform,
    _copy_scene_with_transform,
    _matrix_from_json,
    _scene_to_mesh,
    _xy_aabb_center,
    _xy_aabb_size,
    _z_up_to_glb_y_up_transform,
)
from embodichain.gen_sim.prompt2scene.utils.io import (
    relative_path,
)

# Kept in sync with the names the package ``__init__`` re-exports.
__all__ = [
    "_center_xy_aabb_layout",
    "_footprint_layout_diagnostics",
    "_object_scenes_xy_aabb_manifest",
    "_settle_and_pack_object_footprints",
    "_xy_aabb_overlap",
    "_xy_union_area",
    "_xy_union_bounds",
]


def _object_scenes_xy_aabb_manifest(
    *,
    object_scenes: list[tuple[str, Any]],
    trimesh: Any,
    unit_scale: float,
    unit: str,
) -> dict[str, Any]:
    """Summarize the union XY bounding box of a set of object scenes.

    Args:
        object_scenes: ``(object_id, scene)`` pairs; only the scenes are used.
        trimesh: Forwarded to ``_scene_to_mesh``.
        unit_scale: Factor applied to the XY bounds before reporting.
        unit: Label stored verbatim under ``"unit"``.

    Returns:
        ``{"status": "empty", ...}`` when ``object_scenes`` is empty, otherwise
        a ``"status": "ok"`` dict with ``min_xy``, ``max_xy``, ``center_xy``,
        ``size_xy`` (lists) and ``area`` of the union box, all scaled by
        ``unit_scale``.
    """
    if not object_scenes:
        return {
            "status": "empty",
            "unit": unit,
            "object_count": 0,
        }
    bounds = [
        np.asarray(_scene_to_mesh(scene, trimesh=trimesh).bounds, dtype=np.float64)
        for _, scene in object_scenes
    ]
    # Row 0 is the component-wise minimum corner, row 1 the maximum corner.
    union_bounds = np.vstack(
        [
            np.vstack([item[0] for item in bounds]).min(axis=0),
            np.vstack([item[1] for item in bounds]).max(axis=0),
        ]
    )
    min_xy = union_bounds[0, :2] * unit_scale
    max_xy = union_bounds[1, :2] * unit_scale
    size_xy = max_xy - min_xy
    center_xy = 0.5 * (min_xy + max_xy)
    return {
        "status": "ok",
        "unit": unit,
        "object_count": len(object_scenes),
        "min_xy": min_xy.tolist(),
        "max_xy": max_xy.tolist(),
        "center_xy": center_xy.tolist(),
        "size_xy": size_xy.tolist(),
        "area": float(size_xy[0] * size_xy[1]),
    }


def _settle_and_pack_object_footprints(
    *,
    object_scenes: list[tuple[str, Any]],
    output_dir: Path,
    output_root: Path,
    trimesh: Any,
) -> dict[str, Any]:
    """Gravity-settle each object, then pack the settled XY footprints.

    Each object is moved so its AABB bottom sits on the XY plane, exported to
    a temporary GLB, and dropped in a gravity simulation.  The settled
    footprints are then relaxed by ``_optimize_xy_aabb_footprint_layout`` and
    every original scene is re-posed with the composed transform.

    A failed gravity simulation is not fatal: the object keeps an identity
    gravity transform and the traceback is recorded in its manifest item.

    Args:
        object_scenes: ``(object_id, scene)`` pairs to settle and pack.
        output_dir: Only recorded in the manifest, relative to
            ``output_root``; nothing is written there by this function.
        output_root: Root used to relativize ``output_dir``.
        trimesh: Forwarded to ``_scene_to_mesh``.

    Returns:
        Dict with ``"object_scenes"`` (packed ``(object_id, scene)`` pairs),
        ``"object_layout_transforms"`` (object id -> 4x4 array) and
        ``"manifest"`` (JSON-style description of the run).
    """
    sim = SimulationManager(headless=True, sim_device="cpu")
    footprint_items: list[dict[str, Any]] = []
    settled_entries: list[dict[str, Any]] = []
    output_axis_transform = _z_up_to_glb_y_up_transform()
    output_to_internal_transform = np.linalg.inv(output_axis_transform)

    with tempfile.TemporaryDirectory(prefix="p2s_footprint_drop_") as tmp_dir:
        tmp_path = Path(tmp_dir)
        for object_id, scene in object_scenes:
            mesh = _scene_to_mesh(scene, trimesh=trimesh)
            mesh_bounds = np.asarray(mesh.bounds, dtype=np.float64)
            mesh_z_height = max(float(mesh_bounds[1][2] - mesh_bounds[0][2]), 0.0)
            bottom_to_xy_plane_transform = _aabb_bottom_to_xy_plane_transform(
                mesh_bounds
            )
            normalized_scene = _copy_scene_with_transform(
                scene,
                bottom_to_xy_plane_transform,
            )
            normalized_output_scene = _copy_scene_with_transform(
                normalized_scene,
                output_axis_transform,
            )
            pre_gravity_path = tmp_path / f"{object_id}_pre_gravity.glb"
            normalized_output_scene.export(pre_gravity_path)
            # Drop from 10% of the object's own height.
            gravity_initial_height = mesh_z_height * 0.1

            gravity_status = "ok"
            gravity_transform = np.eye(4, dtype=np.float64)
            gravity_reason = ""
            try:
                gravity_result = sim.run_gravity_simulation(
                    GravityDropRequest(
                        glb_path=pre_gravity_path,
                        max_convex_hull_num=32,
                        initial_height=gravity_initial_height,
                    )
                )
                gravity_transform = _matrix_from_json(
                    gravity_result.final_pose,
                    name=f"{object_id}.gravity_final_pose",
                )
            except Exception:
                # Deliberate best effort: keep the identity transform and
                # surface the failure through the manifest item.
                gravity_status = "failed"
                gravity_reason = traceback.format_exc()

            # NOTE(review): the final pose is applied to the internal-frame
            # scene although the GLB was exported through
            # ``output_axis_transform``; presumably the simulation manager
            # reports poses in the internal frame -- confirm.
            settled_origin_scene = _copy_scene_with_transform(
                normalized_scene,
                gravity_transform,
            )
            settled_mesh = _scene_to_mesh(settled_origin_scene, trimesh=trimesh)
            settled_bounds = np.asarray(settled_mesh.bounds, dtype=np.float64)
            settled_xy_center = _xy_aabb_center(settled_bounds)
            settled_xy_size = _xy_aabb_size(settled_bounds)
            settled_entries.append(
                {
                    "id": object_id,
                    "scene": scene,
                    # XY centre of the untouched input scene, computed from
                    # the bounds already at hand so the scene does not have
                    # to be merged into a mesh a second time later on.
                    "initial_xy_center": _xy_aabb_center(mesh_bounds),
                    "bottom_to_xy_plane_transform": bottom_to_xy_plane_transform,
                    "mesh_z_height": mesh_z_height,
                    "gravity_initial_height": gravity_initial_height,
                    "gravity_transform": gravity_transform,
                    "settled_bounds": settled_bounds,
                    "settled_xy_center": settled_xy_center,
                    "settled_xy_size": settled_xy_size,
                    "gravity_status": gravity_status,
                    "gravity_reason": gravity_reason,
                }
            )

    # Footprint sizes come from the settled pose; starting centres come from
    # where the objects were in the input scenes.
    layout_result = _optimize_xy_aabb_footprint_layout(
        object_ids=[str(entry["id"]) for entry in settled_entries],
        xy_sizes={
            str(entry["id"]): np.asarray(entry["settled_xy_size"], dtype=np.float64)
            for entry in settled_entries
        },
        current_centers={
            str(entry["id"]): entry["initial_xy_center"]
            for entry in settled_entries
        },
    )
    target_centers = layout_result["centers"]

    packed_object_scenes: list[tuple[str, Any]] = []
    object_layout_transforms: dict[str, np.ndarray] = {}
    for entry in settled_entries:
        object_id = str(entry["id"])
        settled_bounds = np.asarray(entry["settled_bounds"], dtype=np.float64)
        target_xy = target_centers[object_id]
        # Translate the settled footprint centre onto its target and put the
        # settled AABB bottom back on z = 0.
        placement_transform = np.eye(4, dtype=np.float64)
        placement_transform[:3, 3] = [
            float(target_xy[0] - entry["settled_xy_center"][0]),
            float(target_xy[1] - entry["settled_xy_center"][1]),
            -float(settled_bounds[0][2]),
        ]
        # Applied right-to-left: ground the mesh, settle it, then place it.
        object_transform = (
            placement_transform
            @ entry["gravity_transform"]
            @ entry["bottom_to_xy_plane_transform"]
        )
        packed_scene = _copy_scene_with_transform(entry["scene"], object_transform)
        packed_object_scenes.append((object_id, packed_scene))
        object_layout_transforms[object_id] = object_transform

        packed_bounds = np.asarray(
            _scene_to_mesh(packed_scene, trimesh=trimesh).bounds,
            dtype=np.float64,
        )
        footprint_items.append(
            {
                "id": object_id,
                "gravity_status": entry["gravity_status"],
                "gravity_reason": entry["gravity_reason"],
                "bottom_to_xy_plane_transform": entry[
                    "bottom_to_xy_plane_transform"
                ].tolist(),
                "mesh_z_height": entry["mesh_z_height"],
                "gravity_initial_height": entry["gravity_initial_height"],
                "gravity_transform": entry["gravity_transform"].tolist(),
                "placement_transform": placement_transform.tolist(),
                "object_layout_transform": object_transform.tolist(),
                "settled_xy_size": entry["settled_xy_size"].tolist(),
                "target_xy_center": target_xy.tolist(),
                "packed_bounds": packed_bounds.tolist(),
            }
        )

    manifest = {
        "status": "ok",
        "method": "per_object_gravity_then_geometry_knn_2d_aabb_relaxation",
        "output_dir": relative_path(str(output_dir), output_root),
        "internal_up_axis": [0.0, 0.0, 1.0],
        "gravity_glb_up_axis": [0.0, 1.0, 0.0],
        "internal_to_gravity_glb_transform": output_axis_transform.tolist(),
        "gravity_glb_to_internal_transform": output_to_internal_transform.tolist(),
        "layout_optimization": layout_result["metadata"],
        "items": footprint_items,
    }
    return {
        "object_scenes": packed_object_scenes,
        "object_layout_transforms": object_layout_transforms,
        "manifest": manifest,
    }


def _optimize_xy_aabb_footprint_layout(
    *,
    object_ids: list[str],
    xy_sizes: dict[str, np.ndarray],
    current_centers: dict[str, np.ndarray],
    padding_ratio: float = 0.08,
    max_iterations: int = 300,
    overlap_strength: float = 1.0,
    neighbor_strength: float = 0.04,
    compactness_strength: float = 0.01,
    target_expansion_ratio: float = 1.2,
    max_neighbors: int = 3,
) -> dict[str, Any]:
    """Relax 2D AABB footprints until they stop overlapping.

    Every iteration applies, in order: pairwise overlap separation along the
    axis of least penetration, a pull that restores the initial offset
    between k-nearest-neighbour pairs, a compactness pull toward the union
    centre when the union box has grown too much, and a re-centring of the
    whole layout on the origin.

    Args:
        object_ids: Objects to lay out; also fixes the pair iteration order.
        xy_sizes: Object id -> XY footprint size.
        current_centers: Object id -> starting XY centre; missing ids start
            at the origin.
        padding_ratio: Padding between footprints as a fraction of the
            largest footprint extent (never below ``1e-3``).
        max_iterations: Upper bound on relaxation iterations.
        overlap_strength: Fraction of each overlap resolved per iteration.
        neighbor_strength: Strength of the neighbour-offset restoring pull.
        compactness_strength: Strength of the compactness pull.
        target_expansion_ratio: Union-area growth (relative to the initial
            layout) tolerated before the compactness pull kicks in.
        max_neighbors: Maximum number of nearest neighbours per object used
            to build the neighbour edges.

    Returns:
        ``{"centers": {object_id: xy_array}, "metadata": {...}}``.  For an
        empty ``object_ids`` the metadata only holds ``method``,
        ``iterations`` and ``confidence_score``.
    """
    if not object_ids:
        return {
            "centers": {},
            "metadata": {
                "method": "geometry_knn_2d_aabb_relaxation",
                "iterations": 0,
                "confidence_score": 1.0,
            },
        }

    max_extent = max(
        float(max(xy_sizes[object_id][0], xy_sizes[object_id][1]))
        for object_id in object_ids
    )
    padding = max(max_extent * padding_ratio, 1e-3)
    knn_k = min(max_neighbors, max(len(object_ids) - 1, 0))
    # Copy so the in-place updates below never touch the caller's arrays.
    centers = {
        object_id: np.asarray(
            current_centers.get(object_id, np.zeros(2, dtype=np.float64)),
            dtype=np.float64,
        ).copy()
        for object_id in object_ids
    }
    centers = _center_xy_aabb_layout(
        centers=centers,
        xy_sizes=xy_sizes,
    )
    initial_centers = {
        object_id: center.copy()
        for object_id, center in centers.items()
    }
    initial_union_bounds = _xy_union_bounds(
        centers=initial_centers,
        xy_sizes=xy_sizes,
    )
    neighbor_edges = _knn_neighbor_edges(
        centers=initial_centers,
        k=knn_k,
    )

    iterations = 0
    for iteration in range(max_iterations):
        iterations = iteration + 1
        max_delta = 0.0

        # 1) Push overlapping pairs apart along the axis of least overlap.
        for i, object_id in enumerate(object_ids):
            for other_id in object_ids[i + 1 :]:
                overlap = _xy_aabb_overlap(
                    center_a=centers[object_id],
                    size_a=xy_sizes[object_id],
                    center_b=centers[other_id],
                    size_b=xy_sizes[other_id],
                    padding=padding,
                )
                if overlap is None:
                    continue
                overlap_x, overlap_y = overlap
                if overlap_x <= overlap_y:
                    axis = 0
                    sign = (
                        -1.0
                        if centers[object_id][0] <= centers[other_id][0]
                        else 1.0
                    )
                    amount = overlap_x
                else:
                    axis = 1
                    sign = (
                        -1.0
                        if centers[object_id][1] <= centers[other_id][1]
                        else 1.0
                    )
                    amount = overlap_y
                # Each box moves half of the overlap (plus a small epsilon so
                # the pair ends up strictly separated).
                shift = 0.5 * (amount + 1e-6) * overlap_strength
                centers[object_id][axis] += sign * shift
                centers[other_id][axis] -= sign * shift
                max_delta = max(max_delta, shift)

        # 2) Pull neighbour pairs back toward their initial relative offset.
        for edge in neighbor_edges:
            object_id = edge["object"]
            neighbor_id = edge["neighbor"]
            initial_delta = np.asarray(edge["initial_delta"], dtype=np.float64)
            error = (centers[object_id] - centers[neighbor_id]) - initial_delta
            correction = 0.5 * neighbor_strength * error
            centers[object_id] -= correction
            centers[neighbor_id] += correction
            max_delta = max(max_delta, float(np.linalg.norm(correction)))

        # 3) Counteract excessive growth of the union bounding box.
        max_delta = max(
            max_delta,
            _apply_compactness_pull(
                centers=centers,
                xy_sizes=xy_sizes,
                initial_union_bounds=initial_union_bounds,
                target_expansion_ratio=target_expansion_ratio,
                strength=compactness_strength,
            ),
        )

        # 4) Keep the layout centred on the origin.
        centers = _center_xy_aabb_layout(
            centers=centers,
            xy_sizes=xy_sizes,
        )
        # Always run at least 21 iterations before accepting convergence.
        if iteration >= 20 and max_delta < 1e-5:
            break

    diagnostics = _footprint_layout_diagnostics(
        object_ids=object_ids,
        centers=centers,
        initial_centers=initial_centers,
        xy_sizes=xy_sizes,
        padding=padding,
        initial_union_bounds=initial_union_bounds,
    )
    metadata = {
        "method": "geometry_knn_2d_aabb_relaxation",
        "relation_usage": "disabled",
        "iterations": iterations,
        "padding": padding,
        "padding_ratio": padding_ratio,
        "max_iterations": max_iterations,
        "overlap_strength": overlap_strength,
        "neighbor_strength": neighbor_strength,
        "compactness_strength": compactness_strength,
        "target_expansion_ratio": target_expansion_ratio,
        "knn_k": knn_k,
        "neighbor_edges": neighbor_edges,
        "final_centers": {
            object_id: centers[object_id].tolist()
            for object_id in object_ids
        },
        **diagnostics,
    }
    return {"centers": centers, "metadata": metadata}


def _knn_neighbor_edges(
    *,
    centers: dict[str, np.ndarray],
    k: int,
) -> list[dict[str, Any]]:
    """Build undirected k-nearest-neighbour edges between footprint centres.

    Object ids are visited in sorted order and distance ties are broken by
    id, so the result is deterministic.  Each unordered pair appears at most
    once.

    Args:
        centers: Object id -> XY centre.
        k: Number of nearest neighbours considered per object.

    Returns:
        List of ``{"object", "neighbor", "initial_delta"}`` dicts where
        ``initial_delta`` is ``centers[object] - centers[neighbor]`` as a
        list.  Empty when ``k <= 0`` or fewer than two centres are given.
    """
    if k <= 0 or len(centers) < 2:
        return []
    object_ids = sorted(centers)
    edges: list[dict[str, Any]] = []
    seen: set[tuple[str, str]] = set()
    for object_id in object_ids:
        distances = []
        for other_id in object_ids:
            if other_id == object_id:
                continue
            distance = float(np.linalg.norm(centers[object_id] - centers[other_id]))
            distances.append((distance, other_id))
        for _, neighbor_id in sorted(distances)[:k]:
            edge_key = tuple(sorted((object_id, neighbor_id)))
            if edge_key in seen:
                continue
            seen.add(edge_key)
            edges.append(
                {
                    "object": object_id,
                    "neighbor": neighbor_id,
                    "initial_delta": (
                        centers[object_id] - centers[neighbor_id]
                    ).tolist(),
                }
            )
    return edges


def _apply_compactness_pull(
    *,
    centers: dict[str, np.ndarray],
    xy_sizes: dict[str, np.ndarray],
    initial_union_bounds: np.ndarray,
    target_expansion_ratio: float,
    strength: float,
) -> float:
    """Pull all centres toward the union centre when the layout grew too much.

    Nothing happens while the union-area ratio (current / initial) is at or
    below ``target_expansion_ratio``.  Otherwise every centre moves toward
    the centre of the current union box by ``strength * excess`` of its
    distance, with ``excess`` capped at 1.

    ``centers`` is updated in place (values are rebound to new arrays).

    Returns:
        The largest displacement applied to any centre (``0.0`` if none).
    """
    current_bounds = _xy_union_bounds(centers=centers, xy_sizes=xy_sizes)
    expansion_ratio = _xy_union_area(current_bounds) / max(
        _xy_union_area(initial_union_bounds),
        1.0e-12,
    )
    if expansion_ratio <= target_expansion_ratio:
        return 0.0
    excess = min(expansion_ratio / target_expansion_ratio - 1.0, 1.0)
    union_center = 0.5 * (current_bounds[0] + current_bounds[1])
    factor = strength * excess
    max_delta = 0.0
    # Only existing keys are reassigned, so iterating the dict is safe.
    for object_id, center in centers.items():
        delta = factor * (union_center - center)
        centers[object_id] = center + delta
        max_delta = max(max_delta, float(np.linalg.norm(delta)))
    return max_delta


def _footprint_layout_diagnostics(
    *,
    object_ids: list[str],
    centers: dict[str, np.ndarray],
    initial_centers: dict[str, np.ndarray],
    xy_sizes: dict[str, np.ndarray],
    padding: float,
    initial_union_bounds: np.ndarray,
) -> dict[str, Any]:
    """Compute quality metrics for a relaxed footprint layout.

    Returns:
        Dict with ``remaining_overlaps`` (padded overlaps still present),
        ``average_displacement`` / ``max_displacement`` of the centres from
        ``initial_centers``, ``union_aabb_expansion_ratio`` (current union
        area over initial union area) and the derived ``confidence_score``.
    """
    remaining_overlaps = _remaining_xy_overlaps(
        object_ids=object_ids,
        centers=centers,
        xy_sizes=xy_sizes,
        padding=padding,
    )
    displacements = [
        float(np.linalg.norm(centers[object_id] - initial_centers[object_id]))
        for object_id in object_ids
    ]
    current_union_bounds = _xy_union_bounds(centers=centers, xy_sizes=xy_sizes)
    expansion_ratio = _xy_union_area(current_union_bounds) / max(
        _xy_union_area(initial_union_bounds),
        1.0e-12,
    )
    average_displacement = float(np.mean(displacements)) if displacements else 0.0
    max_displacement = float(np.max(displacements)) if displacements else 0.0
    confidence_score = _footprint_confidence_score(
        remaining_overlap_count=len(remaining_overlaps),
        average_displacement=average_displacement,
        # Largest footprint side; falls back to 1.0 for an empty layout.
        max_extent=max(
            float(max(xy_sizes[object_id][0], xy_sizes[object_id][1]))
            for object_id in object_ids
        )
        if object_ids
        else 1.0,
        expansion_ratio=expansion_ratio,
    )
    return {
        "remaining_overlaps": remaining_overlaps,
        "average_displacement": average_displacement,
        "max_displacement": max_displacement,
        "union_aabb_expansion_ratio": expansion_ratio,
        "confidence_score": confidence_score,
    }


def _remaining_xy_overlaps(
    *,
    object_ids: list[str],
    centers: dict[str, np.ndarray],
    xy_sizes: dict[str, np.ndarray],
    padding: float,
) -> list[dict[str, Any]]:
    """List every unordered pair whose padded footprints still overlap.

    Returns:
        One ``{"object", "other", "overlap_x", "overlap_y"}`` dict per
        overlapping pair, in ``object_ids`` order.
    """
    overlaps: list[dict[str, Any]] = []
    for index, object_id in enumerate(object_ids):
        for other_id in object_ids[index + 1 :]:
            overlap = _xy_aabb_overlap(
                center_a=centers[object_id],
                size_a=xy_sizes[object_id],
                center_b=centers[other_id],
                size_b=xy_sizes[other_id],
                padding=padding,
            )
            if overlap is None:
                continue
            overlaps.append(
                {
                    "object": object_id,
                    "other": other_id,
                    "overlap_x": overlap[0],
                    "overlap_y": overlap[1],
                }
            )
    return overlaps


def _footprint_confidence_score(
    *,
    remaining_overlap_count: int,
    average_displacement: float,
    max_extent: float,
    expansion_ratio: float,
) -> float:
    """Score a layout in ``[0, 1]``; higher means fewer compromises.

    Starts from 1.0 and subtracts three capped penalties: 0.35 per remaining
    overlap (at most 0.7), 0.1 per ``max_extent`` of average displacement (at
    most 0.2), and 0.25 per unit of union-area expansion beyond 1.2 (at most
    0.2).
    """
    displacement_scale = max(max_extent, 1.0e-6)
    overlap_penalty = min(0.35 * remaining_overlap_count, 0.7)
    displacement_penalty = min(0.1 * average_displacement / displacement_scale, 0.2)
    expansion_penalty = min(max(expansion_ratio - 1.2, 0.0) * 0.25, 0.2)
    return float(
        np.clip(
            1.0
            - overlap_penalty
            - displacement_penalty
            - expansion_penalty,
            0.0,
            1.0,
        )
    )


def _center_xy_aabb_layout(
    *,
    centers: dict[str, np.ndarray],
    xy_sizes: dict[str, np.ndarray],
) -> dict[str, np.ndarray]:
    """Translate all centres so the union footprint box is centred at origin.

    Returns:
        A new dict of shifted centres; the input dict itself is returned
        unchanged when it is empty.
    """
    if not centers:
        return centers
    bounds_min = []
    bounds_max = []
    for object_id, center in centers.items():
        # Coerce both operands, matching ``_xy_union_bounds``.
        center = np.asarray(center, dtype=np.float64)
        half_size = 0.5 * np.asarray(xy_sizes[object_id], dtype=np.float64)
        bounds_min.append(center - half_size)
        bounds_max.append(center + half_size)
    clutter_center = 0.5 * (
        np.vstack(bounds_min).min(axis=0)
        + np.vstack(bounds_max).max(axis=0)
    )
    return {
        object_id: np.asarray(center, dtype=np.float64) - clutter_center
        for object_id, center in centers.items()
    }


def _xy_union_bounds(
    *,
    centers: dict[str, np.ndarray],
    xy_sizes: dict[str, np.ndarray],
) -> np.ndarray:
    """Return the union XY box of all footprints as ``[[min_x, min_y], [max_x, max_y]]``.

    An empty ``centers`` yields a 2x2 array of zeros.
    """
    if not centers:
        return np.zeros((2, 2), dtype=np.float64)
    bounds_min = []
    bounds_max = []
    for object_id, center in centers.items():
        half_size = 0.5 * np.asarray(xy_sizes[object_id], dtype=np.float64)
        bounds_min.append(np.asarray(center, dtype=np.float64) - half_size)
        bounds_max.append(np.asarray(center, dtype=np.float64) + half_size)
    return np.vstack(
        [
            np.vstack(bounds_min).min(axis=0),
            np.vstack(bounds_max).max(axis=0),
        ]
    )


def _xy_union_area(bounds: np.ndarray) -> float:
    """Return the area of a ``[[min], [max]]`` XY box.

    Each side is floored at ``1e-9`` so the area is always positive.
    """
    bounds = np.asarray(bounds, dtype=np.float64)
    size = np.maximum(bounds[1] - bounds[0], 1.0e-9)
    return float(size[0] * size[1])


def _xy_aabb_overlap(
    *,
    center_a: np.ndarray,
    size_a: np.ndarray,
    center_b: np.ndarray,
    size_b: np.ndarray,
    padding: float,
) -> tuple[float, float] | None:
    """Return the padded per-axis overlap of two XY boxes.

    Returns:
        ``(overlap_x, overlap_y)`` when the boxes, grown by ``padding``,
        overlap on both axes; ``None`` when the overlap on either axis is
        zero or negative.
    """
    half_a = 0.5 * np.asarray(size_a, dtype=np.float64)
    half_b = 0.5 * np.asarray(size_b, dtype=np.float64)
    delta = np.abs(
        np.asarray(center_b, dtype=np.float64)
        - np.asarray(center_a, dtype=np.float64)
    )
    overlap = half_a + half_b + padding - delta
    if float(overlap[0]) <= 0.0 or float(overlap[1]) <= 0.0:
        return None
    return float(overlap[0]), float(overlap[1])
---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simready_manager.manager import ( + SimreadyManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simready_manager.schemas import ( + MakeAssetSimreadyRequest, + MakeAssetSimreadyResult, + MakeTableSimreadyRequest, + MakeTableSimreadyResult, +) + +__all__ = [ + "MakeAssetSimreadyRequest", + "MakeAssetSimreadyResult", + "MakeTableSimreadyRequest", + "MakeTableSimreadyResult", + "SimreadyManager", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/manager.py new file mode 100644 index 00000000..6f92e1f8 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/manager.py @@ -0,0 +1,396 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Any + +import numpy as np + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager.manager import ( + DEFAULT_INPUT_UP_AXIS, + DEFAULT_UP_AXIS, + GeometryManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager.schemas import ( + AlignToAxisRequest, + CenterMeshRequest, + ConvertUpAxisRequest, + DetectTabletopRequest, + ExportMeshRequest, + LoadMeshRequest, + NormalizeRequest, + PlaceAbovePlaneRequest, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.matplotlib_manager.manager import ( + MatplotlibManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.matplotlib_manager.schemas import ( + RenderSupportRegionRequest, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simready_manager.schemas import ( + MakeAssetSimreadyRequest, + MakeAssetSimreadyResult, + MakeTableSimreadyRequest, + MakeTableSimreadyResult, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simulation_manager.manager import ( + SimulationManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simulation_manager.schemas import ( + GravityDropRequest, +) + + +class SimreadyManager: + """Prepare generated GLB assets for simulation placement.""" + + def __init__( + self, + *, + geometry_manager: GeometryManager | None = None, + simulation_manager: SimulationManager | None = None, + matplotlib_manager: MatplotlibManager | None = None, 
+ ) -> None: + self.geometry_manager = geometry_manager or GeometryManager() + self.simulation_manager = simulation_manager or SimulationManager() + self.matplotlib_manager = matplotlib_manager or MatplotlibManager() + + def make_asset_simready( + self, + request: MakeAssetSimreadyRequest, + ) -> MakeAssetSimreadyResult: + input_path = request.input_path.expanduser().resolve() + output_path = request.output_path.expanduser().resolve() + if output_path.suffix.lower() != ".glb": + raise ValueError("Sim-ready asset output_path must be a .glb file.") + output_path.parent.mkdir(parents=True, exist_ok=True) + + input_up_axis = _request_axis(request.input_up_axis, DEFAULT_INPUT_UP_AXIS) + raw_to_simready = np.eye(4, dtype=np.float64) + geom = self.geometry_manager + sim = self.simulation_manager + + mesh = geom.load_mesh(LoadMeshRequest(mesh_path=input_path)).mesh + + transform = _axis_conversion_transform(input_up_axis, DEFAULT_UP_AXIS) + raw_to_simready = transform @ raw_to_simready + mesh = geom.convert_up_axis( + ConvertUpAxisRequest( + mesh=mesh, + input_up_axis=input_up_axis, + output_up_axis=DEFAULT_UP_AXIS, + ) + ).mesh + + center_result = geom.center_by_bbox(CenterMeshRequest(mesh=mesh)) + mesh = center_result.mesh + transform = _translation_transform(-np.asarray(center_result.bbox_center)) + raw_to_simready = transform @ raw_to_simready + + transform = _place_above_plane_transform(mesh, request.ground_clearance) + raw_to_simready = transform @ raw_to_simready + mesh = geom.place_above_plane( + PlaceAbovePlaneRequest(mesh=mesh, clearance=request.ground_clearance) + ).mesh + + pre_gravity_mesh = geom.convert_up_axis( + ConvertUpAxisRequest( + mesh=mesh, + input_up_axis=DEFAULT_UP_AXIS, + output_up_axis=DEFAULT_INPUT_UP_AXIS, + ) + ).mesh + pre_gravity_path = output_path.with_name(f".{output_path.stem}_pre_gravity.glb") + geom.export_mesh( + ExportMeshRequest(mesh=pre_gravity_mesh, output_path=pre_gravity_path) + ) + try: + gravity_result = 
sim.run_gravity_simulation( + GravityDropRequest(glb_path=pre_gravity_path, max_convex_hull_num=32) + ) + + gravity_transform = _as_transform(gravity_result.final_pose) + settled_mesh = mesh.copy() + settled_mesh.apply_transform(gravity_transform) + raw_to_simready = gravity_transform @ raw_to_simready + transform = _center_aabb_bottom_xy_at_origin_transform(settled_mesh) + settled_mesh.apply_transform(transform) + raw_to_simready = transform @ raw_to_simready + + transform = _center_aabb_bottom_xy_at_origin_transform(settled_mesh) + raw_to_simready = transform @ raw_to_simready + final_mesh = _center_aabb_bottom_xy_at_origin(settled_mesh) + + normalize_result = geom.normalize(NormalizeRequest(mesh=final_mesh)) + final_mesh = normalize_result.mesh + transform = _scale_transform(normalize_result.scale_factor) + raw_to_simready = transform @ raw_to_simready + + transform = _place_above_plane_transform(final_mesh, request.ground_clearance) + raw_to_simready = transform @ raw_to_simready + final_mesh = geom.place_above_plane( + PlaceAbovePlaneRequest( + mesh=final_mesh, + clearance=request.ground_clearance, + ) + ).mesh + + transform = _axis_conversion_transform(DEFAULT_UP_AXIS, DEFAULT_INPUT_UP_AXIS) + raw_to_simready = transform @ raw_to_simready + final_mesh = geom.convert_up_axis( + ConvertUpAxisRequest( + mesh=final_mesh, + input_up_axis=DEFAULT_UP_AXIS, + output_up_axis=DEFAULT_INPUT_UP_AXIS, + ) + ).mesh + + geom.export_mesh(ExportMeshRequest(mesh=final_mesh, output_path=output_path)) + finally: + pre_gravity_path.unlink(missing_ok=True) + + return MakeAssetSimreadyResult( + output_path=output_path, + transform_matrix=raw_to_simready.tolist(), + ) + + def make_table_simready( + self, + request: MakeTableSimreadyRequest, + ) -> MakeTableSimreadyResult: + input_path = request.input_path.expanduser().resolve() + output_path = request.output_path.expanduser().resolve() + if output_path.suffix.lower() != ".glb": + raise ValueError("Sim-ready table output_path must be 
a .glb file.") + output_path.parent.mkdir(parents=True, exist_ok=True) + + input_up_axis = _request_axis(request.input_up_axis, DEFAULT_INPUT_UP_AXIS) + up_axis = _request_axis(request.up_axis, DEFAULT_UP_AXIS) + raw_to_simready = np.eye(4, dtype=np.float64) + geom = self.geometry_manager + sim = self.simulation_manager + mpl = self.matplotlib_manager + + mesh = geom.load_mesh(LoadMeshRequest(mesh_path=input_path)).mesh + + transform = _axis_conversion_transform(input_up_axis, DEFAULT_UP_AXIS) + raw_to_simready = transform @ raw_to_simready + mesh = geom.convert_up_axis( + ConvertUpAxisRequest( + mesh=mesh, + input_up_axis=input_up_axis, + output_up_axis=DEFAULT_UP_AXIS, + ) + ).mesh + + center_result = geom.center_by_bbox(CenterMeshRequest(mesh=mesh)) + mesh = center_result.mesh + transform = _translation_transform(-np.asarray(center_result.bbox_center)) + raw_to_simready = transform @ raw_to_simready + + detect_result = geom.detect_tabletop(DetectTabletopRequest(mesh=mesh)) + + transform = _axis_conversion_transform(detect_result.oriented_normal, up_axis) + raw_to_simready = transform @ raw_to_simready + mesh = geom.align_to_axis( + AlignToAxisRequest( + mesh=mesh, + source_axis=detect_result.oriented_normal, + target_axis=up_axis, + ) + ).mesh + + transform = _place_above_plane_transform(mesh, request.ground_clearance) + raw_to_simready = transform @ raw_to_simready + mesh = geom.place_above_plane( + PlaceAbovePlaneRequest(mesh=mesh, clearance=request.ground_clearance) + ).mesh + + pre_gravity_mesh = geom.convert_up_axis( + ConvertUpAxisRequest( + mesh=mesh, + input_up_axis=DEFAULT_UP_AXIS, + output_up_axis=DEFAULT_INPUT_UP_AXIS, + ) + ).mesh + pre_gravity_path = output_path.with_name(f".{output_path.stem}_pre_gravity.glb") + geom.export_mesh( + ExportMeshRequest(mesh=pre_gravity_mesh, output_path=pre_gravity_path) + ) + try: + gravity_result = sim.run_gravity_simulation( + GravityDropRequest(glb_path=pre_gravity_path, max_convex_hull_num=16) + ) + + 
gravity_transform = _as_transform(gravity_result.final_pose) + settled_mesh = mesh.copy() + settled_mesh.apply_transform(gravity_transform) + raw_to_simready = gravity_transform @ raw_to_simready + transform = _center_aabb_bottom_xy_at_origin_transform(settled_mesh) + settled_mesh.apply_transform(transform) + raw_to_simready = transform @ raw_to_simready + + settled_detect = geom.detect_tabletop( + DetectTabletopRequest(mesh=settled_mesh) + ) + + mpl.render_selected_support_region( + RenderSupportRegionRequest( + mesh=settled_mesh, + face_indices=settled_detect.selected.face_indices, + output_path=output_path.with_name( + f"{output_path.stem}_support_region.png" + ), + ) + ) + + transform = _center_aabb_bottom_xy_at_origin_transform(settled_mesh) + raw_to_simready = transform @ raw_to_simready + final_mesh = _center_aabb_bottom_xy_at_origin(settled_mesh) + + normalize_result = geom.normalize(NormalizeRequest(mesh=final_mesh)) + final_mesh = normalize_result.mesh + transform = _scale_transform(normalize_result.scale_factor) + raw_to_simready = transform @ raw_to_simready + + transform = _place_above_plane_transform(final_mesh, request.ground_clearance) + raw_to_simready = transform @ raw_to_simready + final_mesh = geom.place_above_plane( + PlaceAbovePlaneRequest( + mesh=final_mesh, + clearance=request.ground_clearance, + ) + ).mesh + + transform = _axis_conversion_transform(DEFAULT_UP_AXIS, DEFAULT_INPUT_UP_AXIS) + raw_to_simready = transform @ raw_to_simready + final_mesh = geom.convert_up_axis( + ConvertUpAxisRequest( + mesh=final_mesh, + input_up_axis=DEFAULT_UP_AXIS, + output_up_axis=DEFAULT_INPUT_UP_AXIS, + ) + ).mesh + + geom.export_mesh(ExportMeshRequest(mesh=final_mesh, output_path=output_path)) + finally: + pre_gravity_path.unlink(missing_ok=True) + + return MakeTableSimreadyResult( + output_path=output_path, + transform_matrix=raw_to_simready.tolist(), + ) + + +def _request_axis(value: list[float] | None, default: tuple[float, float, float]) -> 
def _center_aabb_bottom_xy_at_origin(mesh: Any) -> Any:
    """Return a copy of *mesh* whose bounding box is centred on the XY origin.

    Only X and Y are shifted; Z is left untouched.  The offset is taken from
    ``_center_aabb_bottom_xy_at_origin_transform`` so the mesh and the matrix
    that callers accumulate can never drift apart.

    Args:
        mesh: Object exposing ``bounds`` (2x3 min/max), ``copy()`` and
            ``apply_translation()``.

    Returns:
        A translated copy; the input mesh is not modified.
    """
    offset = _center_aabb_bottom_xy_at_origin_transform(mesh)[:3, 3]
    centered = mesh.copy()
    centered.apply_translation(offset.tolist())
    return centered


def _axis_conversion_transform(source_axis: list[float], target_axis: list[float]) -> np.ndarray:
    """Return the 4x4 rotation that turns *source_axis* onto *target_axis*.

    Raises:
        ValueError: If either axis has zero length or is non-finite.
    """
    source = np.asarray(source_axis, dtype=np.float64)
    target = np.asarray(target_axis, dtype=np.float64)
    return _rotation_between_vectors(source, target)


def _place_above_plane_transform(mesh: Any, clearance: float) -> np.ndarray:
    """Return the Z translation that puts the mesh's lowest point at *clearance*."""
    min_z = float(mesh.bounds[0][2])
    return _translation_transform(np.array([0.0, 0.0, clearance - min_z]))


def _center_aabb_bottom_xy_at_origin_transform(mesh: Any) -> np.ndarray:
    """Return the XY translation that centres the mesh's bounding box on the origin."""
    bounds = mesh.bounds
    center_x = (float(bounds[0][0]) + float(bounds[1][0])) * 0.5
    center_y = (float(bounds[0][1]) + float(bounds[1][1])) * 0.5
    return _translation_transform(np.array([-center_x, -center_y, 0.0]))


def _translation_transform(translation: np.ndarray) -> np.ndarray:
    """Return a 4x4 homogeneous matrix that translates by *translation* (3-vector)."""
    transform = np.eye(4, dtype=np.float64)
    transform[:3, 3] = translation
    return transform


def _scale_transform(scale: float) -> np.ndarray:
    """Return a 4x4 homogeneous matrix that scales uniformly about the origin."""
    transform = np.eye(4, dtype=np.float64)
    transform[:3, :3] *= float(scale)
    return transform


def _as_transform(value: Any) -> np.ndarray:
    """Convert *value* to a float64 4x4 array.

    Raises:
        ValueError: If the converted array is not 4x4.
    """
    transform = np.asarray(value, dtype=np.float64)
    if transform.shape != (4, 4):
        raise ValueError("Expected a 4x4 transform matrix.")
    return transform


def _rotation_between_vectors(source: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Return the 4x4 rotation taking the direction of *source* onto *target*.

    Parallel vectors give the identity.  Anti-parallel vectors give a
    half-turn about an arbitrary axis perpendicular to *source*.

    Raises:
        ValueError: If either vector has zero length or is non-finite.  Such
            input used to produce an all-zero rotation block, i.e. a singular
            transform that silently collapsed everything multiplied by it.
    """
    source = np.asarray(source, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    for label, vector in (("source", source), ("target", target)):
        norm = float(np.linalg.norm(vector))
        if not np.isfinite(norm) or norm == 0.0:
            raise ValueError(
                f"Cannot build a rotation from a zero-length or non-finite {label} axis."
            )

    source = _normalize(source)
    target = _normalize(target)
    dot = float(np.clip(np.dot(source, target), -1.0, 1.0))
    transform = np.eye(4, dtype=np.float64)
    if dot > 1.0 - 1e-8:
        # Already aligned: no rotation needed.
        return transform
    if dot < -1.0 + 1e-8:
        # Opposite directions: the cross product vanishes, so pick any
        # perpendicular axis and rotate by pi.
        axis = _orthogonal_axis(source)
        rotation = _axis_angle_rotation(axis, np.pi)
    else:
        axis = _normalize(np.cross(source, target))
        angle = float(np.arccos(dot))
        rotation = _axis_angle_rotation(axis, angle)
    transform[:3, :3] = rotation
    return transform


def _axis_angle_rotation(axis: np.ndarray, angle: float) -> np.ndarray:
    """Return the 3x3 rotation by *angle* radians about *axis* (Rodrigues' formula)."""
    axis = _normalize(axis)
    x, y, z = axis
    c = float(np.cos(angle))
    s = float(np.sin(angle))
    one_c = 1.0 - c
    return np.array(
        [
            [c + x * x * one_c, x * y * one_c - z * s, x * z * one_c + y * s],
            [y * x * one_c + z * s, c + y * y * one_c, y * z * one_c - x * s],
            [z * x * one_c - y * s, z * y * one_c + x * s, c + z * z * one_c],
        ],
        dtype=np.float64,
    )


def _orthogonal_axis(vector: np.ndarray) -> np.ndarray:
    """Return a unit vector perpendicular to *vector*.

    The X axis is used as the helper direction unless *vector* is nearly
    parallel to it, in which case the Y axis is used instead.
    """
    axis = np.array([1.0, 0.0, 0.0], dtype=np.float64)
    if abs(float(np.dot(vector, axis))) > 0.9:
        axis = np.array([0.0, 1.0, 0.0], dtype=np.float64)
    return _normalize(np.cross(vector, axis))


def _normalize(vector: np.ndarray) -> np.ndarray:
    """Return *vector* scaled to unit length; a zero vector is returned unchanged."""
    norm = float(np.linalg.norm(vector))
    if norm == 0.0:
        return vector
    return vector / norm
@dataclass(frozen=True)
class MakeAssetSimreadyRequest:
    """Request to prepare a general asset GLB for simulation placement.

    The fields are identical to ``MakeTableSimreadyRequest``.  The code that
    consumes this request is not part of this module, so the per-field notes
    are assumptions to confirm against that handler.
    """

    # Mesh file to read.
    input_path: Path
    # File the processed mesh is written to.
    output_path: Path
    # Up direction of the input mesh as a 3-vector; presumably ``None`` means
    # "use the manager's default" -- TODO confirm.
    input_up_axis: list[float] | None = None
    # Up direction to align the asset to; presumably ``None`` means "use the
    # manager's default" -- TODO confirm.
    up_axis: list[float] | None = None
    # Gap left between the asset's lowest point and the ground plane
    # (presumably metres -- TODO confirm).
    ground_clearance: float = 0.01


@dataclass(frozen=True)
class MakeAssetSimreadyResult:
    """Result of making an asset simulation-ready."""

    # Location of the exported mesh.
    output_path: Path
    # Nested-list matrix describing the applied transform; presumably 4x4 like
    # the table result -- TODO confirm.
    transform_matrix: list[list[float]]


@dataclass(frozen=True)
class MakeTableSimreadyRequest:
    """Request to prepare a generated table GLB for simulation placement."""

    # Mesh file loaded by the geometry manager.
    input_path: Path
    # File the final mesh is exported to; its parent directory is created on
    # demand by the handler.
    output_path: Path
    # Up direction of the input mesh; ``None`` selects the handler's
    # ``DEFAULT_INPUT_UP_AXIS``.
    input_up_axis: list[float] | None = None
    # Axis the detected tabletop normal is aligned to; ``None`` selects the
    # handler's ``DEFAULT_UP_AXIS``.
    up_axis: list[float] | None = None
    # Clearance passed to the "place above plane" steps.
    ground_clearance: float = 0.01


@dataclass(frozen=True)
class MakeTableSimreadyResult:
    """Result of making a table simulation-ready."""

    # Location of the exported mesh.
    output_path: Path
    # 4x4 matrix (nested lists) accumulated from every step of the pipeline,
    # mapping the raw input mesh to the exported one.
    transform_matrix: list[list[float]]
Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simulation_manager.manager import ( + SimulationManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simulation_manager.schemas import ( + GravityDropRequest, + GravityDropResult, +) + +__all__ = [ + "GravityDropRequest", + "GravityDropResult", + "SimulationManager", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/simulation_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/simulation_manager/manager.py new file mode 100644 index 00000000..4a072110 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/simulation_manager/manager.py @@ -0,0 +1,124 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
class SimulationManager:
    """Manager for gravity-based asset placement.

    Wraps an EmbodiChain simulation instance with typed request/response
    methods, following the same pattern as service clients.  A fresh
    simulation is created for every drop and torn down afterwards.
    """

    def __init__(
        self,
        *,
        headless: bool = True,
        physics_dt: float = 0.01,
        sim_device: str = "cpu",
    ) -> None:
        """Initialize the simulation manager.

        Args:
            headless: Whether to run without a GUI.
            physics_dt: Physics timestep in seconds.
            sim_device: Device to run the simulation on.
        """
        self._headless = headless
        self._physics_dt = physics_dt
        self._sim_device = sim_device

    def run_gravity_simulation(
        self, request: GravityDropRequest
    ) -> GravityDropResult:
        """Drop one GLB under gravity and return its final pose.

        Args:
            request: Path of the GLB, convex-decomposition budget and an
                optional explicit drop height.  When the height is ``None``
                it is derived from the mesh bounding box.

        Returns:
            The pose matrix of the object after 300 simulation steps, as a
            NumPy float array.

        Raises:
            FileNotFoundError: If ``request.glb_path`` is not an existing file.
        """
        glb_path = request.glb_path.expanduser().resolve()
        if not glb_path.is_file():
            raise FileNotFoundError(f"GLB file not found: {glb_path}")

        initial_height = (
            float(request.initial_height)
            if request.initial_height is not None
            else self._compute_adaptive_drop_height(glb_path)
        )
        sim = _EmbodiSimManager(
            SimulationManagerCfg(
                headless=self._headless,
                physics_dt=self._physics_dt,
                sim_device=self._sim_device,
            )
        )
        try:
            obj = sim.add_rigid_object(
                RigidObjectCfg(
                    uid="dropped_asset",
                    shape=MeshCfg(fpath=str(glb_path)),
                    init_pos=(0.0, 0.0, initial_height),
                    init_rot=(0.0, 0.0, 0.0),
                    body_type="dynamic",
                    max_convex_hull_num=request.max_convex_hull_num,
                )
            )
            sim.update(step=300)
            final_pose = obj.get_local_pose(to_matrix=True)[0].detach().cpu()
        finally:
            # Tear the simulation down even when loading or stepping fails,
            # so a failed drop does not leave a live simulator behind.
            sim._deferred_destroy()
        return GravityDropResult(
            final_pose=np.asarray(final_pose.numpy(), dtype=float),
        )

    def _compute_adaptive_drop_height(
        self,
        glb_path: Path,
        *,
        min_clearance: float = 0.2,
        height_scale: float = 1.25,
    ) -> float:
        """Compute an initial drop height from a GLB bounding box.

        The result is the larger of ``height * height_scale`` and
        ``height + min_clearance``.

        Args:
            glb_path: GLB file to measure.
            min_clearance: Minimum extra height above the object's own height.
            height_scale: Multiplier applied to the object's height.

        Raises:
            ValueError: If ``min_clearance`` is negative, ``height_scale`` is
                not positive, or the file contains no measurable geometry.
        """
        if min_clearance < 0.0:
            raise ValueError("min_clearance must be non-negative.")
        if height_scale <= 0.0:
            raise ValueError("height_scale must be positive.")

        glb_path = glb_path.expanduser().resolve()
        loaded = trimesh.load(glb_path, force=None)
        # Scenes and single meshes both expose ``bounds``; it is ``None`` when
        # there is no geometry to measure.
        bounds = loaded.bounds
        if bounds is None:
            raise ValueError(f"GLB has no geometry to measure: {glb_path}")
        # NOTE(review): height is measured along index 2 (Z) of the loaded
        # bounds, while GLB files are conventionally Y-up -- confirm this is
        # the intended axis for the files passed in.
        height = float(bounds[1][2] - bounds[0][2])
        return max(height * height_scale, height + min_clearance)
@dataclass(frozen=True)
class GravityDropRequest:
    """Request to drop a GLB asset under gravity simulation."""

    # GLB file to load into the simulation as the dropped object.
    glb_path: Path
    # Forwarded to the rigid-object configuration as ``max_convex_hull_num``
    # (presumably the cap on convex pieces used for collision -- TODO confirm).
    max_convex_hull_num: int = 32
    # Starting Z position of the object; ``None`` lets the simulation manager
    # derive a height from the mesh bounding box.
    initial_height: float | None = None


@dataclass(frozen=True)
class GravityDropResult:
    """Result of dropping a GLB asset under gravity."""

    # Pose of the object after the simulation; the simulation manager stores a
    # NumPy float array here (presumably a 4x4 matrix -- TODO confirm).
    final_pose: Any
---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.table_clutter_fit_manager.manager import ( + fit_table_to_clutter, +) + +__all__ = ["fit_table_to_clutter"] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/manager.py new file mode 100644 index 00000000..987e1487 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/manager.py @@ -0,0 +1,298 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
def _resolve_generated_path(value: Any, output_root: Path) -> Path:
    """Turn a manifest path entry into an absolute, resolved ``Path``.

    Falsy entries (``None``, ``""``) yield an empty ``Path()``.  Absolute
    entries are resolved as they are; relative entries are interpreted
    relative to *output_root*.
    """
    if not value:
        return Path()
    candidate = Path(str(value)).expanduser()
    if not candidate.is_absolute():
        candidate = output_root.expanduser().resolve() / candidate
    return candidate.resolve()


def _gravity_settle_table_fit_internal_z_scene(
    scene: Any,
    *,
    z_to_y: np.ndarray,
    sim_device: str,
) -> Any:
    """Drop *scene* under gravity and return a settled copy.

    The scene is exported, after applying *z_to_y*, to a temporary GLB, which
    is dropped from a height of 0.05 in a headless simulation.  The resulting
    pose is then applied to a copy of the original scene.
    """
    simulator = SimulationManager(headless=True, sim_device=sim_device)
    with tempfile.TemporaryDirectory(prefix="p2s_table_fit_gravity_") as scratch:
        glb_file = Path(scratch) / "table_pre_gravity.glb"
        export_scene = _copy_scene_with_transform(scene, z_to_y)
        export_scene.export(glb_file)
        drop_request = GravityDropRequest(
            glb_path=glb_file,
            max_convex_hull_num=16,
            initial_height=0.05,
        )
        drop = simulator.run_gravity_simulation(drop_request)
    settled_scene = scene.copy()
    pose = np.asarray(drop.final_pose, dtype=np.float64)
    settled_scene.apply_transform(pose)
    return settled_scene
def _write_table_fit_json(path: Path, data: dict[str, Any]) -> None:
    """Write *data* to *path* as indented UTF-8 JSON, creating parent directories."""
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(
        json.dumps(data, ensure_ascii=False, indent=2) + "\n",
        encoding="utf-8",
    )


def fit_table_to_clutter(
    *,
    table_result: dict[str, Any],
    clutter_result: dict[str, Any],
    output_root: Path,
    output_dir: Path,
    margin_cm: float = 10.0,
    support_occupancy_ratio: float = 0.80,
    gravity_settle_table: bool = True,
    sim_device: str = "cpu",
) -> dict[str, Any]:
    """Fit a table mesh to an already laid-out clutter result.

    The table is scaled so that its detected support surface can hold the
    clutter footprint, optionally settled under gravity, and moved so the
    support centre sits on the XY origin with the table bottom at Z = 0.
    Every clutter object is then shifted onto the table, and the table, the
    objects and a JSON manifest are written to *output_dir*.

    Args:
        table_result: Dict providing the table GLB under
            ``"simready_geometry_path"`` or, failing that, ``"mesh_path"``.
        clutter_result: Dict with an ``"objects"`` list.  Only dict entries
            whose ``"status"`` is ``"ok"`` are used, and each must provide
            ``"laid_out_glb_path"`` or ``"settled_glb_path"``.  An optional
            ``"clutter_2d_aabb_cm"`` entry supplies the footprint aspect ratio.
        output_root: Base directory for relative paths in the two result
            dicts and for the returned manifest path.
        output_dir: Directory that receives the exported files.
        margin_cm: Margin added on each side of the required support size.
        support_occupancy_ratio: Fraction of the support surface the clutter
            may occupy; clipped to ``[0.1, 1.0]``.
        gravity_settle_table: Whether to drop the scaled table under gravity.
        sim_device: Device passed to the gravity simulation.

    Returns:
        ``{"status": "ok", "manifest_path": ...}`` with the manifest path
        expressed through ``relative_path(manifest_path, output_root)``.

    Raises:
        RuntimeError: If ``trimesh`` cannot be imported.
        FileNotFoundError: If the table GLB does not exist.
        ValueError: If there are no ``"ok"`` objects, or none of them has an
            existing GLB file.
    """
    try:
        import trimesh
    except ImportError as exc:
        raise RuntimeError("Table fitting requires trimesh.") from exc

    output_root = output_root.expanduser().resolve()
    output_dir = output_dir.expanduser().resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    # Resolve the table geometry.
    table_simready_path = _resolve_generated_path(
        table_result.get("simready_geometry_path") or table_result.get("mesh_path"),
        output_root,
    )
    if not table_simready_path.is_file():
        raise FileNotFoundError(f"Table simready GLB not found: {table_simready_path}")

    # Resolve the clutter object geometries.
    settled_objects = [
        item
        for item in clutter_result.get("objects", [])
        if isinstance(item, dict) and item.get("status") == "ok"
    ]
    if not settled_objects:
        raise ValueError("No successfully settled objects for table fitting.")

    object_glb_paths: list[tuple[str, Path]] = []
    for item in settled_objects:
        glb_path = _resolve_generated_path(
            item.get("laid_out_glb_path") or item.get("settled_glb_path"),
            output_root,
        )
        # Objects whose GLB is missing on disk are silently skipped.
        if glb_path.is_file():
            # NOTE(review): an "ok" entry without an "id" key raises KeyError
            # here -- confirm upstream always sets it.
            object_glb_paths.append((str(item["id"]), glb_path))

    if not object_glb_paths:
        raise ValueError("No valid settled object GLBs for table fitting.")

    # Geometry is processed in an internal Z-up frame and converted back when
    # exporting; y_to_z is the inverse of the export transform.
    z_to_y = _z_up_to_glb_y_up_transform()
    y_to_z = np.linalg.inv(z_to_y)

    # Load the table and detect its support surface.
    table_scene = _load_table_fit_scene_internal_z(
        table_simready_path,
        trimesh=trimesh,
        y_to_z=y_to_z,
    )
    table_mesh = _scene_to_mesh(table_scene, trimesh=trimesh)
    clutter_aabb = clutter_result.get("clutter_2d_aabb_cm") or {}
    # Without a reported footprint the aspect ratio defaults to 1:1.
    clutter_size = clutter_aabb.get("size_xy", [1.0, 1.0])
    target_aspect = float(clutter_size[0]) / max(float(clutter_size[1]), 1.0e-6)
    initial_support = _detect_table_fit_support_quad(
        table_mesh,
        target_aspect=target_aspect,
    )

    # Load the clutter scenes.
    clutter_scenes = [
        (oid, _load_table_fit_scene_internal_z(path, trimesh=trimesh, y_to_z=y_to_z))
        for oid, path in object_glb_paths
    ]
    clutter_bounds = _table_fit_scene_union_bounds(
        [scene for _, scene in clutter_scenes],
        trimesh=trimesh,
    )

    # Compute the required table size and uniform scale.
    # The * 100.0 factors convert scene units to centimetres (scene units are
    # presumably metres -- TODO confirm).
    clutter_size_cm = (clutter_bounds[1, :2] - clutter_bounds[0, :2]) * 100.0
    occupancy = float(np.clip(support_occupancy_ratio, 0.1, 1.0))
    required_size_cm = clutter_size_cm / occupancy + 2.0 * float(margin_cm)
    support_size_cm = np.asarray(initial_support["size_xy"], dtype=np.float64) * 100.0
    scale_x = _table_fit_safe_positive_ratio(required_size_cm[0], support_size_cm[0])
    scale_y = _table_fit_safe_positive_ratio(required_size_cm[1], support_size_cm[1])
    # One factor for both axes: the larger ratio wins so both dimensions fit.
    uniform_scale = max(scale_x, scale_y)
    table_scale_transform = _table_fit_uniform_xy_scale_transform(
        center_xy=np.asarray(initial_support["center_xy"], dtype=np.float64),
        scale=uniform_scale,
    )
    table_scene.apply_transform(table_scale_transform)

    # Settle the table under gravity.
    if gravity_settle_table:
        table_scene = _gravity_settle_table_fit_internal_z_scene(
            table_scene,
            z_to_y=z_to_y,
            sim_device=sim_device,
        )

    # Reposition the table at the origin.
    # The support surface is detected again because scaling and settling may
    # have moved it.
    final_table_mesh = _scene_to_mesh(table_scene, trimesh=trimesh)
    final_support = _detect_table_fit_support_quad(
        final_table_mesh,
        target_aspect=float(required_size_cm[0] / max(required_size_cm[1], 1.0e-6)),
    )
    support_center = np.asarray(final_support["center"], dtype=np.float64)
    table_bounds = np.asarray(final_table_mesh.bounds, dtype=np.float64)
    table_bottom_z = float(table_bounds[0, 2])

    # Support centre goes to the XY origin, table bottom to Z = 0.
    table_shift = np.eye(4, dtype=np.float64)
    table_shift[:3, 3] = [-support_center[0], -support_center[1], -table_bottom_z]
    table_scene.apply_transform(table_shift)
    support_z_after = float((support_center + table_shift[:3, 3])[2])

    # Measure the table surface height.
    # Use the highest point of the table mesh (after scaling + gravity + shift)
    # rather than the support-plane mean Z, so that thin objects sit above the
    # actual geometry even when the tabletop has slight unevenness.
    _table_mesh_after_shift = _scene_to_mesh(table_scene, trimesh=trimesh)
    _table_max_z = float(
        np.asarray(_table_mesh_after_shift.bounds, dtype=np.float64)[1, 2]
    )
    _surface_z_margin = 0.01  # 1 cm above the highest table point

    # Place the objects on the table.
    placed_objects: list[dict[str, Any]] = []
    shifted_clutter: list[tuple[str, Any]] = []
    # NOTE(review): the clutter scenes have not been transformed since
    # ``clutter_bounds`` was computed, so this looks like a recomputation of
    # the same bounds -- confirm the helper has no side effects.
    clutter_after = _table_fit_scene_union_bounds(
        [scene for _, scene in clutter_scenes],
        trimesh=trimesh,
    )
    clutter_center_xy = 0.5 * (clutter_after[0, :2] + clutter_after[1, :2])
    for oid, scene in clutter_scenes:
        obj_mesh = _scene_to_mesh(scene, trimesh=trimesh)
        obj_bounds = np.asarray(obj_mesh.bounds, dtype=np.float64)
        obj_bottom_z = float(obj_bounds[0, 2])
        # All objects share one XY shift (clutter centre to the origin), but
        # each gets its own Z shift so its lowest point ends up just above
        # the table's highest point.
        # NOTE(review): objects resting on other objects would also be moved
        # down to table level -- confirm the layout is always single-layer.
        obj_shift = np.eye(4, dtype=np.float64)
        obj_shift[:3, 3] = [
            -float(clutter_center_xy[0]),
            -float(clutter_center_xy[1]),
            _table_max_z - obj_bottom_z + _surface_z_margin,
        ]
        # The loaded scene is transformed in place.
        scene.apply_transform(obj_shift)
        shifted_clutter.append((oid, scene))

    # Export the fitted table and placed objects.
    final_table_path = output_dir / "table_fit_to_clutter.glb"
    _copy_scene_with_transform(table_scene, z_to_y).export(final_table_path)

    for oid, scene in shifted_clutter:
        # NOTE(review): ``oid`` is used verbatim in the file name -- confirm
        # ids can never contain path separators.
        object_path = output_dir / f"{oid}_on_table.glb"
        _copy_scene_with_transform(scene, z_to_y).export(object_path)
        placed_objects.append({"id": oid, "path": str(object_path)})

    # Write the fit manifest.
    final_clutter_bounds = _table_fit_scene_union_bounds(
        [scene for _, scene in shifted_clutter],
        trimesh=trimesh,
    )
    final_clutter_aabb_cm = _table_fit_bounds_xy_manifest(
        final_clutter_bounds,
        unit_scale=100.0,
    )
    # Express the detected support quad in the shifted table frame: the
    # centre takes the full shift, the XY entries are re-centred on the
    # support centre.
    final_support_centered = {
        **final_support,
        "center": (support_center + table_shift[:3, 3]).tolist(),
        "center_xy": (
            np.asarray(final_support["center_xy"], dtype=np.float64)
            - support_center[:2]
        ).tolist(),
        "corners_xy": (
            np.asarray(final_support["corners_xy"], dtype=np.float64)
            - support_center[:2]
        ).tolist(),
    }
    manifest = {
        "status": "ok",
        "output_dir": str(output_dir),
        "table_simready_path": str(table_simready_path),
        "table_output_path": str(final_table_path),
        "objects": placed_objects,
        "margin_cm": margin_cm,
        "support_occupancy_ratio": occupancy,
        "gravity_settle_table": gravity_settle_table,
        "table_bottom_z_after_shift": 0.0,
        "support_z_after_shift": support_z_after,
        "initial_support_quad": initial_support,
        "final_support_quad_centered": final_support_centered,
        "clutter_2d_aabb_cm": final_clutter_aabb_cm,
        "required_support_size_cm": required_size_cm.tolist(),
        "table_xy_scale": {
            "uniform_scale": uniform_scale,
            "scale_x_raw": scale_x,
            "scale_y_raw": scale_y,
            "support_size_before_scale_cm": support_size_cm.tolist(),
        },
        # Compares the clutter footprint with the support size, both in cm.
        "fit_check": {
            "fits_width": float(final_clutter_aabb_cm["size_xy"][0])
            <= float(np.asarray(final_support_centered["size_xy"])[0] * 100.0),
            "fits_depth": float(final_clutter_aabb_cm["size_xy"][1])
            <= float(np.asarray(final_support_centered["size_xy"])[1] * 100.0),
        },
    }
    manifest_path = output_dir / "table_fit_to_clutter_manifest.json"
    _write_table_fit_json(manifest_path, manifest)
    return {
        "status": "ok",
        "manifest_path": relative_path(manifest_path, output_root),
    }
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.text_layout_manager.layout import ( + _layout_text_objects_grid, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.text_layout_manager.optimization import ( + _optimize_text_layout_slp, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.text_layout_manager.settle import ( + settle_text_objects_to_ground, +) + +__all__ = [ + "_layout_text_objects_grid", + "_optimize_text_layout_slp", + "settle_text_objects_to_ground", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/layout.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/layout.py new file mode 100644 index 00000000..7b94a852 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/layout.py @@ -0,0 +1,383 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def _transitive_closure(
    nodes: list[str],
    edges: list[tuple[str, str]],
) -> list[tuple[str, str]]:
    """Return the transitive closure of *edges* over *nodes* (Floyd–Warshall).

    Edges with an endpoint outside *nodes* are ignored.  The result lists
    each reachable ``(src, dst)`` pair once, ordered by the position of
    ``src`` and then ``dst`` in *nodes*.  When *nodes* or *edges* is empty,
    a plain copy of *edges* is returned.
    """
    if not nodes or not edges:
        return list(edges)
    index = {name: pos for pos, name in enumerate(nodes)}
    size = len(nodes)
    reach = [[False] * size for _ in range(size)]
    for src, dst in edges:
        if src in index and dst in index:
            reach[index[src]][index[dst]] = True
    for mid in range(size):
        via = reach[mid]
        for row in reach:
            if row[mid]:
                # Everything reachable from ``mid`` is reachable from here.
                for col in range(size):
                    if via[col]:
                        row[col] = True
    return [
        (nodes[i], nodes[j])
        for i in range(size)
        for j in range(size)
        if reach[i][j]
    ]


def _longest_path_ranks(
    nodes: list[str],
    edges: list[tuple[str, str]],
) -> dict[str, int]:
    """Assign integer ranks satisfying ``(A,B)`` → rank[A] < rank[B].

    Ranks are longest-path depths computed during a Kahn topological sort.
    Every node in *nodes* gets an entry: isolated nodes have rank 0, and
    nodes that are on a cycle or only reachable through one also stay at 0.
    Edges with an endpoint outside *nodes* are ignored.
    """
    # Local import: a deque gives O(1) pops from the front of the work queue,
    # where ``list.pop(0)`` is O(n).
    from collections import deque

    ranks: dict[str, int] = {n: 0 for n in nodes}
    if not edges:
        return ranks

    successors: dict[str, list[str]] = {n: [] for n in nodes}
    in_degree: dict[str, int] = {n: 0 for n in nodes}
    for src, dst in edges:
        if src not in successors or dst not in successors:
            continue
        successors[src].append(dst)
        in_degree[dst] += 1

    ready = deque(n for n in nodes if in_degree[n] == 0)
    while ready:
        current = ready.popleft()
        # A node is dequeued only after all of its predecessors, so its rank
        # is final here and can be pushed to its successors straight away.
        next_rank = ranks[current] + 1
        for follower in successors[current]:
            if ranks[follower] < next_rank:
                ranks[follower] = next_rank
            in_degree[follower] -= 1
            if in_degree[follower] == 0:
                ready.append(follower)
    return ranks
+ left_of_edges: list[tuple[str, str]] = [] + front_of_edges: list[tuple[str, str]] = [] + seen: set[tuple[str, str, str]] = set() + for rel in spatial_relations: + subject = str(rel.get("subject") or "") + obj = str(rel.get("object") or "") + relation = str(rel.get("relation") or "") + if not subject or not obj or subject == obj: + continue + key = (subject, relation, obj) + if key in seen: + continue + seen.add(key) + if relation == "left_of": + left_of_edges.append((subject, obj)) + elif relation == "front_of": + front_of_edges.append((subject, obj)) + + # Compute transitive closures. + left_of_closed = _transitive_closure(object_ids, left_of_edges) + front_of_closed = _transitive_closure(object_ids, front_of_edges) + + # Parse nine-grid constraints. + # −Y = front, so front row = 0, back row = 2 + _GRID_TO_RC: dict[str, tuple[int, int]] = { + "left_front": (0, 0), "center_front": (1, 0), "right_front": (2, 0), + "left_center": (0, 1), "center": (1, 1), "right_center": (2, 1), + "left_back": (0, 2), "center_back": (1, 2), "right_back": (2, 2), + "front": (1, 0), "back": (1, 2), + "left": (0, 1), "right": (2, 1), + } + grid_targets: dict[str, tuple[int, int]] = {} + for tc in (table_constraints or []): + asset = str(tc.get("asset") or "") + grid_name = str(tc.get("grid") or "").strip() + if asset in object_ids and grid_name in _GRID_TO_RC: + grid_targets[asset] = _GRID_TO_RC[grid_name] + + # Select a center object when none is explicit. 
+ auto_center_oid: str | None = None + has_explicit_center = any( + tc.get("grid") == "center" for tc in (table_constraints or []) + ) + if not has_explicit_center: + # Degree = appearances in left_of + front_of (subject or object) + degree: dict[str, int] = {oid: 0 for oid in object_ids} + for src, dst in left_of_closed + front_of_closed: + if src in degree: + degree[src] += 1 + if dst in degree: + degree[dst] += 1 + max_deg = max(degree.values()) if degree else 0 + if max_deg > 0: + candidates = [oid for oid, d in degree.items() if d == max_deg] + # Tie-breaker: largest AABB area + centre_oid = max( + candidates, + key=lambda oid: float(xy_sizes[oid][0]) * float(xy_sizes[oid][1]), + ) + grid_targets[centre_oid] = (1, 1) # 9‑grid centre + auto_center_oid = centre_oid + + # Derive ranks from the transitive closures. + x_rank = _longest_path_ranks(object_ids, left_of_closed) + # −Y = front: A front_of B → A.y < B.y → row[A] < row[B]. + # _longest_path_ranks gives rank[src] < rank[dst]; edges are + # already (A,B) for "A front_of B", so NO reversal needed. + y_rank = _longest_path_ranks(object_ids, front_of_closed) + + # Apply nine-grid shifts. + # Pin 9‑grid objects to their target ranks; shift all connected + # objects (both upstream and downstream) to preserve topology. 
+ if grid_targets: + # Build undirected connected-components via relation edges + all_edges = left_of_closed + front_of_closed + neighbours: dict[str, set[str]] = {oid: set() for oid in object_ids} + for src, dst in all_edges: + if src in neighbours and dst in neighbours: + neighbours[src].add(dst) + neighbours[dst].add(src) + for oid in grid_targets: + neighbours.setdefault(oid, set()) + + # For each 9‑grid object, BFS the component and shift uniformly + shifted: set[str] = set() + for oid, (target_col, target_row) in grid_targets.items(): + if oid in shifted: + continue + dx = target_col - x_rank.get(oid, 0) + dy = target_row - y_rank.get(oid, 0) + + # BFS to collect the full connected component + component: set[str] = {oid} + queue = [oid] + while queue: + u = queue.pop(0) + for v in neighbours.get(u, set()): + if v not in component: + component.add(v) + queue.append(v) + + for oid2 in component: + if oid2 not in grid_targets: # only shift non‑anchored objects + x_rank[oid2] = x_rank.get(oid2, 0) + dx + y_rank[oid2] = y_rank.get(oid2, 0) + dy + shifted.update(component) + + # Propagate row and column alignment. + # left_of A B → same row (y_rank[A] = y_rank[B]) + # front_of A B → same col (x_rank[A] = x_rank[B]) + # Priority (higher wins): 9‑grid > higher degree > larger area. 
+ _prio = { + oid: ( + oid in grid_targets, + sum(1 for e in left_of_closed + front_of_closed if oid in e), + float(xy_sizes[oid][0]) * float(xy_sizes[oid][1]), + ) + for oid in object_ids + } + for src, dst in left_of_closed: + if _prio[src] >= _prio[dst]: + y_rank[dst] = y_rank.get(src, 0) + else: + y_rank[src] = y_rank.get(dst, 0) + for src, dst in front_of_closed: + if _prio[src] >= _prio[dst]: + x_rank[dst] = x_rank.get(src, 0) + else: + x_rank[src] = x_rank.get(dst, 0) + + # Normalise to >= 0 + min_x = min(x_rank.values()) if x_rank else 0 + min_y = min(y_rank.values()) if y_rank else 0 + for oid in object_ids: + x_rank[oid] = x_rank.get(oid, 0) - min_x + y_rank[oid] = y_rank.get(oid, 0) - min_y + + # Resolve cell collisions: spread objects sharing the same (col, row) + cell_occupants: dict[tuple[int, int], list[str]] = {} + for oid in object_ids: + cell = (x_rank[oid], y_rank[oid]) + cell_occupants.setdefault(cell, []).append(oid) + for (col, row), occupants in cell_occupants.items(): + if len(occupants) > 1: + for offset, oid in enumerate(occupants[1:], start=1): + x_rank[oid] = col + offset + + # Place unconstrained objects in wrapped rows. + constrained = set() + for src, dst in left_of_closed + front_of_closed: + constrained.update([src, dst]) + constrained.update(grid_targets) + free_objects = [oid for oid in object_ids if oid not in constrained] + + if free_objects: + free_row = max(y_rank.values()) + 1 if y_rank else 0 + # Max row width ≈ existing union width × 1.5 (at least 3 cols) + col_keys = list(x_rank.values()) + existing_cols = max(col_keys) - min(col_keys) + 1 if col_keys else 1 + max_cols_per_row = max(existing_cols, 3) + free_sorted = sorted( + free_objects, + key=lambda oid: float(xy_sizes[oid][0]), + reverse=True, + ) + col = 0 + row_offset = 0 + for oid in free_sorted: + x_rank[oid] = col + y_rank[oid] = free_row + row_offset + col += 1 + if col >= max_cols_per_row: + col = 0 + row_offset += 1 + + # Convert ranks to XY positions. 
+ col_widths: dict[int, float] = {} + row_heights: dict[int, float] = {} + for oid in object_ids: + c = x_rank[oid] + r = y_rank[oid] + col_widths[c] = max(col_widths.get(c, 0.0), float(xy_sizes[oid][0])) + row_heights[r] = max(row_heights.get(r, 0.0), float(xy_sizes[oid][1])) + + x_cumsum: dict[int, float] = {} + cumulative = 0.0 + for c in sorted(col_widths): + x_cumsum[c] = cumulative + cumulative += col_widths[c] + grid_spacing + + y_cumsum: dict[int, float] = {} + cumulative = 0.0 + for r in sorted(row_heights): + y_cumsum[r] = cumulative + cumulative += row_heights[r] + grid_spacing + + centers: dict[str, np.ndarray] = {} + for oid in object_ids: + c = x_rank[oid] + r = y_rank[oid] + cx = x_cumsum[c] + 0.5 * float(xy_sizes[oid][0]) + cy = y_cumsum[r] + 0.5 * float(xy_sizes[oid][1]) + centers[oid] = np.array([cx, cy], dtype=np.float64) + + centers = _center_xy_aabb_layout(centers=centers, xy_sizes=xy_sizes) + + initial_centers = {oid: c.copy() for oid, c in centers.items()} + + # Snap initial grid positions as 9‑grid spring targets + grid_spring_targets: dict[str, np.ndarray] = { + oid: initial_centers[oid].copy() + for oid in grid_targets + if oid in initial_centers + } + + # Optimize positions and remove mesh AABB collisions. + optimized = _optimize_text_layout_slp( + object_ids=object_ids, + xy_sizes=xy_sizes, + initial_centers=initial_centers, + left_of_edges=left_of_closed, + front_of_edges=front_of_closed, + grid_spring_targets=grid_spring_targets, + padding_ratio=padding_ratio, + ) + centers = optimized["centers"] + optimization_metadata = optimized["metadata"] + + # Collect layout metadata. 
+ metadata = { + "method": "transitive_closure_longest_path_with_9grid_and_sa", + "grid_spacing": grid_spacing, + "auto_center_oid": auto_center_oid, + "has_explicit_center": has_explicit_center, + "table_constraint_count": len(grid_targets), + "left_of_count": len(left_of_edges), + "left_of_closed_count": len(left_of_closed), + "front_of_count": len(front_of_edges), + "front_of_closed_count": len(front_of_closed), + "free_object_count": len(free_objects), + "x_ranks": {oid: x_rank.get(oid, 0) for oid in object_ids}, + "y_ranks": {oid: y_rank.get(oid, 0) for oid in object_ids}, + "optimization": optimization_metadata, + } + return { + "centers": centers, + "initial_centers": initial_centers, + "metadata": metadata, + } diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/optimization.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/optimization.py new file mode 100644 index 00000000..b8915fc4 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/optimization.py @@ -0,0 +1,404 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ----------------------------------------------------------------------------

from __future__ import annotations

from typing import Any

import numpy as np
from scipy.optimize import minimize

from embodichain.gen_sim.prompt2scene.agent_tools.managers.optimization_manager import (
    _center_xy_aabb_layout,
    _footprint_layout_diagnostics,
    _xy_aabb_overlap,
    _xy_union_bounds,
)

__all__ = ["_optimize_text_layout_slp"]

# SLSQP solve options — matching the original example_optimization SA pipeline.
_SLSQP_OPTIONS: dict[str, Any] = {"maxiter": 500, "ftol": 1e-6, "disp": False}

# Objective weights (relations are hard constraints, not in the objective).
_WEIGHTS: dict[str, float] = {
    "seed": 1.0,
    "overlap": 200.0,
    "grid": 100.0,
}


def _optimize_text_layout_slp(
    *,
    object_ids: list[str],
    xy_sizes: dict[str, np.ndarray],
    initial_centers: dict[str, np.ndarray],
    left_of_edges: list[tuple[str, str]],
    front_of_edges: list[tuple[str, str]],
    grid_spring_targets: dict[str, np.ndarray],
    padding_ratio: float,
) -> dict[str, Any]:
    """Optimize 2D centres with scipy SLSQP, then remove mesh AABB overlap.

    The problem handed to the solver is built as follows:
      - left_of / front_of      -> inequality constraints
      - bounding box            -> variable bounds (2x the initial union)
      - seed / overlap / grid   -> soft penalties in the objective
      - post-solve collision cleanup on the footprint AABBs

    Args:
        object_ids: Objects to place; their order fixes the packing order of
            the flat solver vector ``[x0, y0, x1, y1, ...]``.
        xy_sizes: Per-object footprint size; index 0 is read as the X extent
            and index 1 as the Y extent.
        initial_centers: Seed centre for every id in ``object_ids``.
        left_of_edges: ``(subject, object)`` pairs constrained along X.
        front_of_edges: ``(subject, object)`` pairs constrained along Y.
        grid_spring_targets: Centres that the "grid" penalty pulls objects
            toward; ids that are not being optimized are ignored.
        padding_ratio: Fraction of the largest footprint extent used as
            clearance (the clearance is never below 1e-3).

    Returns:
        ``{"centers": ..., "metadata": ...}``.  If the solver reports failure
        or raises, the seed positions are used instead and the outcome is
        recorded under ``metadata["slsqp"]``.
    """
    if not object_ids:
        return {
            "centers": {},
            "metadata": {
                "method": "text_slsqp_then_mesh_aabb_collision_removal",
                "slsqp_iterations": 0,
                "collision_iterations": 0,
            },
        }

    # Clearance scales with the largest footprint extent among all objects.
    max_extent = max(
        float(max(xy_sizes[oid][0], xy_sizes[oid][1])) for oid in object_ids
    )
    padding = max(max_extent * padding_ratio, 1e-3)

    # Work on private float64 copies so the caller's arrays are never mutated.
    initial_centers = {
        oid: np.asarray(initial_centers[oid], dtype=np.float64).copy()
        for oid in object_ids
    }
    initial_union_bounds = _xy_union_bounds(
        centers=initial_centers,
        xy_sizes=xy_sizes,
    )

    index_by_id = {oid: i for i, oid in enumerate(object_ids)}
    x0 = _pack_centers(object_ids, initial_centers)

    # Build inequality constraints for left_of and front_of.
    constraints: list[dict[str, Any]] = []
    _build_relation_constraints(
        constraints=constraints,
        object_ids=object_ids,
        index_by_id=index_by_id,
        xy_sizes=xy_sizes,
        left_of_edges=left_of_edges,
        front_of_edges=front_of_edges,
        padding=padding,
    )

    # Bound variables to twice the initial union size.
    init_size = initial_union_bounds[1] - initial_union_bounds[0]
    margin = init_size * 0.5  # 50 % each side -> 2x total
    x_bounds = (
        float(initial_union_bounds[0, 0] - margin[0]),
        float(initial_union_bounds[1, 0] + margin[0]),
    )
    y_bounds = (
        float(initial_union_bounds[0, 1] - margin[1]),
        float(initial_union_bounds[1, 1] + margin[1]),
    )
    bounds = []
    for _ in object_ids:
        bounds.append(x_bounds)  # x
        bounds.append(y_bounds)  # y

    # Define the optimization objective.
    def _objective(xvec: np.ndarray) -> float:
        centers = _unpack_centers(object_ids, xvec)
        loss = 0.0

        # seed: stay close to initial positions
        for oid in object_ids:
            delta = centers[oid] - initial_centers[oid]
            loss += _WEIGHTS["seed"] * float(np.dot(delta, delta))

        # overlap: AABB overlap area penalty
        # NOTE(review): `ov` is presumably the per-axis overlap extents, or
        # None when the padded boxes do not intersect — confirm against
        # _xy_aabb_overlap.
        for i, oid in enumerate(object_ids):
            for other_id in object_ids[i + 1 :]:
                ov = _xy_aabb_overlap(
                    center_a=centers[oid],
                    size_a=xy_sizes[oid],
                    center_b=centers[other_id],
                    size_b=xy_sizes[other_id],
                    padding=padding,
                )
                if ov is not None:
                    loss += _WEIGHTS["overlap"] * float(ov[0] * ov[1])

        # grid: spring toward 9-grid targets
        for oid, target in grid_spring_targets.items():
            if oid not in centers:
                continue
            delta = centers[oid] - target
            loss += _WEIGHTS["grid"] * float(np.dot(delta, delta))

        return float(loss)

    # Solve the constrained optimization problem.  The fallback record has the
    # same keys as the success record so consumers can read it unconditionally.
    slsqp_result: dict[str, Any] = {
        "success": False,
        "nit": 0,
        "message": "",
        "fun": None,
    }
    try:
        result = minimize(
            _objective,
            x0,
            method="SLSQP",
            bounds=bounds,
            constraints=constraints,
            options=_SLSQP_OPTIONS,
        )
        slsqp_result = {
            "success": bool(result.success),
            "nit": int(getattr(result, "nit", 0)),
            "message": str(result.message),
            "fun": float(result.fun) if result.fun is not None else None,
        }
        if result.success:
            x_opt = result.x
        else:
            # SLSQP failed — fall back to seed positions
            x_opt = x0.copy()
    except Exception as exc:
        # Deliberate best-effort: layout must still produce a result, so fall
        # back to the seed, but keep the cause so the failure is diagnosable.
        x_opt = x0.copy()
        slsqp_result["message"] = "SLSQP raised an exception; using seed positions."
        slsqp_result["error"] = f"{type(exc).__name__}: {exc}"

    centers = _unpack_centers(object_ids, x_opt)
    centers = _center_xy_aabb_layout(centers=centers, xy_sizes=xy_sizes)

    # Remove residual collisions.
    centers, collision_metadata = _remove_mesh_aabb_collisions(
        object_ids=object_ids,
        xy_sizes=xy_sizes,
        centers=centers,
        initial_centers=initial_centers,
        left_of_edges=left_of_edges,
        front_of_edges=front_of_edges,
        padding=padding,
    )
    centers = _center_xy_aabb_layout(centers=centers, xy_sizes=xy_sizes)

    # Collect optimization metadata.
    diagnostics = _footprint_layout_diagnostics(
        object_ids=object_ids,
        centers=centers,
        initial_centers=initial_centers,
        xy_sizes=xy_sizes,
        padding=padding,
        initial_union_bounds=initial_union_bounds,
    )
    metadata: dict[str, Any] = {
        "method": "text_slsqp_then_mesh_aabb_collision_removal",
        "relation_usage": "left_of_front_of_hard_constraints",
        "padding": float(padding),
        "padding_ratio": float(padding_ratio),
        "weights": dict(_WEIGHTS),
        "slsqp": slsqp_result,
        "bounds_expansion": 2.0,
        "initial_union_size": init_size.tolist(),
        **collision_metadata,
        "final_centers": {oid: centers[oid].tolist() for oid in object_ids},
        **diagnostics,
    }
    return {"centers": centers, "metadata": metadata}


# Build relation constraints.
+ + +def _build_relation_constraints( + *, + constraints: list[dict[str, Any]], + object_ids: list[str], + index_by_id: dict[str, int], + xy_sizes: dict[str, np.ndarray], + left_of_edges: list[tuple[str, str]], + front_of_edges: list[tuple[str, str]], + padding: float, +) -> None: + """Append SLSQP inequality constraints for left_of / front_of edges.""" + + for subject, obj in left_of_edges: + if subject not in index_by_id or obj not in index_by_id: + continue + i_a = index_by_id[subject] + i_b = index_by_id[obj] + # A.x + gap ≤ B.x → B.x - A.x - gap ≥ 0 + gap = ( + 0.5 * float(xy_sizes[subject][0]) + + 0.5 * float(xy_sizes[obj][0]) + + padding + ) + constraints.append( + { + "type": "ineq", + "fun": lambda x, ia=i_a, ib=i_b, g=gap: float( + x[2 * ib] - x[2 * ia] - g + ), + } + ) + + for subject, obj in front_of_edges: + if subject not in index_by_id or obj not in index_by_id: + continue + i_a = index_by_id[subject] + i_b = index_by_id[obj] + # A.y + gap ≤ B.y → B.y - A.y - gap ≥ 0 + gap = ( + 0.5 * float(xy_sizes[subject][1]) + + 0.5 * float(xy_sizes[obj][1]) + + padding + ) + constraints.append( + { + "type": "ineq", + "fun": lambda x, ia=i_a, ib=i_b, g=gap: float( + x[2 * ib + 1] - x[2 * ia + 1] - g + ), + } + ) + + +# Remove AABB collisions. 
+ + +def _remove_mesh_aabb_collisions( + *, + object_ids: list[str], + xy_sizes: dict[str, np.ndarray], + centers: dict[str, np.ndarray], + initial_centers: dict[str, np.ndarray], + left_of_edges: list[tuple[str, str]], + front_of_edges: list[tuple[str, str]], + padding: float, +) -> tuple[dict[str, np.ndarray], dict[str, Any]]: + relation_pairs = set(left_of_edges + front_of_edges) + relation_pairs.update((b, a) for a, b in left_of_edges + front_of_edges) + current = { + oid: np.asarray(center, dtype=np.float64).copy() + for oid, center in centers.items() + } + max_rounds = 80 + total_pushes = 0 + last_overlap_count = 0 + + for iteration in range(max_rounds): + overlaps = _mesh_aabb_collision_pairs( + object_ids=object_ids, + xy_sizes=xy_sizes, + centers=current, + padding=padding, + ) + last_overlap_count = len(overlaps) + if not overlaps: + return current, { + "collision_iterations": iteration, + "collision_pushes": total_pushes, + "collision_remaining": 0, + "collision_removal": "iterative_mesh_aabb_push", + } + for item in overlaps: + object_a = item["object"] + object_b = item["other"] + axis = int(item["axis"]) + sign = -1.0 if current[object_a][axis] <= current[object_b][axis] else 1.0 + amount = 0.5 * (float(item["overlap"]) + 1.0e-6) + if (object_a, object_b) in relation_pairs: + current[object_a][axis] += sign * amount + current[object_b][axis] -= sign * amount + else: + drift_a = np.linalg.norm( + current[object_a] - initial_centers[object_a] + ) + drift_b = np.linalg.norm( + current[object_b] - initial_centers[object_b] + ) + if drift_a <= drift_b: + current[object_a][axis] += sign * amount * 1.25 + current[object_b][axis] -= sign * amount * 0.75 + else: + current[object_a][axis] += sign * amount * 0.75 + current[object_b][axis] -= sign * amount * 1.25 + total_pushes += 1 + current = _center_xy_aabb_layout(centers=current, xy_sizes=xy_sizes) + + return current, { + "collision_iterations": max_rounds, + "collision_pushes": total_pushes, + 
"collision_remaining": last_overlap_count, + "collision_removal": "iterative_mesh_aabb_push", + } + + +def _mesh_aabb_collision_pairs( + *, + object_ids: list[str], + xy_sizes: dict[str, np.ndarray], + centers: dict[str, np.ndarray], + padding: float, +) -> list[dict[str, Any]]: + pairs: list[dict[str, Any]] = [] + for i, oid in enumerate(object_ids): + for other_id in object_ids[i + 1 :]: + ov = _xy_aabb_overlap( + center_a=centers[oid], + size_a=xy_sizes[oid], + center_b=centers[other_id], + size_b=xy_sizes[other_id], + padding=padding, + ) + if ov is None: + continue + axis = 0 if ov[0] <= ov[1] else 1 + pairs.append( + { + "object": oid, + "other": other_id, + "axis": axis, + "overlap": float(ov[axis]), + "overlap_x": float(ov[0]), + "overlap_y": float(ov[1]), + } + ) + pairs.sort(key=lambda item: item["overlap"], reverse=True) + return pairs + + +# Pack and unpack center coordinates. + + +def _pack_centers( + object_ids: list[str], + centers: dict[str, np.ndarray], +) -> np.ndarray: + values: list[float] = [] + for oid in object_ids: + c = np.asarray(centers[oid], dtype=np.float64) + values.extend([float(c[0]), float(c[1])]) + return np.asarray(values, dtype=np.float64) + + +def _unpack_centers( + object_ids: list[str], + xvec: np.ndarray, +) -> dict[str, np.ndarray]: + return { + oid: np.asarray( + [xvec[2 * i], xvec[2 * i + 1]], + dtype=np.float64, + ) + for i, oid in enumerate(object_ids) + } diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/settle.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/settle.py new file mode 100644 index 00000000..da3cdde6 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/settle.py @@ -0,0 +1,429 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + + +from __future__ import annotations + +import tempfile +import traceback +from pathlib import Path +from typing import Any + +import numpy as np + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simulation_manager import ( + SimulationManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simulation_manager.schemas import ( + GravityDropRequest, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.optimization_manager import ( + _object_scenes_xy_aabb_manifest, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager.scene_geometry import ( + _aabb_bottom_to_xy_plane_transform, + _copy_scene_with_transform, + _matrix_from_json, + _scale_transform, + _scene_to_mesh, + _xy_aabb_center, + _xy_aabb_size, + _z_up_to_glb_y_up_transform, +) +from embodichain.gen_sim.prompt2scene.utils.io import ( + relative_path, + write_json, +) +from embodichain.gen_sim.prompt2scene.utils.log import log_warning +from embodichain.gen_sim.prompt2scene.agent_tools.managers.matplotlib_manager import ( + MatplotlibManager, + RenderFootprintLayoutRequest, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.text_layout_manager.layout import ( + _layout_text_objects_grid, +) + +__all__ = ["settle_text_objects_to_ground"] + + +def settle_text_objects_to_ground( + *, + objects: list[dict[str, Any]], 
+ spatial_relations: list[dict[str, Any]] | None = None, + table_constraints: list[dict[str, Any]] | None = None, + output_dir: Path, + output_root: Path, + sim_device: str = "cpu", +) -> dict[str, Any]: + """Scale simready objects to real-world size, gravity-settle, layout on table. + + For each text-input object: + 1. Load simready GLB (GLB Y-up) → convert to internal Z-up + 2. Apply scene-level metric scale_factor → real-world size + 3. Gravity simulation to settle on ground plane + 4. Move AABB bottom centre to XY origin at Z=0 + 5. Build grid/rank initialization from left_of/front_of and table constraints + 6. Run SA-based 2D point optimization and mesh AABB collision cleanup + 7. Apply layout positions + + Returns laid-out scenes and per-object metadata. + """ + try: + import trimesh + except ImportError as exc: + raise RuntimeError("Text object gravity settling requires trimesh.") from exc + + output_dir = output_dir.expanduser().resolve() + output_dir.mkdir(parents=True, exist_ok=True) + sim = SimulationManager(headless=True, sim_device=sim_device) + z_to_y = _z_up_to_glb_y_up_transform() + y_to_z = np.linalg.inv(z_to_y) + + settled_objects: list[dict[str, Any]] = [] + object_scenes: list[tuple[str, Any]] = [] + + with tempfile.TemporaryDirectory(prefix="p2s_text_settle_") as tmp_dir: + tmp_path = Path(tmp_dir) + for obj in objects: + obj_id = str(obj.get("id", "")) + obj_name = str(obj.get("name", "")) + + # Validate the metric scale. + metric_scale = obj.get("metric_scale") + if not isinstance(metric_scale, dict): + settled_objects.append( + { + "id": obj_id, + "name": obj_name, + "status": "skipped", + "reason": "missing_metric_scale", + } + ) + continue + scale_factor = float(metric_scale.get("scale_factor", 1.0)) + if not np.isfinite(scale_factor) or scale_factor <= 0.0: + settled_objects.append( + { + "id": obj_id, + "name": obj_name, + "status": "skipped", + "reason": "invalid_scale_factor", + } + ) + continue + + # Load the simulation-ready GLB. 
+ simready_path = _resolve_generated_path( + obj.get("simready_geometry_path") or obj.get("mesh_path"), + output_root, + ) + if not simready_path.is_file(): + settled_objects.append( + { + "id": obj_id, + "name": obj_name, + "status": "skipped", + "reason": "missing_simready_glb", + } + ) + continue + + try: + # Load simready (GLB Y-up) → convert to internal Z-up + scene_yup = trimesh.load(simready_path, force="scene") + scene = _copy_scene_with_transform(scene_yup, y_to_z) + + # Apply real-world scale + scale_transform = _scale_transform(scale_factor) + scene.apply_transform(scale_transform) + + # Settle the object under gravity. + mesh = _scene_to_mesh(scene, trimesh=trimesh) + mesh_bounds = np.asarray(mesh.bounds, dtype=np.float64) + mesh_z_height = max(float(mesh_bounds[1][2] - mesh_bounds[0][2]), 0.0) + bottom_to_xy = _aabb_bottom_to_xy_plane_transform(mesh_bounds) + normalized_scene = _copy_scene_with_transform(scene, bottom_to_xy) + + # Export to Y-up GLB for gravity + pre_gravity_scene = _copy_scene_with_transform(normalized_scene, z_to_y) + pre_gravity_path = tmp_path / f"{obj_id}_pre_gravity.glb" + pre_gravity_scene.export(pre_gravity_path) + gravity_initial_height = mesh_z_height * 0.1 + + gravity_status = "ok" + gravity_transform = np.eye(4, dtype=np.float64) + gravity_reason = "" + try: + gravity_result = sim.run_gravity_simulation( + GravityDropRequest( + glb_path=pre_gravity_path, + max_convex_hull_num=32, + initial_height=gravity_initial_height, + ) + ) + gravity_transform = _matrix_from_json( + gravity_result.final_pose, + name=f"{obj_id}.gravity_final_pose", + ) + except Exception: + gravity_status = "failed" + gravity_reason = traceback.format_exc() + + # Apply gravity result (in internal Z-up space) + settled_scene = _copy_scene_with_transform( + normalized_scene, + gravity_transform, + ) + + # Center the bottom of the AABB at the XY origin. 
+ settled_mesh = _scene_to_mesh(settled_scene, trimesh=trimesh) + settled_bounds = np.asarray(settled_mesh.bounds, dtype=np.float64) + settled_xy_center = _xy_aabb_center(settled_bounds) + settled_xy_size = _xy_aabb_size(settled_bounds) + settled_bottom_z = float(settled_bounds[0, 2]) + + centre_transform = np.eye(4, dtype=np.float64) + centre_transform[:3, 3] = [ + -float(settled_xy_center[0]), + -float(settled_xy_center[1]), + -settled_bottom_z, + ] + centred_scene = _copy_scene_with_transform( + settled_scene, + centre_transform, + ) + + # Verify final bounds + centred_mesh = _scene_to_mesh(centred_scene, trimesh=trimesh) + centred_bounds = np.asarray(centred_mesh.bounds, dtype=np.float64) + centred_xy_size = _xy_aabb_size(centred_bounds) + + # Export settled GLB (Z-up → Y-up for GLB output) + settled_glb_path = output_dir / f"{obj_id}_settled.glb" + _copy_scene_with_transform(centred_scene, z_to_y).export( + settled_glb_path + ) + + item = { + "id": obj_id, + "name": obj_name, + "status": "ok", + "gravity_status": gravity_status, + "gravity_reason": gravity_reason, + "scale_factor": scale_factor, + "settled_glb_path": relative_path( + str(settled_glb_path), + output_root, + ), + "settled_xy_size_m": centred_xy_size.tolist(), + "settled_xy_size_cm": (centred_xy_size * 100.0).tolist(), + "settled_bounds_m": centred_bounds.tolist(), + "mesh_z_height_m": mesh_z_height, + "bottom_to_xy_transform": bottom_to_xy.tolist(), + "gravity_transform": gravity_transform.tolist(), + "centre_transform": centre_transform.tolist(), + "composed_settle_transform": ( + centre_transform + @ gravity_transform + @ bottom_to_xy + @ scale_transform + @ y_to_z + ).tolist(), + } + settled_objects.append(item) + object_scenes.append((obj_id, centred_scene)) + + except Exception: + settled_objects.append( + { + "id": obj_id, + "name": obj_name, + "status": "failed", + "reason": traceback.format_exc(), + } + ) + + # Optimize the spatial layout. 
+ layout_result = None + if object_scenes: + xy_sizes = { + oid: np.asarray( + _xy_aabb_size(_scene_to_mesh(scene, trimesh=trimesh).bounds), + dtype=np.float64, + ) + for oid, scene in object_scenes + } + relations = list(spatial_relations or []) + layout_result = _layout_text_objects_grid( + object_ids=[oid for oid, _ in object_scenes], + xy_sizes=xy_sizes, + spatial_relations=relations, + table_constraints=list(table_constraints or []), + ) + target_centers = layout_result["centers"] + initial_centers = layout_result.get("initial_centers", {}) + + # Render footprint layout diagnostics. + debug_dir = output_dir / "debug" + debug_dir.mkdir(parents=True, exist_ok=True) + debug_object_ids = [oid for oid, _ in object_scenes] + debug_before_centers = { + oid: np.zeros(2, dtype=np.float64) for oid in debug_object_ids + } + debug_renders = ( + ( + "footprint_layout_xy_before.png", + "Before Layout (all at origin)", + debug_before_centers, + ), + ( + "footprint_layout_xy_grid_init.png", + "After Grid Initialisation", + initial_centers, + ), + ( + "footprint_layout_xy_after.png", + "After SA Optimisation", + target_centers, + ), + ) + for filename, title, debug_centers in debug_renders: + try: + MatplotlibManager(figsize=(8, 8), dpi=180).render_footprint_layout( + RenderFootprintLayoutRequest( + object_ids=debug_object_ids, + centers=debug_centers, + xy_sizes=xy_sizes, + output_path=debug_dir / filename, + title=title, + ) + ) + except Exception as exc: + log_warning( + f"text clutter debug render failed file={filename} error={exc}" + ) + + # Apply layout positions to centred scenes + laid_out_scenes: list[tuple[str, Any]] = [] + for oid, scene in object_scenes: + target_xy = target_centers[oid] + settled_mesh = _scene_to_mesh(scene, trimesh=trimesh) + settled_bounds = np.asarray(settled_mesh.bounds, dtype=np.float64) + current_xy = _xy_aabb_center(settled_bounds) + placement = np.eye(4, dtype=np.float64) + placement[:3, 3] = [ + float(target_xy[0] - current_xy[0]), + 
float(target_xy[1] - current_xy[1]), + 0.0, + ] + laid_out_scene = _copy_scene_with_transform(scene, placement) + laid_out_scenes.append((oid, laid_out_scene)) + + # Export laid-out GLB (replaces the origin-centred one) + laid_out_glb_path = output_dir / f"{oid}_laid_out.glb" + _copy_scene_with_transform(laid_out_scene, z_to_y).export(laid_out_glb_path) + + # Update per-object metadata with layout position + for item in settled_objects: + if item.get("id") == oid: + item["layout_target_xy"] = target_xy.tolist() + item["layout_placement_transform"] = placement.tolist() + item["laid_out_glb_path"] = relative_path( + str(laid_out_glb_path), output_root + ) + laid_out_bounds = np.asarray( + _scene_to_mesh(laid_out_scene, trimesh=trimesh).bounds, + dtype=np.float64, + ) + item["laid_out_xy_size_cm"] = ( + _xy_aabb_size(laid_out_bounds) * 100.0 + ).tolist() + break + + object_scenes = laid_out_scenes + + clutter_2d_aabb_cm = _object_scenes_xy_aabb_manifest( + object_scenes=object_scenes, + trimesh=trimesh, + unit_scale=100.0, + unit="cm", + ) + + debug_manifest = { + "status": "ok", + "output_dir": relative_path(str(output_dir), output_root), + "object_count": len(objects), + "settled_count": len(object_scenes), + "clutter_2d_aabb_cm": clutter_2d_aabb_cm, + "debug_image_before_path": ( + relative_path( + str(debug_dir / "footprint_layout_xy_before.png"), + output_root, + ) + if object_scenes + else "" + ), + "debug_image_grid_init_path": ( + relative_path( + str(debug_dir / "footprint_layout_xy_grid_init.png"), + output_root, + ) + if object_scenes + else "" + ), + "debug_image_after_path": ( + relative_path( + str(debug_dir / "footprint_layout_xy_after.png"), + output_root, + ) + if object_scenes + else "" + ), + "layout_optimization": layout_result["metadata"] if layout_result else None, + "objects": settled_objects, + } + debug_manifest_path = output_dir / "debug" / "settle_diagnostics.json" + write_json(debug_manifest_path, debug_manifest) + + # Keep workflow state 
limited to the contract consumed by table fitting. + workflow_objects = [ + { + key: item[key] + for key in ( + "id", + "name", + "status", + "reason", + "settled_glb_path", + "laid_out_glb_path", + ) + if key in item + } + for item in settled_objects + ] + return { + "status": "ok", + "clutter_2d_aabb_cm": clutter_2d_aabb_cm, + "objects": workflow_objects, + "debug_manifest_path": relative_path(str(debug_manifest_path), output_root), + } + + +def _resolve_generated_path(value: Any, output_root: Path) -> Path: + path = Path(str(value or "")).expanduser() + if path.is_absolute(): + return path.resolve() + return (output_root / path).resolve() diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/servers/__init__.py b/embodichain/gen_sim/prompt2scene/agent_tools/servers/__init__.py new file mode 100644 index 00000000..e50272ef --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/servers/__init__.py @@ -0,0 +1,16 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- +"""External servers, ignored by git, for testing or demo purposes.""" \ No newline at end of file diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/__init__.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/__init__.py new file mode 100644 index 00000000..015c4151 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/__init__.py @@ -0,0 +1,19 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +__all__: list[str] = [] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/gym_export.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/gym_export.py new file mode 100644 index 00000000..9f3c638f --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/gym_export.py @@ -0,0 +1,319 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# Default physics attributes written into every exported rigid object.
_DEFAULT_OBJECT_ATTRS: dict[str, Any] = {
    "mass": 0.01,
    "contact_offset": 0.003,
    "rest_offset": 0.001,
    "restitution": 0.01,
    "max_depenetration_velocity": 10.0,
    "min_position_iters": 32,
    "min_velocity_iters": 8,
}

# Default physics attributes written into the exported table entry.
_DEFAULT_TABLE_ATTRS: dict[str, Any] = {
    "mass": 10.0,
    "static_friction": 0.95,
    "dynamic_friction": 0.9,
    "restitution": 0.01,
}

_DEFAULT_MAX_CONVEX_HULL_NUM = 8


def _resolve_path(value: str | None, output_root: Path) -> Path:
    """Resolve a manifest path entry to an absolute path.

    Relative values are anchored at ``output_root``. A missing value
    (``None`` or ``""``) resolves to ``output_root`` itself instead of
    raising, so callers can reject it with a plain ``is_file()`` check.

    Args:
        value: Path string taken from a manifest, possibly missing.
        output_root: Directory that relative paths are relative to.

    Returns:
        The resolved absolute path.
    """
    path = Path(str(value or "")).expanduser()
    if path.is_absolute():
        return path.resolve()
    return (output_root / path).resolve()


def _read_json(path: Path) -> dict[str, Any]:
    """Load ``path`` as UTF-8 JSON and require a top-level object.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the content is not valid JSON (``JSONDecodeError``)
            or the top-level value is not a JSON object.
    """
    with path.open("r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, dict):
        raise ValueError(f"Expected JSON object at {path}")
    return data


def _matrix_to_euler_xyz_deg(matrix: list[list[float]]) -> list[float]:
    """Decompose a 3×3 or 4×4 rotation matrix into XYZ Euler angles (degrees).

    Only the upper-left 3×3 block is read. The returned ``[x, y, z]`` angles
    satisfy ``R = Rz(z) @ Ry(y) @ Rx(x)``. Near gimbal lock (``cos(y)`` close
    to zero) ``z`` is fixed to ``0`` and the remaining rotation is put in ``x``.
    """
    m = np.asarray(matrix, dtype=np.float64)
    r = m[:3, :3]
    # sy == |cos(y)|; it vanishes exactly at the gimbal-lock configuration.
    sy = math.sqrt(float(r[0, 0]) ** 2 + float(r[1, 0]) ** 2)
    if sy > 1e-6:
        x = math.atan2(float(r[2, 1]), float(r[2, 2]))
        y = math.atan2(-float(r[2, 0]), sy)
        z = math.atan2(float(r[1, 0]), float(r[0, 0]))
    else:
        x = math.atan2(-float(r[1, 2]), float(r[1, 1]))
        y = math.atan2(-float(r[2, 0]), sy)
        z = 0.0
    return [math.degrees(x), math.degrees(y), math.degrees(z)]


def _load_glb_mesh(glb_path: Path) -> Any:
    """Load ``glb_path`` and flatten it into a single ``trimesh.Trimesh``.

    Shared by the AABB helpers below so both measure the same geometry.
    """
    # Imported lazily so importing this module does not require trimesh.
    import trimesh

    scene = trimesh.load(glb_path, force="scene")
    if isinstance(scene, trimesh.Trimesh):
        return scene
    dumped = scene.dump(concatenate=True)
    if isinstance(dumped, trimesh.Trimesh):
        return dumped
    # Fallback for a dump that yields a sequence of geometries.
    return trimesh.util.concatenate(
        [m for m in dumped if isinstance(m, trimesh.Trimesh)]
    )


def _glb_aabb_bottom_center(glb_path: Path) -> list[float]:
    """``[x, y, z]`` bottom-centre position in **simulation Z-up** space.

    The GLB is stored in Y-up convention (X=width, Y=up, Z=depth).
    EmbodiChain simulation converts to Z-up on load, so we return the
    position in Z-up space: ``center_X``, ``-center_Z``, ``min_Y``.
    """
    b = np.asarray(_load_glb_mesh(glb_path).bounds, dtype=np.float64)
    return [
        float(0.5 * (b[0, 0] + b[1, 0])),  # centre X
        float(-0.5 * (b[0, 2] + b[1, 2])),  # -centre Z (GLB Z → internal -Y)
        float(b[0, 1]),  # min Y (GLB up → internal Z)
    ]


def _glb_max_z(glb_path: Path) -> float:
    """Maximum height (Y in GLB, Z in simulation) of a mesh."""
    bounds = np.asarray(_load_glb_mesh(glb_path).bounds, dtype=np.float64)
    return float(bounds[1, 1])  # max Y


def export_gym_config(
    output_root: Path,
    *,
    export_dir: Path | None = None,
) -> Path:
    """Export the unified-scene-gen result as a gym_config.json bundle.

    Uses **simready** GLBs — transforms are written explicitly as
    ``body_scale``, ``init_pos``, and ``init_rot``.

    Args:
        output_root: Root directory of a finished scene-generation run.
        export_dir: Destination directory; defaults to
            ``output_root / "gym_export"``.

    Returns:
        Path of the written ``gym_config.json``.

    Raises:
        FileNotFoundError: If the table-fit manifest or the table simready
            GLB cannot be found.
        ValueError: If a required JSON file is not a JSON object.
    """
    output_root = output_root.expanduser().resolve()
    if export_dir is None:
        export_dir = output_root / "gym_export"
    else:
        export_dir = export_dir.expanduser().resolve()
    export_dir.mkdir(parents=True, exist_ok=True)

    # ── step result & table-fit manifest ──────────────────────────────
    step_result = _read_json(
        output_root / UNIFIED_SCENE_GEN_STEP / STEP_RESULT_FILENAME
    )
    table_fit = step_result.get("table_fit_to_clutter") or {}
    manifest_path = _resolve_path(table_fit.get("manifest_path"), output_root)
    if not manifest_path.is_file():
        # A missing entry resolves to output_root; fail with a clear message
        # rather than an opaque error from opening a directory.
        raise FileNotFoundError(f"Table-fit manifest not found: {manifest_path}")
    manifest = _read_json(manifest_path)

    # ── per-object metadata from simready→aligned manifest ────────────
    aligned_by_id: dict[str, dict[str, Any]] = {}
    aligned_manifest_path = (
        output_root
        / UNIFIED_SCENE_GEN_STEP
        / "glb_gen"
        / "simready_to_aligned_manifest.json"
    )
    if aligned_manifest_path.is_file():
        aligned_manifest = _read_json(aligned_manifest_path)
        for item in aligned_manifest.get("items", []) or []:
            if isinstance(item, dict):
                aligned_by_id[str(item.get("id", ""))] = item

    # ── table surface Z (from fitted table GLB) ───────────────────────
    fitted_table_path = _resolve_path(
        manifest.get("table_output_path"), output_root
    )
    table_surface_z = (
        _glb_max_z(fitted_table_path) if fitted_table_path.is_file() else 0.0
    )

    # ── description lookup ────────────────────────────────────────────
    object_meta_by_id: dict[str, dict[str, str]] = {}
    for obj in step_result.get("objects", []) or []:
        if isinstance(obj, dict):
            oid = str(obj.get("id", ""))
            if oid:
                object_meta_by_id[oid] = {
                    "description": str(obj.get("description", "")).strip(),
                    "name": str(obj.get("name", "")).strip(),
                }

    table_info = step_result.get("table") or {}
    table_desc = str(
        table_info.get("complete_table_description")
        or table_info.get("description", "")
    ).strip()

    mesh_assets_dir = export_dir / "mesh_assets"
    mesh_assets_dir.mkdir(parents=True, exist_ok=True)

    # ── table ─────────────────────────────────────────────────────────
    table_simready = _resolve_path(
        table_info.get("simready_geometry_path") or table_info.get("mesh_path"),
        output_root,
    )
    if not table_simready.is_file():
        raise FileNotFoundError(f"Table simready GLB not found: {table_simready}")
    table_dst = mesh_assets_dir / "table" / "table_0.glb"
    table_dst.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(table_simready, table_dst)

    uniform_scale = 1.0
    ts = manifest.get("table_xy_scale")
    if isinstance(ts, dict):
        uniform_scale = float(ts.get("uniform_scale", 1.0))

    # ── objects ───────────────────────────────────────────────────────
    # Entries without an id or a path cannot be matched to an object, so
    # they are ignored instead of aborting the export with a KeyError.
    table_fit_objects = {
        str(e["id"]): _resolve_path(e["path"], output_root)
        for e in (manifest.get("objects") or [])
        if isinstance(e, dict) and e.get("id") is not None and e.get("path")
    }
    objects_info = step_result.get("objects") or []
    rigid_objects: list[dict[str, Any]] = []

    def _obj_desc(obj: dict[str, Any]) -> str:
        """Description of *obj*, falling back to its name, else ``""``."""
        meta = object_meta_by_id.get(str(obj.get("id", "")))
        return (meta["description"] or meta["name"]) if meta else ""

    for obj in objects_info:
        if not isinstance(obj, dict):
            continue
        object_id = str(obj.get("id", ""))
        if not object_id:
            continue

        # ── GLB: simready (normalised, no baked transforms) ──────────
        # Objects with no geometry path resolve to output_root (a directory)
        # and are skipped by the is_file() check below.
        source = obj.get("simready_geometry_path") or obj.get("mesh_path")
        object_src = _resolve_path(source, output_root)
        if not object_src.is_file():
            continue

        safe_name = object_id.replace("interact_", "").strip("_") or "object"
        obj_dir = mesh_assets_dir / safe_name / object_id
        obj_dir.mkdir(parents=True, exist_ok=True)
        object_dst = obj_dir / f"{object_id}.glb"
        shutil.copy2(object_src, object_dst)

        # ── body_scale ────────────────────────────────────────────────
        ms = obj.get("metric_scale")
        raw_scale = ms.get("scale_factor") if isinstance(ms, dict) else None
        # A missing or null estimate means "leave the mesh at its own size".
        scale_factor = 1.0 if raw_scale is None else float(raw_scale)
        body_scale = [scale_factor, scale_factor, scale_factor]

        # ── init_pos: read from fitted on-table GLB ───────────────────
        fitted_path = table_fit_objects.get(object_id)
        if fitted_path and fitted_path.is_file():
            init_pos = _glb_aabb_bottom_center(fitted_path)
        else:
            init_pos = [0.0, 0.0, table_surface_z]

        # ── init_rot: decompose from simready→aligned rotation ────────
        init_rot: list[float] = [0.0, 0.0, 0.0]
        aligned = aligned_by_id.get(object_id)
        if aligned:
            rot = aligned.get("rotation_matrix")
            if rot and isinstance(rot, list):
                init_rot = _matrix_to_euler_xyz_deg(rot)

        rigid_objects.append(
            {
                "uid": object_id,
                "description": _obj_desc(obj),
                "shape": {
                    "shape_type": "Mesh",
                    "fpath": str(object_dst.relative_to(export_dir)),
                    "compute_uv": False,
                },
                "attrs": dict(_DEFAULT_OBJECT_ATTRS),
                "body_type": "dynamic",
                "init_pos": init_pos,
                "init_rot": init_rot,
                "body_scale": body_scale,
                "max_convex_hull_num": _DEFAULT_MAX_CONVEX_HULL_NUM,
            }
        )

    # ── write config ──────────────────────────────────────────────────
    config = {
        "id": f"Prompt2Scene-{int(time.time() * 1000)}-v0",
        "max_episodes": 10,
        "max_episode_steps": 300,
        "env": {"events": {}, "observations": {}, "dataset": {}},
        "robot": {},
        "sensor": [],
        "light": {},
        "background": [
            {
                "uid": "table",
                "description": table_desc,
                "shape": {
                    "shape_type": "Mesh",
                    "fpath": str(table_dst.relative_to(export_dir)),
                    "compute_uv": False,
                },
                "attrs": dict(_DEFAULT_TABLE_ATTRS),
                "body_scale": [uniform_scale, uniform_scale, 1.0],
                "body_type": "kinematic",
                "init_pos": [0.0, 0.0, 0.0],
                "init_rot": [0.0, 0.0, 0.0],
            }
        ],
        "rigid_object": rigid_objects,
    }

    config_path = export_dir / "gym_config.json"
    config_path.write_text(
        json.dumps(config, indent=4, ensure_ascii=False) + "\n",
        encoding="utf-8",
    )
    return config_path
00000000..2275c40f --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_scene_asset_generation.py @@ -0,0 +1,636 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + + +from __future__ import annotations + +import shutil +import traceback +from pathlib import Path +from typing import Any + +import numpy as np + +from embodichain.gen_sim.prompt2scene.utils.log import log_info, log_warning +from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_segmentation_client import ( + decode_rle_mask, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_generation_manager import ( + GeometryGenerationManager, + RgbaImageToGeometryRequest, + RgbaImagesToGeometriesRequest, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_generation_manager import ( + ImageGenerationManager, + TextToAssetImageRequest, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_segmentation_manager import ( + AssetImageToRgbaRequest, + ImageSegmentationManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_scene_manager import ( + _export_support_aligned_layout_glbs, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simready_manager import ( + MakeAssetSimreadyRequest, + 
# Step name passed as ``step_name`` to the metric-scale estimation request.
UNIFIED_SCENE_STEP = "unified_scene"


def generate_image_scene_assets(
    object_specs: list[dict[str, Any]],
    table_spec: dict[str, Any],
    spatial_relations: list[dict[str, Any]],
    segments_data: dict[str, Any],
    image_gen_dir: Path,
    glb_gen_dir: Path,
    debug_dir: Path,
    output_root: Path,
    llm: Any | None = None,
) -> dict[str, Any]:
    """Run layout-aware table/support and object generation from image masks.

    Decodes one mask per table/object, requests geometry for all masks in a
    single batch, makes each result simready, then estimates metric scales,
    aligns objects to the table and writes layout manifests. Failures in the
    generation or alignment stage are captured in the returned ``status`` /
    ``reason`` rather than raised; mask decoding and manifest writing run
    outside those ``try`` blocks and can still raise.

    Args:
        object_specs: Object dicts; ``id``, ``name`` and ``description`` are read.
        table_spec: Table dict; ``id``, ``name``, ``description``,
            ``complete_table_description`` and ``is_complete_visible_table``
            are read.
        spatial_relations: Forwarded unchanged to the alignment step.
        segments_data: Segmentation output; ``image_path``,
            ``asset_segments`` (entries with ``asset_id`` / ``mask_rle``),
            ``table_segment`` (with ``mask_rle``) and
            ``bbox_name_image_path`` are read.
        image_gen_dir: Receives the RGBA image of a description-generated table.
        glb_gen_dir: Parent of the raw-download, simready and aligned GLB dirs.
        debug_dir: Receives decoded masks and description-generated intermediates.
        output_root: Base directory used to relativise every reported path.
        llm: Forwarded to metric-scale estimation and alignment.

    Returns:
        Dict with ``status``, ``table``, ``objects``, ``alignment`` and
        ``manifests``; ``reason`` (a traceback string) is added on failure.
    """
    log_info(f"image object layout generation started count={len(object_specs)}")
    status = "ok"
    failure_reason = ""
    original_image_path = str(segments_data.get("image_path", ""))
    # Index object segments by asset id; segments without an id are dropped.
    segment_by_id: dict[str, dict[str, Any]] = {
        str(seg["asset_id"]): seg
        for seg in segments_data.get("asset_segments", [])
        if seg.get("asset_id")
    }
    table_segment = segments_data.get("table_segment")
    if not isinstance(table_segment, dict):
        table_segment = None
    debug_subdir = debug_dir / "multi_object_masks"
    masks_dir = debug_subdir / "masks"
    raw_download_dir = glb_gen_dir / "raw_downloads"
    simready_dir = glb_gen_dir / "multi_object_layouts_simready"
    aligned_dir = glb_gen_dir / "multi_object_layouts_aligned"
    masks_dir.mkdir(parents=True, exist_ok=True)
    raw_download_dir.mkdir(parents=True, exist_ok=True)
    simready_dir.mkdir(parents=True, exist_ok=True)
    aligned_dir.mkdir(parents=True, exist_ok=True)

    # requested_items and mask_paths are appended in lockstep, so index i of
    # one describes index i of the other.
    requested_items: list[dict[str, Any]] = []
    mask_paths: list[Path] = []

    table_id = str(table_spec.get("id", "table")).strip() or "table"
    table_name = str(table_spec.get("name", "table")).strip() or "table"
    is_complete_visible_table = bool(
        table_spec.get("is_complete_visible_table", False)
    )
    # Set when the table mask is unusable; turned into a failure further down.
    skipped_table: dict[str, Any] | None = None
    if table_segment is None:
        skipped_table = {
            "id": table_id,
            "name": table_name,
            "reason": "missing_table_segment",
        }
    else:
        table_mask_rle = table_segment.get("mask_rle")
        if table_mask_rle is None:
            skipped_table = {
                "id": table_id,
                "name": table_name,
                "reason": "missing_table_mask_rle",
            }
        else:
            # The table is queued first, so it gets index 0000 in the file name.
            mask_path = masks_dir / f"{len(requested_items):04d}_{table_id}_mask.png"
            decode_rle_mask(table_mask_rle).save(mask_path)
            mask_paths.append(mask_path)
            requested_items.append(
                {
                    "id": table_id,
                    "name": table_name,
                    "kind": "table",
                    "mask_path": str(mask_path),
                }
            )

    # Objects without an id, a matching segment, or a mask are silently skipped.
    for obj_spec in object_specs:
        obj_id = str(obj_spec.get("id", "")).strip()
        obj_name = str(obj_spec.get("name", "")).strip()
        if not obj_id:
            continue
        segment = segment_by_id.get(obj_id)
        if segment is None:
            continue
        mask_rle = segment.get("mask_rle")
        if mask_rle is None:
            continue

        mask_path = masks_dir / f"{len(requested_items):04d}_{obj_id}_mask.png"
        decode_rle_mask(mask_rle).save(mask_path)
        mask_paths.append(mask_path)
        requested_items.append(
            {
                "id": obj_id,
                "name": obj_name,
                "description": str(obj_spec.get("description", "")),
                "kind": "object",
                "mask_path": str(mask_path),
            }
        )

    generated_objects: list[dict[str, Any]] = []
    generated_table: dict[str, Any] | None = None
    image_manager = ImageGenerationManager()
    segmentation_manager = ImageSegmentationManager()
    geometry_manager = GeometryGenerationManager()
    simready_manager = SimreadyManager()
    # Anything raised in this block marks the whole run as failed, but items
    # collected before the exception are kept and still post-processed below.
    try:
        if skipped_table is not None:
            raise ValueError(
                "No valid table/support mask found for image multi-object "
                f"layout generation: {skipped_table['reason']}"
            )
        if not mask_paths:
            raise ValueError(
                "No valid masks found for image multi-object layout generation."
            )

        # One batched request for the original image plus every decoded mask.
        result = geometry_manager.convert_rgba_images_to_geometries(
            RgbaImagesToGeometriesRequest(
                image_path=Path(original_image_path),
                mask_paths=mask_paths,
                output_dir=raw_download_dir,
            )
        )
        if len(result.objects) != len(requested_items):
            raise RuntimeError(
                "Multi-object SAM3D result count mismatch: "
                f"requested {len(requested_items)}, got {len(result.objects)}"
            )
        for requested, generated in zip(requested_items, result.objects):
            # Results are matched positionally; the name check guards against
            # the service returning them in a different order.
            expected_sam3d_name = Path(requested["mask_path"]).stem
            if generated.name != expected_sam3d_name:
                raise RuntimeError(
                    "Multi-object SAM3D result order mismatch: "
                    f"expected {expected_sam3d_name!r}, got {generated.name!r}"
                )
            downloaded_raw_path = Path(generated.geometry_path).expanduser().resolve()
            raw_geometry_path = str(downloaded_raw_path)
            # Per-item problems are accumulated here and become the item status.
            status_parts: list[str] = []
            transform_matrix: list[list[float]] = []
            try:
                transform = _compose_sam3d_multi_object_transform(
                    rotation_quaternion_wxyz=generated.rotation_quaternion_wxyz,
                    translation=generated.translation,
                    scale=generated.scale,
                )
                transform_matrix = transform.tolist()
            except Exception:
                status_parts.append(
                    f"transform_matrix_failed: {traceback.format_exc()}"
                )

            simready_geometry_path = ""
            raw_to_simready_glb_matrix: list[list[float]] = []
            # Placeholder; filled in later by the metric-scale estimation step.
            metric_scale: dict[str, Any] | None = None
            try:
                if requested["kind"] == "table":
                    # An incomplete table is not made simready here; it is
                    # replaced by a description-generated table further down.
                    if is_complete_visible_table:
                        table_result = simready_manager.make_table_simready(
                            MakeTableSimreadyRequest(
                                input_path=Path(raw_geometry_path),
                                output_path=simready_dir
                                / f"{requested['id']}_simready.glb",
                            )
                        )
                        simready_geometry_path = str(table_result.output_path)
                        raw_to_simready_glb_matrix = table_result.transform_matrix
                else:
                    asset_result = simready_manager.make_asset_simready(
                        MakeAssetSimreadyRequest(
                            input_path=Path(raw_geometry_path),
                            output_path=simready_dir
                            / f"{requested['id']}_simready.glb",
                        )
                    )
                    simready_geometry_path = str(asset_result.output_path)
                    raw_to_simready_glb_matrix = asset_result.transform_matrix
            except Exception:
                status_parts.append(f"simready_failed: {traceback.format_exc()}")
            item_status = "ok" if not status_parts else "; ".join(status_parts)
            # Table-only fields get neutral values ("" / False) for objects.
            generated_item = {
                "id": requested["id"],
                "name": requested["name"],
                "kind": requested["kind"],
                "description": str(table_spec.get("description", ""))
                if requested["kind"] == "table"
                else str(requested.get("description", "")),
                "complete_table_description": str(
                    table_spec.get("complete_table_description")
                    or table_spec.get("description", "")
                ).strip()
                if requested["kind"] == "table"
                else "",
                "is_complete_visible_table": is_complete_visible_table
                if requested["kind"] == "table"
                else False,
                "status": item_status,
                "mask_path": relative_path(requested["mask_path"], output_root),
                "raw_geometry_path": relative_path(raw_geometry_path, output_root),
                "simready_geometry_path": relative_path(
                    simready_geometry_path, output_root
                )
                if simready_geometry_path
                else "",
                "mesh_path": relative_path(simready_geometry_path, output_root)
                if simready_geometry_path
                else "",
                "sam3d_name": generated.name,
                "downloaded_raw_geometry_path": relative_path(
                    str(downloaded_raw_path), output_root
                ),
                "rotation_quaternion_wxyz": generated.rotation_quaternion_wxyz,
                "translation": generated.translation,
                "scale": generated.scale,
                "transform_matrix": transform_matrix,
                "raw_to_simready_glb_matrix": raw_to_simready_glb_matrix,
                "metric_scale": metric_scale,
            }
            if requested["kind"] == "table":
                # The segmented table geometry is always kept as the support
                # reference, even when the table asset itself is regenerated.
                support_reference_path = raw_download_dir / "support_surface_raw.glb"
                table_raw_path = raw_download_dir / "table_raw.glb"
                shutil.copy2(downloaded_raw_path, support_reference_path)
                if is_complete_visible_table:
                    shutil.copy2(downloaded_raw_path, table_raw_path)
                    generated_item["raw_geometry_path"] = relative_path(
                        str(table_raw_path),
                        output_root,
                    )
                generated_item["support_reference_geometry_path"] = relative_path(
                    str(support_reference_path),
                    output_root,
                )
                generated_item["support_reference_transform_matrix"] = transform_matrix
                generated_item["support_normal_source"] = "segmented_table"
                generated_item["table_asset_source"] = "segmented_table"
                if not is_complete_visible_table:
                    # Replace partial image table with description-generated table.
                    # Pipeline: text -> image -> RGBA cut-out -> raw GLB -> simready GLB.
                    incomplete_table_id = str(
                        generated_item.get("id")
                        or table_spec.get("id")
                        or "table"
                    )
                    incomplete_table_desc = str(
                        table_spec.get("complete_table_description")
                        or table_spec.get("description", "")
                    ).strip()
                    incomplete_debug_dir = (
                        debug_dir / incomplete_table_id / "description_generated"
                    )
                    incomplete_debug_dir.mkdir(parents=True, exist_ok=True)
                    incomplete_raw_download_dir = glb_gen_dir / "raw_downloads"
                    incomplete_raw_download_dir.mkdir(parents=True, exist_ok=True)
                    # NOTE(review): each manager call result is wrapped in str()
                    # and then used as a path — presumably the managers return
                    # the written output path; confirm against their schemas.
                    incomplete_raw_image = str(
                        image_manager.generate_asset_image_from_text(
                            TextToAssetImageRequest(
                                prompt=incomplete_table_desc,
                                output_path=incomplete_debug_dir
                                / f"{incomplete_table_id}_complete.png",
                            )
                        )
                    )
                    incomplete_rgba = str(
                        segmentation_manager.convert_asset_image_to_rgba(
                            AssetImageToRgbaRequest(
                                image_path=Path(incomplete_raw_image),
                                prompt=incomplete_table_desc
                                if incomplete_table_desc.strip()
                                else "whole table",
                                output_path=image_gen_dir
                                / f"{incomplete_table_id}_complete.png",
                            )
                        )
                    )
                    incomplete_raw_glb = str(
                        geometry_manager.convert_rgba_image_to_geometry(
                            RgbaImageToGeometryRequest(
                                image_path=Path(incomplete_rgba),
                                output_path=incomplete_debug_dir
                                / f"{incomplete_table_id}_complete_raw.glb",
                            )
                        )
                    )
                    incomplete_table_raw_path = (
                        incomplete_raw_download_dir / "table_raw.glb"
                    )
                    shutil.copy2(incomplete_raw_glb, incomplete_table_raw_path)
                    incomplete_simready = simready_manager.make_table_simready(
                        MakeTableSimreadyRequest(
                            input_path=incomplete_table_raw_path,
                            output_path=glb_gen_dir
                            / "multi_object_layouts_simready"
                            / f"{incomplete_table_id}_simready.glb",
                        )
                    )
                    # Overwrite the segmented-table fields with the generated
                    # table; its layout transform is reset to identity.
                    generated_item.update(
                        {
                            "image_path": relative_path(
                                incomplete_rgba, output_root
                            ),
                            "raw_geometry_path": relative_path(
                                str(incomplete_table_raw_path), output_root
                            ),
                            "generated_table_raw_geometry_path": relative_path(
                                incomplete_raw_glb, output_root
                            ),
                            "simready_geometry_path": relative_path(
                                str(incomplete_simready.output_path),
                                output_root,
                            ),
                            "mesh_path": relative_path(
                                str(incomplete_simready.output_path),
                                output_root,
                            ),
                            "raw_to_simready_glb_matrix": (
                                incomplete_simready.transform_matrix
                            ),
                            "transform_matrix": np.eye(
                                4, dtype=np.float64
                            ).tolist(),
                            "table_asset_source": "description_generated",
                            "complete_table_description": incomplete_table_desc,
                        }
                    )
                generated_table = generated_item
            else:
                generated_objects.append(generated_item)
    except Exception as exc:
        status = "failed"
        failure_reason = traceback.format_exc()
        log_warning(f"image object geometry generation failed error={exc}")

    # The return value is ignored: the helper is called for what it records
    # on the object dicts through MetricScaleManager.
    if generated_objects:
        _estimate_image_scene_metric_scales(
            objects=generated_objects,
            bbox_name_image_path=segments_data.get("bbox_name_image_path"),
            output_dir=glb_gen_dir,
            output_root=output_root,
            llm=llm,
        )

    # Alignment needs both a table and at least one object.
    alignment_result: dict[str, Any] | None = None
    if generated_table is not None and generated_objects:
        try:
            alignment_result = _export_support_aligned_layout_glbs(
                table=generated_table,
                objects=generated_objects,
                spatial_relations=spatial_relations,
                original_image_path=Path(original_image_path)
                if original_image_path
                else None,
                llm=llm,
                output_dir=aligned_dir,
                output_root=output_root,
            )
            # Copy each aligned GLB path back onto the matching generated object.
            aligned_object_by_id = {
                item["id"]: item for item in alignment_result["objects"]
            }
            for generated_object in generated_objects:
                aligned_object = aligned_object_by_id.get(generated_object["id"])
                if aligned_object is not None:
                    generated_object["aligned_geometry_path"] = aligned_object[
                        "aligned_geometry_path"
                    ]
        except Exception as exc:
            # Overwrites any failure reason recorded by the generation stage.
            status = "failed"
            failure_reason = traceback.format_exc()
            log_warning(f"image object alignment failed error={exc}")
            alignment_result = {
                "status": "failed",
                "reason": failure_reason,
            }

    manifest_paths = _write_multi_object_layout_manifests(
        glb_gen_dir=glb_gen_dir,
        output_root=output_root,
        table=generated_table,
        objects=generated_objects,
        alignment=alignment_result,
    )
    # Only these keys are exposed in the returned workflow state; the full
    # item dicts are handed to the manifest writer above.
    table_fields = (
        "id",
        "name",
        "status",
        "is_complete_visible_table",
        "complete_table_description",
        "table_asset_source",
        "support_normal_source",
        "image_path",
        "raw_geometry_path",
        "support_reference_geometry_path",
        "generated_table_raw_geometry_path",
        "transformed_geometry_path",
        "simready_geometry_path",
        "aligned_geometry_path",
        "mesh_path",
    )
    object_fields = (
        "id",
        "name",
        "status",
        "image_path",
        "mesh_path",
        "aligned_geometry_path",
        "metric_scale",
    )
    workflow_table = (
        {key: generated_table[key] for key in table_fields if key in generated_table}
        if generated_table is not None
        else None
    )
    workflow_objects = [
        {key: item[key] for key in object_fields if key in item}
        for item in generated_objects
    ]
    # Detailed per-item error strings are collapsed to a plain "failed" here.
    if workflow_table is not None and workflow_table.get("status") != "ok":
        workflow_table["status"] = "failed"
    for item in workflow_objects:
        if item.get("status") != "ok":
            item["status"] = "failed"
    workflow_alignment = (
        {
            key: alignment_result[key]
            for key in ("status", "final_clutter_2d_aabb_cm")
            if key in alignment_result
        }
        if alignment_result is not None
        else None
    )
    result = {
        "status": status,
        "table": workflow_table,
        "objects": workflow_objects,
        "alignment": workflow_alignment,
        "manifests": manifest_paths,
    }
    if failure_reason:
        result["reason"] = failure_reason
    log_info(
        "image object layout generation completed "
        f"status={status} generated={len(generated_objects)}"
    )
    return result


def _estimate_image_scene_metric_scales(
    *,
    objects: list[dict[str, Any]],
    bbox_name_image_path: Any,
    output_dir: Path,
    output_root: Path,
    llm: Any | None,
) -> dict[str, Any]:
    """Estimate per-object metric scales from the bbox/name overlay image.

    Never raises: every outcome is reported through the returned dict's
    ``status`` (``"skipped"``, ``"ok"`` or ``"failed"``) and ``reason``.
    On skip or failure the same outcome is recorded for every object via
    ``MetricScaleManager.set_for_all_objects``; on success the estimates are
    passed to ``MetricScaleManager.apply_to_objects``.

    Args:
        objects: Generated object dicts (presumably updated in place by the
            manager calls — confirm against ``MetricScaleManager``).
        bbox_name_image_path: Path of the annotated image, absolute or
            relative to ``output_root``; may be missing.
        output_dir: Receives ``image_metric_scale_raw_model_output.json``.
        output_root: Base directory for resolving relative paths.
        llm: Model handle; estimation is skipped when it is ``None``.

    Returns:
        Summary dict with ``status``, ``method``, ``bbox_name_image_path``,
        ``objects`` and, depending on the outcome, ``reason`` or
        ``object_scales`` / ``unit_note``.
    """
    result: dict[str, Any] = {
        "status": "skipped",
        "method": "image_scene_bbox_name_vlm_candidate_shape_ratio_median_scale",
        "bbox_name_image_path": str(bbox_name_image_path or ""),
        "objects": [],
    }
    try:
        # Three skip conditions are checked in order: feature flag, missing
        # model, missing annotated image.
        if not METRIC_SCALE_ENABLED:
            result["reason"] = "metric_scale_disabled"
            MetricScaleManager.set_for_all_objects(
                objects=objects,
                status="skipped",
                reason="metric_scale_disabled",
                method=str(result["method"]),
            )
            return result
        if llm is None:
            result["reason"] = "missing_llm"
            MetricScaleManager.set_for_all_objects(
                objects=objects,
                status="skipped",
                reason="missing_llm",
                method=str(result["method"]),
            )
            return result

        bbox_image = _resolve_generated_path(bbox_name_image_path, output_root)
        if not bbox_image.is_file():
            result["reason"] = "missing_bbox_name_image"
            MetricScaleManager.set_for_all_objects(
                objects=objects,
                status="skipped",
                reason="missing_bbox_name_image",
                method=str(result["method"]),
            )
            return result

        metric_objects = _build_metric_scale_inputs(
            objects=objects,
            output_root=output_root,
        )
        # The same payload is stored in the summary and sent in the prompt.
        result["objects"] = MetricScaleManager.object_prompt_payload(metric_objects)
        metric_result = MetricScaleManager.estimate_metric_scales(
            EstimateMetricScalesRequest(
                objects=metric_objects,
                messages=build_image_metric_scale_messages(
                    bbox_name_image_path=bbox_image,
                    objects_json=result["objects"],
                ),
                schema=IMAGE_METRIC_SCALE_JSON_SCHEMA,
                llm=llm,
                context="Image scene metric scale estimate",
                method=str(result["method"]),
                step_name=UNIFIED_SCENE_STEP,
                raw_output_path=output_dir / "image_metric_scale_raw_model_output.json",
            )
        )
        estimates = metric_result.object_scales
        MetricScaleManager.apply_to_objects(objects=objects, object_scales=estimates)
        result.update(
            {
                "status": "ok",
                "object_scales": estimates,
                "unit_note": (
                    "Per-object scale_factor is not baked into simready GLBs. "
                    "Image alignment later computes one clamped global clutter "
                    "scale from these per-object estimates, on top of SAM3D "
                    "per-object layout scale."
                ),
            }
        )
    except Exception:
        # The summary keeps the full traceback; objects get a fixed reason code.
        result.update({"status": "failed", "reason": traceback.format_exc()})
        MetricScaleManager.set_for_all_objects(
            objects=objects,
            status="failed",
            reason="image_scene_metric_scale_failed",
            method=str(result["method"]),
        )
    return result


def _build_metric_scale_inputs(
    *,
    objects: list[dict[str, Any]],
    output_root: Path,
) -> list[MetricScaleObjectInput]:
    """Build one ``MetricScaleObjectInput`` per generated object.

    The mesh is taken from ``simready_geometry_path``, falling back to
    ``mesh_path``, and resolved against ``output_root``.

    Raises:
        FileNotFoundError: If any object's resolved mesh path is not a file.
    """
    inputs: list[MetricScaleObjectInput] = []
    for obj in objects:
        mesh_path = _resolve_generated_path(
            obj.get("simready_geometry_path") or obj.get("mesh_path"),
            output_root,
        )
        # One missing mesh aborts the whole batch; the caller reports it as a
        # failed estimation.
        if not mesh_path.is_file():
            raise FileNotFoundError(f"Simready object GLB not found: {mesh_path}")
        inputs.append(
            MetricScaleObjectInput(
                object_id=str(obj.get("id", "")),
                object_name=str(obj.get("name", "")),
                object_description=str(obj.get("description", "")),
                mesh_path=mesh_path,
            )
        )
    return inputs


def _resolve_generated_path(value: Any, output_root: Path) -> Path:
    """Resolve ``value`` to an absolute path, anchoring relative paths at ``output_root``.

    Falsy values are treated as the empty string, so they resolve to
    ``output_root`` itself.
    """
    path = Path(str(value or "")).expanduser()
    if path.is_absolute():
        return path.resolve()
    return (output_root / path).resolve()
def _run_table_fit(
    scene_kind: str,
    build_inputs: Any,
    *,
    output_root: Path,
    output_dir: Path,
) -> dict[str, Any]:
    """Run ``fit_table_to_clutter`` and turn any exception into result data.

    ``build_inputs`` is a zero-argument callable returning the
    ``(table_result, clutter_result)`` pair; it is invoked inside the guarded
    region so that errors while assembling the inputs are reported as a
    failed fit too. ``scene_kind`` only labels the log lines.
    """
    try:
        table_payload, clutter_payload = build_inputs()
        outcome = fit_table_to_clutter(
            table_result=table_payload,
            clutter_result=clutter_payload,
            output_root=output_root,
            output_dir=output_dir,
        )
        log_info(f"{scene_kind} table fit completed status={outcome.get('status')}")
        return outcome
    except Exception as exc:
        log_warning(f"{scene_kind} table fit failed error={exc}")
        return {
            "status": "failed",
            "reason": traceback.format_exc(),
        }


def fit_text_scene_table(
    *,
    table_result: dict[str, Any],
    clutter_layout_result: dict[str, Any],
    output_root: Path,
    output_dir: Path,
) -> dict[str, Any]:
    """Fit the text-scene table to its clutter layout.

    Exceptions are not propagated; they are returned as
    ``{"status": "failed", "reason": <traceback>}``.
    """
    return _run_table_fit(
        "text",
        lambda: (table_result, clutter_layout_result),
        output_root=output_root,
        output_dir=output_dir,
    )


def fit_image_scene_table(
    *,
    layout_result: dict[str, Any],
    fallback_table_result: dict[str, Any] | None,
    output_root: Path,
    output_dir: Path,
) -> dict[str, Any]:
    """Fit the image-scene table to the aligned objects.

    Returns a ``"skipped"`` result when the layout has no table (and no
    fallback), no objects, or no alignment dict; otherwise behaves like
    :func:`fit_text_scene_table`, reporting failures as result data.
    """
    table_entry = layout_result.get("table") or fallback_table_result
    object_entries = layout_result.get("objects") or []
    alignment = layout_result.get("alignment")

    inputs_available = (
        table_entry is not None
        and bool(object_entries)
        and isinstance(alignment, dict)
    )
    if not inputs_available:
        return {
            "status": "skipped",
            "reason": "missing_table_objects_or_alignment",
        }

    def _build_inputs() -> tuple[Any, dict[str, Any]]:
        # Only objects that carry both an id and an aligned GLB take part.
        laid_out: list[dict[str, Any]] = []
        for entry in object_entries:
            if not (entry.get("id") and entry.get("aligned_geometry_path")):
                continue
            laid_out.append(
                {
                    "id": entry["id"],
                    "status": "ok",
                    "laid_out_glb_path": entry["aligned_geometry_path"],
                }
            )
        clutter = {
            "clutter_2d_aabb_cm": alignment.get("final_clutter_2d_aabb_cm"),
            "objects": laid_out,
        }
        return table_entry, clutter

    return _run_table_fit(
        "image",
        _build_inputs,
        output_root=output_root,
        output_dir=output_dir,
    )
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import shutil +import traceback +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_generation_manager import ( + GeometryGenerationManager, + RgbaImageToGeometryRequest, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_generation_manager import ( + ImageGenerationManager, + TextToAssetImageRequest, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_segmentation_manager import ( + AssetImageToRgbaRequest, + ImageSegmentationManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simready_manager import ( + MakeAssetSimreadyRequest, + MakeTableSimreadyRequest, + SimreadyManager, +) +from embodichain.gen_sim.prompt2scene.utils.log import log_info, log_warning + +__all__ = [ + "generate_text_object_asset", + "generate_text_object_assets", + "generate_text_table_asset", +] + + +def generate_text_object_asset( + *, + object_spec: dict[str, Any], + image_gen_dir: Path, + glb_gen_dir: Path, + debug_dir: Path, +) -> dict[str, Any]: + """Generate one object asset from a text-origin object spec.""" + object_id = str(object_spec.get("id", "object")) + object_name = str(object_spec.get("name", "")) + description = str(object_spec.get("description", "")) + class_candidates = [ + str(candidate).replace("_", " ") + for candidate in object_spec.get("class_candidate", []) + if isinstance(candidate, str) and candidate.strip() + ] + 
status = "ok" + image_path = "" + raw_geometry_path = "" + mesh_path = "" + raw_to_simready_matrix: list[list[float]] = [] + + debug_subdir = debug_dir / object_id + debug_subdir.mkdir(parents=True, exist_ok=True) + log_info(f"text object generation started id={object_id} name={object_name}") + + image_manager = ImageGenerationManager() + segmentation_manager = ImageSegmentationManager() + geometry_manager = GeometryGenerationManager() + simready_manager = SimreadyManager() + + try: + image_prompt = f"{object_name}, {description}".strip(", ") + raw_image_path = str( + image_manager.generate_asset_image_from_text( + TextToAssetImageRequest( + prompt=image_prompt, + output_path=debug_subdir / f"{object_id}.png", + ) + ) + ) + + rgba_prompts: list[str] = [] + if description.strip(): + rgba_prompts.append(description.strip()) + for candidate in class_candidates: + candidate_prompt = f"The entire {candidate} on the center of the image" + if candidate_prompt not in rgba_prompts: + rgba_prompts.append(candidate_prompt) + if not rgba_prompts: + rgba_prompts.append( + f"the entire single isolated object {object_name}" + if object_name + else "the entire single isolated object" + ) + + rgba_path = "" + last_rgba_error: Exception | None = None + for prompt in rgba_prompts: + try: + rgba_path = str( + segmentation_manager.convert_asset_image_to_rgba( + AssetImageToRgbaRequest( + image_path=Path(raw_image_path), + prompt=prompt, + output_path=image_gen_dir / f"{object_id}.png", + ) + ) + ) + break + except Exception as exc: + last_rgba_error = exc + log_warning( + "text object segmentation prompt failed " + f"id={object_id} prompt={prompt!r} error={exc}" + ) + if not rgba_path: + raise last_rgba_error or RuntimeError( + f"No RGBA prompt succeeded for {object_id}" + ) + + raw_glb_path = str( + geometry_manager.convert_rgba_image_to_geometry( + RgbaImageToGeometryRequest( + image_path=Path(rgba_path), + output_path=debug_subdir / f"{object_id}_raw.glb", + ) + ) + ) + 
raw_geometry_dir = glb_gen_dir / "raw_downloads" + raw_geometry_dir.mkdir(parents=True, exist_ok=True) + object_raw_path = raw_geometry_dir / f"{object_id}_raw.glb" + shutil.copy2(raw_glb_path, object_raw_path) + raw_geometry_path = str(object_raw_path) + + simready_result = simready_manager.make_asset_simready( + MakeAssetSimreadyRequest( + input_path=Path(raw_glb_path), + output_path=glb_gen_dir + / "text_objects_simready" + / f"{object_id}_simready.glb", + ) + ) + mesh_path = str(simready_result.output_path) + raw_to_simready_matrix = simready_result.transform_matrix + + image_path = rgba_path + log_info(f"text object generation completed id={object_id} mesh={mesh_path}") + except Exception as exc: + status = f"failed: {traceback.format_exc()}" + log_warning(f"text object generation failed id={object_id} error={exc}") + + return { + "id": object_id, + "name": object_name, + "status": status, + "image_path": image_path, + "raw_geometry_path": raw_geometry_path, + "mesh_path": mesh_path, + "simready_geometry_path": mesh_path, + "raw_to_simready_glb_matrix": raw_to_simready_matrix, + "metric_scale": None, + } + + +def generate_text_object_assets( + *, + object_specs: list[dict[str, Any]], + image_gen_dir: Path, + glb_gen_dir: Path, + debug_dir: Path, +) -> list[dict[str, Any]]: + """Generate all object assets for a text-origin unified scene.""" + log_info(f"text object batch generation started count={len(object_specs)}") + results = [ + generate_text_object_asset( + object_spec=object_spec, + image_gen_dir=image_gen_dir, + glb_gen_dir=glb_gen_dir, + debug_dir=debug_dir, + ) + for object_spec in object_specs + ] + succeeded = sum(result.get("status") == "ok" for result in results) + log_info( + f"text object batch generation completed " + f"succeeded={succeeded} failed={len(results) - succeeded}" + ) + return results + + +def generate_text_table_asset( + *, + table_spec: dict[str, Any], + image_gen_dir: Path, + glb_gen_dir: Path, + debug_dir: Path, +) -> dict[str, 
Any]: + """Generate the table asset for a text-origin unified scene.""" + table_id = str(table_spec.get("id", "table")) + description = str( + table_spec.get("complete_table_description") + or table_spec.get("description", "") + ).strip() + status = "ok" + image_path = "" + raw_geometry_path = "" + generated_table_raw_geometry_path = "" + mesh_path = "" + + debug_subdir = debug_dir / table_id + debug_subdir.mkdir(parents=True, exist_ok=True) + log_info(f"text table generation started id={table_id}") + + image_manager = ImageGenerationManager() + segmentation_manager = ImageSegmentationManager() + geometry_manager = GeometryGenerationManager() + simready_manager = SimreadyManager() + + try: + raw_image_path = str( + image_manager.generate_asset_image_from_text( + TextToAssetImageRequest( + prompt=description, + output_path=debug_subdir / f"{table_id}.png", + ) + ) + ) + rgba_path = str( + segmentation_manager.convert_asset_image_to_rgba( + AssetImageToRgbaRequest( + image_path=Path(raw_image_path), + prompt=description if description.strip() else "whole table", + output_path=image_gen_dir / f"{table_id}.png", + ) + ) + ) + raw_glb_path = str( + geometry_manager.convert_rgba_image_to_geometry( + RgbaImageToGeometryRequest( + image_path=Path(rgba_path), + output_path=debug_subdir / f"{table_id}_raw.glb", + ) + ) + ) + generated_table_raw_geometry_path = raw_glb_path + raw_geometry_dir = glb_gen_dir / "raw_downloads" + raw_geometry_dir.mkdir(parents=True, exist_ok=True) + table_raw_path = raw_geometry_dir / "table_raw.glb" + shutil.copy2(raw_glb_path, table_raw_path) + raw_geometry_path = str(table_raw_path) + mesh_path = str( + simready_manager.make_table_simready( + MakeTableSimreadyRequest( + input_path=Path(raw_geometry_path), + output_path=glb_gen_dir + / "text_objects_simready" + / f"{table_id}_simready.glb", + ) + ).output_path + ) + image_path = rgba_path + log_info(f"text table generation completed id={table_id} mesh={mesh_path}") + except Exception as exc: + 
status = f"failed: {traceback.format_exc()}" + log_warning(f"text table generation failed id={table_id} error={exc}") + + return { + "id": table_id, + "name": str(table_spec.get("name", "table")), + "description": str(table_spec.get("description", "")), + "complete_table_description": description, + "is_complete_visible_table": bool( + table_spec.get("is_complete_visible_table", False) + ), + "status": status, + "image_path": image_path, + "raw_geometry_path": raw_geometry_path, + "generated_table_raw_geometry_path": generated_table_raw_geometry_path, + "support_reference_geometry_path": "", + "table_asset_source": "description_generated", + "support_normal_source": "", + "mesh_path": mesh_path, + "simready_geometry_path": mesh_path, + } diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_clutter_layout.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_clutter_layout.py new file mode 100644 index 00000000..80bc3210 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_clutter_layout.py @@ -0,0 +1,62 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import traceback +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.prompt2scene.utils.log import log_info, log_warning +from embodichain.gen_sim.prompt2scene.agent_tools.managers.text_layout_manager import ( + settle_text_objects_to_ground, +) + +__all__ = ["generate_text_clutter_layout"] + + +def generate_text_clutter_layout( + *, + object_results: list[dict[str, Any]], + spatial_relations: list[dict[str, Any]], + table_constraints: list[dict[str, Any]], + output_dir: Path, + output_root: Path, +) -> dict[str, Any]: + """Settle and spatially arrange generated text-scene objects.""" + if not object_results: + return { + "status": "skipped", + "reason": "no_text_objects", + } + + try: + log_info(f"text clutter layout started count={len(object_results)}") + result = settle_text_objects_to_ground( + objects=object_results, + spatial_relations=spatial_relations, + table_constraints=table_constraints, + output_dir=output_dir, + output_root=output_root, + ) + log_info(f"text clutter layout completed status={result.get('status')}") + return result + except Exception as exc: + log_warning(f"text clutter layout failed error={exc}") + return { + "status": "failed", + "reason": traceback.format_exc(), + } diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_scene_metric_scale.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_scene_metric_scale.py new file mode 100644 index 00000000..fd0b1383 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_scene_metric_scale.py @@ -0,0 +1,161 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import traceback +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.metric_scale_manager import ( + METRIC_SCALE_ENABLED, + EstimateMetricScalesRequest, + MetricScaleManager, + MetricScaleObjectInput, +) +from embodichain.gen_sim.prompt2scene.utils.io import write_json +from embodichain.gen_sim.prompt2scene.utils.log import log_info, log_warning + +__all__ = ["build_metric_scale_inputs", "estimate_text_scene_metric_scale"] + + +def estimate_text_scene_metric_scale( + *, + object_results: list[dict[str, Any]], + user_text: str, + messages: list[dict[str, Any]], + schema: dict[str, Any], + output_dir: Path, + output_root: Path, + llm: Any | None, + step_name: str, +) -> dict[str, Any]: + """Estimate real-world scales for generated text-scene objects.""" + result: dict[str, Any] = { + "status": "skipped", + "method": "text_scene_vlm_candidate_shape_ratio_median_scale", + "user_text": user_text, + "objects": [], + } + try: + if not object_results: + result["reason"] = "missing_objects" + log_warning("text scene metric scale skipped reason=missing_objects") + return result + if not METRIC_SCALE_ENABLED: + result["reason"] = "metric_scale_disabled" + MetricScaleManager.set_for_all_objects( + objects=object_results, + status="skipped", + reason="metric_scale_disabled", + method=str(result["method"]), + ) + log_info("text scene metric scale skipped reason=metric_scale_disabled") + 
return result + if llm is None: + result["reason"] = "missing_llm" + MetricScaleManager.set_for_all_objects( + objects=object_results, + status="skipped", + reason="missing_llm", + method=str(result["method"]), + ) + log_warning("text scene metric scale skipped reason=missing_llm") + return result + + log_info(f"text scene metric scale started count={len(object_results)}") + metric_objects = build_metric_scale_inputs( + objects=object_results, + output_root=output_root, + ) + result["objects"] = MetricScaleManager.object_prompt_payload(metric_objects) + metric_result = MetricScaleManager.estimate_metric_scales( + EstimateMetricScalesRequest( + objects=metric_objects, + messages=messages, + schema=schema, + llm=llm, + context="Text scene metric scale estimate", + method=str(result["method"]), + step_name=step_name, + raw_output_path=output_dir / "raw_model_output.json", + ) + ) + raw_model_output = metric_result.raw_model_output or {} + if not (output_dir / "raw_model_output.json").is_file(): + try: + write_json(output_dir / "raw_model_output.json", raw_model_output) + except Exception as exc: + log_warning(f"metric scale raw output write failed error={exc}") + + estimates = metric_result.object_scales + MetricScaleManager.apply_to_objects( + objects=object_results, + object_scales=estimates, + ) + result.update( + { + "status": "ok", + "object_scales": estimates, + "unit_note": ( + "Per-object scale_factor is not baked into simready GLBs. " + "For text input, simready_geometry_path multiplied by this " + "scale_factor gives the estimated real-world size." 
+ ), + } + ) + log_info(f"text scene metric scale completed count={len(estimates)}") + except Exception as exc: + result.update({"status": "failed", "reason": traceback.format_exc()}) + MetricScaleManager.set_for_all_objects( + objects=object_results, + status="failed", + reason="text_scene_metric_scale_failed", + method=str(result["method"]), + ) + log_warning(f"text scene metric scale failed error={exc}") + return result + + +def build_metric_scale_inputs( + *, + objects: list[dict[str, Any]], + output_root: Path, +) -> list[MetricScaleObjectInput]: + inputs: list[MetricScaleObjectInput] = [] + for obj in objects: + mesh_path = _resolve_generated_path( + obj.get("simready_geometry_path") or obj.get("mesh_path"), + output_root, + ) + if not mesh_path.is_file(): + raise FileNotFoundError(f"Simready object GLB not found: {mesh_path}") + inputs.append( + MetricScaleObjectInput( + object_id=str(obj.get("id", "")), + object_name=str(obj.get("name", "")), + object_description=str(obj.get("description", "")), + mesh_path=mesh_path, + ) + ) + return inputs + + +def _resolve_generated_path(value: Any, output_root: Path) -> Path: + path = Path(str(value or "")).expanduser() + if path.is_absolute(): + return path.resolve() + return (output_root / path).resolve() diff --git a/embodichain/gen_sim/prompt2scene/cli/__init__.py b/embodichain/gen_sim/prompt2scene/cli/__init__.py new file mode 100644 index 00000000..015c4151 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/cli/__init__.py @@ -0,0 +1,19 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +__all__: list[str] = [] diff --git a/embodichain/gen_sim/prompt2scene/cli/start.py b/embodichain/gen_sim/prompt2scene/cli/start.py new file mode 100644 index 00000000..fdc3a27b --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/cli/start.py @@ -0,0 +1,90 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import argparse +from pathlib import Path + +from embodichain.gen_sim.prompt2scene.pipeline.runner import run_prompt2scene +from embodichain.gen_sim.prompt2scene.llms import load_llm_config +from embodichain.gen_sim.prompt2scene.workflows.request import Prompt2SceneInput + +__all__ = ["cli_prompt2scene", "main"] + + +def cli_prompt2scene( + image_path: str | None, + text: str | None, + output_root: str, + llm_config_path: str | None = None, +) -> None: + """Run prompt2scene from normalized CLI argument values. + + Args: + image_path: Path to an input image, if image mode is used. + text: Text prompt, if text mode is used. + output_root: Directory where prompt2scene outputs are written. + llm_config_path: Optional path to the LLM config JSON file. + """ + request = Prompt2SceneInput.from_cli_args( + image_path=Path(image_path) if image_path is not None else None, + text=text, + output_root=Path(output_root), + ) + llm_cfg = load_llm_config( + Path(llm_config_path) if llm_config_path is not None else None + ) + run_prompt2scene(request, llm_cfg=llm_cfg) + + +def main() -> None: + """Parse command line arguments and launch the prompt2scene pipeline.""" + parser = argparse.ArgumentParser( + description="embodichain.gen_sim.prompt2scene Prompt-to-Scene Pipeline" + ) + + input_group = parser.add_mutually_exclusive_group(required=True) + input_group.add_argument( + "--image", + type=str, + help="Path to the input image file (.jpg, .jpeg, or .png)", + ) + input_group.add_argument( + "--text", + type=str, + help="Text prompt describing the target scene", + ) + parser.add_argument( + "--output_root", + type=str, + required=True, + help="Path to the output directory", + ) + parser.add_argument( + "--llm_config", + type=str, + default=None, + help="Path to the LLM config JSON file", + ) + + args = parser.parse_args() + + cli_prompt2scene(args.image, 
args.text, args.output_root, args.llm_config) + + +if __name__ == "__main__": + main() diff --git a/embodichain/gen_sim/prompt2scene/configs/client_config.json b/embodichain/gen_sim/prompt2scene/configs/client_config.json new file mode 100644 index 00000000..b8662eaf --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/configs/client_config.json @@ -0,0 +1,21 @@ +{ + "sam3_segmentation": { + "base_url": "http://192.168.3.23:5014", + "timeout_s": 1200, + "health_path": "/health", + "segment_single_object_path": "/predict" + }, + "sam3d_generation": { + "base_url": "http://10.7.7.32:5019", + "timeout_s": 1800, + "health_path": "/health", + "generate_multiple_objects_path": "/generate_multiple_objects", + "generate_single_object_path": "/generate_single_object" + }, + "zimage": { + "base_url": "http://192.168.3.23:5013", + "timeout_s": 120, + "health_path": "/health", + "generate_single_object_path": "/generate.png" + } +} diff --git a/embodichain/gen_sim/prompt2scene/configs/llm_config.json b/embodichain/gen_sim/prompt2scene/configs/llm_config.json new file mode 100644 index 00000000..9dd82514 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/configs/llm_config.json @@ -0,0 +1,11 @@ +{ + "llm": { + "openai_compatible": { + "api_key": "", + "model": "", + "base_url": "", + "default_query": {}, + "max_attempts": 5 + } + } +} diff --git a/embodichain/gen_sim/prompt2scene/llms/__init__.py b/embodichain/gen_sim/prompt2scene/llms/__init__.py new file mode 100644 index 00000000..8412eff4 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/llms/__init__.py @@ -0,0 +1,31 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.llms.config import OpenAICompatibleLLMCfg +from embodichain.gen_sim.prompt2scene.llms.openai_compatible import ( + DEFAULT_LLM_CONFIG_PATH, + build_chat_model, + load_llm_config, +) + +__all__ = [ + "DEFAULT_LLM_CONFIG_PATH", + "OpenAICompatibleLLMCfg", + "build_chat_model", + "load_llm_config", +] diff --git a/embodichain/gen_sim/prompt2scene/llms/config.py b/embodichain/gen_sim/prompt2scene/llms/config.py new file mode 100644 index 00000000..f84c4fcf --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/llms/config.py @@ -0,0 +1,49 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass, field + +__all__ = [ + "OpenAICompatibleLLMCfg", +] + + +@dataclass(frozen=True) +class OpenAICompatibleLLMCfg: + """OpenAI-compatible LLM configuration.""" + + api_key: str + model: str + base_url: str + default_query: dict[str, str] = field(default_factory=dict) + max_attempts: int = 3 + + def to_manifest(self) -> dict[str, object]: + """Convert the LLM config to a JSON-safe manifest. + + Returns: + LLM config metadata with sensitive values removed. + """ + return { + "provider": "openai_compatible", + "model": self.model, + "base_url": self.base_url, + "has_api_key": bool(self.api_key), + "default_query": self.default_query, + "max_attempts": self.max_attempts, + } diff --git a/embodichain/gen_sim/prompt2scene/llms/openai_compatible.py b/embodichain/gen_sim/prompt2scene/llms/openai_compatible.py new file mode 100644 index 00000000..91e94a59 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/llms/openai_compatible.py @@ -0,0 +1,115 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import json +import os +from pathlib import Path +from typing import Any + +from langchain_openai import ChatOpenAI + +from embodichain.gen_sim.prompt2scene.llms.config import OpenAICompatibleLLMCfg + +__all__ = ["DEFAULT_LLM_CONFIG_PATH", "build_chat_model", "load_llm_config"] + +DEFAULT_LLM_CONFIG_PATH = ( + Path(__file__).resolve().parents[1] / "configs" / "llm_config.json" +) + + +def load_llm_config(config_path: Path | None = None) -> OpenAICompatibleLLMCfg: + """Load the prompt2scene OpenAI-compatible LLM config. + + Args: + config_path: Optional path to the LLM config JSON file. + + Returns: + Parsed OpenAI-compatible LLM config. + + Raises: + FileNotFoundError: If the config file does not exist. + ValueError: If required config fields are missing. + """ + config_path = config_path or DEFAULT_LLM_CONFIG_PATH + config_path = config_path.expanduser().resolve() + + if not config_path.exists(): + raise FileNotFoundError(f"LLM config not found: {config_path}") + + with config_path.open("r", encoding="utf-8") as f: + raw_cfg: dict[str, Any] = json.load(f) + + cfg = raw_cfg.get("llm", {}).get("openai_compatible", {}) + api_key = os.getenv("OPENAI_API_KEY") or cfg.get("api_key", "") + model = os.getenv("OPENAI_MODEL") or cfg.get("model", "") + base_url = os.getenv("OPENAI_BASE_URL") or cfg.get("base_url", "") + default_query = cfg.get("default_query", {}) + max_attempts = _load_positive_int( + os.getenv("OPENAI_MAX_ATTEMPTS") or cfg.get("max_attempts", 3), + key="max_attempts", + ) + + if base_url: + base_url = base_url.rstrip("/") + + missing = [ + name + for name, value in { + "api_key": api_key, + "model": model, + "base_url": base_url, + }.items() + if not value + ] + if missing: + raise ValueError(f"Missing required LLM config keys: {missing}") + + if not isinstance(default_query, dict): + raise ValueError("LLM config key default_query must 
be a dict.") + + return OpenAICompatibleLLMCfg( + api_key=api_key, + model=model, + base_url=base_url, + default_query=default_query, + max_attempts=max_attempts, + ) + + +def _load_positive_int(value: object, *, key: str) -> int: + try: + parsed = int(value) + except (TypeError, ValueError) as exc: + raise ValueError(f"LLM config key {key} must be an integer.") from exc + if parsed < 1: + raise ValueError(f"LLM config key {key} must be >= 1.") + return parsed + + +def build_chat_model(cfg: OpenAICompatibleLLMCfg) -> Any: + """Build a LangChain OpenAI-compatible chat model.""" + kwargs: dict[str, Any] = { + "api_key": cfg.api_key, + "base_url": cfg.base_url, + "model": cfg.model, + "temperature": 0, + } + if cfg.default_query: + kwargs["default_query"] = cfg.default_query + + return ChatOpenAI(**kwargs) diff --git a/embodichain/gen_sim/prompt2scene/pipeline/__init__.py b/embodichain/gen_sim/prompt2scene/pipeline/__init__.py new file mode 100644 index 00000000..a1450f03 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/pipeline/__init__.py @@ -0,0 +1,25 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.pipeline.runner import ( + Prompt2SceneRunResult, + run_prompt2scene, +) + +__all__ = ["Prompt2SceneRunResult", "run_prompt2scene"] + diff --git a/embodichain/gen_sim/prompt2scene/pipeline/runner.py b/embodichain/gen_sim/prompt2scene/pipeline/runner.py new file mode 100644 index 00000000..7931f00b --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/pipeline/runner.py @@ -0,0 +1,239 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
@dataclass(frozen=True)
class Prompt2SceneRunResult:
    """Result returned by the prompt2scene runner.

    Every optional path defaults to ``None``. ``run_prompt2scene`` fills them
    in only when it is given an LLM config: the image-specific paths for image
    input, ``text_relations_path`` for any other input kind.

    Args:
        output_root: Directory where prompt2scene outputs were written.
        manifest_path: Path to the serialized input manifest.
        scene_intake_path: Path to the serialized scene intake output.
        image_segments_path: Path to serialized image segment alignment output.
        image_spatial_relations_path: Path to serialized image spatial relations.
        text_relations_path: Path to serialized text spatial relations.
        unified_scene_path: Path to serialized unified scene output.
        gym_config_path: Path returned by the gym config export step.
    """

    output_root: Path
    manifest_path: Path
    scene_intake_path: Path | None = None
    image_segments_path: Path | None = None
    image_spatial_relations_path: Path | None = None
    text_relations_path: Path | None = None
    unified_scene_path: Path | None = None
    gym_config_path: Path | None = None


def run_prompt2scene(
    request: Prompt2SceneInput,
    llm_cfg: OpenAICompatibleLLMCfg | None = None,
) -> Prompt2SceneRunResult:
    """Run the prompt2scene pipeline.

    The runner always creates the output directory and writes the input
    manifest. When ``llm_cfg`` is provided it then runs, in order: scene
    intake, image or text relations (selected by ``request.input_kind``),
    unified scene, unified scene generation, and gym config export. Without
    an LLM config only the manifest is written.

    Args:
        request: Parsed prompt2scene input.
        llm_cfg: Optional LLM config. ``None`` skips every later step.

    Returns:
        Paths created by the runner; paths of steps that did not run are
        ``None``.
    """
    log.log_info(
        "run start "
        f"input_kind={request.input_kind.value} output_root={request.output_root}"
    )
    request.output_root.mkdir(parents=True, exist_ok=True)
    manifest_path = request.output_root / INPUT_MANIFEST_FILENAME
    manifest = request.to_manifest()
    # Record the LLM settings next to the request so the run is reproducible.
    if llm_cfg is not None:
        manifest["llm"] = llm_cfg.to_manifest()
    write_json(manifest_path, manifest)

    # Step outputs stay None unless the corresponding step runs below.
    scene_intake_path = None
    image_segments_path = None
    image_spatial_relations_path = None
    text_relations_path = None
    unified_scene_path = None
    gym_config_path = None
    if llm_cfg is not None:
        log.log_info("step start scene_intake")
        scene_intake = run_scene_intake(request, llm_cfg=llm_cfg)
        scene_intake_path = write_step_result(
            request.output_root,
            SCENE_INTAKE_STEP,
            scene_intake.to_manifest(),
        )
        log.log_info(
            f"step end scene_intake status=ok output={scene_intake_path}"
        )
        if request.input_kind == InputKind.IMAGE:
            log.log_info("step start image_relations")
            image_relations = run_image_relations(
                request,
                scene_intake=scene_intake,
                llm_cfg=llm_cfg,
                output_root=request.output_root,
            )
            image_segments_path = step_result_path(
                request.output_root,
                IMAGE_SEGMENTS_STEP,
            )
            # Write the segmentation manifest only if the file does not exist
            # yet (presumably run_image_relations may already have written it).
            if not image_segments_path.is_file():
                write_step_result(
                    request.output_root,
                    IMAGE_SEGMENTS_STEP,
                    image_relations.to_segmentation_manifest(),
                )
            image_spatial_relations_path = step_result_path(
                request.output_root,
                IMAGE_SPATIAL_RELATIONS_STEP,
            )
            # Same fallback for the spatial-relations manifest.
            if not image_spatial_relations_path.is_file():
                write_step_result(
                    request.output_root,
                    IMAGE_SPATIAL_RELATIONS_STEP,
                    image_relations.to_spatial_manifest(),
                )
            log.log_info(
                "step end image_relations "
                f"status=ok output={image_spatial_relations_path}"
            )
            log.log_info("step start unified_scene")
            # The return value is not used below; the result path is derived
            # from output_root instead.
            unified_scene = run_unified_scene(
                request,
                scene_intake=scene_intake,
                image_relations=image_relations,
                output_root=request.output_root,
            )
            unified_scene_path = step_result_path(
                request.output_root,
                UNIFIED_SCENE_STEP,
            )
        else:
            # Any non-image input kind goes through the text relations step.
            log.log_info("step start text_relations")
            text_relations = run_text_relations(
                request,
                scene_intake=scene_intake,
                llm_cfg=llm_cfg,
                output_root=request.output_root,
            )
            text_relations_path = step_result_path(
                request.output_root,
                TEXT_RELATIONS_STEP,
            )
            log.log_info(
                f"step end text_relations status=ok output={text_relations_path}"
            )
            log.log_info("step start unified_scene")
            # As in the image branch, only the derived path is used afterwards.
            unified_scene = run_unified_scene(
                request,
                scene_intake=scene_intake,
                text_relations=text_relations,
                output_root=request.output_root,
            )
            unified_scene_path = step_result_path(
                request.output_root,
                UNIFIED_SCENE_STEP,
            )
        log.log_info(
            f"step end unified_scene status=ok output={unified_scene_path}"
        )
        log.log_info("step start unified_scene_gen")
        run_unified_scene_gen(
            request.output_root,
            unified_scene_result_path=unified_scene_path,
            llm_cfg=llm_cfg,
        )
        log.log_info("step end unified_scene_gen status=ok")

        log.log_info("step start gym_export")
        gym_config_path = export_gym_config(request.output_root)
        log.log_info(f"step end gym_export status=ok output={gym_config_path}")

    log.log_info(f"run end output_root={request.output_root}")

    return Prompt2SceneRunResult(
        output_root=request.output_root,
        manifest_path=manifest_path,
        scene_intake_path=scene_intake_path,
        image_segments_path=image_segments_path,
        image_spatial_relations_path=image_spatial_relations_path,
        text_relations_path=text_relations_path,
        unified_scene_path=unified_scene_path,
        gym_config_path=gym_config_path,
    )
def load_prompt(prompt_name: str) -> str:
    """Return the text of a bundled prompt template.

    Delegates to the module-level ``default_prompt_renderer``.
    """
    renderer = default_prompt_renderer
    return renderer.load_prompt(prompt_name)


def load_prompt_data(prompt_name: str) -> dict[str, object]:
    """Return the parsed content of a bundled YAML prompt data file.

    Delegates to the module-level ``default_prompt_renderer``.
    """
    renderer = default_prompt_renderer
    return renderer.load_prompt_data(prompt_name)


def render_prompt(
    prompt_name: str,
    values: dict[str, object] | None = None,
    *,
    prompt_key: str | None = None,
) -> str:
    """Render a bundled prompt, optionally substituting placeholder values.

    Delegates to the module-level ``default_prompt_renderer``; ``prompt_key``
    is forwarded as a keyword argument.
    """
    renderer = default_prompt_renderer
    rendered = renderer.render_prompt(
        prompt_name,
        values,
        prompt_key=prompt_key,
    )
    return rendered
class PromptRenderer:
    """Load and render prompt templates bundled as package data.

    Files are resolved with ``importlib.resources`` relative to the package
    given at construction time. Loaded templates and parsed YAML documents
    are cached per renderer instance.
    """

    def __init__(self, package: Any) -> None:
        """Create a renderer for the prompt files shipped in ``package``.

        Args:
            package: Anchor handed to ``importlib.resources.files``.
        """
        self._package = package
        # Caches live on the instance. Decorating the methods with
        # ``functools.lru_cache`` keys the cache on ``self`` and keeps every
        # renderer alive for as long as the class exists (ruff B019).
        self._prompt_cache: dict[str, str] = {}
        self._prompt_data_cache: dict[str, dict[str, Any]] = {}

    def load_prompt(self, prompt_name: str) -> str:
        """Load a plain-text prompt template by file name.

        Args:
            prompt_name: File name inside the prompt package (no separators).

        Returns:
            The file content with surrounding whitespace stripped.

        Raises:
            ValueError: If ``prompt_name`` contains a path separator.
            FileNotFoundError: If no such file exists in the package.
        """
        cached = self._prompt_cache.get(prompt_name)
        # Compare against None so an empty template is still a cache hit.
        if cached is not None:
            return cached

        prompt_path = self._get_prompt_path(prompt_name)
        if not prompt_path.is_file():
            raise FileNotFoundError(f"Prompt data file not found: {prompt_name}")
        template = prompt_path.read_text(encoding="utf-8").strip()
        self._prompt_cache[prompt_name] = template
        return template

    def load_prompt_data(self, prompt_name: str) -> dict[str, Any]:
        """Load a YAML prompt data file by file name.

        The parsed mapping is cached and the same object is returned on every
        call, so callers should treat it as read-only.

        Args:
            prompt_name: File name inside the prompt package (no separators).

        Returns:
            The top-level YAML mapping.

        Raises:
            ValueError: If ``prompt_name`` contains a path separator or the
                YAML document is not a mapping.
            FileNotFoundError: If no such file exists in the package.
        """
        cached = self._prompt_data_cache.get(prompt_name)
        if cached is not None:
            return cached

        prompt_path = self._get_prompt_path(prompt_name)
        if not prompt_path.is_file():
            raise FileNotFoundError(f"Prompt data file not found: {prompt_name}")

        prompt_data = yaml.safe_load(prompt_path.read_text(encoding="utf-8"))
        if not isinstance(prompt_data, dict):
            raise ValueError(f"Prompt YAML must contain a mapping: {prompt_name}")
        self._prompt_data_cache[prompt_name] = prompt_data
        return prompt_data

    def render_prompt(
        self,
        prompt_name: str,
        values: Mapping[str, object] | None = None,
        *,
        prompt_key: str | None = None,
    ) -> str:
        """Render a prompt template and fill placeholders.

        Args:
            prompt_name: File name of the template or YAML data file.
            values: Optional ``$name`` placeholder values. Placeholders with
                no matching value are left untouched (``safe_substitute``).
            prompt_key: When given, ``prompt_name`` is read as YAML and the
                template is the string stored under this key.

        Returns:
            The template, unmodified when ``values`` is ``None``.

        Raises:
            KeyError: If ``prompt_key`` is missing or its value is not a string.
        """
        if prompt_key is None:
            template = self.load_prompt(prompt_name)
        else:
            prompt_data = self.load_prompt_data(prompt_name)
            template = prompt_data.get(prompt_key)
            if not isinstance(template, str):
                raise KeyError(f"Prompt key {prompt_key!r} not found in {prompt_name}")

        if values is None:
            return template
        return Template(template).safe_substitute(values)

    def _get_prompt_path(self, prompt_name: str) -> Path:
        """Resolve ``prompt_name`` inside the prompt package.

        Names containing ``/`` or ``\\`` are rejected so a caller cannot
        address files outside the package directory via a nested path.
        """
        if "/" in prompt_name or "\\" in prompt_name:
            raise ValueError(f"Prompt name must be a file name: {prompt_name}")
        # NOTE(review): importlib.resources returns a Traversable; it is a
        # real pathlib path only for packages installed on disk - confirm
        # the ``Path`` annotation is intended.
        return resources.files(self._package).joinpath(prompt_name)
+# ---------------------------------------------------------------------------- + +"""Bundled prompt template data files.""" + +from __future__ import annotations + +__all__: list[str] = [] diff --git a/embodichain/gen_sim/prompt2scene/prompts/data/image_relations.yaml b/embodichain/gen_sim/prompt2scene/prompts/data/image_relations.yaml new file mode 100644 index 00000000..50ed6964 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/prompts/data/image_relations.yaml @@ -0,0 +1,238 @@ +name: image_relations +version: 1 + +filter_extra_instances_system: | + + You are a careful image segmentation verification assistant for tabletop scenes. + + + + You will receive: + - One target object class name. + - One target object description. + - The expected number of target instances. + - A short candidate class list for that target object. + - One image with numbered colored masks drawn over candidate segmentation + results for that target object. + + Your only task is to choose which numbered masks should be removed so the + remaining masks best match the requested object class, target description, and + expected instance count. + + This is not a scene-description task and not a spatial-relation task. + Do not describe the scene. Do not infer object-object relations. Do not rename + the requested object class. Do not add new masks. + + + + - Use the target object class name as the primary class. + - Use the target description to distinguish visually similar objects from the + same broad category. + - Use the expected instance count as a hard target when enough plausible masks + are available. + - Use the candidate class list only as synonyms or fallback names for the same + target object. + - If more plausible masks are present than the expected count, keep only the + expected number of best matches and remove the rest. + - If exactly the expected number of plausible masks are present, keep them. 
+ - If fewer than the expected number of plausible masks are present, keep every + plausible mask and remove only clearly wrong or duplicate masks. + - Remove a numbered mask if it clearly covers a different object class. + - Remove a numbered mask if it is a duplicate detection of the same physical + instance already covered by another better mask. + - Remove a numbered mask if it mostly covers background, a hand, or an + unrelated partial region. + - Remove a numbered mask that mostly covers a table or support region unless + the requested target class itself is that table/support target. + - If a mask is ambiguous but plausibly covers the requested object class, keep + it. + + + + { + "extra_instance_numbers": [3], + "reason": "Mask 3 covers a different object, not the requested class." + } + + + + Example 1: + Target object class: soccer_ball + Target description: A round soccer ball with black-and-white panels. + Expected instance count: 2 + Candidate classes: soccer_ball, football, ball, sports_ball, toy_ball + Observation: Masks 1 and 2 cover two soccer balls. Mask 3 covers a paper cup. + Output: + { + "extra_instance_numbers": [3], + "reason": "Masks 1 and 2 are soccer balls; mask 3 is a paper cup." + } + + Example 2: + Target object class: apple + Target description: A round red apple with smooth skin. + Expected instance count: 1 + Candidate classes: apple, fruit, red_apple, food, produce + Observation: Mask 1 tightly covers the apple. Mask 2 overlaps the same apple and + is a duplicate looser detection. + Output: + { + "extra_instance_numbers": [2], + "reason": "Mask 2 is a duplicate detection of the same apple covered by mask 1." + } + + Example 3: + Target object class: mug + Target description: A white ceramic coffee mug with a handle. + Expected instance count: 1 + Candidate classes: mug, coffee_mug, cup, drinkware, ceramic_cup + Observation: Mask 1 covers a real mug. Mask 2 covers a bowl. 
+ Output: + { + "extra_instance_numbers": [2], + "reason": "Mask 1 is a mug; mask 2 is a bowl and should be removed." + } + + Example 4: + Target object class: fork + Target description: A silver metal fork with four tines. + Expected instance count: 1 + Candidate classes: fork, dinner_fork, utensil, cutlery, tableware + Observation: Mask 1 plausibly covers a fork, although part of it is occluded. + Output: + { + "extra_instance_numbers": [], + "reason": "Mask 1 plausibly covers the requested fork and should be kept." + } + + + + - extra_instance_numbers must contain 1-based mask numbers exactly as shown in + the numbered-mask image. + - If no masks should be removed, output an empty list. + - Output JSON only. Do not include markdown or explanations outside JSON. + + +filter_extra_instances_user: | + Verify the numbered segmentation masks for this object class: + + + Target object class: $name + Target description: $description + Expected instance count: $expected_count + Candidate classes: $class_candidate + + + + Inspect the numbered-mask image. + Return the 1-based numbers of masks that should be removed so the remaining + masks best match the target description and expected instance count. + + +spatial_layout_system: | + + You are a careful tabletop spatial-layout verifier. + + + + You will receive one tabletop image with final bounding boxes and labels for + every detected object instance. Your task is to output: + - One anchor object, its 9-grid table location, and the reason for choosing it + and assigning that grid. + - Object groups ordered from left to right. + - Object groups ordered from front to back. + - Whether each object has arbitrary layout, plus a concise support-pose reason. + + Do not output pairwise left/right/front/behind relations. The program will + derive canonical left_of and front_of relations from your x_order and y_order. + Use ordered groups conservatively. Prefer fewer relations over a wrong + relation. 
+ + + + - x_order must be ordered from image/table left to image/table right. + - y_order must be ordered from table front to table back. + - Split x_order groups when the left/right order is reasonably clear from the + bbox-name image. + - If an object's left/right order is ambiguous, keep it in a shared x_order + group. Never omit it. + - Front/back is especially hard to judge. Split y_order only when depth + separation is obvious, preferably from contact positions or bbox bottoms. + - If front/back is close, roughly collinear, overlapping, occluded, similarly + aligned, or hard to compare, place objects in the same y_order group. + - Ordered groups are interpreted as monotonic DAG ranks. The program only + creates direct edges between adjacent groups, then derives transitive + closure. For example, G1 < G2 < G3 creates direct edges G1 -> G2 and + G2 -> G3; G1 -> G3 is implicit. + + + + - Choose one clearly visible object as anchor. + - Prefer a large, unoccluded object whose 9-grid location is easy to judge. + - The anchor reason must explain both why this object was selected and why its + grid is correct. + - The anchor grid must be one of: + center, front, back, left_center, right_center, left_front, right_front, + left_back, right_back. + + + + - is_arbitrary_layout is true when the object does not need a specified + support pose before physics simulation, such as balls, round fruits, loose + natural objects, or objects that will naturally settle by gravity. + - is_arbitrary_layout is false when the object needs a deliberate support pose, + such as cups, bottles, cans, boxes, utensils, remotes, blocks, bags, or + objects that should stand or lie in a controlled way. + - If is_arbitrary_layout is false, the reason must describe the default support + pose visible or implied in the image, such as standing upright on the table, + lying flat on the table, lying on its side, or leaning against another object. 
+ - If is_arbitrary_layout is true, the reason must explain that the object can + settle naturally under gravity or has no meaningful preset support pose. + + + + { + "anchor": { + "asset_id": "interact_paper_cup_0", + "grid": "center", + "reason": "The paper cup is clearly visible and near the table center, so it is a reliable anchor for the center grid." + }, + "x_order": [ + ["interact_wooden_block_0"], + ["interact_paper_cup_0"], + ["interact_snack_bag_0"] + ], + "y_order": [ + ["interact_paper_cup_0"], + ["interact_wooden_block_0", "interact_snack_bag_0"] + ], + "asset_states": [ + { + "asset_id": "interact_paper_cup_0", + "is_arbitrary_layout": false, + "reason": "The paper cup is standing upright on the table, so it needs a deliberate upright support pose." + } + ] + } + + + + - Every provided asset_id must appear exactly once in x_order. + - Every provided asset_id must appear exactly once in y_order. + - Every provided asset_id must appear exactly once in asset_states. + - Use one large group on an axis if the left-right or front-back order is not + visually obvious. Do not omit uncertain objects. + - anchor.asset_id must be one of the provided asset_ids. + - anchor.reason and every asset state reason must be concise but explicit. + - Only the anchor may have a grid. Do not add grid to asset_states. + - Output JSON only. Do not include markdown or explanations outside JSON. + + +spatial_layout_user: | + Infer spatial order, anchor grid, and object states for these detected object instances: + + + $asset_ids + + + Inspect the attached bbox-name image and return the JSON object. 
diff --git a/embodichain/gen_sim/prompt2scene/prompts/data/scene_intake.yaml b/embodichain/gen_sim/prompt2scene/prompts/data/scene_intake.yaml new file mode 100644 index 00000000..cabf99cb --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/prompts/data/scene_intake.yaml @@ -0,0 +1,468 @@ +name: scene_intake +version: 1 + +text_system: | + + You are a careful 3D tabletop scene intake assistant for TEXT input. + + + + You will receive a text description of a tabletop scene. + This is only the first-stage scene intake step: + - Extract the object categories and counts on the tabletop. + - Extract the table or tabletop region that carries the objects, using + the fixed output field named table. + + Do not analyze object-object relations, grids, orientations, stacking, + inside/container relations, layout, pose, masks, bounding boxes, or + segmentation results. + + + + - Output only real physical objects that can become 3D asset generation targets. + - Do not include the table or tabletop region in assets. + - assets is a list of object category groups, not a list of individual object + instances. + - name must be the most specific English, singular, canonical object class + supported by the input. + - Prefer a concrete small category over a broad category. For example, output + fork instead of utensil, paper_cup instead of container, toy_car instead of + toy, remote_control instead of handheld_device, and cereal_box instead of + box when those categories are supported by the input. + - Use a broad fallback name only when the specific object category cannot be + reasonably inferred. + - Prefer snake_case names, such as apple, banana, soccer_ball, coffee_mug. + - Treat multiple objects as one repeated asset group only when they are + effectively the same object type and can share the same name, the same + object-only description, and the same class_candidate list without losing + important visual identity. + - Never output two asset rows with the same name. 
If the same name would be + repeated, merge them into one row and increase count. + - If repeated instances are truly the same asset group, output exactly one + asset row and set count to the number of visible or described instances. + - If two objects need meaningfully different descriptions, names, or + class_candidate lists, they are not repeated instances. Output separate + asset rows with specific different names. + - Only merge objects when they can reasonably be found by the same segmentation + prompts from name, class_candidate, and description. + - Do not merge visually different subtypes under a broad name. For example, + paper_cup and popcorn_cup must be separate rows, not one cup row; snack_bag + and paper_bag must be separate rows; remote_control and phone must be + separate rows. + - Do not output instance IDs such as apple_0 or banana_0. Instance IDs will be + generated by code from name and count. + - Do not output extra fields such as source_text, source_image_path, image_path, + bbox, mask, or id. + - class_candidate must contain exactly five English, singular, canonical + object class names that could help later image detection or segmentation. + - class_candidate must prioritize specific small categories. The first item + must equal name. The next items should be specific plausible classes before + broader fallback classes. + - Do not replace a known small category with a broad category. If the object is + a fork, include fork first; broader classes such as utensil or cutlery may + appear only later as fallbacks. + - For text inputs, class_candidate should follow the stated object category + and include detector-friendly small-category synonyms before broader + classes. + + + + - table.name, table.description, table.complete_table_description, + table.class_candidate, and every asset.description must be non-empty. + - Descriptions are used to generate images and then 3D geometry. 
+ - Write each description as one concise English sentence, normally 10 to 25 + words. + - Every description must describe a SINGLE STANDALONE OBJECT isolated on a + pure-white background. Do NOT mention any other object, the table, the scene, + the room, or any background context. + - Do NOT include any spatial, positional, or layout information such as + "sitting on the table", "placed in front of", "to the left of", "on a + surface", "on the tabletop", etc. + - When describing an object, first state what the object is, then describe its + appearance in detail. + - For TEXT input you MUST invent reasonable and vivid appearance details: + color (be specific: "crimson red", "matte charcoal", "glossy navy blue", + "warm honey oak"), material (polished stainless steel, glazed ceramic, + rough terracotta, smooth beechwood, frosted glass), texture (ribbed, + brushed, speckled, woven, hammered), shape (cylindrical, tapered, flared + rim, curved handle, wide brim). + - Vary colours across objects — do not make everything white or neutral. + A tabletop scene naturally has diverse materials and hues. + - table.description must describe the actual table as a standalone target: + include type, color, shape, material, and legs/base when applicable. + - table.complete_table_description must describe a complete standalone table + asset for generation. It must always include a complete physical table-like + object, with a tabletop and a plausible support structure such as legs, + pedestal, frame, or tray body. It must not describe only a surface plane, + tabletop patch, texture, or support region. + - Do not write generic phrases such as "support surface", "tabletop", or + "surface" when table.name is a concrete object such as table, desk, tray, + counter, shelf, or floor. Use the concrete class in the description. + - For repeated instances, write one object-only description for the shared + category. Do not mention instance positions. 
 + - If two objects require different descriptions, they must be separate asset + rows with distinct names. + + + + - Do not output a table id. The code will set table.id to "table". + - The table field represents the scene table or tabletop target. table.name + must be the best class name for that target, such as table, desk, + dining_table, coffee_table, workbench, or tabletop. + - table.class_candidate must contain exactly five English, singular, + canonical class names for segmenting the support target. The first item must + equal table.name. + - For text inputs, set table.is_complete_visible_table to false. + + + + { + "table": { + "name": "table", + "description": "A rectangular wooden table with a brown top and four straight legs.", + "complete_table_description": "A complete rectangular wooden table with a brown top and four straight legs.", + "is_complete_visible_table": false, + "class_candidate": ["table", "dining_table", "desk", "wooden_table", "furniture"] + }, + "assets": [ + { + "name": "apple", + "description": "A shiny deep-red apple with a smooth curved shape and a small brown stem on top.", + "class_candidate": ["apple", "red_apple", "fruit", "food", "produce"], + "count": 1 + }, + { + "name": "coffee_mug", + "description": "A glossy navy blue ceramic coffee mug with a curved handle and a slightly flared rim.", + "class_candidate": ["coffee_mug", "ceramic_mug", "mug", "cup", "drinkware"], + "count": 2 + } + ] + } + + + + - The top-level object must contain only table and assets. + - table must contain only name, description, complete_table_description, + is_complete_visible_table, and class_candidate. + - Each asset must contain only name, description, class_candidate, and count. + - table.name must be a non-empty string. + - table.description must be a non-empty string. + - table.complete_table_description must be a non-empty string. + - table.is_complete_visible_table must be a boolean. 
+ - table.class_candidate must be a list of exactly five non-empty strings, and + the first item must equal table.name. + - assets must be a list. + - Each asset.name must be a non-empty string. + - Each asset.description must be a non-empty string. + - Each asset.class_candidate must be a list of exactly five non-empty strings. + - Each asset.count must be an integer greater than or equal to 1. + - Output JSON only. Do not include markdown or explanations outside JSON. + + +image_system: | + + You are a careful 3D tabletop scene intake assistant for IMAGE input. + + + + You will receive one image of a tabletop scene. + This is only the first-stage scene intake step: + - Extract the object categories and counts on the tabletop. + - Extract the visible table or tabletop region that carries the objects, using + the fixed output field named table. + + Do not analyze object-object relations, grids, orientations, stacking, + inside/container relations, layout, pose, masks, bounding boxes, or + segmentation results. + + + + - Output only real physical objects that can become 3D asset generation targets. + - Do not include the table or tabletop region in assets. + - assets is a list of object category groups, not a list of individual object + instances. + - name must be the most specific English, singular, canonical object class + supported by the input. + - Prefer a concrete small category over a broad category. For example, output + fork instead of utensil, paper_cup instead of container, toy_car instead of + toy, remote_control instead of handheld_device, and cereal_box instead of + box when those categories are supported by the input. + - Use a broad fallback name only when the specific object category cannot be + reasonably inferred. + - Prefer snake_case names, such as apple, banana, soccer_ball, coffee_mug. 
+ - Treat multiple objects as one repeated asset group only when they are + effectively the same object type and can share the same name, the same + object-only description, and the same class_candidate list without losing + important visual identity. + - Never output two asset rows with the same name. If the same name would be + repeated, merge them into one row and increase count. + - If repeated instances are truly the same asset group, output exactly one + asset row and set count to the number of visible or described instances. + - If two objects need meaningfully different descriptions, names, or + class_candidate lists, they are not repeated instances. Output separate + asset rows with specific different names. + - Only merge objects when they can reasonably be found by the same segmentation + prompts from name, class_candidate, and description. + - Do not merge visually different subtypes under a broad name. For example, + paper_cup and popcorn_cup must be separate rows, not one cup row; snack_bag + and paper_bag must be separate rows; remote_control and phone must be + separate rows. + - Do not output instance IDs such as apple_0 or banana_0. Instance IDs will be + generated by code from name and count. + - Do not output extra fields such as source_text, source_image_path, image_path, + bbox, mask, or id. + - class_candidate must contain exactly five English, singular, canonical + object class names that could help later image detection or segmentation. + - class_candidate must prioritize specific small categories. The first item + must equal name. The next items should be specific plausible classes before + broader fallback classes. + - Do not replace a known small category with a broad category. If the object is + a fork, include fork first; broader classes such as utensil or cutlery may + appear only later as fallbacks. 
+ - For image inputs, if the exact object category is uncertain, use + class_candidate to list likely categories from specific to broader, such as + remote_control, handheld_device, electronic_device, gadget, tool. + + + + - table.name, table.description, table.complete_table_description, + table.class_candidate, and every asset.description must be non-empty. + - Descriptions are used to generate images and then 3D geometry. + - Write each description as one concise English sentence, normally 8 to 20 + words. + - Every description must describe a SINGLE STANDALONE OBJECT isolated on a + pure-white background. Do NOT mention any other object, the table, the scene, + the room, or any background context. + - Do NOT include any spatial, positional, or layout information such as + "sitting on the table", "placed in front of", "to the left of", "on a + surface", "on the tabletop", etc. + - When describing an object, first state what the object is, then mention + visible texture, color, shape, material, and similar appearance details. + - Keep descriptions simple. Focus only on what the object looks like, not + where it is or how it relates to anything else. + - For IMAGE inputs, include ONLY information supported by the image. + Do NOT invent or embellish details not visible in the image. If a colour + is ambiguous, use a reasonable neutral description ("light-colored", + "dark-toned", "metallic"). + - table.description must describe the actual visible table or tabletop region + as a standalone target. If the complete table is visible, describe that + physical table directly, including type, color, shape, material, and legs + when visible. If only a partial tabletop is visible, describe that visible + tabletop area directly. + - table.complete_table_description must describe a complete standalone table + asset for generation. 
If only a partial tabletop is visible, convert that + partial surface into a complete table description with matching color, + material, and texture. + - table.complete_table_description must always include a complete physical + table-like object, with a tabletop and a plausible support structure such as + legs, pedestal, frame, or tray body. It must not describe only a surface + plane, tabletop patch, texture, or support region. + - Do not write generic phrases such as "support surface", "tabletop", or + "surface" when table.name is a concrete object such as table, desk, tray, + counter, shelf, or floor. Use the concrete class in the description. + - For repeated instances, write one object-only description for the shared + category. Do not mention instance positions. + - If two objects require different descriptions, they must be separate asset + rows with distinct names. + + + + - Do not output a table id. The code will set table.id to "table". + - The table field represents the scene table or tabletop target. table.name + must be the best visible class name for that target, such as table, desk, + dining_table, coffee_table, workbench, or tabletop. + - table.class_candidate must contain exactly five English, singular, + canonical class names for segmenting the support target. The first item must + equal table.name. + - For image inputs, set table.is_complete_visible_table to true only when a + mostly complete table or desk is visible and suitable as the final table + geometry source. "Mostly complete" means both the tabletop outline/shape is + mostly visible and the table/desk legs or support structure are mostly + visible. + - Set table.is_complete_visible_table to false when only a cropped tabletop + patch, partial table surface, or heavily occluded table is visible. 
+ - Set table.is_complete_visible_table to false when the tabletop shape is not + mostly visible, when the legs/support structure are not visible or only + barely visible, or when the image only shows a surface plane. + - If table.is_complete_visible_table is false, table.description may describe + the visible partial tabletop, but table.complete_table_description must + describe a complete table with matching tabletop color, material, and + texture. + - If table.description describes only a visible surface or tabletop patch, + table.complete_table_description must rewrite it as a full table-like asset + with matching tabletop appearance plus plausible legs, pedestal, frame, or + support body. + + + + { + "table": { + "name": "table", + "description": "A rectangular wooden table with a brown top and four straight legs.", + "complete_table_description": "A complete rectangular wooden table with a brown top and four straight legs.", + "is_complete_visible_table": false, + "class_candidate": ["table", "dining_table", "desk", "wooden_table", "furniture"] + }, + "assets": [ + { + "name": "apple", + "description": "A round red apple with smooth glossy skin and a short brown stem.", + "class_candidate": ["apple", "red_apple", "fruit", "produce", "food"], + "count": 1 + }, + { + "name": "coffee_mug", + "description": "A white ceramic coffee mug with a curved handle.", + "class_candidate": ["coffee_mug", "ceramic_mug", "mug", "cup", "drinkware"], + "count": 2 + } + ] + } + + + + - The top-level object must contain only table and assets. + - table must contain only name, description, complete_table_description, + is_complete_visible_table, and class_candidate. + - Each asset must contain only name, description, class_candidate, and count. + - table.name must be a non-empty string. + - table.description must be a non-empty string. + - table.complete_table_description must be a non-empty string. + - table.is_complete_visible_table must be a boolean.
+ - table.class_candidate must be a list of exactly five non-empty strings, and + the first item must equal table.name. + - assets must be a list. + - Each asset.name must be a non-empty string. + - Each asset.description must be a non-empty string. + - Each asset.class_candidate must be a list of exactly five non-empty strings. + - Each asset.count must be an integer greater than or equal to 1. + - Output JSON only. Do not include markdown or explanations outside JSON. + + +text_user: | + Extract the objects and support target from this text: + $text + +image_user: | + Extract tabletop objects and the visible support target from this image. + +verifier_system: | + + You are a strict scene-intake verifier for tabletop object grouping. + + + + You will receive an original tabletop input and a draft scene_intake JSON. + Verify and correct the draft so it follows the same scene_intake schema. + + Your main job is to check: + - Whether asset groups are correctly merged or split. + - Whether each asset count matches the visible or described instance count. + - Whether each name is specific enough for later image segmentation. + - Whether table.name, table.description, table.complete_table_description, + table.is_complete_visible_table, and table.class_candidate describe the + actual table/tabletop target. + - For image inputs, independently re-check table.is_complete_visible_table + against the original image. + - Independently re-check that table.complete_table_description describes a + complete standalone table/desk/workbench/tray-like asset, not only a surface + plane, tabletop patch, texture, or support region. + + Return the corrected scene_intake JSON. Do not return comments, diffs, or + explanations. + + + + - assets is a list of object category groups, not individual instances. + - Use count to represent repeated instances only when they can share the same + name, object-only description, and class_candidate list. 
+ - If two objects need different descriptions, names, or class_candidate lists, + split them into separate asset rows with specific names. + - Never keep two asset rows with the same name. If they are truly repeated + instances, merge them and increase count. If they are not truly the same, + rename them into more specific different names. + - Do not merge visually different subtypes under a broad name. For example, + paper_cup and popcorn_cup must be separate rows, not one cup row. + - Prefer small, visually segmentable names such as fork, paper_cup, + popcorn_cup, soccer_ball, snack_bag, wooden_block. + - Avoid broad names such as object, item, utensil, container, cup, bag, toy, + box, or device when the input supports a more specific category. + - class_candidate must contain exactly five names; the first item must equal + name. + - table.class_candidate must contain exactly five names; the first item must + equal table.name. + - Preserve the fixed table field as the table/tabletop target. + - For text inputs, table.is_complete_visible_table must be false. + - For image inputs, do not trust the draft value of + table.is_complete_visible_table. Judge it again from the attached original + image. + - For image inputs, table.is_complete_visible_table is true only if a mostly + complete table is visible and suitable as final table geometry. "Mostly + complete" means both the tabletop outline/shape is mostly visible and the + table/desk legs or support structure are mostly visible. + - If only a partial tabletop is visible, table.is_complete_visible_table must + be false and table.complete_table_description must describe a complete table + with matching tabletop color, material, and texture. + - If the table/desk legs or support structure are not visible, or if the + tabletop outline/shape is not mostly visible, table.is_complete_visible_table + must be false. 
+ - table.complete_table_description must always be a complete physical + table-like asset description, including a tabletop and a plausible support + structure such as legs, pedestal, frame, or tray body. It must not describe + only "a surface", "a tabletop surface", "a plane", "a patch", or only a + material/texture. + - If the draft table.complete_table_description describes only a visible + partial surface, rewrite it into a complete table-like object with matching + tabletop color, material, and texture plus a plausible support structure. + - For image inputs, only count clearly visible target instances. If uncertain, + use the most conservative count supported by the image. + - For text inputs, count only objects explicitly stated or strongly implied by + the text. + + + + { + "table": { + "name": "table", + "description": "A rectangular wooden table with a brown top and four straight legs.", + "complete_table_description": "A complete rectangular wooden table with a brown top and four straight legs.", + "is_complete_visible_table": false, + "class_candidate": ["table", "dining_table", "desk", "wooden_table", "furniture"] + }, + "assets": [ + { + "name": "paper_cup", + "description": "A small white paper cup with blue printed details.", + "class_candidate": ["paper_cup", "disposable_cup", "cup", "drinkware", "container"], + "count": 1 + } + ] + } + + + + - The top-level object must contain only table and assets. + - table must contain only name, description, complete_table_description, + is_complete_visible_table, and class_candidate. + - Each asset must contain only name, description, class_candidate, and count. + - Output JSON only. Do not include markdown or explanations outside JSON. + + +verifier_text_user: | + Verify and correct this draft scene_intake JSON against the original text. + + + $text + + + + $scene_intake_json + + +verifier_image_user: | + Verify and correct this draft scene_intake JSON against the attached tabletop image. 
+ + + $scene_intake_json + diff --git a/embodichain/gen_sim/prompt2scene/prompts/data/text_relations.yaml b/embodichain/gen_sim/prompt2scene/prompts/data/text_relations.yaml new file mode 100644 index 00000000..7a267d09 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/prompts/data/text_relations.yaml @@ -0,0 +1,110 @@ +name: text_relations +version: 1 + +system: | + + You are a strict tabletop text spatial-relation extractor. + + + + Extract only spatial constraints that are explicitly stated or strongly and + directly implied by the user's text. Do not complete the full scene layout. + Do not infer unstated object positions. Output only canonical left_of and + front_of relations. Do not add inverse or transitive relations; the program + will derive transitive closure later. + + + + - object_relations: direct object-object relations stated in text. + - table_constraints: direct object-to-table 9-grid locations stated in text. + - object_layouts: direct object support-pose constraints stated in text. + + + + - Only use these relation values: left_of, front_of. + - If the text says "A is left of B", output exactly A left_of B. + - If the text says "A is right of B", output exactly B left_of A. + - If the text says "A is in front of B", output exactly A front_of B. + - If the text says "A is behind B", output exactly B front_of A. + - Do not output right_of or behind. + - Do not output transitive relations. + - Use only asset names from the provided scene-intake assets. + + + + - Only output table_constraints when the original text explicitly states an + object-to-table region. + - Valid grid values are: + center, front, back, left_center, right_center, left_front, right_front, + left_back, right_back. + - Map natural language table regions directly: + center -> center; front -> front; back -> back; left side -> left_center; + right side -> right_center; front-left -> left_front; front-right -> + right_front; back-left -> left_back; back-right -> right_back. 
+ - If the text does not explicitly state a table region for an object, do not + create a table constraint for that object. + - Do not infer table grid locations from object-object relations. + - If no explicit table grid constraints are stated, output table_constraints + as an empty list. + + + + - Output object_layouts only when the text explicitly describes an object's + support pose or when the object category itself strongly implies arbitrary + layout, such as a ball or round fruit. + - is_arbitrary_layout is true when the object does not need a specified support + pose before physics simulation and can settle naturally under gravity. + - is_arbitrary_layout is false when the object needs a stated/default support + pose from the text. + - For non-arbitrary objects, reason must describe the support pose, such as + standing upright on the table, lying flat on the table, lying on its side, or + leaning against another object. + + + + { + "object_relations": [ + { + "subject": "paper_cup", + "relation": "left_of", + "object": "plate", + "evidence": "The text says the paper cup is left of the plate." + } + ], + "table_constraints": [ + { + "asset": "paper_cup", + "grid": "left_front", + "evidence": "The text says the paper cup is at the front-left of the table." + } + ], + "object_layouts": [ + { + "asset": "water_bottle", + "is_arbitrary_layout": false, + "reason": "The text says the water bottle is standing upright on the table." + } + ] + } + + + + - If no relation of a type is stated, output an empty list for that field. + - Every subject, object, and asset must be one of the provided scene-intake + asset names. + - The top-level object must contain only object_relations, table_constraints, + and object_layouts. + - Do not output anchor or inferred table-region fields. + - Output JSON only. Do not include markdown or explanations outside JSON. + + +user: | + Extract explicit text spatial constraints from this prompt. 
+ + + $asset_names + + + + $text + diff --git a/embodichain/gen_sim/prompt2scene/prompts/data/unified_scene_gen.yaml b/embodichain/gen_sim/prompt2scene/prompts/data/unified_scene_gen.yaml new file mode 100644 index 00000000..22d33af3 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/prompts/data/unified_scene_gen.yaml @@ -0,0 +1,225 @@ +name: unified_scene_gen +version: 1 + +up_down_flip_check_system: | + + You are a careful 3D tabletop geometry orientation verifier. + + + + You will receive: + - Image A: the original tabletop scene photo. + - Image B: one comparison image containing two fixed front-oblique + orthographic renders of generated 3D objects only. Each render has a + visible numeric label. + + Your task is to choose the numbered generated render that has the correct + up/down orientation relative to the original photo. + + + + - Choose selected_number=1 when candidate 1 better matches the original + photo's visible object tops and support-facing sides. + - Choose selected_number=2 when candidate 2 better matches the original + photo's visible object tops and support-facing sides. + - Do not request a yaw rotation around the vertical axis. This task is not + about left-right ordering or rotating the layout in the image plane; both + candidates have already been yaw-aligned by geometric scoring. + - The generated renders are not strict top views. They are slightly + front-oblique views so object tops and front/side faces may both be visible. + - Ignore the missing table/support in the candidate renders; it is + intentionally omitted. + - If the renders are ambiguous, symmetric, low quality, or insufficient to + distinguish up/down orientation, choose selected_number=1. + - confidence must be a number from 0 to 1. + - reason must be concise and explain the visual evidence. + + + + { + "selected_number": 1, + "confidence": 0.72, + "reason": "Candidate 1 shows the visible tops of the objects more consistently with the original image." 
+ } + + + + - Output JSON only. Do not include markdown or explanations outside JSON. + - The JSON object must include all required keys: selected_number, + confidence, reason. + - selected_number must be exactly 1 or 2. + + +up_down_flip_check_user: | + Compare the original scene photo with the numbered generated object-only + front-oblique comparison image. + + + Choose which generated render has the correct up/down orientation. Return + exactly one JSON object with: + - selected_number: 1 or 2 + - confidence: number from 0 to 1 + - reason: short string + + +asset_metric_scale_system: | + + You estimate plausible real-world tabletop object bounding-box dimensions + from semantic descriptions. + + + + Given an object name and description, output one plausible real-world + bounding-box dimension in centimeters. + + + + - The dimensions must be in centimeters. + - The order of the three dimensions does not matter; the program will match + shape proportions. + - Estimate the full real-world object bbox, not only the visible part. + - Use common tabletop object sizes when the description is generic. + - Prefer a slightly larger but still plausible tabletop size when uncertain. + - Use confidence to express semantic certainty, not visual certainty. + - Output JSON only. Do not include markdown or text outside JSON. + + + + { + "bbox_dims_cm": [18.0, 8.0, 5.0], + "confidence": 0.72, + "reason": "Typical compact tabletop item size." + } + + +asset_metric_scale_user: | + Estimate plausible real-world bounding-box dimensions for this object. + + + $object_name + + + + $object_description + + + Return exactly one JSON object with: + - bbox_dims_cm: one slightly generous plausible size, three positive numbers in centimeters + - confidence: number from 0 to 1 + - reason: short string + +image_metric_scale_system: | + + You estimate plausible real-world tabletop object bounding-box dimensions + from a labeled scene image and object descriptions. 
+ + + + You will receive: + - One image with each object marked by a bounding box and its object name. + - One JSON list containing object_id, object_name, and object_description + for all objects. + + For each object in the JSON list, output one plausible real-world + bounding-box dimension in centimeters. + + + + - Output one entry for every object_id in the input JSON. + - Use the labeled image to understand the object category and relative + visible scale in the scene. + - Use object_name and object_description as semantic anchors. + - The dimensions must be in centimeters. + - The order of the three dimensions does not matter. + - Prefer a slightly larger but still plausible tabletop size when uncertain. + - Use confidence to express semantic certainty. + - Output JSON only. Do not include markdown or text outside JSON. + + + + { + "object_scales": [ + { + "object_id": "interact_cup_0", + "bbox_dims_cm": [8.0, 8.0, 12.0], + "confidence": 0.78, + "reason": "Typical tabletop cup size." + } + ] + } + + +image_metric_scale_user: | + Estimate real-world dimensions for every object in the JSON below. + + + $objects_json + + + The attached image has bbox + name labels matching object_name. Return exactly + one JSON object with: + - object_scales: list of objects, one for every input object_id + - object_id: copied exactly from input + - bbox_dims_cm: one slightly generous plausible size, three positive numbers in centimeters + - confidence: number from 0 to 1 + - reason: short string + +text_metric_scale_system: | + + You estimate plausible real-world tabletop object bounding-box dimensions + from a full text scene prompt and object descriptions. + + + + You will receive: + - The user's original scene text. + - One JSON list containing object_id, object_name, and object_description + for all objects. + + For each object in the JSON list, output one plausible real-world + bounding-box dimension in centimeters. 
+ + + + - Output one entry for every object_id in the input JSON. + - Use the full scene text to infer intended object scale and context. For + example, a "small soccer ball on a table" should not be treated as a full + regulation soccer ball. + - Use object_name and object_description as semantic anchors. + - The dimensions must be in centimeters. + - The order of the three dimensions does not matter. + - Prefer a slightly larger but still plausible tabletop size when uncertain. + - Use confidence to express semantic certainty. + - Output JSON only. Do not include markdown or text outside JSON. + + + + { + "object_scales": [ + { + "object_id": "interact_small_soccer_ball_0", + "bbox_dims_cm": [6.0, 6.0, 6.0], + "confidence": 0.74, + "reason": "The scene text describes a small tabletop soccer ball." + } + ] + } + + +text_metric_scale_user: | + Estimate real-world dimensions for every object in the JSON below. + + + $user_text + + + + $objects_json + + + Return exactly one JSON object with: + - object_scales: list of objects, one for every input object_id + - object_id: copied exactly from input + - bbox_dims_cm: one slightly generous plausible size, three positive numbers in centimeters + - confidence: number from 0 to 1 + - reason: short string diff --git a/embodichain/gen_sim/prompt2scene/utils/__init__.py b/embodichain/gen_sim/prompt2scene/utils/__init__.py new file mode 100644 index 00000000..8378c49a --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/utils/__init__.py @@ -0,0 +1,39 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from . import log +from embodichain.gen_sim.prompt2scene.utils.io import ( + image_to_data_url, + relative_path, + write_json, +) +from embodichain.gen_sim.prompt2scene.utils.log import ( + log_api_request_start, + log_info, + log_warning, +) + +__all__ = [ + "log", + "log_api_request_start", + "log_info", + "log_warning", + "image_to_data_url", + "relative_path", + "write_json", +] diff --git a/embodichain/gen_sim/prompt2scene/utils/io.py b/embodichain/gen_sim/prompt2scene/utils/io.py new file mode 100644 index 00000000..6057d198 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/utils/io.py @@ -0,0 +1,66 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import base64 +import json +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.prompt2scene.utils.log import log_info + +__all__ = ["image_to_data_url", "relative_path", "write_json"] + + +def relative_path(path: str | Path, root: Path) -> str: + """Return ``path`` relative to ``root``, or ``str(path)`` if not under ``root``.""" + resolved_path = Path(path) + try: + return str(resolved_path.relative_to(root)) + except ValueError: + return str(path) + + +def write_json(path: Path, payload: dict[str, Any]) -> None: + """Write a JSON payload with prompt2scene's default formatting. + + Args: + path: Output JSON file path. + payload: JSON-serializable dictionary payload. + """ + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text( + json.dumps(payload, indent=2, ensure_ascii=False) + "\n", + encoding="utf-8", + ) + if not path.is_file(): + raise FileNotFoundError(f"JSON output was not written: {path}") + log_info(f"Wrote JSON: {path}") + + +def image_to_data_url(image_path: Path) -> str: + """Return a base64 data URL for a local image; unknown suffixes use image/png.""" + suffix_to_mime = { + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".png": "image/png", + ".webp": "image/webp", + ".gif": "image/gif", + } + mime_type = suffix_to_mime.get(image_path.suffix.lower(), "image/png") + encoded = base64.b64encode(image_path.read_bytes()).decode("ascii") + return f"data:{mime_type};base64,{encoded}" diff --git a/embodichain/gen_sim/prompt2scene/utils/log.py b/embodichain/gen_sim/prompt2scene/utils/log.py new file mode 100644 index 00000000..47bdfa44 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/utils/log.py @@ -0,0 +1,62 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import logging +from typing import Any + +__all__ = ["log_api_request_start", "log_info", "log_warning"] + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [EmbodiChain %(levelname)s]: %(message)s", + datefmt="%H:%M:%S", +) + +_LOGGER = logging.getLogger(__name__) +_LOGGER.setLevel(logging.INFO) + + +def _format_message(level: str, message: str) -> str: + _ = level + return f"Prompt2Scene: {message}" + + +def log_info(message: str) -> None: + """Log ``message`` at INFO level, prefixed with ``Prompt2Scene:``.""" + _LOGGER.info(_format_message("INFO", message)) + + +def log_warning(message: str) -> None: + """Log ``message`` at WARNING level, prefixed with ``Prompt2Scene:``.""" + _LOGGER.warning(_format_message("WARNING", message)) + + +def log_api_request_start( + *, + step: str, + request: str, + attempt: int | None = None, + **details: Any, +) -> None: + """Log an API request start as ordered ``key=value`` fields via ``log_info``.""" + fields = [f"step={step}", f"request={request}"] + if attempt is not None: + fields.append(f"attempt={attempt}") + for key, value in details.items(): + fields.append(f"{key}={value}") + log_info("api request start " + " ".join(fields)) diff --git a/embodichain/gen_sim/prompt2scene/workflows/__init__.py b/embodichain/gen_sim/prompt2scene/workflows/__init__.py new file mode 100644 index
00000000..393b0022 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/__init__.py @@ -0,0 +1,41 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.workflows.artifact_writer import ( + DEBUG_DIRNAME, + IMAGE_SEGMENTS_STEP, + IMAGE_SPATIAL_RELATIONS_STEP, + RAW_MODEL_OUTPUT_FILENAME, + SCENE_INTAKE_STEP, + STEP_RESULT_FILENAME, + TEXT_RELATIONS_STEP, + UNIFIED_SCENE_STEP, + WorkflowArtifactWriter, +) + +__all__ = [ + "DEBUG_DIRNAME", + "IMAGE_SEGMENTS_STEP", + "IMAGE_SPATIAL_RELATIONS_STEP", + "RAW_MODEL_OUTPUT_FILENAME", + "SCENE_INTAKE_STEP", + "STEP_RESULT_FILENAME", + "TEXT_RELATIONS_STEP", + "UNIFIED_SCENE_STEP", + "WorkflowArtifactWriter", +] diff --git a/embodichain/gen_sim/prompt2scene/workflows/artifact_writer.py b/embodichain/gen_sim/prompt2scene/workflows/artifact_writer.py new file mode 100644 index 00000000..6587ccbb --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/artifact_writer.py @@ -0,0 +1,271 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. 
# File and directory names shared by every workflow step.
STEP_RESULT_FILENAME = "result.json"
DEBUG_DIRNAME = "debug"
RAW_MODEL_OUTPUT_FILENAME = "raw_model_output.json"

# Step names; each one is used as a directory name under the output root.
SCENE_INTAKE_STEP = "scene_intake"
IMAGE_SEGMENTS_STEP = "image_segments"
IMAGE_SPATIAL_RELATIONS_STEP = "image_spatial_relations"
TEXT_RELATIONS_STEP = "text_relations"
UNIFIED_SCENE_STEP = "unified_scene"
UNIFIED_SCENE_GEN_STEP = "unified_scene_gen"

# Matches "round_<digits>" optionally followed by "_<label>"; group 1 is the index.
DEBUG_ROUND_PATTERN = re.compile(r"^round_(\d+)(?:_|$)")

# Upper bound on the label part of a debug round directory name.
_MAX_PATH_TOKEN_LENGTH = 80


def step_dir_path(output_root: Path, step_name: str) -> Path:
    """Return the directory path for a pipeline step."""
    return output_root / step_name


def step_result_path(output_root: Path, step_name: str) -> Path:
    """Return the final result JSON path for a pipeline step."""
    return step_dir_path(output_root, step_name) / STEP_RESULT_FILENAME


def debug_dir_path(output_root: Path, step_name: str) -> Path:
    """Return the debug directory path for a pipeline step."""
    return step_dir_path(output_root, step_name) / DEBUG_DIRNAME


def debug_round_dir_path(
    output_root: Path,
    step_name: str,
    round_name: str,
) -> Path:
    """Return a debug subdirectory path for one model/tool round."""
    return debug_dir_path(output_root, step_name) / round_name


def next_debug_round_name(
    output_root: Path,
    step_name: str,
    label: str | None = None,
) -> str:
    """Return the next step-local debug round name.

    The step's debug directory is scanned for sub-directories named
    ``round_<n>`` or ``round_<n>_<label>``; the result uses ``max(n) + 1``
    zero-padded to three digits, or ``round_001`` when none exist.

    Nothing is created on disk, so two calls without an intervening write
    return the same index.

    Args:
        output_root: Root directory that holds all step directories.
        step_name: Name of the step whose debug directory is scanned.
        label: Optional free-form text appended as a sanitized suffix.

    Returns:
        The round directory name (not a full path).
    """
    debug_dir = debug_dir_path(output_root, step_name)
    max_index = 0
    if debug_dir.is_dir():
        for path in debug_dir.iterdir():
            # Plain files never count as rounds.
            if not path.is_dir():
                continue
            match = DEBUG_ROUND_PATTERN.match(path.name)
            if match is not None:
                max_index = max(max_index, int(match.group(1)))
    round_name = f"round_{max_index + 1:03d}"
    if label:
        round_name = f"{round_name}_{_path_token(label)}"
    return round_name


def next_debug_round_dir_path(
    output_root: Path,
    step_name: str,
    label: str | None = None,
) -> Path:
    """Return the next step-local debug round directory path."""
    return debug_round_dir_path(
        output_root,
        step_name,
        next_debug_round_name(output_root, step_name, label),
    )


def write_step_result(
    output_root: Path,
    step_name: str,
    payload: dict[str, Any],
) -> Path:
    """Write a step's final result JSON and return its path."""
    # NOTE(review): write_json is a project helper; presumably it creates
    # missing parent directories - confirm against utils.io.
    path = step_result_path(output_root, step_name)
    write_json(path, payload)
    return path


def write_debug_json(
    output_root: Path,
    step_name: str,
    round_name: str,
    filename: str,
    payload: dict[str, Any],
) -> Path:
    """Write a debug JSON file under one step debug round and return its path."""
    path = debug_round_dir_path(output_root, step_name, round_name) / filename
    write_json(path, payload)
    return path


def write_debug_round_json(
    debug_round_dir: Path,
    filename: str,
    payload: dict[str, Any],
) -> Path:
    """Write a debug JSON file under an already selected debug round directory."""
    path = debug_round_dir / filename
    write_json(path, payload)
    return path


def write_raw_model_output(
    output_root: Path,
    step_name: str,
    round_name: str,
    payload: dict[str, Any],
) -> Path:
    """Write one raw structured model output under a step debug round."""
    return write_debug_json(
        output_root,
        step_name,
        round_name,
        RAW_MODEL_OUTPUT_FILENAME,
        payload,
    )


def write_next_raw_model_output(
    output_root: Path,
    step_name: str,
    payload: dict[str, Any],
    label: str | None = None,
) -> Path:
    """Write raw model output under the next step-local debug round."""
    round_name = next_debug_round_name(output_root, step_name, label)
    return write_raw_model_output(output_root, step_name, round_name, payload)


class WorkflowArtifactWriter:
    """Write workflow artifacts under a fixed step directory.

    Thin object wrapper around the module-level path/write helpers that
    remembers ``output_root`` and ``step_name`` so callers pass them once.
    """

    def __init__(self, output_root: Path, step_name: str) -> None:
        self._output_root = output_root
        self._step_name = step_name

    @property
    def output_root(self) -> Path:
        """Root directory that holds all step directories."""
        return self._output_root

    @property
    def step_name(self) -> str:
        """Name of the step this writer is bound to."""
        return self._step_name

    @property
    def step_dir(self) -> Path:
        """Directory of the bound step."""
        return step_dir_path(self._output_root, self._step_name)

    @property
    def debug_dir(self) -> Path:
        """Debug directory of the bound step."""
        return debug_dir_path(self._output_root, self._step_name)

    @property
    def result_path(self) -> Path:
        """Path of the bound step's final result JSON."""
        return step_result_path(self._output_root, self._step_name)

    def next_debug_round_name(self, label: str | None = None) -> str:
        """Return the next debug round name for this step."""
        return next_debug_round_name(self._output_root, self._step_name, label)

    def next_debug_round_dir(self, label: str | None = None) -> Path:
        """Return the next debug round directory for this step."""
        return next_debug_round_dir_path(self._output_root, self._step_name, label)

    def debug_round_dir(self, round_name: str) -> Path:
        """Return one debug round directory under this step."""
        return debug_round_dir_path(self._output_root, self._step_name, round_name)

    def write_step_result(self, payload: dict[str, Any]) -> Path:
        """Write the step's final result JSON."""
        return write_step_result(self._output_root, self._step_name, payload)

    def write_debug_round_json(
        self,
        *,
        round_name: str,
        filename: str,
        payload: dict[str, Any],
    ) -> Path:
        """Write a JSON artifact inside one named debug round."""
        return write_debug_round_json(
            self.debug_round_dir(round_name),
            filename=filename,
            payload=payload,
        )

    def write_raw_model_output(
        self,
        *,
        round_name: str,
        payload: dict[str, Any],
    ) -> Path:
        """Write a raw model output into one named debug round."""
        return write_raw_model_output(
            self._output_root,
            self._step_name,
            round_name,
            payload,
        )

    def write_next_raw_model_output(
        self,
        *,
        payload: dict[str, Any],
        label: str | None = None,
    ) -> Path:
        """Write a raw model output into the next available debug round."""
        return write_next_raw_model_output(
            self._output_root,
            self._step_name,
            payload,
            label=label,
        )


def _path_token(value: str) -> str:
    """Return *value* reduced to a short, filesystem-friendly token.

    Every non-alphanumeric character becomes ``_``; leading and trailing
    underscores are removed and the result is capped at 80 characters.
    Falls back to ``"round"`` when nothing usable is left.
    """
    token = "".join(character if character.isalnum() else "_" for character in value)
    # Truncation can expose a new trailing underscore (e.g. "aaa_b" cut after
    # the "_"), so trim once more after slicing.
    token = token.strip("_")[:_MAX_PATH_TOKEN_LENGTH].rstrip("_")
    return token or "round"
class AttemptState(TypedDict):
    """Common retry/error fields for one model-call stage."""

    # Number of attempts made so far.
    attempt_count: int
    # Upper bound on attempts.
    # NOTE(review): presumably routers stop retrying once attempt_count
    # reaches this value - confirm against the graph routing functions.
    max_attempts: int
    # Message of the most recent failure, or None.
    # NOTE(review): None presumably means the last attempt succeeded - confirm.
    last_error: str | None
    # Error messages collected for this stage.
    # NOTE(review): presumably appended in chronological order - confirm.
    errors: list[str]
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.workflows.image_relations.graph import ( + build_image_relations_graph, + run_image_relations, +) + +__all__ = ["build_image_relations_graph", "run_image_relations"] diff --git a/embodichain/gen_sim/prompt2scene/workflows/image_relations/graph.py b/embodichain/gen_sim/prompt2scene/workflows/image_relations/graph.py new file mode 100644 index 00000000..ff67f3a0 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/image_relations/graph.py @@ -0,0 +1,189 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def route_after_filter_extra_instances(state: ImageRelationsState) -> str:
    """Pick the edge to follow once the VLM extra-instance filter has run.

    Returns ``"retry"`` only when the last attempt failed and attempts remain;
    otherwise ``"continue"`` (both on success and when attempts are exhausted).
    """
    should_retry = (
        state["last_error"] is not None
        and state["attempt_count"] < state["max_attempts"]
    )
    return "retry" if should_retry else "continue"


def route_after_spatial_layout(state: ImageRelationsState) -> str:
    """Pick the edge to follow once spatial-layout extraction has run.

    Returns ``"retry"`` only when the last attempt failed and attempts remain;
    otherwise ``"end"``.
    """
    # NOTE(review): attempt_count lives in the same state as the filter stage;
    # if that stage also increments it, spatial layout gets fewer retries than
    # max_attempts - confirm this sharing is intended.
    should_retry = (
        state["last_error"] is not None
        and state["attempt_count"] < state["max_attempts"]
    )
    return "retry" if should_retry else "end"


def build_image_relations_graph(llm: Any) -> Any:
    """Assemble and compile the fixed image-segmentation LangGraph workflow.

    Args:
        llm: Chat model handed to every node that needs a VLM call.

    Returns:
        The compiled graph.
    """

    def with_llm(node: Any) -> Any:
        # Adapt a ``node(state, *, llm)`` function to the one-argument form
        # the graph invokes.
        return lambda state: node(state, llm=llm)

    workflow = StateGraph(ImageRelationsState)

    # Nodes, registered in pipeline order.
    workflow.add_node("prepare_segmentation_input", prepare_segmentation_input_node)
    workflow.add_node("segment_by_name", segment_by_name_node)
    workflow.add_node(
        "call_vlm_filter_initial_segments",
        with_llm(call_vlm_filter_initial_segments_node),
    )
    workflow.add_node(
        "retry_missing_by_candidates",
        with_llm(retry_missing_by_candidates_node),
    )
    workflow.add_node("normalize_asset_segments", normalize_asset_segments_node)
    workflow.add_node("segment_table", with_llm(segment_table_node))
    workflow.add_node(
        "call_vlm_spatial_layout",
        with_llm(call_vlm_spatial_layout_node),
    )

    # Edges: a straight line with two self-retry loops on the VLM stages.
    workflow.set_entry_point("prepare_segmentation_input")
    workflow.add_edge("prepare_segmentation_input", "segment_by_name")
    workflow.add_edge("segment_by_name", "call_vlm_filter_initial_segments")
    filter_routes = {
        "retry": "call_vlm_filter_initial_segments",
        "continue": "retry_missing_by_candidates",
    }
    workflow.add_conditional_edges(
        "call_vlm_filter_initial_segments",
        route_after_filter_extra_instances,
        filter_routes,
    )
    workflow.add_edge("retry_missing_by_candidates", "normalize_asset_segments")
    workflow.add_edge("normalize_asset_segments", "segment_table")
    workflow.add_edge("segment_table", "call_vlm_spatial_layout")
    layout_routes = {
        "retry": "call_vlm_spatial_layout",
        "end": END,
    }
    workflow.add_conditional_edges(
        "call_vlm_spatial_layout",
        route_after_spatial_layout,
        layout_routes,
    )
    return workflow.compile()


def run_image_relations(
    request: Prompt2SceneInput,
    *,
    scene_intake: SceneIntakeSpec,
    llm_cfg: OpenAICompatibleLLMCfg,
    output_root: Path,
) -> ImageRelationSpec:
    """Run the image-relations workflow for one prompt2scene request.

    Args:
        request: The prompt2scene request being processed.
        scene_intake: Scene-intake result for the request.
        llm_cfg: Configuration used to build the chat model; its
            ``max_attempts`` seeds the retry budget.
        output_root: Root directory placed into the workflow state.

    Returns:
        The final spec, only when its status is ``"ok"`` and an anchor is set.

    Raises:
        RuntimeError: When no spec was produced, when segmentation failed for
            some group, or when the spatial layout is missing. The message is
            logged as a warning before raising.
    """
    compiled_graph = build_image_relations_graph(build_chat_model(llm_cfg))
    initial_state = {
        "request": request,
        "scene_intake": scene_intake,
        "output_root": output_root,
        "segment_groups": [],
        "raw_model_output": None,
        "image_relations": None,
        "attempt_count": 0,
        "max_attempts": llm_cfg.max_attempts,
        "last_error": None,
        "errors": [],
    }
    final_state = compiled_graph.invoke(initial_state)

    spec = final_state.get("image_relations")
    if spec is None:
        # The workflow never produced a spec at all.
        message = format_result_missing_error(
            "Image relations",
            "ImageRelationSpec",
            attempt_count=final_state.get("attempt_count", 0),
            last_error=final_state.get("last_error"),
            errors=final_state.get("errors", []),
        )
    elif spec.status != "ok":
        # Segmentation failed: report every group that is not "ok".
        failed_groups = [
            group.to_manifest() for group in spec.groups if group.status != "ok"
        ]
        table_group = spec.table_group
        if table_group is not None and table_group.status != "ok":
            failed_groups.append(table_group.to_manifest())
        message = (
            "Image relations failed to align all image segments. "
            f"Failed groups: {failed_groups}"
        )
    elif spec.anchor is None:
        # Segmentation succeeded but the spatial layout was never applied.
        message = format_result_missing_error(
            "Image relations",
            "spatial layout",
            attempt_count=final_state.get("attempt_count", 0),
            last_error=final_state.get("last_error"),
            errors=final_state.get("errors", []),
        )
    else:
        return spec

    log.log_warning(message)
    raise RuntimeError(message)
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_segmentation_client import ( + decode_rle_mask, + draw_numbered_masks, +) +from embodichain.gen_sim.prompt2scene.workflows.image_relations.schema import ( + ImageAssetSegment, + ImageRelationGroup, + ImageRelationSpec, +) +from embodichain.gen_sim.prompt2scene.workflows.request import InputKind +from embodichain.gen_sim.prompt2scene.workflows.image_relations.schema import ( + FILTER_EXTRA_INSTANCES_JSON_SCHEMA, + SPATIAL_LAYOUT_JSON_SCHEMA, +) +from embodichain.gen_sim.prompt2scene.utils import ( + log_api_request_start, + log, +) +from embodichain.gen_sim.prompt2scene.workflows.artifact_writer import ( + IMAGE_SEGMENTS_STEP, + IMAGE_SPATIAL_RELATIONS_STEP, + WorkflowArtifactWriter, +) +from embodichain.gen_sim.prompt2scene.workflows.image_relations.utils import ( + append_unique, + apply_spatial_layout_output, + asset_bbox_label, + draw_labeled_bboxes, + expand_asset_ids, + filter_group_segments_with_vlm, + filter_segments_with_vlm, + merge_non_overlapping_segments, + prompt_text, + path_token, + require_image_path, + segment_prompt, + segments_from_response, +) +from embodichain.gen_sim.prompt2scene.workflows.image_relations.prompts import ( + build_filter_extra_instances_messages, + build_spatial_layout_messages, +) +from embodichain.gen_sim.prompt2scene.workflows.image_relations.state import ( + ImageRelationsState, +) 
def prepare_segmentation_input_node(state: ImageRelationsState) -> dict[str, object]:
    """Prepare scene-intake asset groups for class-level segmentation.

    Builds one mutable "segment group" dict per scene-intake asset; later
    nodes fill in ``segments``, ``tried_prompts`` and ``debug_images``.

    Raises:
        ValueError: If the request is not an image request or has no image path.
    """
    request = state["request"]
    if request.input_kind != InputKind.IMAGE or request.image_path is None:
        raise ValueError("Image relations requires an image input.")

    segment_groups = []
    for asset in state["scene_intake"].assets:
        group = {
            "name": asset.name,
            "description": asset.description,
            "asset_ids": expand_asset_ids(asset.id, asset.count),
            "class_candidate": list(asset.class_candidate),
            "segments": [],
            "tried_prompts": [],
            "debug_images": [],
            "status": "pending",
            "error": None,
            "expected_count": asset.count,
        }
        segment_groups.append(group)
    return {"segment_groups": segment_groups}


def segment_by_name_node(state: ImageRelationsState) -> dict[str, object]:
    """Run SAM3 once per object name."""
    image_path = require_image_path(state)
    segment_groups = []
    for group in state["segment_groups"]:
        prompt = prompt_text(group["name"])
        response = segment_prompt(image_path=image_path, prompt=prompt)
        # Copy so the group stored in the incoming state is not mutated.
        group = dict(group)
        group["tried_prompts"] = append_unique(group["tried_prompts"], prompt)
        group["segments"] = segments_from_response(
            group=group,
            response=response,
            source_prompt=prompt,
        )
        segment_groups.append(group)
    return {"segment_groups": segment_groups}


def call_vlm_filter_extra_instances_node(
    state: ImageRelationsState,
    *,
    llm: Any,
) -> dict[str, object]:
    """Compatibility wrapper for the initial VLM segment filter."""
    return call_vlm_filter_initial_segments_node(state, llm=llm)


def call_vlm_filter_initial_segments_node(
    state: ImageRelationsState,
    *,
    llm: Any,
) -> dict[str, object]:
    """Ask VLM to remove wrong masks from initial name-based SAM3 output."""
    return filter_segments_with_vlm(state=state, llm=llm, stage="initial")


def _segment_fallback_prompt(
    *,
    llm: Any,
    image_path: Path,
    artifact_writer: WorkflowArtifactWriter,
    group: dict[str, Any],
    segments: list[dict[str, Any]],
    prompt: str,
    stage: str,
) -> list[dict[str, Any]]:
    """Segment with one fallback prompt, VLM-filter the hits, merge them in.

    Records *prompt* in ``group["tried_prompts"]`` and returns the merged
    segment list; ``group["segments"]`` itself is left untouched so the
    caller decides when to commit.
    """
    response = segment_prompt(image_path=image_path, prompt=prompt)
    group["tried_prompts"] = append_unique(group["tried_prompts"], prompt)
    new_segments = segments_from_response(
        group=group,
        response=response,
        source_prompt=prompt,
    )
    new_segments = filter_group_segments_with_vlm(
        llm=llm,
        image_path=image_path,
        artifact_writer=artifact_writer,
        group=group,
        segments=new_segments,
        stage=stage,
    )
    return merge_non_overlapping_segments(
        existing=segments,
        incoming=new_segments,
        limit=group["expected_count"],
    )


def retry_missing_by_candidates_node(
    state: ImageRelationsState,
    *,
    llm: Any,
) -> dict[str, object]:
    """Use remaining class candidates to add missing segment instances.

    For every group that still has fewer segments than expected, each
    class candidate after the first is tried in order, then the asset
    description, stopping as soon as the expected count is reached.
    Prompts already listed in ``tried_prompts`` are skipped.
    """
    image_path = require_image_path(state)
    artifact_writer = WorkflowArtifactWriter(state["output_root"], IMAGE_SEGMENTS_STEP)
    segment_groups = []
    for group in state["segment_groups"]:
        group = dict(group)
        segments = group["segments"]
        expected_count = group["expected_count"]

        # First candidate is skipped.
        # NOTE(review): presumably it duplicates the name prompt already
        # tried by segment_by_name_node - confirm against scene intake.
        for candidate_name in group["class_candidate"][1:]:
            if len(segments) >= expected_count:
                break
            prompt = prompt_text(candidate_name)
            if prompt in group["tried_prompts"]:
                continue
            segments = _segment_fallback_prompt(
                llm=llm,
                image_path=image_path,
                artifact_writer=artifact_writer,
                group=group,
                segments=segments,
                prompt=prompt,
                stage=f"fallback_{path_token(prompt)}",
            )

        # Last resort: the free-text description as the prompt.
        if len(segments) < expected_count:
            description_prompt = str(group.get("description") or "").strip()
            if description_prompt and description_prompt not in group["tried_prompts"]:
                segments = _segment_fallback_prompt(
                    llm=llm,
                    image_path=image_path,
                    artifact_writer=artifact_writer,
                    group=group,
                    segments=segments,
                    prompt=description_prompt,
                    stage="fallback_description",
                )

        group["segments"] = segments
        segment_groups.append(group)
    return {"segment_groups": segment_groups}


def normalize_asset_segments_node(state: ImageRelationsState) -> dict[str, object]:
    """Assign final segments to scene-intake asset IDs.

    A group is ``"ok"`` only when its segment count equals the expected
    count; any mismatch marks the group and the overall result as
    ``"failed"``. Segments of failed groups are not emitted. On overall
    success a labeled bounding-box image is rendered into the step
    directory. The segmentation manifest is always written as the step
    result.
    """
    image_path = require_image_path(state)
    artifact_writer = WorkflowArtifactWriter(state["output_root"], IMAGE_SEGMENTS_STEP)
    asset_segments: list[ImageAssetSegment] = []
    relation_groups: list[ImageRelationGroup] = []
    status = "ok"

    for group in state["segment_groups"]:
        expected_count = group["expected_count"]
        segments = group["segments"]
        group_status = "ok"
        error = None
        if len(segments) < expected_count:
            group_status = "failed"
            error = "missing_segments"
            status = "failed"
        elif len(segments) > expected_count:
            group_status = "failed"
            error = "extra_segments"
            status = "failed"

        relation_groups.append(
            ImageRelationGroup(
                name=group["name"],
                expected_count=expected_count,
                detected_count=len(segments),
                status=group_status,
                tried_prompts=list(group["tried_prompts"]),
                asset_ids=list(group["asset_ids"]),
                debug_images=list(group["debug_images"]),
                error=error,
            )
        )

        if group_status != "ok":
            continue
        # Counts match here, so zip pairs every asset id with one segment.
        for asset_id, segment in zip(group["asset_ids"], segments):
            asset_segments.append(
                ImageAssetSegment(
                    asset_id=asset_id,
                    name=group["name"],
                    segment_id=segment["segment_id"],
                    bbox_xyxy=list(segment["bbox_xyxy"]),
                    score=float(segment["score"]),
                    source_prompt=segment["source_prompt"],
                    mask_rle=segment.get("mask_rle"),
                )
            )

    bbox_name_image_path = None
    if status == "ok":
        bbox_name_image_path = str(
            draw_labeled_bboxes(
                image_path=image_path,
                boxes=[
                    {
                        "bbox_xyxy": segment.bbox_xyxy,
                        "label": asset_bbox_label(segment.asset_id),
                    }
                    for segment in asset_segments
                ],
                output_path=artifact_writer.step_dir / "asset_segments_bbox_name.png",
            )
        )

    image_relations = ImageRelationSpec(
        status=status,
        image_path=str(image_path),
        asset_segments=asset_segments,
        groups=relation_groups,
        bbox_name_image_path=bbox_name_image_path,
    )
    artifact_writer.write_step_result(image_relations.to_segmentation_manifest())
    return {"image_relations": image_relations}


def segment_table_node(
    state: ImageRelationsState,
    *,
    llm: Any,
) -> dict[str, object]:
    """Segment the table/support target after object segmentation is complete.

    Does nothing (returns an empty update) unless object segmentation
    finished with status ``"ok"``. Prompts are tried in fallback order until
    one yields candidates; the largest candidate is kept. ``llm`` is accepted
    for signature parity with the other nodes but is not used here.
    """
    image_relations = state["image_relations"]
    if image_relations is None or image_relations.status != "ok":
        return {}

    image_path = require_image_path(state)
    table = state["scene_intake"].table
    artifact_writer = WorkflowArtifactWriter(state["output_root"], IMAGE_SEGMENTS_STEP)
    group = {
        "name": table.name,
        "description": table.description,
        "asset_ids": [table.id],
        "class_candidate": list(table.class_candidate),
        "segments": [],
        "tried_prompts": [],
        "debug_images": [],
        "status": "pending",
        "error": None,
        "expected_count": 1,
    }
    segments: list[dict[str, Any]] = []

    for prompt in _table_segmentation_prompts(group):
        if len(segments) >= 1:
            break
        response = segment_prompt(image_path=image_path, prompt=prompt)
        group["tried_prompts"] = append_unique(group["tried_prompts"], prompt)
        new_segments = segments_from_response(
            group=group,
            response=response,
            source_prompt=prompt,
        )
        _write_table_candidate_debug_image(
            image_path=image_path,
            artifact_writer=artifact_writer,
            group=group,
            segments=new_segments,
            stage=f"table_{path_token(prompt)}",
        )
        selected_segment = _select_largest_table_segment(new_segments)
        if selected_segment is not None:
            segments = [selected_segment]

    group_status = "ok" if len(segments) == 1 else "failed"
    error = None if group_status == "ok" else "missing_table_segment"
    table_group = ImageRelationGroup(
        name=group["name"],
        expected_count=1,
        detected_count=len(segments),
        status=group_status,
        tried_prompts=list(group["tried_prompts"]),
        asset_ids=[table.id],
        debug_images=list(group["debug_images"]),
        error=error,
    )
    table_segment = None
    if group_status == "ok":
        segment = segments[0]
        table_segment = ImageAssetSegment(
            asset_id=table.id,
            name=table.name,
            segment_id=segment["segment_id"],
            bbox_xyxy=list(segment["bbox_xyxy"]),
            score=float(segment["score"]),
            source_prompt=segment["source_prompt"],
            mask_rle=segment.get("mask_rle"),
        )

    # Rebuild the spec: everything carried over, table fields and status updated.
    updated_image_relations = ImageRelationSpec(
        status="ok" if group_status == "ok" else "failed",
        image_path=image_relations.image_path,
        asset_segments=image_relations.asset_segments,
        groups=image_relations.groups,
        table_segment=table_segment,
        table_group=table_group,
        bbox_name_image_path=image_relations.bbox_name_image_path,
        anchor=image_relations.anchor,
        x_order=image_relations.x_order,
        y_order=image_relations.y_order,
        asset_layouts=image_relations.asset_layouts,
    )
    artifact_writer.write_step_result(updated_image_relations.to_segmentation_manifest())
    return {"image_relations": updated_image_relations}


def _table_segmentation_prompts(group: dict[str, Any]) -> list[str]:
    """Return table/support segmentation prompts in object-style fallback order.

    Order: the name, class candidates after the first, then the description;
    empty prompts and duplicates are dropped, first occurrence wins.
    """
    prompts = [prompt_text(group["name"])]
    for candidate_name in group["class_candidate"][1:]:
        prompts.append(prompt_text(candidate_name))
    description_prompt = str(group.get("description") or "").strip()
    if description_prompt:
        prompts.append(description_prompt)

    unique_prompts: list[str] = []
    for prompt in prompts:
        if prompt and prompt not in unique_prompts:
            unique_prompts.append(prompt)
    return unique_prompts


def _write_table_candidate_debug_image(
    *,
    image_path: Path,
    artifact_writer: WorkflowArtifactWriter,
    group: dict[str, Any],
    segments: list[dict[str, Any]],
    stage: str,
) -> None:
    """Write table/support candidate mask debug image without VLM filtering.

    No-op when *segments* is empty. The image path is recorded in
    ``group["debug_images"]``.
    """
    if not segments:
        return
    round_name = artifact_writer.next_debug_round_name(label=f"{stage}_{group['name']}")
    round_dir = artifact_writer.debug_round_dir(round_name)
    debug_image_path = draw_numbered_masks(
        image_path=image_path,
        segments=segments,
        output_path=round_dir / "mask.png",
    )
    group["debug_images"] = append_unique(
        group["debug_images"],
        str(debug_image_path),
    )


def _select_largest_table_segment(
    segments: list[dict[str, Any]],
) -> dict[str, Any] | None:
    """Select the largest SAM3 table/support candidate without VLM filtering."""
    if not segments:
        return None
    return max(segments, key=_segment_area)


def _segment_area(segment: dict[str, Any]) -> float:
    """Return the segment's area in pixels.

    Uses the count of non-zero mask pixels when ``mask_rle`` is present and
    decodable; otherwise falls back to the bounding-box area (clamped at 0
    for inverted boxes).
    """
    mask_rle = segment.get("mask_rle")
    if mask_rle is not None:
        try:
            mask = decode_rle_mask(mask_rle).convert("L")
            histogram = mask.histogram()
            # Bucket 0 is background; every other bucket is a mask pixel.
            return float(sum(count for value, count in enumerate(histogram) if value))
        except Exception as exc:
            # Deliberately best-effort: an undecodable mask must not abort
            # table selection, but the fallback should be visible in logs.
            log.log_warning(
                f"Could not decode segment mask; using bbox area instead: {exc}"
            )
    x1, y1, x2, y2 = segment["bbox_xyxy"]
    return max(0.0, float(x2) - float(x1)) * max(0.0, float(y2) - float(y1))


def call_vlm_spatial_layout_node(
    state: ImageRelationsState,
    *,
    llm: Any,
) -> dict[str, object]:
    """Ask VLM for object ordering, anchor grid, and per-object layout states.

    Returns an empty update when segmentation did not finish ``"ok"``.
    Model-output errors and ``ValueError`` are recorded in the state
    (``last_error``/``errors``) for the retry router to act on; any other
    exception propagates.

    Raises:
        ValueError: If the spec has no ``bbox_name_image_path``.
    """
    image_relations = state["image_relations"]
    if image_relations is None or image_relations.status != "ok":
        return {}
    if image_relations.bbox_name_image_path is None:
        raise ValueError("Image spatial layout requires bbox_name_image_path.")

    attempt_count = state["attempt_count"] + 1
    asset_ids = [segment.asset_id for segment in image_relations.asset_segments]
    artifact_writer = WorkflowArtifactWriter(
        state["output_root"],
        IMAGE_SPATIAL_RELATIONS_STEP,
    )
    messages = build_spatial_layout_messages(
        bbox_name_image_path=Path(image_relations.bbox_name_image_path),
        asset_ids=asset_ids,
    )

    try:
        log_api_request_start(
            step=IMAGE_SPATIAL_RELATIONS_STEP,
            request="spatial_layout",
            attempt=attempt_count,
        )
        raw_model_output = call_structured_json_model_step(
            llm=llm,
            schema=SPATIAL_LAYOUT_JSON_SCHEMA,
            messages=messages,
            context="Image spatial layout",
            step_name=IMAGE_SPATIAL_RELATIONS_STEP,
            output_root=None,
            attempt_count=attempt_count,
            raw_output_label="spatial_layout",
            artifact_writer=artifact_writer,
        )
        updated_image_relations = apply_spatial_layout_output(
            image_relations=image_relations,
            raw_model_output=raw_model_output,
        )
        artifact_writer.write_step_result(updated_image_relations.to_spatial_manifest())
    except Exception as exc:
        # Only recoverable model/validation failures are turned into state.
        if is_model_output_error(exc) or isinstance(exc, ValueError):
            error = format_attempt_error("Image relations spatial layout", attempt_count, exc)
            log.log_warning(error)
            return {
                "attempt_count": attempt_count,
                "last_error": error,
                "errors": state["errors"] + [error],
            }
        raise
    return {
        "attempt_count": attempt_count,
        "image_relations": updated_image_relations,
        "last_error": None,
    }
IMAGE_RELATIONS_PROMPT_NAME = "image_relations.yaml"


def build_filter_extra_instances_messages(
    *,
    debug_image_path: Path,
    name: str,
    description: str,
    expected_count: int,
    class_candidate: list[str],
) -> list[dict[str, Any]]:
    """Build the system/user message pair for VLM extra-mask filtering.

    The user message carries the rendered text prompt followed by the debug
    image as a data URL. Underscores in ``name`` and in every class
    candidate are shown to the model as spaces.
    """
    system_text = render_prompt(
        IMAGE_RELATIONS_PROMPT_NAME,
        prompt_key="filter_extra_instances_system",
    )
    template_values = {
        "name": name.replace("_", " "),
        "description": description,
        "expected_count": str(expected_count),
        "class_candidate": ", ".join(
            [candidate.replace("_", " ") for candidate in class_candidate]
        ),
    }
    user_text = render_prompt(
        IMAGE_RELATIONS_PROMPT_NAME,
        template_values,
        prompt_key="filter_extra_instances_user",
    )
    image_part = {
        "type": "image_url",
        "image_url": {"url": image_to_data_url(debug_image_path)},
    }

    system_message = {"role": "system", "content": system_text}
    user_message = {
        "role": "user",
        "content": [{"type": "text", "text": user_text}, image_part],
    }
    return [system_message, user_message]


def build_spatial_layout_messages(
    *,
    bbox_name_image_path: Path,
    asset_ids: list[str],
) -> list[dict[str, Any]]:
    """Build the system/user message pair for VLM spatial-layout extraction.

    Asset ids are passed to the prompt template as a dash-bulleted list,
    one id per line; the labeled bounding-box image is attached as a data URL.
    """
    system_text = render_prompt(
        IMAGE_RELATIONS_PROMPT_NAME,
        prompt_key="spatial_layout_system",
    )
    bulleted_ids = "\n".join([f"- {asset_id}" for asset_id in asset_ids])
    user_text = render_prompt(
        IMAGE_RELATIONS_PROMPT_NAME,
        {"asset_ids": bulleted_ids},
        prompt_key="spatial_layout_user",
    )
    image_part = {
        "type": "image_url",
        "image_url": {"url": image_to_data_url(bbox_name_image_path)},
    }

    system_message = {"role": "system", "content": system_text}
    user_message = {
        "role": "user",
        "content": [{"type": "text", "text": user_text}, image_part],
    }
    return [system_message, user_message]
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +from embodichain.gen_sim.prompt2scene.workflows.spatial import GRID_VALUE_LIST + +__all__ = [ + "FILTER_EXTRA_INSTANCES_JSON_SCHEMA", + "ImageAnchor", + "ImageAssetLayout", + "ImageAssetSegment", + "ImageRelationGroup", + "ImageRelationSpec", + "SPATIAL_LAYOUT_JSON_SCHEMA", +] + +FILTER_EXTRA_INSTANCES_JSON_SCHEMA: dict[str, Any] = { + "title": "FilterExtraImageInstancesOutput", + "type": "object", + "additionalProperties": False, + "properties": { + "extra_instance_numbers": { + "type": "array", + "description": "1-based mask numbers that should be removed.", + "items": {"type": "integer", "minimum": 1}, + }, + "reason": { + "type": "string", + "description": "Brief reason for the removal decision.", + }, + }, + "required": ["extra_instance_numbers", "reason"], +} + +SPATIAL_LAYOUT_JSON_SCHEMA: dict[str, Any] = { + "title": "ImageSpatialLayoutOutput", + "type": "object", + "additionalProperties": False, + "properties": { + "anchor": { + "type": "object", + "additionalProperties": False, + "properties": { + "asset_id": {"type": "string", "minLength": 1}, + "grid": { + "type": "string", + "enum": GRID_VALUE_LIST, + }, + "reason": {"type": "string"}, + }, + "required": ["asset_id", "grid", "reason"], + }, + "x_order": { + "type": "array", + "description": "Asset-id groups ordered from left to right.", + "items": { + "type": "array", + "items": {"type": "string", "minLength": 1}, + "minItems": 1, + }, + "minItems": 1, + }, + "y_order": { + "type": "array", + "description": "Asset-id groups ordered from front to back.", + "items": { + "type": "array", + "items": {"type": "string", "minLength": 1}, + "minItems": 1, + }, + "minItems": 1, + }, + "asset_states": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": True, + "properties": { + "asset_id": {"type": 
"string", "minLength": 1}, + "is_arbitrary_layout": {"type": "boolean"}, + "reason": {"type": "string", "minLength": 1}, + }, + "required": [ + "asset_id", + "is_arbitrary_layout", + "reason", + ], + }, + }, + }, + "required": ["anchor", "x_order", "y_order", "asset_states"], +} + + +@dataclass(frozen=True) +class ImageAssetSegment: + """Image segmentation result aligned to one scene-intake asset.""" + + asset_id: str + name: str + segment_id: str + bbox_xyxy: list[float] + score: float + source_prompt: str + mask_rle: dict[str, Any] | None = None + + def to_manifest(self) -> dict[str, Any]: + """Convert the segment to JSON-safe data.""" + return { + "asset_id": self.asset_id, + "name": self.name, + "segment_id": self.segment_id, + "bbox_xyxy": list(self.bbox_xyxy), + "score": self.score, + "source_prompt": self.source_prompt, + "mask_rle": self.mask_rle, + } + + +@dataclass(frozen=True) +class ImageRelationGroup: + """Segmentation alignment status for assets sharing one object name.""" + + name: str + expected_count: int + detected_count: int + status: str + tried_prompts: list[str] = field(default_factory=list) + asset_ids: list[str] = field(default_factory=list) + debug_images: list[str] = field(default_factory=list) + error: str | None = None + + def to_manifest(self) -> dict[str, Any]: + """Convert the group to JSON-safe data.""" + return { + "name": self.name, + "expected_count": self.expected_count, + "detected_count": self.detected_count, + "status": self.status, + "tried_prompts": list(self.tried_prompts), + "asset_ids": list(self.asset_ids), + "debug_images": list(self.debug_images), + "error": self.error, + } + + +@dataclass(frozen=True) +class ImageAnchor: + """Anchor object used to place relative ordering onto the table grid.""" + + asset_id: str + grid: str + reason: str = "" + + def to_manifest(self) -> dict[str, Any]: + """Convert the anchor to JSON-safe data.""" + return { + "asset_id": self.asset_id, + "grid": self.grid, + "reason": self.reason, + 
} + + +@dataclass(frozen=True) +class ImageAssetLayout: + """Support state for one image asset instance.""" + + asset_id: str + is_arbitrary_layout: bool + reason: str = "" + + def to_manifest(self) -> dict[str, Any]: + """Convert the layout to JSON-safe data.""" + return { + "asset_id": self.asset_id, + "is_arbitrary_layout": self.is_arbitrary_layout, + "reason": self.reason, + } + + +@dataclass(frozen=True) +class ImageRelationSpec: + """Image asset segmentation alignment and spatial relations.""" + + status: str + image_path: str + asset_segments: list[ImageAssetSegment] + groups: list[ImageRelationGroup] + table_segment: ImageAssetSegment | None = None + table_group: ImageRelationGroup | None = None + bbox_name_image_path: str | None = None + anchor: ImageAnchor | None = None + x_order: list[list[str]] = field(default_factory=list) + y_order: list[list[str]] = field(default_factory=list) + asset_layouts: list[ImageAssetLayout] = field(default_factory=list) + + def to_manifest(self) -> dict[str, Any]: + """Convert the image relation spec to JSON-safe data.""" + manifest = self.to_segmentation_manifest() + manifest.update(self.to_spatial_manifest()) + return manifest + + def to_segmentation_manifest(self) -> dict[str, Any]: + """Convert only the segmentation alignment result to JSON-safe data.""" + return { + "image_path": self.image_path, + "bbox_name_image_path": self.bbox_name_image_path, + "asset_segments": [ + segment.to_manifest() for segment in self.asset_segments + ], + "groups": [group.to_manifest() for group in self.groups], + "table_segment": ( + self.table_segment.to_manifest() if self.table_segment else None + ), + "table_group": ( + self.table_group.to_manifest() if self.table_group else None + ), + } + + def to_spatial_manifest(self) -> dict[str, Any]: + """Convert only spatial relations and layout states to JSON-safe data.""" + return { + "image_path": self.image_path, + "bbox_name_image_path": self.bbox_name_image_path, + "anchor": 
self.anchor.to_manifest() if self.anchor else None, + "spatial_order": { + "left_to_right": [list(group) for group in self.x_order], + "front_to_back": [list(group) for group in self.y_order], + }, + "objects": [ + layout.to_manifest() for layout in self.asset_layouts + ], + } diff --git a/embodichain/gen_sim/prompt2scene/workflows/image_relations/state.py b/embodichain/gen_sim/prompt2scene/workflows/image_relations/state.py new file mode 100644 index 00000000..59853005 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/image_relations/state.py @@ -0,0 +1,42 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class ImageRelationsState(AttemptState):
    """LangGraph state for image asset segmentation alignment.

    Workflow steps read this mapping with ``state[...]`` and return partial
    dicts that patch it. The retry bookkeeping keys those steps use
    (``attempt_count``, ``last_error``, ``errors``) are not declared here;
    presumably they come from ``AttemptState`` -- confirm against that class.
    """

    # Normalized user input; ``request.image_path`` must be set for this stage.
    request: Prompt2SceneInput
    # Scene-intake result. NOTE(review): not read by the helpers visible in
    # this module's utils -- confirm which node consumes it.
    scene_intake: SceneIntakeSpec
    # Root directory handed to ``WorkflowArtifactWriter`` for step artifacts.
    output_root: Path
    # One dict per asset group. The VLM filter step reads the keys "name",
    # "description", "expected_count", "class_candidate", "debug_images" and
    # "segments", and writes back "segments" and "debug_images".
    segment_groups: list[dict[str, Any]]
    # Presumably the last raw structured-model payload -- TODO confirm writer.
    raw_model_output: dict[str, Any] | None
    # Aligned segmentation / spatial-layout result, once a step produced one.
    image_relations: ImageRelationSpec | None
MAX_SEGMENT_RETRIES = 1
OVERLAP_IOU_THRESHOLD = 0.5


def require_image_path(state: dict[str, Any]) -> Path:
    """Return the request image path or raise if the input is invalid.

    Args:
        state: Workflow state; ``state["request"].image_path`` is read.

    Raises:
        ValueError: If the request carries no image path.
    """
    image_path = state["request"].image_path
    if image_path is None:
        raise ValueError("Image relations requires request.image_path.")
    return image_path


def prompt_text(name: str) -> str:
    """Convert an asset name to a natural-language prompt (``_`` -> space)."""
    return name.replace("_", " ")


def asset_bbox_label(asset_id: str) -> str:
    """Convert an internal asset id into a display label.

    A single leading ``"interact_"`` prefix is dropped; other ids are returned
    unchanged.
    """
    return asset_id.removeprefix("interact_")


def expand_asset_ids(asset_id: str, count: int) -> list[str]:
    """Expand a grouped asset id into ``count`` zero-based instance ids."""
    return [f"{asset_id}_{index}" for index in range(count)]


def path_token(value: str) -> str:
    """Convert a label into a filesystem-safe token.

    Non-alphanumeric characters become ``_``, surrounding underscores are
    stripped, and the result is cut to 80 characters. An empty result falls
    back to ``"prompt"``.
    """
    token = "".join(character if character.isalnum() else "_" for character in value)
    return token.strip("_")[:80] or "prompt"


def append_unique(values: list[str], value: str) -> list[str]:
    """Append a string only if it does not already exist in the list.

    Returns:
        ``values`` itself when ``value`` is already present, otherwise a new
        list; the input list is never mutated.
    """
    return values if value in values else values + [value]


def segment_prompt(
    *,
    image_path: Path,
    prompt: str,
) -> "ImageSegmentationServerResponse":
    """Call the segmentation server with a single prompt.

    Raises:
        RuntimeError: If the client returns an ``ImageSegmentationError``; its
            message is logged as a warning first.
    """
    client = ImageSegmentationClient()
    log_api_request_start(
        step=IMAGE_SEGMENTS_STEP,
        request="sam3_segment",
        prompt=prompt,
    )
    result = client.segment(
        ImageSegmentationServerRequest(prompt=prompt, image_path=image_path),
        max_retries=MAX_SEGMENT_RETRIES,
    )
    if isinstance(result, ImageSegmentationError):
        log.log_warning(result.error_message)
        raise RuntimeError(result.error_message)
    return result


def segments_from_response(
    *,
    group: dict[str, Any],
    response: "ImageSegmentationServerResponse",
    source_prompt: str,
) -> list[dict[str, Any]]:
    """Convert segmentation server output into internal segment dicts.

    Candidates rejected by ``is_usable_segmentation_candidate`` are dropped.

    Returns:
        Segment dicts passed through ``sort_segments_by_bbox``. Each
        ``segment_id`` is numbered in candidate order, i.e. before sorting.
    """
    segments = []
    for candidate in response.result.candidates:
        if not is_usable_segmentation_candidate(candidate):
            continue
        segments.append(
            {
                "segment_id": f"{group['name']}_{len(segments)}",
                "bbox_xyxy": list(candidate.bbox_xyxy),
                "score": float(candidate.score),
                "mask_rle": candidate.mask_rle,
                "source_prompt": source_prompt,
            }
        )
    return sort_segments_by_bbox(segments)


def apply_spatial_layout_output(
    *,
    image_relations: "ImageRelationSpec",
    raw_model_output: dict[str, Any],
) -> "ImageRelationSpec":
    """Apply VLM spatial-layout output to an image-relations spec.

    Returns:
        A new spec that keeps the segmentation fields of ``image_relations``
        and carries the parsed anchor, orderings and per-asset layouts.

    Raises:
        ValueError: From the ``parse_*`` helpers when the output is malformed.
    """
    asset_ids = [segment.asset_id for segment in image_relations.asset_segments]
    asset_id_set = set(asset_ids)

    anchor = parse_anchor(raw_model_output.get("anchor"), asset_id_set=asset_id_set)
    x_order = parse_order_groups(
        raw_model_output.get("x_order"),
        asset_ids=asset_ids,
        field_name="x_order",
    )
    y_order = parse_order_groups(
        raw_model_output.get("y_order"),
        asset_ids=asset_ids,
        field_name="y_order",
    )
    state_by_asset_id = parse_asset_states(
        raw_model_output.get("asset_states"),
        asset_ids=asset_ids,
    )
    # Layouts follow the segment order, not the order the model answered in.
    asset_layouts = [
        ImageAssetLayout(
            asset_id=asset_id,
            is_arbitrary_layout=state_by_asset_id[asset_id]["is_arbitrary_layout"],
            reason=state_by_asset_id[asset_id]["reason"],
        )
        for asset_id in asset_ids
    ]
    return ImageRelationSpec(
        status=image_relations.status,
        image_path=image_relations.image_path,
        asset_segments=image_relations.asset_segments,
        groups=image_relations.groups,
        table_segment=image_relations.table_segment,
        table_group=image_relations.table_group,
        bbox_name_image_path=image_relations.bbox_name_image_path,
        anchor=anchor,
        x_order=x_order,
        y_order=y_order,
        asset_layouts=asset_layouts,
    )


def parse_anchor(raw_anchor: Any, *, asset_id_set: set[str]) -> "ImageAnchor":
    """Parse and validate the anchor entry.

    Raises:
        ValueError: If the anchor is not a dict, names an unknown asset, or
            uses a grid value outside ``GRID_VALUES``.
    """
    if not isinstance(raw_anchor, dict):
        raise ValueError("anchor must be an object.")
    asset_id = str(raw_anchor.get("asset_id") or "").strip()
    grid = str(raw_anchor.get("grid") or "").strip()
    reason = str(raw_anchor.get("reason") or "").strip()
    if asset_id not in asset_id_set:
        raise ValueError(f"anchor.asset_id is not a known asset: {asset_id!r}.")
    if grid not in GRID_VALUES:
        raise ValueError(f"anchor.grid is not valid: {grid!r}.")
    return ImageAnchor(asset_id=asset_id, grid=grid, reason=reason)


def parse_order_groups(
    raw_order: Any,
    *,
    asset_ids: list[str],
    field_name: str,
) -> list[list[str]]:
    """Parse ordered asset-id groups from VLM output.

    Args:
        raw_order: Expected to be a non-empty list of non-empty lists.
        asset_ids: Ids the flattened groups are checked against via
            ``validate_exact_asset_id_coverage``.
        field_name: Name used in error messages.

    Raises:
        ValueError: If the outer list or any group is not a non-empty list.
    """
    if not isinstance(raw_order, list) or not raw_order:
        raise ValueError(f"{field_name} must be a non-empty list.")

    groups: list[list[str]] = []
    flattened: list[str] = []
    for group_index, raw_group in enumerate(raw_order):
        if not isinstance(raw_group, list) or not raw_group:
            raise ValueError(f"{field_name}[{group_index}] must be a non-empty list.")
        group = [str(raw_asset_id).strip() for raw_asset_id in raw_group]
        flattened.extend(group)
        groups.append(group)

    validate_exact_asset_id_coverage(
        values=flattened,
        expected_asset_ids=asset_ids,
        context=field_name,
    )
    return groups


def parse_asset_states(
    raw_asset_states: Any,
    *,
    asset_ids: list[str],
) -> dict[str, dict[str, Any]]:
    """Parse per-asset layout state annotations.

    Returns:
        Mapping of asset id to ``{"is_arbitrary_layout": bool, "reason": str}``.

    Raises:
        ValueError: If the input is not a list of dicts, a flag is not a bool,
            a reason is empty, or an asset id is repeated.
    """
    if not isinstance(raw_asset_states, list):
        raise ValueError("asset_states must be a list.")

    state_by_asset_id: dict[str, dict[str, Any]] = {}
    for state_index, raw_state in enumerate(raw_asset_states):
        if not isinstance(raw_state, dict):
            raise ValueError(f"asset_states[{state_index}] must be an object.")
        asset_id = str(raw_state.get("asset_id") or "").strip()
        is_arbitrary_layout = raw_state.get("is_arbitrary_layout")
        reason = str(raw_state.get("reason") or "").strip()
        if not isinstance(is_arbitrary_layout, bool):
            raise ValueError(
                f"asset_states[{state_index}].is_arbitrary_layout must be boolean."
            )
        if not reason:
            raise ValueError(f"asset_states[{state_index}].reason must be non-empty.")
        if asset_id in state_by_asset_id:
            raise ValueError(f"asset_states has duplicate asset_id: {asset_id!r}.")
        state_by_asset_id[asset_id] = {
            "is_arbitrary_layout": is_arbitrary_layout,
            "reason": reason,
        }

    validate_exact_asset_id_coverage(
        values=list(state_by_asset_id),
        expected_asset_ids=asset_ids,
        context="asset_states",
    )
    return state_by_asset_id


def filter_group_segments_with_vlm(
    *,
    llm: Any,
    image_path: Path,
    artifact_writer: "WorkflowArtifactWriter",
    group: dict[str, Any],
    segments: list[dict[str, Any]],
    stage: str,
) -> list[dict[str, Any]]:
    """Ask VLM to remove wrong or duplicate instances from one SAM3 result.

    Side effects: writes a numbered-mask debug image into a fresh debug round
    directory, rebinds ``group["debug_images"]`` to include that image, and
    stores the raw model output in the same round.

    Returns:
        The sorted segments minus those flagged by the model. Empty input is
        returned as-is without calling the model.
    """
    segments = sort_segments_by_bbox(segments)
    if not segments:
        return segments

    round_name = artifact_writer.next_debug_round_name(label=f"{stage}_{group['name']}")
    round_dir = artifact_writer.debug_round_dir(round_name)
    debug_image_path = draw_numbered_masks(
        image_path=image_path,
        segments=segments,
        output_path=round_dir / "mask.png",
    )
    group["debug_images"] = append_unique(
        group["debug_images"],
        str(debug_image_path),
    )
    log_api_request_start(
        step=IMAGE_SEGMENTS_STEP,
        request=f"vlm_filter_{stage}",
        debug_image=str(debug_image_path),
    )
    messages = build_filter_extra_instances_messages(
        debug_image_path=debug_image_path,
        name=group["name"],
        description=group["description"],
        expected_count=group["expected_count"],
        class_candidate=group["class_candidate"],
    )
    raw_model_output = call_structured_json_model_step(
        llm=llm,
        schema=FILTER_EXTRA_INSTANCES_JSON_SCHEMA,
        messages=messages,
        context=f"Image relation {stage} segmentation filtering",
        step_name=IMAGE_SEGMENTS_STEP,
        output_root=None,
        attempt_count=0,
        raw_output_writer=lambda payload: artifact_writer.write_debug_round_json(
            round_name=round_name,
            filename=RAW_MODEL_OUTPUT_FILENAME,
            payload=payload,
        ),
    )
    return remove_extra_numbered_segments(
        segments=segments,
        raw_model_output=raw_model_output,
    )


def filter_segments_with_vlm(
    *,
    state: dict[str, Any],
    llm: Any,
    stage: str,
) -> dict[str, object]:
    """Filter all segment groups with VLM and return an updated state patch.

    Returns:
        On success: ``attempt_count``, the filtered ``segment_groups`` and
        ``last_error=None``. On a retryable failure (model-output error or
        ``ValueError``): ``attempt_count``, ``last_error`` and the extended
        ``errors`` list. Any other exception propagates.
    """
    segment_groups = []
    attempt_count = state["attempt_count"] + 1
    image_path = require_image_path(state)
    artifact_writer = WorkflowArtifactWriter(state["output_root"], IMAGE_SEGMENTS_STEP)

    try:
        for group in state["segment_groups"]:
            # Shallow copy so the group dicts held by the state stay untouched.
            group = dict(group)
            group["segments"] = filter_group_segments_with_vlm(
                llm=llm,
                image_path=image_path,
                artifact_writer=artifact_writer,
                group=group,
                segments=group["segments"],
                stage=stage,
            )
            segment_groups.append(group)
    except Exception as exc:
        if is_model_output_error(exc) or isinstance(exc, ValueError):
            error = format_attempt_error("Image relations VLM filter", attempt_count, exc)
            log.log_warning(error)
            return {
                "attempt_count": attempt_count,
                "last_error": error,
                "errors": state["errors"] + [error],
            }
        raise

    return {
        "attempt_count": attempt_count,
        "segment_groups": segment_groups,
        "last_error": None,
    }


def remove_extra_numbered_segments(
    *,
    segments: list[dict[str, Any]],
    raw_model_output: dict[str, Any],
) -> list[dict[str, Any]]:
    """Remove numbered masks flagged as extra by the VLM.

    Args:
        segments: Segments in the order they were numbered (1-based) on the
            debug image.
        raw_model_output: Model payload; ``extra_instance_numbers`` must be a
            list of integers within ``1..len(segments)``.

    Returns:
        A new list with the flagged segments removed, order preserved.

    Raises:
        ValueError: If the field is not a list, contains a non-integer entry
            (booleans included), or references a mask number out of range.
    """
    extra_numbers = raw_model_output.get("extra_instance_numbers")
    if not isinstance(extra_numbers, list):
        raise ValueError("extra_instance_numbers must be a list.")

    extra_indices: set[int] = set()
    for number in extra_numbers:
        # int() coercion used to accept True/1.9 silently and raised TypeError
        # on None, which callers do not treat as a retryable output error.
        if isinstance(number, bool) or not isinstance(number, int):
            raise ValueError(
                f"extra_instance_numbers must contain integers, got {number!r}."
            )
        extra_indices.add(number - 1)

    if any(index < 0 or index >= len(segments) for index in extra_indices):
        raise ValueError("VLM returned an out-of-range extra mask number.")
    return [
        segment for index, segment in enumerate(segments) if index not in extra_indices
    ]


def merge_non_overlapping_segments(
    *,
    existing: list[dict[str, Any]],
    incoming: list[dict[str, Any]],
    limit: int,
) -> list[dict[str, Any]]:
    """Merge non-overlapping segments until a limit is reached.

    Incoming segments are tried in descending ``score`` order and kept only
    when their bbox IoU with every already-kept segment is below
    ``OVERLAP_IOU_THRESHOLD``.

    Returns:
        The merged segments passed through ``sort_segments_by_bbox``.
    """
    merged = list(existing)
    for segment in sorted(
        incoming, key=lambda item: float(item["score"]), reverse=True
    ):
        if len(merged) >= limit:
            break
        if all(
            bbox_iou(segment["bbox_xyxy"], other["bbox_xyxy"]) < OVERLAP_IOU_THRESHOLD
            for other in merged
        ):
            merged.append(segment)
    return sort_segments_by_bbox(merged)
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any, Callable + +from embodichain.gen_sim.prompt2scene.workflows.artifact_writer import ( + WorkflowArtifactWriter, + write_next_raw_model_output, +) + +__all__ = [ + "bind_structured_output", + "coerce_json_object_output", + "is_model_output_error", + "call_structured_json_model_step", + "StructuredModelCallError", + "validate_json_schema", +] + + +class StructuredModelCallError(Exception): + """Retryable structured-model call failure.""" + + def __init__( + self, + *, + context: str, + attempt_count: int, + original_exc: Exception, + ) -> None: + self.context = context + self.attempt_count = attempt_count + self.original_exc = original_exc + super().__init__(str(original_exc)) + + +def bind_structured_output(llm: Any, schema: dict[str, Any]) -> Any: + """Bind a JSON schema to an LLM when the model wrapper supports it.""" + if hasattr(llm, "with_structured_output"): + return llm.with_structured_output(schema) + return llm + + +def coerce_json_object_output(response: Any, *, context: str) -> dict[str, Any]: + """Coerce a model response into a JSON object.""" + if isinstance(response, dict): + return response + + content = getattr(response, "content", response) + if isinstance(content, dict): + return content + + if isinstance(content, list): + text_parts = [ + item.get("text", "") + for item in content + if isinstance(item, dict) and item.get("type") == "text" + ] + content = "\n".join(text_parts) + + if isinstance(content, str): + return _parse_json_text(content, context=context) + + raise ValueError(f"{context} model output has unsupported type: {type(response)!r}") + + +def is_model_output_error(exc: Exception) -> bool: + """Return whether an exception is a retryable model output formatting error.""" + class_name = exc.__class__.__name__ + module_name = exc.__class__.__module__ + return 
( + class_name + in { + "JSONDecodeError", + "OutputParserException", + "SchemaValidationError", + "ValidationError", + "StructuredModelCallError", + } + or module_name.startswith("pydantic") + ) + + +def validate_json_schema( + value: Any, + schema: dict[str, Any], + *, + context: str, +) -> None: + """Validate model output against the subset of JSON Schema used locally.""" + _validate_schema_value(value, schema, path=context) + + +def call_structured_json_model_step( + *, + llm: Any, + schema: dict[str, Any], + messages: list[dict[str, Any]], + context: str, + step_name: str, + output_root: Path | None, + attempt_count: int, + raw_output_label: str | None = None, + artifact_writer: WorkflowArtifactWriter | None = None, + raw_output_writer: Callable[[dict[str, Any]], None] | None = None, +) -> dict[str, Any]: + """Call a structured-output model, validate JSON, and persist raw output.""" + model = bind_structured_output(llm, schema) + try: + response = model.invoke(messages) + raw_model_output = coerce_json_object_output(response, context=context) + validate_json_schema( + raw_model_output, + schema, + context=f"{context} output", + ) + except Exception as exc: + if is_model_output_error(exc) or isinstance(exc, ValueError): + raise StructuredModelCallError( + context=context, + attempt_count=attempt_count, + original_exc=exc, + ) from exc + raise + + if raw_output_writer is not None: + raw_output_writer(raw_model_output) + elif artifact_writer is not None: + artifact_writer.write_next_raw_model_output( + payload=raw_model_output, + label=raw_output_label, + ) + elif output_root is not None: + write_next_raw_model_output( + output_root=output_root, + step_name=step_name, + payload=raw_model_output, + label=raw_output_label, + ) + return raw_model_output + + +def _parse_json_text(content: str, *, context: str) -> dict[str, Any]: + stripped = content.strip() + if stripped.startswith("```"): + lines = stripped.splitlines() + if lines and lines[0].startswith("```"): + 
lines = lines[1:] + if lines and lines[-1].startswith("```"): + lines = lines[:-1] + stripped = "\n".join(lines).strip() + parsed = json.loads(stripped) + if not isinstance(parsed, dict): + raise ValueError(f"{context} model output must be a JSON object.") + return parsed + + +def _validate_schema_value(value: Any, schema: dict[str, Any], *, path: str) -> None: + expected_type = schema.get("type") + if expected_type is not None: + _validate_type(value, expected_type, path=path) + + enum_values = schema.get("enum") + if isinstance(enum_values, list) and value not in enum_values: + raise ValueError(f"{path} must be one of {enum_values}.") + + if expected_type == "object" or isinstance(value, dict): + _validate_object(value, schema, path=path) + elif expected_type == "array" or isinstance(value, list): + _validate_array(value, schema, path=path) + elif expected_type == "string" or isinstance(value, str): + _validate_string(value, schema, path=path) + elif expected_type in {"integer", "number"}: + _validate_number(value, schema, path=path) + + +def _validate_type(value: Any, expected_type: Any, *, path: str) -> None: + if isinstance(expected_type, list): + if any(_matches_type(value, item) for item in expected_type): + return + raise ValueError(f"{path} must match one of these types: {expected_type}.") + + if not _matches_type(value, expected_type): + raise ValueError(f"{path} must be {expected_type}.") + + +def _matches_type(value: Any, expected_type: str) -> bool: + if expected_type == "object": + return isinstance(value, dict) + if expected_type == "array": + return isinstance(value, list) + if expected_type == "string": + return isinstance(value, str) + if expected_type == "integer": + return isinstance(value, int) and not isinstance(value, bool) + if expected_type == "number": + return isinstance(value, int | float) and not isinstance(value, bool) + if expected_type == "boolean": + return isinstance(value, bool) + if expected_type == "null": + return value is None 
+ return True + + +def _validate_object(value: Any, schema: dict[str, Any], *, path: str) -> None: + if not isinstance(value, dict): + return + + properties = schema.get("properties") + properties = properties if isinstance(properties, dict) else {} + + required = schema.get("required", []) + if isinstance(required, list): + missing = [key for key in required if key not in value] + if missing: + raise ValueError(f"{path} missing required keys: {missing}.") + + if schema.get("additionalProperties") is False: + extra = sorted(set(value) - set(properties)) + if extra: + raise ValueError(f"{path} has unexpected keys: {extra}.") + + for key, child_schema in properties.items(): + if key not in value or not isinstance(child_schema, dict): + continue + _validate_schema_value(value[key], child_schema, path=f"{path}.{key}") + + +def _validate_array(value: Any, schema: dict[str, Any], *, path: str) -> None: + if not isinstance(value, list): + return + + min_items = schema.get("minItems") + if isinstance(min_items, int) and len(value) < min_items: + raise ValueError(f"{path} must contain at least {min_items} items.") + + max_items = schema.get("maxItems") + if isinstance(max_items, int) and len(value) > max_items: + raise ValueError(f"{path} must contain at most {max_items} items.") + + items_schema = schema.get("items") + if not isinstance(items_schema, dict): + return + + for index, item in enumerate(value): + _validate_schema_value(item, items_schema, path=f"{path}[{index}]") + + +def _validate_string(value: Any, schema: dict[str, Any], *, path: str) -> None: + if not isinstance(value, str): + return + + min_length = schema.get("minLength") + if isinstance(min_length, int) and len(value) < min_length: + raise ValueError(f"{path} must contain at least {min_length} characters.") + + max_length = schema.get("maxLength") + if isinstance(max_length, int) and len(value) > max_length: + raise ValueError(f"{path} must contain at most {max_length} characters.") + + +def 
_validate_number(value: Any, schema: dict[str, Any], *, path: str) -> None: + if not isinstance(value, int | float) or isinstance(value, bool): + return + + minimum = schema.get("minimum") + if isinstance(minimum, int | float) and value < minimum: + raise ValueError(f"{path} must be greater than or equal to {minimum}.") diff --git a/embodichain/gen_sim/prompt2scene/workflows/request.py b/embodichain/gen_sim/prompt2scene/workflows/request.py new file mode 100644 index 00000000..8cd01c30 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/request.py @@ -0,0 +1,110 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from pathlib import Path + +__all__ = ["InputKind", "Prompt2SceneInput"] + +SUPPORTED_IMAGE_SUFFIXES: frozenset[str] = frozenset({".jpg", ".jpeg", ".png"}) + + +class InputKind(str, Enum): + """Supported prompt2scene input kinds.""" + + IMAGE = "image" + TEXT = "text" + + +@dataclass(frozen=True) +class Prompt2SceneInput: + """Normalized prompt2scene input.""" + + input_kind: InputKind + output_root: Path + image_path: Path | None = None + text: str | None = None + + @classmethod + def from_cli_args( + cls, + *, + image_path: Path | None, + text: str | None, + output_root: Path, + ) -> "Prompt2SceneInput": + """Create a prompt2scene input from CLI arguments. + + Args: + image_path: Input image path, if image mode is selected. + text: Text prompt, if text mode is selected. + output_root: Directory where prompt2scene outputs are written. + + Returns: + Normalized prompt2scene input. + + Raises: + FileNotFoundError: If the image input path does not exist. + ValueError: If the image path is invalid or text input is empty. 
+ """ + output_root = output_root.expanduser().resolve() + + if image_path is not None: + image_path = image_path.expanduser().resolve() + cls._validate_image_path(image_path) + return cls( + input_kind=InputKind.IMAGE, + image_path=image_path, + output_root=output_root, + ) + + if text is None or not text.strip(): + raise ValueError("Text input must be non-empty.") + + return cls( + input_kind=InputKind.TEXT, + text=text.strip(), + output_root=output_root, + ) + + def to_manifest(self) -> dict[str, str]: + """Convert the input to a JSON-serializable manifest.""" + manifest: dict[str, str] = { + "input_kind": self.input_kind.value, + "output_root": str(self.output_root), + } + if self.input_kind == InputKind.IMAGE: + image_path = self.image_path + manifest["image_path"] = str(image_path) + else: + text = self.text + manifest["text"] = "" if text is None else text + return manifest + + @staticmethod + def _validate_image_path(image_path: Path) -> None: + """Validate supported image input paths.""" + if not image_path.exists(): + raise FileNotFoundError(f"Image input not found: {image_path}") + if not image_path.is_file(): + raise ValueError(f"Image input is not a file: {image_path}") + if image_path.suffix.lower() not in SUPPORTED_IMAGE_SUFFIXES: + raise ValueError( + "Image input must have one of these extensions: .jpg, .jpeg, .png" + ) diff --git a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/__init__.py b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/__init__.py new file mode 100644 index 00000000..ac862308 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/__init__.py @@ -0,0 +1,24 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.graph import ( + build_scene_intake_graph, + run_scene_intake, +) + +__all__ = ["build_scene_intake_graph", "run_scene_intake"] diff --git a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/graph.py b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/graph.py new file mode 100644 index 00000000..77874b15 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/graph.py @@ -0,0 +1,142 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Any + +from langgraph.graph import END, StateGraph + +from embodichain.gen_sim.prompt2scene.llms import ( + OpenAICompatibleLLMCfg, + build_chat_model, +) +from embodichain.gen_sim.prompt2scene.utils import log +from embodichain.gen_sim.prompt2scene.workflows.stage_errors import ( + format_result_missing_error, +) +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.schema import ( + SceneIntakeSpec, +) +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.nodes import ( + call_vlm_scene_intake_node, + call_vlm_verify_scene_intake_node, + normalize_scene_intake_node, + normalize_verified_scene_intake_node, + prepare_input_node, +) +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.state import ( + SceneIntakeState, +) +from embodichain.gen_sim.prompt2scene.workflows.request import Prompt2SceneInput + +__all__ = ["build_scene_intake_graph", "run_scene_intake"] + + +def route_after_normalize(state: SceneIntakeState) -> str: + """Route to retry or verify after draft scene intake normalization.""" + if state["draft_scene_intake"] is not None: + return "verify" + if state["attempt_count"] < state["max_attempts"]: + return "retry" + return "end" + + +def route_after_verified_normalize(state: SceneIntakeState) -> str: + """Route to retry or finish after scene intake verifier normalization.""" + if state["scene_intake"] is not None: + return "end" + if state["attempt_count"] < state["max_attempts"]: + return "retry" + return "end" + + +def build_scene_intake_graph(llm: Any) -> Any: + """Build the fixed LangGraph scene intake workflow.""" + graph = StateGraph(SceneIntakeState) + graph.add_node("prepare_input", prepare_input_node) + graph.add_node( + "call_vlm_scene_intake", + lambda state: call_vlm_scene_intake_node(state, llm=llm), + ) + graph.add_node("normalize_scene_intake", normalize_scene_intake_node) + 
graph.add_node( + "call_vlm_verify_scene_intake", + lambda state: call_vlm_verify_scene_intake_node(state, llm=llm), + ) + graph.add_node( + "normalize_verified_scene_intake", + normalize_verified_scene_intake_node, + ) + + graph.set_entry_point("prepare_input") + graph.add_edge("prepare_input", "call_vlm_scene_intake") + graph.add_edge("call_vlm_scene_intake", "normalize_scene_intake") + graph.add_conditional_edges( + "normalize_scene_intake", + route_after_normalize, + { + "retry": "call_vlm_scene_intake", + "verify": "call_vlm_verify_scene_intake", + "end": END, + }, + ) + graph.add_edge("call_vlm_verify_scene_intake", "normalize_verified_scene_intake") + graph.add_conditional_edges( + "normalize_verified_scene_intake", + route_after_verified_normalize, + { + "retry": "call_vlm_verify_scene_intake", + "end": END, + }, + ) + return graph.compile() + + +def run_scene_intake( + request: Prompt2SceneInput, + llm_cfg: OpenAICompatibleLLMCfg, +) -> SceneIntakeSpec: + """Run fixed VLM-based scene intake for one prompt2scene request.""" + llm = build_chat_model(llm_cfg) + graph = build_scene_intake_graph(llm) + result = graph.invoke( + { + "request": request, + "messages": [], + "raw_model_output": None, + "draft_scene_intake": None, + "scene_intake": None, + "attempt_count": 0, + "max_attempts": llm_cfg.max_attempts, + "last_error": None, + "errors": [], + } + ) + + scene_intake = result.get("scene_intake") + if scene_intake is not None: + return scene_intake + + error = format_result_missing_error( + "Scene intake", + "SceneIntakeSpec", + attempt_count=result.get("attempt_count", 0), + last_error=result.get("last_error"), + errors=result.get("errors", []), + ) + log.log_warning(error) + raise RuntimeError(error) diff --git a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/nodes.py b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/nodes.py new file mode 100644 index 00000000..8c7baf55 --- /dev/null +++ 
b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/nodes.py @@ -0,0 +1,211 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Any + +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.schema import ( + SCENE_INTAKE_JSON_SCHEMA, + SceneIntakeSpec, +) +from embodichain.gen_sim.prompt2scene.utils import ( + log_api_request_start, + log, +) +from embodichain.gen_sim.prompt2scene.workflows.artifact_writer import ( + SCENE_INTAKE_STEP, + WorkflowArtifactWriter, +) +from embodichain.gen_sim.prompt2scene.workflows.llm_output import ( + StructuredModelCallError, + call_structured_json_model_step, +) +from embodichain.gen_sim.prompt2scene.workflows.stage_errors import ( + format_attempt_error, +) +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.prompts import ( + build_scene_intake_messages, + build_scene_intake_verifier_messages, +) +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.state import ( + SceneIntakeState, +) +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.utils import ( + build_scene_intake_spec, +) + +__all__ = [ + "call_vlm_scene_intake_node", + "call_vlm_verify_scene_intake_node", + "normalize_scene_intake_node", + 
"normalize_verified_scene_intake_node", + "prepare_input_node", +] + + +def prepare_input_node(state: SceneIntakeState) -> dict[str, object]: + """Prepare chat messages for the scene intake model call.""" + return {"messages": build_scene_intake_messages(state["request"])} + + +def call_vlm_scene_intake_node( + state: SceneIntakeState, + *, + llm: Any, +) -> dict[str, object]: + """Call the configured VLM for fixed scene intake extraction.""" + attempt_count = state["attempt_count"] + 1 + + try: + log_api_request_start( + step=SCENE_INTAKE_STEP, + request="extract", + attempt=attempt_count, + ) + artifact_writer = WorkflowArtifactWriter( + state["request"].output_root, + SCENE_INTAKE_STEP, + ) + raw_model_output = call_structured_json_model_step( + llm=llm, + schema=SCENE_INTAKE_JSON_SCHEMA, + messages=state["messages"], + context="Scene intake", + step_name=SCENE_INTAKE_STEP, + output_root=None, + attempt_count=attempt_count, + raw_output_label="extract", + artifact_writer=artifact_writer, + ) + except StructuredModelCallError as exc: + error = format_attempt_error("Scene intake", attempt_count, exc) + log.log_warning(error) + return { + "attempt_count": attempt_count, + "raw_model_output": None, + "last_error": error, + "errors": state["errors"] + [error], + } + + return { + "attempt_count": attempt_count, + "raw_model_output": raw_model_output, + "last_error": None, + } + + +def normalize_scene_intake_node(state: SceneIntakeState) -> dict[str, object]: + """Normalize raw VLM JSON into a draft scene intake schema.""" + raw_model_output = state["raw_model_output"] + if raw_model_output is None: + return {} + + try: + scene_intake = build_scene_intake_spec( + request=state["request"], + model_output=raw_model_output, + ) + except ValueError as exc: + error = format_attempt_error("Scene intake", state["attempt_count"], exc) + return { + "draft_scene_intake": None, + "last_error": error, + "errors": state["errors"] + [error], + } + + return {"draft_scene_intake": 
scene_intake, "scene_intake": None} + + +def call_vlm_verify_scene_intake_node( + state: SceneIntakeState, + *, + llm: Any, +) -> dict[str, object]: + """Ask VLM to verify and correct scene-intake grouping and counts.""" + draft_scene_intake = state["draft_scene_intake"] + if draft_scene_intake is None: + return {} + + attempt_count = state["attempt_count"] + 1 + messages = build_scene_intake_verifier_messages( + request=state["request"], + scene_intake=draft_scene_intake, + ) + + try: + log_api_request_start( + step=SCENE_INTAKE_STEP, + request="verify", + attempt=attempt_count, + ) + artifact_writer = WorkflowArtifactWriter( + state["request"].output_root, + SCENE_INTAKE_STEP, + ) + raw_model_output = call_structured_json_model_step( + llm=llm, + schema=SCENE_INTAKE_JSON_SCHEMA, + messages=messages, + context="Scene intake verifier", + step_name=SCENE_INTAKE_STEP, + output_root=None, + attempt_count=attempt_count, + raw_output_label="verify", + artifact_writer=artifact_writer, + ) + except StructuredModelCallError as exc: + error = format_attempt_error("Scene intake verifier", attempt_count, exc) + log.log_warning(error) + return { + "attempt_count": attempt_count, + "raw_model_output": None, + "scene_intake": None, + "last_error": error, + "errors": state["errors"] + [error], + } + + return { + "attempt_count": attempt_count, + "raw_model_output": raw_model_output, + "scene_intake": None, + "last_error": None, + } + + +def normalize_verified_scene_intake_node( + state: SceneIntakeState, +) -> dict[str, object]: + """Normalize verifier output into the final scene intake schema.""" + raw_model_output = state["raw_model_output"] + if raw_model_output is None: + return {} + + try: + scene_intake = build_scene_intake_spec( + request=state["request"], + model_output=raw_model_output, + ) + except ValueError as exc: + error = format_attempt_error("Scene intake verifier", state["attempt_count"], exc) + log.log_warning(error) + return { + "scene_intake": None, + 
"last_error": error, + "errors": state["errors"] + [error], + } + + return {"scene_intake": scene_intake, "last_error": None} diff --git a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/prompts.py b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/prompts.py new file mode 100644 index 00000000..611c5bf9 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/prompts.py @@ -0,0 +1,197 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import json +from typing import Any + +from embodichain.gen_sim.prompt2scene.prompts import render_prompt +from embodichain.gen_sim.prompt2scene.workflows.request import ( + InputKind, + Prompt2SceneInput, +) +from embodichain.gen_sim.prompt2scene.utils.io import image_to_data_url +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.schema import ( + SceneIntakeSpec, +) + +__all__ = [ + "build_scene_intake_messages", + "build_scene_intake_verifier_messages", +] + +SCENE_INTAKE_PROMPT_NAME = "scene_intake.yaml" + + +def build_scene_intake_messages(request: Prompt2SceneInput) -> list[dict[str, Any]]: + """Build LangChain-compatible messages for scene intake.""" + if request.input_kind == InputKind.TEXT: + return _build_text_messages(request) + return _build_image_messages(request) + + +def build_scene_intake_verifier_messages( + *, + request: Prompt2SceneInput, + scene_intake: SceneIntakeSpec, +) -> list[dict[str, Any]]: + """Build messages for scene-intake group and count verification.""" + scene_intake_json = json.dumps( + { + "table": { + "name": scene_intake.table.name, + "description": scene_intake.table.description, + "complete_table_description": ( + scene_intake.table.complete_table_description + ), + "is_complete_visible_table": ( + scene_intake.table.is_complete_visible_table + ), + "class_candidate": list(scene_intake.table.class_candidate), + }, + "assets": [ + { + "name": asset.name, + "description": asset.description, + "class_candidate": list(asset.class_candidate), + "count": asset.count, + } + for asset in scene_intake.assets + ], + }, + ensure_ascii=False, + indent=2, + ) + if request.input_kind == InputKind.TEXT: + return _build_text_verifier_messages( + request=request, + scene_intake_json=scene_intake_json, + ) + return _build_image_verifier_messages( + request=request, + scene_intake_json=scene_intake_json, + ) + + 
+def _build_text_messages(request: Prompt2SceneInput) -> list[dict[str, Any]]: + return [ + { + "role": "system", + "content": render_prompt(SCENE_INTAKE_PROMPT_NAME, prompt_key="text_system"), + }, + { + "role": "user", + "content": render_prompt( + SCENE_INTAKE_PROMPT_NAME, + {"text": request.text or ""}, + prompt_key="text_user", + ), + }, + ] + + +def _build_image_messages(request: Prompt2SceneInput) -> list[dict[str, Any]]: + image_path = request.image_path + if image_path is None: + raise ValueError("Image input requires image_path.") + + return [ + { + "role": "system", + "content": render_prompt(SCENE_INTAKE_PROMPT_NAME, prompt_key="image_system"), + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": render_prompt( + SCENE_INTAKE_PROMPT_NAME, + prompt_key="image_user", + ), + }, + { + "type": "image_url", + "image_url": {"url": image_to_data_url(image_path)}, + }, + ], + }, + ] + + +def _build_text_verifier_messages( + *, + request: Prompt2SceneInput, + scene_intake_json: str, +) -> list[dict[str, Any]]: + return [ + { + "role": "system", + "content": render_prompt( + SCENE_INTAKE_PROMPT_NAME, + prompt_key="verifier_system", + ), + }, + { + "role": "user", + "content": render_prompt( + SCENE_INTAKE_PROMPT_NAME, + { + "text": request.text or "", + "scene_intake_json": scene_intake_json, + }, + prompt_key="verifier_text_user", + ), + }, + ] + + +def _build_image_verifier_messages( + *, + request: Prompt2SceneInput, + scene_intake_json: str, +) -> list[dict[str, Any]]: + image_path = request.image_path + if image_path is None: + raise ValueError("Image input requires image_path.") + + return [ + { + "role": "system", + "content": render_prompt( + SCENE_INTAKE_PROMPT_NAME, + prompt_key="verifier_system", + ), + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": render_prompt( + SCENE_INTAKE_PROMPT_NAME, + {"scene_intake_json": scene_intake_json}, + prompt_key="verifier_image_user", + ), + }, + { + "type": "image_url", + 
"image_url": {"url": image_to_data_url(image_path)}, + }, + ], + }, + ] diff --git a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/schema.py b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/schema.py new file mode 100644 index 00000000..80c9ca27 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/schema.py @@ -0,0 +1,244 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +from embodichain.gen_sim.prompt2scene.workflows.request import ( + InputKind, + Prompt2SceneInput, +) + +__all__ = [ + "SCENE_INTAKE_JSON_SCHEMA", + "SceneIntakeAsset", + "SceneIntakeInputRecord", + "SceneIntakeSpec", + "SceneIntakeTable", +] + +SCENE_INTAKE_JSON_SCHEMA: dict[str, Any] = { + "title": "SceneIntakeModelOutput", + "description": ( + "Objects and table information extracted from a text or image input." 
+ ), + "type": "object", + "additionalProperties": False, + "properties": { + "table": { + "type": "object", + "additionalProperties": False, + "properties": { + "name": { + "type": "string", + "description": ( + "Canonical English class name for the visible table " + "or tabletop target, such as table, desk, dining_table, " + "coffee_table, workbench, or tabletop." + ), + }, + "description": { + "type": "string", + "minLength": 20, + "maxLength": 180, + "description": ( + "One concise standalone appearance description of the " + "visible table or tabletop region." + ), + }, + "complete_table_description": { + "type": "string", + "minLength": 20, + "maxLength": 220, + "description": ( + "One concise standalone description of a complete table " + "asset for text-to-3D generation, matching the visible " + "tabletop color, material, and texture." + ), + }, + "is_complete_visible_table": { + "type": "boolean", + "description": ( + "For image input, whether a mostly complete table is " + "visible and suitable as the final table geometry source. " + "For text input, this should be false." + ), + }, + "class_candidate": { + "type": "array", + "minItems": 5, + "maxItems": 5, + "description": ( + "Exactly five likely class names for segmenting the " + "visible table or tabletop target." + ), + "items": { + "type": "string", + "minLength": 1, + }, + }, + }, + "required": [ + "name", + "description", + "complete_table_description", + "is_complete_visible_table", + "class_candidate", + ], + }, + "assets": { + "type": "array", + "description": ( + "Object category groups on or intended for the tabletop scene." + ), + "items": { + "type": "object", + "additionalProperties": False, + "properties": { + "name": { + "type": "string", + "description": ( + "Canonical English object name, singular, " + "snake_case preferred." 
+ ), + }, + "description": { + "type": "string", + "minLength": 20, + "maxLength": 180, + "description": ( + "One concise appearance description of the object for " + "image and 3D geometry generation." + ), + }, + "class_candidate": { + "type": "array", + "minItems": 5, + "maxItems": 5, + "description": ( + "Exactly five likely object class names for later " + "image detection or segmentation." + ), + "items": { + "type": "string", + "minLength": 1, + }, + }, + "count": { + "type": "integer", + "description": ( + "Number of repeated instances in this object category " + "group. Only group objects that can share the same name, " + "description, and class_candidate list." + ), + "minimum": 1, + }, + }, + "required": ["name", "description", "class_candidate", "count"], + }, + }, + }, + "required": ["table", "assets"], +} + + +@dataclass(frozen=True) +class SceneIntakeInputRecord: + """Normalized input source recorded by scene intake.""" + + input_kind: InputKind + text: str | None = None + image_path: str | None = None + + @classmethod + def from_request(cls, request: Prompt2SceneInput) -> "SceneIntakeInputRecord": + """Create an input record from a prompt2scene request.""" + return cls( + input_kind=request.input_kind, + text=request.text, + image_path=str(request.image_path) if request.image_path else None, + ) + + def to_manifest(self) -> dict[str, str | None]: + """Convert the input record to JSON-safe data.""" + return { + "input_kind": self.input_kind.value, + "text": self.text, + "image_path": self.image_path, + } + + +@dataclass(frozen=True) +class SceneIntakeTable: + """Table/support information extracted during scene intake.""" + + id: str = "table" + name: str = "table" + description: str = "" + complete_table_description: str = "" + is_complete_visible_table: bool = False + class_candidate: list[str] = field(default_factory=list) + + def to_manifest(self) -> dict[str, object]: + """Convert the table record to JSON-safe data.""" + return { + "id": self.id, 
+ "name": self.name, + "description": self.description, + "complete_table_description": self.complete_table_description, + "is_complete_visible_table": self.is_complete_visible_table, + "class_candidate": list(self.class_candidate), + } + + +@dataclass(frozen=True) +class SceneIntakeAsset: + """Object category group extracted during scene intake.""" + + id: str + name: str + count: int = 1 + description: str = "" + class_candidate: list[str] = field(default_factory=list) + + def to_manifest(self) -> dict[str, object]: + """Convert the asset record to JSON-safe data.""" + return { + "id": self.id, + "name": self.name, + "count": self.count, + "description": self.description, + "class_candidate": list(self.class_candidate), + } + + +@dataclass(frozen=True) +class SceneIntakeSpec: + """Unified first-step scene intake output for text and image inputs.""" + + input: SceneIntakeInputRecord + table: SceneIntakeTable + assets: list[SceneIntakeAsset] + + def to_manifest(self) -> dict[str, object]: + """Convert the intake spec to JSON-safe data.""" + return { + "input": self.input.to_manifest(), + "table": self.table.to_manifest(), + "assets": [asset.to_manifest() for asset in self.assets], + } diff --git a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/state.py b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/state.py new file mode 100644 index 00000000..7a96619f --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/state.py @@ -0,0 +1,37 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Any + +from embodichain.gen_sim.prompt2scene.workflows.request import Prompt2SceneInput +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.schema import ( + SceneIntakeSpec, +) +from embodichain.gen_sim.prompt2scene.workflows.attempt_state import AttemptState + +__all__ = ["SceneIntakeState"] + + +class SceneIntakeState(AttemptState): + """LangGraph state for the fixed scene intake workflow.""" + + request: Prompt2SceneInput + messages: list[Any] + raw_model_output: dict[str, Any] | None + draft_scene_intake: SceneIntakeSpec | None + scene_intake: SceneIntakeSpec | None diff --git a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/utils.py b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/utils.py new file mode 100644 index 00000000..e49fe9b3 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/utils.py @@ -0,0 +1,229 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import re +from typing import Any + +from embodichain.gen_sim.prompt2scene.workflows.request import Prompt2SceneInput +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.schema import ( + SceneIntakeAsset, + SceneIntakeInputRecord, + SceneIntakeSpec, + SceneIntakeTable, +) + +__all__ = ["build_scene_intake_spec", "normalize_asset_name"] + + +def normalize_asset_name(name: str) -> str: + """Normalize an object name for stable asset IDs.""" + normalized = name.strip().lower() + normalized = normalized.replace("-", " ").replace("/", " ") + normalized = re.sub(r"[^a-z0-9\s_]", "", normalized) + normalized = re.sub(r"\s+", "_", normalized) + normalized = re.sub(r"_+", "_", normalized).strip("_") + return normalized or "object" + + +def build_scene_intake_spec( + *, + request: Prompt2SceneInput, + model_output: dict[str, Any], +) -> SceneIntakeSpec: + """Normalize raw VLM JSON into the stable scene intake schema.""" + _validate_exact_keys( + model_output, + allowed_keys={"table", "assets"}, + context="Scene intake model output", + ) + input_record = SceneIntakeInputRecord.from_request(request) + table = _parse_table(_require_mapping(model_output.get("table"), "table")) + assets = _parse_assets(_require_list(model_output.get("assets"), "assets")) + return SceneIntakeSpec(input=input_record, table=table, assets=assets) + + +def _parse_table(raw_table: dict[str, Any]) -> SceneIntakeTable: + _validate_exact_keys( + raw_table, + 
allowed_keys={ + "name", + "description", + "complete_table_description", + "is_complete_visible_table", + "class_candidate", + }, + context="Scene intake table", + ) + + if "name" not in raw_table: + raise ValueError("Scene intake table.name is required.") + raw_name = str(raw_table["name"]).strip() + if not raw_name: + raise ValueError("Scene intake table.name must be non-empty.") + name = normalize_asset_name(raw_name) + + if "description" not in raw_table: + raise ValueError("Scene intake table.description is required.") + description = str(raw_table["description"]).strip() + if not description: + raise ValueError("Scene intake table.description must be non-empty.") + + if "complete_table_description" not in raw_table: + raise ValueError("Scene intake table.complete_table_description is required.") + complete_table_description = str( + raw_table["complete_table_description"] + ).strip() + if not complete_table_description: + raise ValueError( + "Scene intake table.complete_table_description must be non-empty." + ) + + if "is_complete_visible_table" not in raw_table: + raise ValueError("Scene intake table.is_complete_visible_table is required.") + is_complete_visible_table = raw_table["is_complete_visible_table"] + if not isinstance(is_complete_visible_table, bool): + raise ValueError( + "Scene intake table.is_complete_visible_table must be a boolean." 
+ ) + + class_candidate = _parse_class_candidate( + raw_table.get("class_candidate"), + asset_index="table", + raw_name=name, + ) + + return SceneIntakeTable( + name=name, + description=description, + complete_table_description=complete_table_description, + is_complete_visible_table=is_complete_visible_table, + class_candidate=class_candidate, + ) + + +def _parse_assets(raw_assets: list[Any]) -> list[SceneIntakeAsset]: + assets: list[SceneIntakeAsset] = [] + seen_names: set[str] = set() + + for asset_index, raw_asset in enumerate(raw_assets): + if not isinstance(raw_asset, dict): + raise ValueError(f"Scene intake asset {asset_index} must be an object.") + _validate_exact_keys( + raw_asset, + allowed_keys={"name", "description", "class_candidate", "count"}, + context=f"Scene intake asset {asset_index}", + ) + + if "name" not in raw_asset: + raise ValueError(f"Scene intake asset {asset_index}.name is required.") + raw_name = str(raw_asset["name"]).strip() + if not raw_name: + raise ValueError( + f"Scene intake asset {asset_index}.name must be non-empty." + ) + + if "description" not in raw_asset: + raise ValueError( + f"Scene intake asset {asset_index}.description is required." + ) + description = str(raw_asset["description"]).strip() + if not description: + raise ValueError( + f"Scene intake asset {asset_index}.description must be non-empty." 
+ ) + + class_candidate = _parse_class_candidate( + raw_asset.get("class_candidate"), + asset_index=asset_index, + raw_name=raw_name, + ) + count = _parse_count(raw_asset.get("count"), asset_index=asset_index) + base_name = normalize_asset_name(raw_name) + name = base_name + suffix = 2 + while name in seen_names: + name = f"{base_name}_{suffix}" + suffix += 1 + seen_names.add(name) + assets.append( + SceneIntakeAsset( + id=f"interact_{name}", + name=name, + count=count, + description=description, + class_candidate=class_candidate, + ) + ) + return assets + + +def _parse_class_candidate( + raw_class_candidate: Any, + *, + asset_index: int | str, + raw_name: str, +) -> list[str]: + if not isinstance(raw_class_candidate, list): + raise ValueError( + f"Scene intake asset {asset_index}.class_candidate must be a list." + ) + class_candidate = [normalize_asset_name(str(item)) for item in raw_class_candidate] + if len(class_candidate) != 5: + raise ValueError( + f"Scene intake asset {asset_index}.class_candidate must contain exactly five entries." + ) + if any(not candidate for candidate in class_candidate): + raise ValueError( + f"Scene intake asset {asset_index}.class_candidate has empty entries." + ) + if class_candidate[0] != normalize_asset_name(raw_name): + raise ValueError( + f"Scene intake asset {asset_index}.class_candidate[0] must equal name." 
+ ) + return class_candidate + + +def _parse_count(raw_count: Any, *, asset_index: int) -> int: + if not isinstance(raw_count, int) or isinstance(raw_count, bool): + raise ValueError(f"Scene intake asset {asset_index}.count must be an integer.") + if raw_count < 1: + raise ValueError(f"Scene intake asset {asset_index}.count must be >= 1.") + return raw_count + + +def _validate_exact_keys( + value: dict[str, Any], + *, + allowed_keys: set[str], + context: str, +) -> None: + extra_keys = sorted(set(value) - allowed_keys) + if extra_keys: + raise ValueError(f"{context} has unexpected keys: {extra_keys}.") + + +def _require_mapping(value: Any, context: str) -> dict[str, Any]: + if not isinstance(value, dict): + raise ValueError(f"{context} must be an object.") + return value + + +def _require_list(value: Any, context: str) -> list[Any]: + if not isinstance(value, list): + raise ValueError(f"{context} must be a list.") + return value diff --git a/embodichain/gen_sim/prompt2scene/workflows/spatial.py b/embodichain/gen_sim/prompt2scene/workflows/spatial.py new file mode 100644 index 00000000..b5f93868 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/spatial.py @@ -0,0 +1,309 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ----------------------------------------------------------------------------

from __future__ import annotations

__all__ = [
    "GRID_VALUE_LIST",
    "GRID_VALUES",
    "RELATION_VALUE_LIST",
    "RELATION_VALUES",
    "assign_grids_from_anchor_and_orders",
    "derive_relations_from_orders",
    "invert_relation",
    "normalize_relation",
    "transitive_relation_closure",
    "validate_exact_asset_id_coverage",
]

# Canonical relation names. ``normalize_relation`` rewrites "right_of" and
# "behind" into these two directions by swapping subject and object.
RELATION_VALUE_LIST = ["left_of", "front_of"]
RELATION_VALUES = frozenset(RELATION_VALUE_LIST)
INVERSE_RELATIONS = {
    "left_of": "right_of",
    "right_of": "left_of",
    "front_of": "behind",
    "behind": "front_of",
}

# The nine grid labels produced by ``_join_grid`` and accepted by
# ``_split_grid``.
GRID_VALUE_LIST = [
    "center",
    "front",
    "back",
    "left_center",
    "right_center",
    "left_front",
    "right_front",
    "left_back",
    "right_back",
]
GRID_VALUES = frozenset(GRID_VALUE_LIST)


def validate_exact_asset_id_coverage(
    *,
    values: list[str],
    expected_asset_ids: list[str],
    context: str,
) -> None:
    """Validate that values contain every expected asset id exactly once.

    Args:
        values: Asset ids to check.
        expected_asset_ids: Asset ids that must each appear once in ``values``.
        context: Prefix used in the error message.

    Raises:
        ValueError: If ``values`` has duplicates, misses an expected id, or
            contains an id that is not expected (checked in that order).
    """
    # Single pass instead of calling ``values.count`` once per element.
    seen_ids: set[str] = set()
    duplicate_ids: set[str] = set()
    for asset_id in values:
        if asset_id in seen_ids:
            duplicate_ids.add(asset_id)
        seen_ids.add(asset_id)

    expected = set(expected_asset_ids)
    duplicates = sorted(duplicate_ids)
    missing = sorted(expected - seen_ids)
    unknown = sorted(seen_ids - expected)
    if duplicates:
        raise ValueError(f"{context} has duplicate asset ids: {duplicates}.")
    if missing:
        raise ValueError(f"{context} is missing asset ids: {missing}.")
    if unknown:
        raise ValueError(f"{context} has unknown asset ids: {unknown}.")


def assign_grids_from_anchor_and_orders(
    *,
    anchor_asset_id: str,
    anchor_grid: str,
    x_order: list[list[str]],
    y_order: list[list[str]],
    asset_ids: list[str],
) -> dict[str, str]:
    """Assign 9-grid labels from one anchor grid and two object orderings.

    Each asset is compared with the anchor on both axes: an earlier group
    index maps to "left" / "front", a later one to "right" / "back", and the
    same index inherits the anchor's own axis label.

    Args:
        anchor_asset_id: Asset whose grid is known.
        anchor_grid: Grid label of the anchor; must be in ``GRID_VALUES``.
        x_order: Groups of asset ids ordered along the left-to-right axis.
        y_order: Groups of asset ids ordered along the front-to-back axis.
        asset_ids: Assets to label.

    Returns:
        Mapping from asset id to grid label.

    Raises:
        ValueError: If ``anchor_grid`` is not a supported grid label.
        KeyError: If the anchor or any asset id is absent from an ordering.
    """
    anchor_x, anchor_y = _split_grid(anchor_grid)
    x_indices = _order_indices(x_order)
    y_indices = _order_indices(y_order)
    anchor_x_index = x_indices[anchor_asset_id]
    anchor_y_index = y_indices[anchor_asset_id]

    grids: dict[str, str] = {}
    for asset_id in asset_ids:
        x_label = _axis_label_from_anchor(
            index=x_indices[asset_id],
            anchor_index=anchor_x_index,
            anchor_label=anchor_x,
            before_label="left",
            after_label="right",
        )
        y_label = _axis_label_from_anchor(
            index=y_indices[asset_id],
            anchor_index=anchor_y_index,
            anchor_label=anchor_y,
            before_label="front",
            after_label="back",
        )
        grids[asset_id] = _join_grid(x_label=x_label, y_label=y_label)
    return grids


def invert_relation(relation: str) -> str:
    """Return the inverse of a supported spatial relation.

    Raises:
        ValueError: If ``relation`` is not a key of ``INVERSE_RELATIONS``.
    """
    if relation not in INVERSE_RELATIONS:
        raise ValueError(f"Unsupported spatial relation: {relation!r}.")
    return INVERSE_RELATIONS[relation]


def normalize_relation(
    *,
    subject: str,
    relation: str,
    object_id: str,
) -> tuple[str, str, str]:
    """Normalize a relation into a canonical directional axis edge.

    "right_of" and "behind" are rewritten as "left_of" and "front_of" with
    subject and object swapped.

    Returns:
        ``(subject, relation, object_id)`` in canonical form.

    Raises:
        ValueError: If ``relation`` is not one of the four supported names.
    """
    if relation == "left_of":
        return subject, "left_of", object_id
    if relation == "right_of":
        return object_id, "left_of", subject
    if relation == "front_of":
        return subject, "front_of", object_id
    if relation == "behind":
        return object_id, "front_of", subject
    raise ValueError(f"Unsupported spatial relation: {relation!r}.")


def transitive_relation_closure(
    relations: list[dict[str, str]],
) -> list[dict[str, str]]:
    """Expand canonical left/front relations with transitive closure.

    Args:
        relations: Records with "subject", "relation" and "object" keys.
            Any relation accepted by ``normalize_relation`` is allowed.

    Returns:
        Canonical relation records, grouped by relation ("left_of" first)
        and sorted by (subject, object) within each group. Each record has a
        "source" of "input" when the canonical edge was given directly and
        "closure" when it was only derived.

    Raises:
        ValueError: If a relation is unsupported, references the same object
            twice, contradicts another direct relation, or the relations of
            one axis form a cycle.
        KeyError: If a record lacks one of the required keys.
    """
    direct_edges: dict[str, set[tuple[str, str]]] = {
        "left_of": set(),
        "front_of": set(),
    }
    input_edges: set[tuple[str, str, str]] = set()
    for relation_record in relations:
        canonical_subject, canonical_relation, canonical_object = normalize_relation(
            subject=relation_record["subject"],
            relation=relation_record["relation"],
            object_id=relation_record["object"],
        )
        if canonical_subject == canonical_object:
            raise ValueError("Spatial relation cannot reference the same object.")
        edge = (canonical_subject, canonical_object)
        inverse_edge = (canonical_object, canonical_subject)
        if inverse_edge in direct_edges[canonical_relation]:
            raise ValueError(
                "Conflicting spatial relations: "
                f"{canonical_subject!r} {canonical_relation} {canonical_object!r}."
            )
        direct_edges[canonical_relation].add(edge)
        # Record the canonical triple: the lookup below uses canonical edges,
        # so storing the raw triple would mislabel "right_of"/"behind" inputs
        # as derived relations.
        input_edges.add((canonical_subject, canonical_relation, canonical_object))

    output: list[dict[str, str]] = []
    seen: set[tuple[str, str, str]] = set()
    for canonical_relation, edges in direct_edges.items():
        closure = _transitive_edges(edges)
        # Direct self-edges were rejected above, so a self-edge in the closure
        # can only come from a cycle, i.e. contradictory relations.
        cyclic_ids = sorted({subject for subject, object_id in closure if subject == object_id})
        if cyclic_ids:
            raise ValueError(
                "Conflicting spatial relations: "
                f"{canonical_relation} relations form a cycle through {cyclic_ids}."
            )
        for subject, object_id in sorted(closure):
            _append_relation(
                output=output,
                seen=seen,
                subject=subject,
                relation=canonical_relation,
                object_id=object_id,
                source=(
                    "input"
                    if (subject, canonical_relation, object_id) in input_edges
                    else "closure"
                ),
            )
    return output


def derive_relations_from_orders(
    *,
    x_order: list[list[str]],
    y_order: list[list[str]],
) -> list[dict[str, str]]:
    """Derive canonical relations from adjacent order groups.

    Adjacent groups yield direct "left_of" (from ``x_order``) and "front_of"
    (from ``y_order``) relations, which are then transitively closed.

    Returns:
        Relation records whose "source" is "order" for relations between
        adjacent groups and "closure" for derived ones.

    Raises:
        ValueError: If the orderings produce conflicting relations.
    """
    relations: list[dict[str, str]] = []
    relations.extend(_relations_from_order_groups(x_order, relation="left_of"))
    relations.extend(_relations_from_order_groups(y_order, relation="front_of"))
    closed = transitive_relation_closure(relations)
    return [
        {
            **relation,
            "source": "order" if relation["source"] == "input" else relation["source"],
        }
        for relation in closed
    ]


def _order_indices(order: list[list[str]]) -> dict[str, int]:
    """Map each asset id to the index of the group that contains it."""
    return {
        asset_id: group_index
        for group_index, group in enumerate(order)
        for asset_id in group
    }


def _split_grid(grid: str) -> tuple[str, str]:
    """Split a grid label into its (x, y) axis labels.

    Raises:
        ValueError: If ``grid`` is not in ``GRID_VALUES``.
    """
    # Reject unknown labels up front; otherwise a label such as "top_left"
    # would be split into axis labels that no grid uses.
    if grid not in GRID_VALUES:
        raise ValueError(f"Unsupported grid value: {grid!r}.")
    if grid == "center":
        return "center", "center"
    if grid in {"front", "back"}:
        return "center", grid
    if grid in {"left_center", "right_center"}:
        return grid.split("_", maxsplit=1)[0], "center"
    x_label, y_label = grid.split("_", maxsplit=1)
    return x_label, y_label


def _axis_label_from_anchor(
    *,
    index: int,
    anchor_index: int,
    anchor_label: str,
    before_label: str,
    after_label: str,
) -> str:
    """Pick an axis label by comparing a group index with the anchor's."""
    if index < anchor_index:
        return before_label
    if index > anchor_index:
        return after_label
    return anchor_label


def _join_grid(*, x_label: str, y_label: str) -> str:
    """Combine x and y axis labels into one grid label."""
    if x_label == "center" and y_label == "center":
        return "center"
    if x_label == "center":
        return y_label
    if y_label == "center":
        return f"{x_label}_center"
    return f"{x_label}_{y_label}"


def _relations_from_order_groups(
    order_groups: list[list[str]],
    *,
    relation: str,
) -> list[dict[str, str]]:
    """Relate every member of each group to every member of the next group."""
    relations: list[dict[str, str]] = []
    for earlier_group, later_group in zip(order_groups, order_groups[1:]):
        for subject in earlier_group:
            for object_id in later_group:
                relations.append(
                    {
                        "subject": subject,
                        "relation": relation,
                        "object": object_id,
                        "source": "input",
                    }
                )
    return relations


def _transitive_edges(
    edges: set[tuple[str, str]],
) -> set[tuple[str, str]]:
    """Return ``edges`` plus every (start, node) pair reachable through them."""
    adjacency: dict[str, set[str]] = {}
    for subject, object_id in edges:
        adjacency.setdefault(subject, set()).add(object_id)
        adjacency.setdefault(object_id, set())

    closure: set[tuple[str, str]] = set(edges)
    for start in adjacency:
        # Depth-first walk from ``start``; ``visited`` guards against cycles.
        stack = list(adjacency[start])
        visited: set[str] = set()
        while stack:
            current = stack.pop()
            if current in visited:
                continue
            visited.add(current)
            closure.add((start, current))
            stack.extend(adjacency.get(current, ()))
    return closure


def _append_relation(
    *,
    output: list[dict[str, str]],
    seen: set[tuple[str, str, str]],
    subject: str,
    relation: str,
    object_id: str,
    source: str,
) -> None:
    """Append a relation record to ``output`` unless its triple was seen."""
    key = (subject, relation, object_id)
    if key in seen:
        return
    seen.add(key)
    output.append(
        {
            "subject": subject,
            "relation": relation,
            "object": object_id,
            "source": source,
        }
    )


# diff --git a/embodichain/gen_sim/prompt2scene/workflows/stage_errors.py b/embodichain/gen_sim/prompt2scene/workflows/stage_errors.py new
file mode 100644 index 00000000..f8d8c230 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/stage_errors.py @@ -0,0 +1,40 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +__all__ = ["format_attempt_error", "format_result_missing_error"] + + +def format_attempt_error(stage_name: str, attempt_count: int, exc: Exception) -> str: + """Format a retryable stage failure message.""" + return f"{stage_name} attempt {attempt_count} failed: {exc}" + + +def format_result_missing_error( + stage_name: str, + result_name: str, + *, + attempt_count: int, + last_error: str | None, + errors: list[str], +) -> str: + """Format a missing-final-result error message.""" + return ( + f"{stage_name} failed to produce a {result_name} after " + f"{attempt_count} attempts. Last error: {last_error}. 
" + f"All retryable errors: {errors}" + ) diff --git a/embodichain/gen_sim/prompt2scene/workflows/text_relations/__init__.py b/embodichain/gen_sim/prompt2scene/workflows/text_relations/__init__.py new file mode 100644 index 00000000..e2c03539 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/text_relations/__init__.py @@ -0,0 +1,24 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.workflows.text_relations.graph import ( + build_text_relations_graph, + run_text_relations, +) + +__all__ = ["build_text_relations_graph", "run_text_relations"] diff --git a/embodichain/gen_sim/prompt2scene/workflows/text_relations/graph.py b/embodichain/gen_sim/prompt2scene/workflows/text_relations/graph.py new file mode 100644 index 00000000..f6aa6078 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/text_relations/graph.py @@ -0,0 +1,124 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from langgraph.graph import END, StateGraph + +from embodichain.gen_sim.prompt2scene.llms import ( + OpenAICompatibleLLMCfg, + build_chat_model, +) +from embodichain.gen_sim.prompt2scene.utils import log +from embodichain.gen_sim.prompt2scene.workflows.stage_errors import ( + format_result_missing_error, +) +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.schema import ( + SceneIntakeSpec, +) +from embodichain.gen_sim.prompt2scene.workflows.text_relations.nodes import ( + call_llm_text_relations_node, + normalize_text_relations_node, + prepare_text_relation_messages_node, +) +from embodichain.gen_sim.prompt2scene.workflows.text_relations.schema import ( + TextRelationSpec, +) +from embodichain.gen_sim.prompt2scene.workflows.text_relations.state import ( + TextRelationsState, +) +from embodichain.gen_sim.prompt2scene.workflows.request import Prompt2SceneInput + +__all__ = ["build_text_relations_graph", "run_text_relations"] + + +def route_after_text_relation_normalize(state: TextRelationsState) -> str: + """Route to retry or finish after text relation normalization.""" + if state["text_relations"] is not None: + return "end" + if state["attempt_count"] < state["max_attempts"]: + return "retry" + return "end" + + +def build_text_relations_graph(llm: Any) -> Any: + """Build the fixed text spatial-relation extraction workflow.""" + graph = StateGraph(TextRelationsState) + 
graph.add_node( + "prepare_text_relation_messages", + prepare_text_relation_messages_node, + ) + graph.add_node( + "call_llm_text_relations", + lambda state: call_llm_text_relations_node(state, llm=llm), + ) + graph.add_node("normalize_text_relations", normalize_text_relations_node) + + graph.set_entry_point("prepare_text_relation_messages") + graph.add_edge("prepare_text_relation_messages", "call_llm_text_relations") + graph.add_edge("call_llm_text_relations", "normalize_text_relations") + graph.add_conditional_edges( + "normalize_text_relations", + route_after_text_relation_normalize, + { + "retry": "call_llm_text_relations", + "end": END, + }, + ) + return graph.compile() + + +def run_text_relations( + request: Prompt2SceneInput, + *, + scene_intake: SceneIntakeSpec, + llm_cfg: OpenAICompatibleLLMCfg, + output_root: Path, +) -> TextRelationSpec: + """Run text spatial-relation extraction for one prompt2scene request.""" + llm = build_chat_model(llm_cfg) + graph = build_text_relations_graph(llm) + result = graph.invoke( + { + "request": request, + "scene_intake": scene_intake, + "output_root": output_root, + "messages": [], + "raw_model_output": None, + "text_relations": None, + "attempt_count": 0, + "max_attempts": llm_cfg.max_attempts, + "last_error": None, + "errors": [], + } + ) + + text_relations = result.get("text_relations") + if text_relations is not None: + return text_relations + + error = format_result_missing_error( + "Text relations", + "TextRelationSpec", + attempt_count=result.get("attempt_count", 0), + last_error=result.get("last_error"), + errors=result.get("errors", []), + ) + log.log_warning(error) + raise RuntimeError(error) diff --git a/embodichain/gen_sim/prompt2scene/workflows/text_relations/nodes.py b/embodichain/gen_sim/prompt2scene/workflows/text_relations/nodes.py new file mode 100644 index 00000000..67b1fc3c --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/text_relations/nodes.py @@ -0,0 +1,144 @@ +# 
---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Any + +from embodichain.gen_sim.prompt2scene.workflows.request import InputKind +from embodichain.gen_sim.prompt2scene.workflows.text_relations.schema import ( + TEXT_RELATIONS_JSON_SCHEMA, + TextRelationSpec, +) +from embodichain.gen_sim.prompt2scene.utils import ( + log_api_request_start, + log, +) +from embodichain.gen_sim.prompt2scene.workflows.artifact_writer import ( + TEXT_RELATIONS_STEP, + WorkflowArtifactWriter, +) +from embodichain.gen_sim.prompt2scene.workflows.llm_output import ( + StructuredModelCallError, + call_structured_json_model_step, +) +from embodichain.gen_sim.prompt2scene.workflows.stage_errors import ( + format_attempt_error, +) +from embodichain.gen_sim.prompt2scene.workflows.text_relations.prompts import ( + build_text_relation_messages, +) +from embodichain.gen_sim.prompt2scene.workflows.text_relations.state import ( + TextRelationsState, +) +from embodichain.gen_sim.prompt2scene.workflows.text_relations.utils import ( + build_text_relation_spec, +) + +__all__ = [ + "call_llm_text_relations_node", + "normalize_text_relations_node", + "prepare_text_relation_messages_node", +] + + +def prepare_text_relation_messages_node( + 
state: TextRelationsState, +) -> dict[str, object]: + """Prepare text-relation extraction messages.""" + request = state["request"] + if request.input_kind != InputKind.TEXT: + raise ValueError("Text relations requires a text input.") + return { + "messages": build_text_relation_messages( + request=request, + scene_intake=state["scene_intake"], + ) + } + + +def call_llm_text_relations_node( + state: TextRelationsState, + *, + llm: Any, +) -> dict[str, object]: + """Call LLM to extract explicit text spatial constraints.""" + attempt_count = state["attempt_count"] + 1 + artifact_writer = WorkflowArtifactWriter( + state["output_root"], + TEXT_RELATIONS_STEP, + ) + + try: + log_api_request_start( + step=TEXT_RELATIONS_STEP, + request="extract", + attempt=attempt_count, + ) + raw_model_output = call_structured_json_model_step( + llm=llm, + schema=TEXT_RELATIONS_JSON_SCHEMA, + messages=state["messages"], + context="Text relations", + step_name=TEXT_RELATIONS_STEP, + output_root=None, + attempt_count=attempt_count, + raw_output_label="extract", + artifact_writer=artifact_writer, + ) + except StructuredModelCallError as exc: + error = format_attempt_error("Text relations", attempt_count, exc) + log.log_warning(error) + return { + "attempt_count": attempt_count, + "raw_model_output": None, + "last_error": error, + "errors": state["errors"] + [error], + } + + return { + "attempt_count": attempt_count, + "raw_model_output": raw_model_output, + "last_error": None, + } + + +def normalize_text_relations_node(state: TextRelationsState) -> dict[str, object]: + """Normalize raw LLM output into TextRelationSpec.""" + raw_model_output = state["raw_model_output"] + if raw_model_output is None: + return {} + + try: + text_relations = build_text_relation_spec( + scene_intake=state["scene_intake"], + model_output=raw_model_output, + ) + except ValueError as exc: + error = format_attempt_error("Text relations", state["attempt_count"], exc) + log.log_warning(error) + return { + 
"text_relations": None, + "last_error": error, + "errors": state["errors"] + [error], + } + + artifact_writer = WorkflowArtifactWriter( + state["output_root"], + TEXT_RELATIONS_STEP, + ) + artifact_writer.write_step_result(text_relations.to_manifest()) + return {"text_relations": text_relations, "last_error": None} diff --git a/embodichain/gen_sim/prompt2scene/workflows/text_relations/prompts.py b/embodichain/gen_sim/prompt2scene/workflows/text_relations/prompts.py new file mode 100644 index 00000000..a6f02e4f --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/text_relations/prompts.py @@ -0,0 +1,55 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Any + +from embodichain.gen_sim.prompt2scene.prompts import render_prompt +from embodichain.gen_sim.prompt2scene.workflows.request import Prompt2SceneInput +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.schema import ( + SceneIntakeSpec, +) + +__all__ = ["build_text_relation_messages"] + +TEXT_RELATIONS_PROMPT_NAME = "text_relations.yaml" + + +def build_text_relation_messages( + *, + request: Prompt2SceneInput, + scene_intake: SceneIntakeSpec, +) -> list[dict[str, Any]]: + """Build messages for explicit text spatial-relation extraction.""" + asset_names = "\n".join(f"- {asset.name}" for asset in scene_intake.assets) + return [ + { + "role": "system", + "content": render_prompt(TEXT_RELATIONS_PROMPT_NAME, prompt_key="system"), + }, + { + "role": "user", + "content": render_prompt( + TEXT_RELATIONS_PROMPT_NAME, + { + "asset_names": asset_names, + "text": request.text or "", + }, + prompt_key="user", + ), + }, + ] diff --git a/embodichain/gen_sim/prompt2scene/workflows/text_relations/schema.py b/embodichain/gen_sim/prompt2scene/workflows/text_relations/schema.py new file mode 100644 index 00000000..db2e513f --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/text_relations/schema.py @@ -0,0 +1,164 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +from embodichain.gen_sim.prompt2scene.workflows.spatial import ( + GRID_VALUE_LIST, + RELATION_VALUE_LIST, +) + +__all__ = [ + "TEXT_RELATIONS_JSON_SCHEMA", + "TextObjectLayout", + "TextObjectRelation", + "TextRelationSpec", + "TextTableConstraint", +] + +TEXT_RELATIONS_JSON_SCHEMA: dict[str, Any] = { + "title": "TextRelationsOutput", + "type": "object", + "additionalProperties": False, + "properties": { + "object_relations": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": False, + "properties": { + "subject": {"type": "string", "minLength": 1}, + "relation": { + "type": "string", + "enum": RELATION_VALUE_LIST, + }, + "object": {"type": "string", "minLength": 1}, + "evidence": {"type": "string", "minLength": 1}, + }, + "required": ["subject", "relation", "object", "evidence"], + }, + }, + "table_constraints": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": False, + "properties": { + "asset": {"type": "string", "minLength": 1}, + "grid": { + "type": "string", + "enum": GRID_VALUE_LIST, + }, + "evidence": {"type": "string", "minLength": 1}, + }, + "required": ["asset", "grid", "evidence"], + }, + }, + "object_layouts": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": False, + "properties": { + "asset": {"type": "string", "minLength": 1}, + "is_arbitrary_layout": {"type": "boolean"}, + "reason": {"type": "string", "minLength": 1}, + }, + "required": ["asset", "is_arbitrary_layout", "reason"], + }, + }, + }, + "required": ["object_relations", "table_constraints", "object_layouts"], +} + + +@dataclass(frozen=True) +class TextObjectRelation: + """Text-stated relation between two 
scene-intake asset groups.""" + + subject: str + relation: str + object: str + evidence: str + + def to_manifest(self) -> dict[str, str]: + """Convert the relation to JSON-safe data.""" + return { + "subject": self.subject, + "relation": self.relation, + "object": self.object, + "evidence": self.evidence, + } + + +@dataclass(frozen=True) +class TextTableConstraint: + """Text-stated table grid constraint for one asset group.""" + + asset: str + grid: str + evidence: str + + def to_manifest(self) -> dict[str, str]: + """Convert the table constraint to JSON-safe data.""" + return { + "asset": self.asset, + "grid": self.grid, + "evidence": self.evidence, + } + + +@dataclass(frozen=True) +class TextObjectLayout: + """Text-stated object support-pose constraint.""" + + asset: str + is_arbitrary_layout: bool + reason: str + + def to_manifest(self) -> dict[str, object]: + """Convert the layout constraint to JSON-safe data.""" + return { + "asset": self.asset, + "is_arbitrary_layout": self.is_arbitrary_layout, + "reason": self.reason, + } + + +@dataclass(frozen=True) +class TextRelationSpec: + """Spatial constraints explicitly extracted from a text prompt.""" + + source_text: str + object_relations: list[TextObjectRelation] = field(default_factory=list) + table_constraints: list[TextTableConstraint] = field(default_factory=list) + object_layouts: list[TextObjectLayout] = field(default_factory=list) + + def to_manifest(self) -> dict[str, object]: + """Convert the text relations to JSON-safe data.""" + return { + "source_text": self.source_text, + "object_relations": [ + relation.to_manifest() for relation in self.object_relations + ], + "table_constraints": [ + constraint.to_manifest() for constraint in self.table_constraints + ], + "object_layouts": [layout.to_manifest() for layout in self.object_layouts], + } diff --git a/embodichain/gen_sim/prompt2scene/workflows/text_relations/state.py b/embodichain/gen_sim/prompt2scene/workflows/text_relations/state.py new file mode 
class TextRelationsState(AttemptState):
    """Graph state carried through the text spatial-relation extraction workflow.

    Extends ``AttemptState`` (defined in ``workflows.attempt_state``) with the
    keys declared below.
    """

    # The prompt2scene request being processed.
    request: Prompt2SceneInput
    # Scene-intake result for the request.
    scene_intake: SceneIntakeSpec
    # Root directory for workflow outputs.
    output_root: Path
    # Untyped message list -- presumably the LLM conversation; confirm in nodes.
    messages: list[Any]
    # Raw model JSON, or None when not available.
    raw_model_output: dict[str, Any] | None
    # Parsed relation spec, or None when not available.
    text_relations: TextRelationSpec | None
def build_text_relation_spec(
    *,
    scene_intake: SceneIntakeSpec,
    model_output: dict[str, Any],
) -> TextRelationSpec:
    """Turn raw LLM JSON into a validated ``TextRelationSpec``.

    The three arrays of ``model_output`` are parsed in order (object
    relations, table constraints, object layouts); every asset name they
    mention must match the name of a ``scene_intake`` asset.

    Raises:
        ValueError: If any array is missing/not a list or an entry is invalid.
    """
    known_assets = {asset.name for asset in scene_intake.assets}
    parsed_relations = _parse_object_relations(
        model_output.get("object_relations"),
        asset_names=known_assets,
    )
    parsed_constraints = _parse_table_constraints(
        model_output.get("table_constraints"),
        asset_names=known_assets,
    )
    parsed_layouts = _parse_object_layouts(
        model_output.get("object_layouts"),
        asset_names=known_assets,
    )
    return TextRelationSpec(
        source_text=scene_intake.input.text or "",
        object_relations=parsed_relations,
        table_constraints=parsed_constraints,
        object_layouts=parsed_layouts,
    )


def _parse_object_relations(
    raw_relations: Any,
    *,
    asset_names: set[str],
) -> list[TextObjectRelation]:
    """Validate raw relation entries, keeping the first of each duplicate triple."""
    if not isinstance(raw_relations, list):
        raise ValueError("text_relations.object_relations must be a list.")
    # Insertion-ordered dict doubles as the "seen" set and the result list.
    unique: dict[tuple[str, str, str], TextObjectRelation] = {}
    for index, entry in enumerate(raw_relations):
        if not isinstance(entry, dict):
            raise ValueError(
                f"text_relations.object_relations[{index}] must be an object."
            )
        subject_name = _parse_asset_name(entry.get("subject"), asset_names, index)
        relation_value = str(entry.get("relation") or "").strip()
        target_name = _parse_asset_name(entry.get("object"), asset_names, index)
        evidence_text = str(entry.get("evidence") or "").strip()
        if relation_value not in RELATION_VALUES:
            raise ValueError(
                f"text_relations.object_relations[{index}].relation is invalid."
            )
        if not evidence_text:
            raise ValueError(
                f"text_relations.object_relations[{index}].evidence is required."
            )
        triple = (subject_name, relation_value, target_name)
        if triple not in unique:
            unique[triple] = TextObjectRelation(
                subject=subject_name,
                relation=relation_value,
                object=target_name,
                evidence=evidence_text,
            )
    return list(unique.values())


def _parse_table_constraints(
    raw_constraints: Any,
    *,
    asset_names: set[str],
) -> list[TextTableConstraint]:
    """Validate raw table constraints, keeping the first of each (asset, grid) pair."""
    if not isinstance(raw_constraints, list):
        raise ValueError("text_relations.table_constraints must be a list.")
    unique: dict[tuple[str, str], TextTableConstraint] = {}
    for index, entry in enumerate(raw_constraints):
        if not isinstance(entry, dict):
            raise ValueError(
                f"text_relations.table_constraints[{index}] must be an object."
            )
        asset_name = _parse_asset_name(entry.get("asset"), asset_names, index)
        grid_value = str(entry.get("grid") or "").strip()
        evidence_text = str(entry.get("evidence") or "").strip()
        if grid_value not in GRID_VALUES:
            raise ValueError(
                f"text_relations.table_constraints[{index}].grid is invalid."
            )
        if not evidence_text:
            raise ValueError(
                f"text_relations.table_constraints[{index}].evidence is required."
            )
        pair = (asset_name, grid_value)
        if pair not in unique:
            unique[pair] = TextTableConstraint(
                asset=asset_name,
                grid=grid_value,
                evidence=evidence_text,
            )
    return list(unique.values())


def _parse_object_layouts(
    raw_layouts: Any,
    *,
    asset_names: set[str],
) -> list[TextObjectLayout]:
    """Validate raw layout entries, keeping only the first entry per asset."""
    if not isinstance(raw_layouts, list):
        raise ValueError("text_relations.object_layouts must be a list.")
    unique: dict[str, TextObjectLayout] = {}
    for index, entry in enumerate(raw_layouts):
        if not isinstance(entry, dict):
            raise ValueError(
                f"text_relations.object_layouts[{index}] must be an object."
            )
        asset_name = _parse_asset_name(entry.get("asset"), asset_names, index)
        arbitrary_flag = entry.get("is_arbitrary_layout")
        reason_text = str(entry.get("reason") or "").strip()
        # Only a real bool is accepted; truthy strings or ints are rejected.
        if not isinstance(arbitrary_flag, bool):
            raise ValueError(
                "text_relations.object_layouts"
                f"[{index}].is_arbitrary_layout must be boolean."
            )
        if not reason_text:
            raise ValueError(
                f"text_relations.object_layouts[{index}].reason is required."
            )
        if asset_name not in unique:
            unique[asset_name] = TextObjectLayout(
                asset=asset_name,
                is_arbitrary_layout=arbitrary_flag,
                reason=reason_text,
            )
    return list(unique.values())


def _parse_asset_name(raw_name: Any, asset_names: set[str], index: int) -> str:
    """Normalize ``raw_name`` and require it to be one of ``asset_names``.

    Raises:
        ValueError: If the normalized name is not a known scene asset.
    """
    name = normalize_asset_name(str(raw_name or ""))
    if name in asset_names:
        return name
    raise ValueError(
        f"text_relations item {index} references unknown scene asset: {name!r}."
    )
def build_unified_scene_graph() -> Any:
    """Compile the single-node graph that assembles the unified scene."""
    builder = StateGraph(UnifiedSceneState)
    builder.add_node("build_unified_scene", build_unified_scene_node)
    # One fixed step: entry -> build_unified_scene -> END.
    builder.set_entry_point("build_unified_scene")
    builder.add_edge("build_unified_scene", END)
    return builder.compile()


def run_unified_scene(
    request: Prompt2SceneInput,
    *,
    scene_intake: SceneIntakeSpec,
    image_relations: ImageRelationSpec | None = None,
    text_relations: TextRelationSpec | None = None,
    output_root: Path,
) -> UnifiedSceneSpec:
    """Assemble the unified scene for one prompt2scene request.

    Returns:
        The ``unified_scene`` value found in the final graph state.

    Raises:
        RuntimeError: If the final state holds no unified scene; the message
            is also logged as a warning first.
    """
    compiled = build_unified_scene_graph()
    initial_state = {
        "request": request,
        "scene_intake": scene_intake,
        "output_root": output_root,
        "image_relations": image_relations,
        "text_relations": text_relations,
        "unified_scene": None,
        "attempt_count": 0,
        "max_attempts": 1,
        "last_error": None,
        "errors": [],
    }
    final_state = compiled.invoke(initial_state)

    unified_scene = final_state.get("unified_scene")
    if unified_scene is None:
        message = format_result_missing_error(
            "Unified scene",
            "UnifiedSceneSpec",
            attempt_count=final_state.get("attempt_count", 0),
            last_error=final_state.get("last_error"),
            errors=final_state.get("errors", []),
        )
        log.log_warning(message)
        raise RuntimeError(message)
    return unified_scene
def build_unified_scene_node(state: UnifiedSceneState) -> dict[str, object]:
    """Build the unified scene, write its manifest, and return the state update.

    Image relations are used when present with ``status == "ok"``; otherwise
    text relations are used when present.

    Raises:
        ValueError: If neither source is usable.
    """
    scene_intake = state["scene_intake"]
    image_relations = state.get("image_relations")
    text_relations = state.get("text_relations")

    if image_relations is not None and image_relations.status == "ok":
        unified_scene = build_unified_scene_from_image_relations(
            scene_intake=scene_intake,
            image_relations=image_relations,
        )
    elif text_relations is None:
        raise ValueError("Unified scene requires image_relations or text_relations.")
    else:
        unified_scene = build_unified_scene_from_text_relations(
            scene_intake=scene_intake,
            text_relations=text_relations,
        )

    # Persist the manifest as this step's result before updating the state.
    writer = WorkflowArtifactWriter(state["output_root"], UNIFIED_SCENE_STEP)
    writer.write_step_result(unified_scene.to_manifest())
    return {"unified_scene": unified_scene}
@dataclass(frozen=True)
class UnifiedTable:
    """Unified table/support object."""

    id: str
    name: str
    description: str
    complete_table_description: str
    is_complete_visible_table: bool
    class_candidate: list[str]
    image_path: str | None = None
    mesh_path: str | None = None
    grid_cells: dict[str, list[str]] | None = None

    def to_manifest(self) -> dict[str, Any]:
        """Convert the table to JSON-safe data.

        Mutable containers (``class_candidate`` and ``grid_cells`` with its
        per-cell lists) are copied, so editing the returned manifest never
        changes this frozen record.
        """
        grid_cells: dict[str, list[str]] | None = None
        if self.grid_cells is not None:
            grid_cells = {
                cell: list(object_ids)
                for cell, object_ids in self.grid_cells.items()
            }
        return {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "complete_table_description": self.complete_table_description,
            "is_complete_visible_table": self.is_complete_visible_table,
            "class_candidate": list(self.class_candidate),
            "image_path": self.image_path,
            "mesh_path": self.mesh_path,
            "grid_cells": grid_cells,
        }


@dataclass(frozen=True)
class UnifiedObject:
    """Unified object instance used by downstream scene generation."""

    id: str
    name: str
    description: str
    class_candidate: list[str]
    grid: str | None = None
    is_arbitrary_layout: bool = False
    layout_reason: str = ""
    image_path: str | None = None
    mesh_path: str | None = None

    def to_manifest(self) -> dict[str, Any]:
        """Convert the object to JSON-safe data (``class_candidate`` is copied)."""
        return {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "class_candidate": list(self.class_candidate),
            "grid": self.grid,
            "is_arbitrary_layout": self.is_arbitrary_layout,
            "layout_reason": self.layout_reason,
            "image_path": self.image_path,
            "mesh_path": self.mesh_path,
        }


@dataclass(frozen=True)
class UnifiedSpatialAnchor:
    """Spatial anchor used to infer a full table grid."""

    object_id: str
    grid: str
    reason: str = ""

    def to_manifest(self) -> dict[str, str]:
        """Convert the anchor to JSON-safe data."""
        return {
            "object_id": self.object_id,
            "grid": self.grid,
            "reason": self.reason,
        }


@dataclass(frozen=True)
class UnifiedSpatialRelation:
    """Unified pairwise spatial relation between two objects."""

    subject: str
    relation: str
    object: str
    source: str

    def to_manifest(self) -> dict[str, str]:
        """Convert the relation to JSON-safe data."""
        return {
            "subject": self.subject,
            "relation": self.relation,
            "object": self.object,
            "source": self.source,
        }


@dataclass(frozen=True)
class UnifiedSpatial:
    """Unified spatial relations for a scene."""

    anchor: UnifiedSpatialAnchor | None = None
    relations: list[UnifiedSpatialRelation] = field(default_factory=list)

    def to_manifest(self) -> dict[str, Any]:
        """Convert the spatial record to JSON-safe data.

        ``anchor`` is emitted as ``None`` when no anchor is set.
        """
        anchor = self.anchor.to_manifest() if self.anchor is not None else None
        return {
            "anchor": anchor,
            "relations": [relation.to_manifest() for relation in self.relations],
        }


@dataclass(frozen=True)
class UnifiedSceneSpec:
    """Unified scene representation consumed by downstream generation steps."""

    input: dict[str, Any]
    table: UnifiedTable
    objects: list[UnifiedObject]
    spatial: UnifiedSpatial

    def to_manifest(self) -> dict[str, Any]:
        """Convert the unified scene to JSON-safe data.

        ``input`` is shallow-copied; table, objects and spatial data are
        converted through their own ``to_manifest`` methods.
        """
        return {
            "input": dict(self.input),
            "table": self.table.to_manifest(),
            "objects": [obj.to_manifest() for obj in self.objects],
            "spatial": self.spatial.to_manifest(),
        }
class UnifiedSceneState(AttemptState):
    """Graph state carried through the unified-scene assembly workflow.

    Extends ``AttemptState`` (defined in ``workflows.attempt_state``) with the
    keys declared below.
    """

    # The prompt2scene request being processed.
    request: Prompt2SceneInput
    # Scene-intake result for the request.
    scene_intake: SceneIntakeSpec
    # Root directory for workflow outputs.
    output_root: Path
    # Image-derived relations, or None when not provided.
    image_relations: ImageRelationSpec | None
    # Text-derived relations, or None when not provided.
    text_relations: TextRelationSpec | None
    # Assembled scene, or None when not available; typed loosely as Any.
    unified_scene: Any | None
def build_unified_object_specs(
    assets: list[SceneIntakeAsset],
) -> list[dict[str, Any]]:
    """Expand each asset into ``asset.count`` instance specs with indexed ids."""
    return [
        {
            "id": f"{asset.id}_{index}",
            "name": asset.name,
            "description": asset.description,
            "class_candidate": list(asset.class_candidate),
        }
        for asset in assets
        for index in range(asset.count)
    ]


def object_ids_by_name(object_specs: list[dict[str, Any]]) -> dict[str, list[str]]:
    """Map each object name to the ids of its instances, in spec order."""
    ids_by_name: dict[str, list[str]] = {}
    for spec in object_specs:
        ids_by_name.setdefault(str(spec["name"]), []).append(str(spec["id"]))
    return ids_by_name


def build_unified_table(
    scene_intake: SceneIntakeSpec,
    *,
    grid_cells: dict[str, list[str]] | None = None,
) -> dict[str, Any]:
    """Build the table record from the scene-intake table.

    Image and mesh paths are always emitted as ``None``.
    """
    table = scene_intake.table
    return {
        "id": table.id,
        "name": table.name,
        "description": table.description,
        "complete_table_description": table.complete_table_description,
        "is_complete_visible_table": table.is_complete_visible_table,
        "class_candidate": list(table.class_candidate),
        "image_path": None,
        "mesh_path": None,
        "grid_cells": grid_cells,
    }


def build_unified_spatial_anchor(anchor: ImageAnchor | None) -> dict[str, Any] | None:
    """Translate an image anchor into an anchor record; ``None`` passes through."""
    if anchor is not None:
        return {
            "object_id": anchor.asset_id,
            "grid": anchor.grid,
            "reason": anchor.reason,
        }
    return None


def build_unified_object(
    *,
    spec: dict[str, Any],
    grid: str | None,
    is_arbitrary_layout: bool,
    layout_reason: str,
) -> dict[str, Any]:
    """Combine an instance spec with its grid and layout state into one record.

    Image and mesh paths are always emitted as ``None``.
    """
    record: dict[str, Any] = {key: spec[key] for key in ("id", "name", "description")}
    record["class_candidate"] = list(spec["class_candidate"])
    record.update(
        grid=grid,
        is_arbitrary_layout=is_arbitrary_layout,
        layout_reason=layout_reason,
        image_path=None,
        mesh_path=None,
    )
    return record


def resolve_image_layout(
    asset_id: str,
    layout_by_id: dict[str, Any],
) -> tuple[bool, str]:
    """Look up the layout for ``asset_id``; missing entries give ``(False, "")``."""
    layout = layout_by_id.get(asset_id)
    if layout is not None:
        return bool(layout.is_arbitrary_layout), str(layout.reason)
    return False, ""


def resolve_text_layout(
    name: str,
    layout_by_name: dict[str, TextObjectLayout],
) -> tuple[bool, str]:
    """Look up the layout for asset ``name``; missing entries give ``(False, "")``."""
    layout = layout_by_name.get(name)
    if layout is not None:
        return bool(layout.is_arbitrary_layout), str(layout.reason)
    return False, ""


def text_grids_by_object_id(
    *,
    text_relations: TextRelationSpec,
    ids_by_name: dict[str, list[str]],
) -> dict[str, str | None]:
    """Map every object id to its text-stated grid, or ``None`` if unconstrained.

    A constraint applies to all instances of its asset; when several
    constraints name the same asset, the last one wins.
    """
    grid_by_id: dict[str, str | None] = {}
    for instance_ids in ids_by_name.values():
        grid_by_id.update(dict.fromkeys(instance_ids))
    for constraint in text_relations.table_constraints:
        for object_id in ids_by_name.get(constraint.asset, ()):
            grid_by_id[object_id] = constraint.grid
    return grid_by_id


def grid_cells_from_objects(objects: list[dict[str, Any]]) -> dict[str, list[str]] | None:
    """Group object ids by grid cell; return ``None`` when no object has a grid.

    The nine standard cells are always present in the result; any other grid
    value gets its own entry appended after them.
    """
    standard_cells = (
        "center",
        "front",
        "back",
        "left_center",
        "right_center",
        "left_front",
        "right_front",
        "left_back",
        "right_back",
    )
    members: dict[str, list[str]] = {cell: [] for cell in standard_cells}
    for record in objects:
        cell = record.get("grid")
        if cell:
            members.setdefault(str(cell), []).append(str(record["id"]))
    # Every list starts empty, so a non-empty one means some object had a grid.
    if any(members.values()):
        return members
    return None


def relations_by_object_id(
    *,
    text_relations: TextRelationSpec,
    ids_by_name: dict[str, list[str]],
) -> list[dict[str, str]]:
    """Expand name-level relations to every pair of distinct instance ids."""
    expanded: list[dict[str, str]] = []
    for relation in text_relations.object_relations:
        subject_ids = ids_by_name.get(relation.subject, [])
        object_ids = ids_by_name.get(relation.object, [])
        expanded.extend(
            {
                "subject": subject_id,
                "relation": relation.relation,
                "object": object_id,
                "source": "input",
            }
            for subject_id in subject_ids
            for object_id in object_ids
            if subject_id != object_id
        )
    return expanded


def build_unified_scene_from_image_relations(
    *,
    scene_intake: SceneIntakeSpec,
    image_relations: ImageRelationSpec,
) -> UnifiedSceneSpec:
    """Build a unified scene from image relation outputs.

    Only the anchor object receives a grid here; pairwise relations come from
    ``derive_relations_from_orders`` applied to the image x/y orders.

    Raises:
        ValueError: If ``image_relations`` has no anchor.
    """
    specs = build_unified_object_specs(scene_intake.assets)
    anchor = build_unified_spatial_anchor(image_relations.anchor)
    if anchor is None:
        raise ValueError("Image unified scene requires an anchor.")
    layouts = {layout.asset_id: layout for layout in image_relations.asset_layouts}

    objects: list[UnifiedObject] = []
    for spec in specs:
        arbitrary, reason = resolve_image_layout(spec["id"], layouts)
        is_anchor = spec["id"] == anchor["object_id"]
        record = build_unified_object(
            spec=spec,
            grid=anchor["grid"] if is_anchor else None,
            is_arbitrary_layout=arbitrary,
            layout_reason=reason,
        )
        objects.append(UnifiedObject(**record))

    derived = derive_relations_from_orders(
        x_order=image_relations.x_order,
        y_order=image_relations.y_order,
    )
    relations = [UnifiedSpatialRelation(**item) for item in derived]

    scene_input = scene_intake.input.to_manifest()
    object_manifests = [item.to_manifest() for item in objects]
    table_record = build_unified_table(
        scene_intake,
        grid_cells=grid_cells_from_objects(object_manifests),
    )
    return UnifiedSceneSpec(
        input=scene_input,
        table=UnifiedTable(**table_record),
        objects=objects,
        spatial=UnifiedSpatial(
            anchor=UnifiedSpatialAnchor(**anchor),
            relations=relations,
        ),
    )


def build_unified_scene_from_text_relations(
    *,
    scene_intake: SceneIntakeSpec,
    text_relations: TextRelationSpec,
) -> UnifiedSceneSpec:
    """Build a unified scene from text relation outputs.

    Grids come from the text table constraints, layouts are resolved by asset
    name, and relations are the expanded text relations passed through
    ``transitive_relation_closure``.  The resulting scene has no anchor.
    """
    specs = build_unified_object_specs(scene_intake.assets)
    ids_by_name = object_ids_by_name(specs)
    grid_by_id = text_grids_by_object_id(
        text_relations=text_relations,
        ids_by_name=ids_by_name,
    )
    layouts = {layout.asset: layout for layout in text_relations.object_layouts}

    objects: list[UnifiedObject] = []
    for spec in specs:
        arbitrary, reason = resolve_text_layout(spec["name"], layouts)
        record = build_unified_object(
            spec=spec,
            grid=grid_by_id.get(spec["id"]),
            is_arbitrary_layout=arbitrary,
            layout_reason=reason,
        )
        objects.append(UnifiedObject(**record))

    closure = transitive_relation_closure(
        relations_by_object_id(
            text_relations=text_relations,
            ids_by_name=ids_by_name,
        )
    )
    relations = [UnifiedSpatialRelation(**item) for item in closure]

    scene_input = scene_intake.input.to_manifest()
    object_manifests = [item.to_manifest() for item in objects]
    table_record = build_unified_table(
        scene_intake,
        grid_cells=grid_cells_from_objects(object_manifests),
    )
    return UnifiedSceneSpec(
        input=scene_input,
        table=UnifiedTable(**table_record),
        objects=objects,
        spatial=UnifiedSpatial(anchor=None, relations=relations),
    )
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.workflows.unified_scene_gen.graph import ( + build_unified_scene_gen_graph, + run_unified_scene_gen, +) + +__all__ = [ + "build_unified_scene_gen_graph", + "run_unified_scene_gen", +] diff --git a/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/graph.py b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/graph.py new file mode 100644 index 00000000..5d542b39 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/graph.py @@ -0,0 +1,106 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from langgraph.graph import END, StateGraph + +from embodichain.gen_sim.prompt2scene.llms import build_chat_model +from embodichain.gen_sim.prompt2scene.llms.config import OpenAICompatibleLLMCfg +from embodichain.gen_sim.prompt2scene.workflows.unified_scene_gen.nodes import ( + fit_image_table_to_clutter_node, + fit_text_table_to_clutter_node, + generate_image_assets_node, + generate_text_assets_node, + generate_text_clutter_layout_node, + load_unified_scene_input_kind_node, +) +from embodichain.gen_sim.prompt2scene.workflows.unified_scene_gen.state import ( + UnifiedSceneGenState, +) +__all__ = [ + "build_unified_scene_gen_graph", + "route_after_load_input_kind", + "run_unified_scene_gen", +] + + +def route_after_load_input_kind(state: UnifiedSceneGenState) -> str: + """Route unified-scene generation by the original input kind.""" + input_kind = state["input_kind"] + if input_kind == "text": + return "generate_text_assets" + if input_kind == "image": + return "generate_image_assets" + raise ValueError(f"Unsupported unified-scene input_kind: {input_kind!r}.") + + +def build_unified_scene_gen_graph() -> Any: + """Build the unified-scene generation graph.""" + graph = StateGraph(UnifiedSceneGenState) + graph.add_node("load_unified_scene_input_kind", load_unified_scene_input_kind_node) + graph.add_node("generate_text_assets", generate_text_assets_node) + graph.add_node("generate_text_clutter_layout", generate_text_clutter_layout_node) + graph.add_node("fit_text_table_to_clutter", fit_text_table_to_clutter_node) + graph.add_node("generate_image_assets", generate_image_assets_node) + graph.add_node("fit_image_table_to_clutter", fit_image_table_to_clutter_node) + + graph.set_entry_point("load_unified_scene_input_kind") + graph.add_conditional_edges( + "load_unified_scene_input_kind", + 
route_after_load_input_kind, + { + "generate_text_assets": "generate_text_assets", + "generate_image_assets": "generate_image_assets", + }, + ) + graph.add_edge("generate_text_assets", "generate_text_clutter_layout") + graph.add_edge("generate_text_clutter_layout", "fit_text_table_to_clutter") + graph.add_edge("fit_text_table_to_clutter", END) + graph.add_edge("generate_image_assets", "fit_image_table_to_clutter") + graph.add_edge("fit_image_table_to_clutter", END) + return graph.compile() + + +def run_unified_scene_gen( + output_root: Path, + *, + unified_scene_result_path: Path | None = None, + llm_cfg: OpenAICompatibleLLMCfg | None = None, +) -> UnifiedSceneGenState: + """Run downstream generation routing from a unified-scene result.""" + llm = build_chat_model(llm_cfg) if llm_cfg is not None else None + initial_state: UnifiedSceneGenState = { + "output_root": output_root, + "unified_scene_result_path": unified_scene_result_path, + "llm": llm, + "unified_scene": None, + "input_kind": None, + "table_result": None, + "text_object_results": [], + "text_clutter_settle_result": None, + "image_objects_layout_result": None, + "table_fit_result": None, + "generation_status": None, + "attempt_count": 0, + "max_attempts": 1, + "last_error": None, + "errors": [], + } + return build_unified_scene_gen_graph().invoke(initial_state) diff --git a/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/nodes.py b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/nodes.py new file mode 100644 index 00000000..e12e41f1 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/nodes.py @@ -0,0 +1,392 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import json + +from embodichain.gen_sim.prompt2scene.utils.log import log_info +from embodichain.gen_sim.prompt2scene.utils.io import write_json +from embodichain.gen_sim.prompt2scene.workflows.unified_scene_gen.state import ( + UnifiedSceneGenState, +) +from embodichain.gen_sim.prompt2scene.workflows.artifact_writer import ( + UNIFIED_SCENE_GEN_STEP, + UNIFIED_SCENE_STEP, + WorkflowArtifactWriter, +) +from embodichain.gen_sim.prompt2scene.agent_tools.tools.text_asset_generation import ( + generate_text_object_assets, + generate_text_table_asset, +) +from embodichain.gen_sim.prompt2scene.agent_tools.tools.text_scene_metric_scale import ( + estimate_text_scene_metric_scale, +) +from embodichain.gen_sim.prompt2scene.agent_tools.tools.text_clutter_layout import ( + generate_text_clutter_layout, +) +from embodichain.gen_sim.prompt2scene.agent_tools.tools.table_fit_scene import ( + fit_image_scene_table, + fit_text_scene_table, +) +from embodichain.gen_sim.prompt2scene.agent_tools.tools.image_scene_asset_generation import ( + generate_image_scene_assets, +) +from embodichain.gen_sim.prompt2scene.workflows.unified_scene_gen.paths import ( + UnifiedScenePaths, +) +from embodichain.gen_sim.prompt2scene.workflows.unified_scene_gen.prompts import ( + build_text_metric_scale_messages, +) +from embodichain.gen_sim.prompt2scene.workflows.unified_scene_gen.schema import ( + IMAGE_METRIC_SCALE_JSON_SCHEMA, +) +from 
embodichain.gen_sim.prompt2scene.workflows.unified_scene_gen.scene_update import ( + update_unified_scene, +) + +__all__ = [ + "fit_image_table_to_clutter_node", + "fit_text_table_to_clutter_node", + "generate_image_assets_node", + "generate_text_assets_node", + "generate_text_clutter_layout_node", + "load_unified_scene_input_kind_node", +] + + +def load_unified_scene_input_kind_node( + state: UnifiedSceneGenState, +) -> dict[str, object]: + """Load unified-scene output and determine the generation route.""" + paths = UnifiedScenePaths(state["output_root"]) + result_path = paths.resolve_scene_result(state["unified_scene_result_path"]) + if not result_path.is_file(): + raise FileNotFoundError(f"Unified scene result not found: {result_path}") + + with result_path.open("r", encoding="utf-8") as f: + unified_scene = json.load(f) + if not isinstance(unified_scene, dict): + raise ValueError("Unified scene result must be a JSON object.") + + input_record = unified_scene.get("input") + if not isinstance(input_record, dict): + raise ValueError("Unified scene result requires input object.") + + input_kind = str(input_record.get("input_kind") or "").strip() + if input_kind not in {"text", "image"}: + raise ValueError( + "Unified scene input.input_kind must be 'text' or 'image', " + f"got {input_kind!r}." + ) + + return { + "unified_scene_result_path": result_path, + "unified_scene": unified_scene, + "input_kind": input_kind, + } + + +def generate_text_assets_node( + state: UnifiedSceneGenState, +) -> dict[str, object]: + """Generate images, RGBA cutouts, geometry, and sim-ready GLBs for a + text-origin unified scene. 
+ """ + unified_scene = state["unified_scene"] + if unified_scene is None: + return {"generation_status": "no_unified_scene"} + + paths = UnifiedScenePaths(state["output_root"]) + output_root = paths.output_root + image_gen_dir, glb_gen_dir, debug_dir = paths.prepare_generation_dirs() + log_info( + "generate_text_assets started " + f"output_dir={output_root / UNIFIED_SCENE_GEN_STEP}" + ) + + table_spec = unified_scene.get("table") or {} + table_result = generate_text_table_asset( + table_spec=table_spec, + image_gen_dir=image_gen_dir, + glb_gen_dir=glb_gen_dir, + debug_dir=debug_dir, + ) + + object_specs = unified_scene.get("objects") or [] + object_results = generate_text_object_assets( + object_specs=object_specs, + image_gen_dir=image_gen_dir, + glb_gen_dir=glb_gen_dir, + debug_dir=debug_dir, + ) + metric_prompt_objects = [ + { + "object_id": str(obj.get("id", "")), + "object_name": str(obj.get("name", "")), + "object_description": str(obj.get("description", "")), + } + for obj in object_results + ] + user_text = str((unified_scene.get("input") or {}).get("text") or "") + text_metric_scale_result = estimate_text_scene_metric_scale( + object_results=object_results, + user_text=user_text, + messages=build_text_metric_scale_messages( + user_text=user_text, + objects_json=metric_prompt_objects, + ), + schema=IMAGE_METRIC_SCALE_JSON_SCHEMA, + output_dir=glb_gen_dir / "metric_scale", + output_root=output_root, + llm=state.get("llm"), + step_name=UNIFIED_SCENE_STEP, + ) + + result_path = paths.resolve_scene_result(state["unified_scene_result_path"]) + update_unified_scene(unified_scene, table_result, object_results, output_root) + write_json(result_path, unified_scene) + WorkflowArtifactWriter(output_root, UNIFIED_SCENE_GEN_STEP).write_step_result( + { + "table": table_result, + "objects": object_results, + "text_metric_scale": text_metric_scale_result, + "generation_status": "ok", + } + ) + log_info( + "generate_text_assets completed " + 
f"table_status={table_result.get('status')} " + f"object_count={len(object_results)}" + ) + + return { + "unified_scene": unified_scene, + "table_result": table_result, + "text_object_results": object_results, + "generation_status": "ok", + } + + +def generate_image_assets_node(state: UnifiedSceneGenState) -> dict[str, object]: + """Generate table assets and layout-aware object GLBs for image input. + + Table/support and objects are generated in one multi-object call from the + original image and existing segmentation masks. + """ + unified_scene = state["unified_scene"] + if unified_scene is None: + return {"generation_status": "no_unified_scene"} + + paths = UnifiedScenePaths(state["output_root"]) + output_root = paths.output_root + image_gen_dir, glb_gen_dir, debug_dir = paths.prepare_generation_dirs() + log_info( + "generate_image_assets started " + f"output_dir={output_root / UNIFIED_SCENE_GEN_STEP}" + ) + + segments_path = paths.image_segments_result + if not segments_path.is_file(): + raise FileNotFoundError( + f"Image segments result not found: {segments_path}" + ) + with segments_path.open("r", encoding="utf-8") as _f: + segments_data = json.load(_f) + if not isinstance(segments_data, dict): + raise ValueError("Image segments result must be a JSON object.") + + table_spec = unified_scene.get("table") or {} + # Image input uses the segmented table/support mask in the multi-object + # SAM3D call below. Text table generation belongs to the text branch. 
+ object_specs = unified_scene.get("objects") or [] + object_layout_result = generate_image_scene_assets( + object_specs=object_specs, + table_spec=table_spec, + spatial_relations=(unified_scene.get("spatial") or {}).get("relations", []), + segments_data=segments_data, + image_gen_dir=image_gen_dir, + glb_gen_dir=glb_gen_dir, + debug_dir=debug_dir, + output_root=output_root, + llm=state.get("llm"), + ) + table_result = object_layout_result.get("table") or { + "id": str(table_spec.get("id", "table")), + "name": str(table_spec.get("name", "table")), + "status": "missing_table_generation", + } + object_results = object_layout_result.get("objects") or [] + generation_status = str(object_layout_result.get("status", "failed")) + if table_result.get("status") != "ok": + generation_status = str(table_result.get("status") or generation_status) + result_path = paths.resolve_scene_result(state["unified_scene_result_path"]) + update_unified_scene(unified_scene, table_result, object_results, output_root) + write_json(result_path, unified_scene) + WorkflowArtifactWriter(output_root, UNIFIED_SCENE_GEN_STEP).write_step_result( + { + "table": table_result, + "objects_layout": object_layout_result, + "objects": object_results, + "table_fit_to_clutter": None, + "generation_status": generation_status, + } + ) + log_info(f"generate_image_assets completed status={generation_status}") + + return { + "unified_scene": unified_scene, + "table_result": table_result, + "text_object_results": object_results, + "image_objects_layout_result": object_layout_result, + "generation_status": generation_status, + } + + +def fit_image_table_to_clutter_node(state: UnifiedSceneGenState) -> dict[str, object]: + """Resize the final table to fit the aligned image-object clutter.""" + if state.get("input_kind") != "image": + return {} + + paths = UnifiedScenePaths(state["output_root"]) + output_root = paths.output_root + output_dir = paths.table_fit_dir + output_dir.mkdir(parents=True, exist_ok=True) + 
log_info(f"fit_image_table_to_clutter started output_dir={output_dir}") + layout_result = dict(state.get("image_objects_layout_result") or {}) + table_fit_result = fit_image_scene_table( + layout_result=layout_result, + fallback_table_result=state.get("table_result"), + output_root=output_root, + output_dir=output_dir, + ) + layout_result["table_fit_to_clutter"] = table_fit_result + WorkflowArtifactWriter(output_root, UNIFIED_SCENE_GEN_STEP).write_step_result( + { + "table": state.get("table_result"), + "objects_layout": layout_result, + "objects": state.get("text_object_results") or [], + "table_fit_to_clutter": table_fit_result, + "generation_status": state.get("generation_status"), + } + ) + log_info( + f"fit_image_table_to_clutter completed status={table_fit_result.get('status')}" + ) + return { + "image_objects_layout_result": layout_result, + "table_fit_result": table_fit_result, + } + + +def generate_text_clutter_layout_node( + state: UnifiedSceneGenState, +) -> dict[str, object]: + """Scale text objects to real-world size, gravity-settle, centre at origin. + + Produces per-object settled GLBs and 2D AABB footprints for downstream + spatial layout optimisation and table fitting. 
+ """ + if state.get("input_kind") != "text": + return {} + + paths = UnifiedScenePaths(state["output_root"]) + output_root = paths.output_root + output_dir = paths.text_clutter_dir + output_dir.mkdir(parents=True, exist_ok=True) + log_info(f"generate_text_clutter_layout started output_dir={output_dir}") + + text_object_results = state.get("text_object_results") or [] + if not text_object_results: + return { + "text_clutter_settle_result": { + "status": "skipped", + "reason": "no_text_objects", + } + } + + unified_scene = state.get("unified_scene") or {} + spatial_data = unified_scene.get("spatial") or {} + spatial_relations = spatial_data.get("relations", []) + table_constraints = spatial_data.get("table_constraints", []) + + settle_result = generate_text_clutter_layout( + object_results=text_object_results, + spatial_relations=spatial_relations, + table_constraints=table_constraints, + output_dir=output_dir, + output_root=output_root, + ) + WorkflowArtifactWriter(output_root, UNIFIED_SCENE_GEN_STEP).write_step_result( + { + "table": state.get("table_result"), + "objects": text_object_results, + "text_clutter_settle": settle_result, + "generation_status": state.get("generation_status"), + } + ) + log_info( + f"generate_text_clutter_layout completed status={settle_result.get('status')}" + ) + return { + "text_clutter_settle_result": settle_result, + } + + +def fit_text_table_to_clutter_node( + state: UnifiedSceneGenState, +) -> dict[str, object]: + """Resize the text-scene table to fit the laid-out clutter footprint.""" + if state.get("input_kind") != "text": + return {} + + paths = UnifiedScenePaths(state["output_root"]) + output_root = paths.output_root + table_result = state.get("table_result") + settle_result = state.get("text_clutter_settle_result") + + if table_result is None or settle_result is None: + return { + "table_fit_result": { + "status": "skipped", + "reason": "missing_table_or_settle_result", + } + } + + output_dir = paths.table_fit_dir + 
output_dir.mkdir(parents=True, exist_ok=True) + log_info(f"fit_text_table_to_clutter started output_dir={output_dir}") + table_fit_result = fit_text_scene_table( + table_result=table_result, + clutter_layout_result=settle_result, + output_root=output_root, + output_dir=output_dir, + ) + WorkflowArtifactWriter(output_root, UNIFIED_SCENE_GEN_STEP).write_step_result( + { + "table": table_result, + "objects": state.get("text_object_results") or [], + "text_clutter_settle": settle_result, + "table_fit_to_clutter": table_fit_result, + "generation_status": state.get("generation_status"), + } + ) + log_info( + f"fit_text_table_to_clutter completed status={table_fit_result.get('status')}" + ) + return { + "table_fit_result": table_fit_result, + } diff --git a/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/paths.py b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/paths.py new file mode 100644 index 00000000..c4af8054 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/paths.py @@ -0,0 +1,102 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.prompt2scene.workflows.artifact_writer import ( + IMAGE_SEGMENTS_STEP, + STEP_RESULT_FILENAME, + UNIFIED_SCENE_GEN_STEP, + UNIFIED_SCENE_STEP, +) + +__all__ = ["UnifiedScenePaths", "resolve_generated_path"] + + +def resolve_generated_path(value: Any, output_root: Path) -> Path: + """Resolve an absolute or output-root-relative generated artifact path.""" + if not value: + return Path() + path = Path(str(value)).expanduser() + if path.is_absolute(): + return path.resolve() + return (output_root.expanduser().resolve() / path).resolve() + + +@dataclass(frozen=True) +class UnifiedScenePaths: + """High-level paths owned by the unified-scene generation workflow.""" + + output_root: Path + + def __post_init__(self) -> None: + object.__setattr__( + self, + "output_root", + self.output_root.expanduser().resolve(), + ) + + @property + def workflow_root(self) -> Path: + return self.output_root / UNIFIED_SCENE_GEN_STEP + + @property + def image_gen_dir(self) -> Path: + return self.workflow_root / "image_gen" + + @property + def glb_gen_dir(self) -> Path: + return self.workflow_root / "glb_gen" + + @property + def debug_dir(self) -> Path: + return self.workflow_root / "debug" + + @property + def text_clutter_dir(self) -> Path: + return self.glb_gen_dir / "text_clutter_settled" + + @property + def table_fit_dir(self) -> Path: + return self.glb_gen_dir / "table_fit_to_clutter" + + @property + def image_segments_result(self) -> Path: + return self.output_root / IMAGE_SEGMENTS_STEP / STEP_RESULT_FILENAME + + def prepare_generation_dirs(self) -> tuple[Path, Path, Path]: + """Create and return the workflow's high-level generation directories.""" + directories = (self.image_gen_dir, self.glb_gen_dir, self.debug_dir) + for directory in directories: + 
directory.mkdir(parents=True, exist_ok=True) + return directories + + def resolve_scene_result(self, explicit_path: Path | None) -> Path: + """Resolve the unified-scene result produced by the preceding workflow.""" + if explicit_path is not None: + return explicit_path.expanduser().resolve() + + scene_dir = self.output_root / UNIFIED_SCENE_STEP + result_path = scene_dir / STEP_RESULT_FILENAME + if result_path.is_file(): + return result_path + + legacy_path = scene_dir / "results.json" + return legacy_path if legacy_path.is_file() else result_path diff --git a/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/prompts.py b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/prompts.py new file mode 100644 index 00000000..1543acfb --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/prompts.py @@ -0,0 +1,141 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.prompt2scene.prompts import render_prompt +from embodichain.gen_sim.prompt2scene.utils.io import image_to_data_url + +__all__ = [ + "build_image_metric_scale_messages", + "build_text_metric_scale_messages", + "build_up_down_flip_check_messages", +] + +UNIFIED_SCENE_GEN_PROMPT_NAME = "unified_scene_gen.yaml" + + +def build_image_metric_scale_messages( + *, + bbox_name_image_path: Path, + objects_json: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Build messages for image-scene object metric scale estimation.""" + return [ + { + "role": "system", + "content": render_prompt( + UNIFIED_SCENE_GEN_PROMPT_NAME, + prompt_key="image_metric_scale_system", + ), + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": render_prompt( + UNIFIED_SCENE_GEN_PROMPT_NAME, + { + "objects_json": json.dumps( + objects_json, + ensure_ascii=False, + indent=2, + ), + }, + prompt_key="image_metric_scale_user", + ), + }, + { + "type": "image_url", + "image_url": {"url": image_to_data_url(bbox_name_image_path)}, + }, + ], + }, + ] + + +def build_text_metric_scale_messages( + *, + user_text: str, + objects_json: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Build messages for text-scene object metric scale estimation.""" + return [ + { + "role": "system", + "content": render_prompt( + UNIFIED_SCENE_GEN_PROMPT_NAME, + prompt_key="text_metric_scale_system", + ), + }, + { + "role": "user", + "content": render_prompt( + UNIFIED_SCENE_GEN_PROMPT_NAME, + { + "user_text": user_text, + "objects_json": json.dumps( + objects_json, + ensure_ascii=False, + indent=2, + ), + }, + prompt_key="text_metric_scale_user", + ), + }, + ] + + +def build_up_down_flip_check_messages( + *, + original_image_path: Path, + comparison_image_path: Path, +) -> list[dict[str, Any]]: + 
"""Build messages for VLM support-normal up/down flip verification.""" + return [ + { + "role": "system", + "content": render_prompt( + UNIFIED_SCENE_GEN_PROMPT_NAME, + prompt_key="up_down_flip_check_system", + ), + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": render_prompt( + UNIFIED_SCENE_GEN_PROMPT_NAME, + prompt_key="up_down_flip_check_user", + ), + }, + { + "type": "image_url", + "image_url": {"url": image_to_data_url(original_image_path)}, + }, + { + "type": "image_url", + "image_url": {"url": image_to_data_url(comparison_image_path)}, + }, + ], + }, + ] diff --git a/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/scene_update.py b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/scene_update.py new file mode 100644 index 00000000..2276e559 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/scene_update.py @@ -0,0 +1,76 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.prompt2scene.utils.io import relative_path + +__all__ = ["update_unified_scene"] + + +def update_unified_scene( + unified_scene: dict[str, Any], + table_result: dict[str, Any], + object_results: list[dict[str, Any]], + output_root: Path, +) -> None: + """Write generated asset references back into a unified-scene payload.""" + table = unified_scene.setdefault("table", {}) + metadata_keys = ( + "table_asset_source", + "support_normal_source", + "is_complete_visible_table", + "complete_table_description", + ) + path_keys = ( + "image_path", + "raw_geometry_path", + "support_reference_geometry_path", + "generated_table_raw_geometry_path", + "transformed_geometry_path", + "simready_geometry_path", + "aligned_geometry_path", + "mesh_path", + ) + for key in metadata_keys: + if key in table_result: + table[key] = table_result[key] + for key in path_keys: + if table_result.get(key): + table[key] = relative_path(table_result[key], output_root) + + objects_by_id = { + str(item.get("id", "")): item + for item in unified_scene.setdefault("objects", []) + if isinstance(item, dict) + } + for result in object_results: + target = objects_by_id.get(str(result.get("id", ""))) + if target is None: + continue + for key in ("image_path", "mesh_path", "aligned_geometry_path"): + if result.get(key): + target[key] = relative_path(result[key], output_root) + metric_scale = result.get("metric_scale") + if isinstance(metric_scale, dict): + target["metric_scale"] = { + key: value + for key, value in metric_scale.items() + if key not in {"result_path", "raw_model_output_path"} + } diff --git a/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/schema.py b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/schema.py new file mode 100644 index 00000000..b22fcebb --- /dev/null +++ 
b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/schema.py @@ -0,0 +1,71 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Any + +__all__ = [ + "IMAGE_METRIC_SCALE_JSON_SCHEMA", + "UP_DOWN_FLIP_CHECK_JSON_SCHEMA", +] + +UP_DOWN_FLIP_CHECK_JSON_SCHEMA: dict[str, Any] = { + "title": "AlignedUpDownFlipCheckOutput", + "type": "object", + "additionalProperties": False, + "properties": { + "selected_number": {"type": "integer", "enum": [1, 2]}, + "confidence": {"type": "number", "minimum": 0.0, "maximum": 1.0}, + "reason": {"type": "string"}, + }, + "required": ["selected_number", "confidence", "reason"], +} + +IMAGE_METRIC_SCALE_JSON_SCHEMA: dict[str, Any] = { + "title": "ImageMetricScaleEstimate", + "type": "object", + "additionalProperties": False, + "properties": { + "object_scales": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": False, + "properties": { + "object_id": {"type": "string"}, + "bbox_dims_cm": { + "type": "array", + "minItems": 3, + "maxItems": 3, + "items": { + "type": "number", + "minimum": 1.0e-6, + }, + }, + "confidence": { + "type": "number", + "minimum": 0.0, + "maximum": 1.0, + }, + "reason": {"type": "string"}, + }, + "required": 
["object_id", "bbox_dims_cm", "confidence", "reason"], + }, + }, + }, + "required": ["object_scales"], +} diff --git a/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/state.py b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/state.py new file mode 100644 index 00000000..12283516 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/state.py @@ -0,0 +1,40 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.prompt2scene.workflows.attempt_state import AttemptState + +__all__ = ["UnifiedSceneGenState"] + + +class UnifiedSceneGenState(AttemptState): + """LangGraph state for downstream unified-scene generation.""" + + output_root: Path + unified_scene_result_path: Path | None + llm: Any | None + unified_scene: dict[str, Any] | None + input_kind: str | None + table_result: dict[str, Any] | None + text_object_results: list[dict[str, Any]] + text_clutter_settle_result: dict[str, Any] | None + image_objects_layout_result: dict[str, Any] | None + table_fit_result: dict[str, Any] | None + generation_status: str | None From 625802fab67d624ad621ab475878c6e9d3938bd4 Mon Sep 17 00:00:00 2001 From: Muzi Wong <178915912+MuziWong@users.noreply.github.com> Date: Mon, 29 Jun 2026 18:32:22 +0800 Subject: [PATCH 2/4] Fixed gym export bug: wrong object description; --- .../gen_sim/prompt2scene/agent_tools/tools/gym_export.py | 4 ++-- .../agent_tools/tools/image_scene_asset_generation.py | 1 + .../prompt2scene/agent_tools/tools/text_asset_generation.py | 1 + 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/gym_export.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/gym_export.py index 9f3c638f..0dcd6718 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/tools/gym_export.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/gym_export.py @@ -183,8 +183,8 @@ def export_gym_config( oid = str(obj.get("id", "")) if oid: object_meta_by_id[oid] = { - "description": str(obj.get("description", "")).strip(), - "name": str(obj.get("name", "")).strip(), + "description": str(obj.get("description") or "").strip(), + "name": str(obj.get("name") or "").strip(), } table_info = step_result.get("table") or {} diff --git 
a/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_scene_asset_generation.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_scene_asset_generation.py index 2275c40f..5df5984a 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_scene_asset_generation.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_scene_asset_generation.py @@ -470,6 +470,7 @@ def generate_image_scene_assets( object_fields = ( "id", "name", + "description", "status", "image_path", "mesh_path", diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_asset_generation.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_asset_generation.py index 1beb7603..b0d4a0f7 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_asset_generation.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_asset_generation.py @@ -162,6 +162,7 @@ def generate_text_object_asset( return { "id": object_id, "name": object_name, + "description": description, "status": status, "image_path": image_path, "raw_geometry_path": raw_geometry_path, From 68adf9a1e11236fdd7d4014eefddf56572eac98a Mon Sep 17 00:00:00 2001 From: Muzi Wong <178915912+MuziWong@users.noreply.github.com> Date: Mon, 29 Jun 2026 19:49:30 +0800 Subject: [PATCH 3/4] 1. Update prompt in scene_intake; 2. 
VLM judge percentage of the object clutter when the input tabletop seems to be a complete one; --- .../table_clutter_fit_manager/manager.py | 15 ++++++- .../tools/image_scene_asset_generation.py | 1 + .../agent_tools/tools/table_fit_scene.py | 2 + .../tools/text_asset_generation.py | 1 + .../prompts/data/scene_intake.yaml | 45 ++++++++++++++++--- .../workflows/scene_intake/prompts.py | 27 ++++++----- .../workflows/scene_intake/schema.py | 18 +++++++- .../workflows/scene_intake/utils.py | 29 +++++++++++- .../workflows/unified_scene/schema.py | 6 ++- .../workflows/unified_scene/utils.py | 7 ++- 10 files changed, 130 insertions(+), 21 deletions(-) diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/manager.py index 987e1487..3a9a86e5 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/manager.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/manager.py @@ -94,10 +94,19 @@ def fit_table_to_clutter( output_dir: Path, margin_cm: float = 10.0, support_occupancy_ratio: float = 0.80, + object_coverage_percent: int | None = None, gravity_settle_table: bool = True, sim_device: str = "cpu", ) -> dict[str, Any]: - """Fit a table mesh to an already laid-out clutter result.""" + """Fit a table mesh to an already laid-out clutter result. + + Args: + object_coverage_percent: If set (percent), overrides + ``support_occupancy_ratio`` by converting the percentage to a ratio + (e.g. 30 → 0.30), clamped to [0.1, 1.0]. The required table size is + clutter_size / ratio + 2 * margin_cm. When None, the passed + ``support_occupancy_ratio`` is used. + """ try: import trimesh except ImportError as exc: @@ -166,6 +175,10 @@ def fit_table_to_clutter( # Compute the required table size and uniform scale. 
clutter_size_cm = (clutter_bounds[1, :2] - clutter_bounds[0, :2]) * 100.0 + if object_coverage_percent is not None: + support_occupancy_ratio = float( + np.clip(object_coverage_percent / 100.0, 0.1, 1.0) + ) occupancy = float(np.clip(support_occupancy_ratio, 0.1, 1.0)) required_size_cm = clutter_size_cm / occupancy + 2.0 * float(margin_cm) support_size_cm = np.asarray(initial_support["size_xy"], dtype=np.float64) * 100.0 diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_scene_asset_generation.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_scene_asset_generation.py index 5df5984a..9d3e42f1 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_scene_asset_generation.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_scene_asset_generation.py @@ -456,6 +456,7 @@ def generate_image_scene_assets( "status", "is_complete_visible_table", "complete_table_description", + "object_coverage_percent", "table_asset_source", "support_normal_source", "image_path", diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/table_fit_scene.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/table_fit_scene.py index ae96b3a3..273f15a6 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/tools/table_fit_scene.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/table_fit_scene.py @@ -42,6 +42,7 @@ def fit_text_scene_table( clutter_result=clutter_layout_result, output_root=output_root, output_dir=output_dir, + object_coverage_percent=table_result.get("object_coverage_percent"), ) log_info(f"text table fit completed status={result.get('status')}") return result @@ -94,6 +95,7 @@ def fit_image_scene_table( clutter_result=clutter_result, output_root=output_root, output_dir=output_dir, + object_coverage_percent=generated_table.get("object_coverage_percent"), ) log_info(f"image table fit completed status={result.get('status')}") return result diff --git 
a/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_asset_generation.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_asset_generation.py index b0d4a0f7..ada7ad78 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_asset_generation.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_asset_generation.py @@ -283,6 +283,7 @@ def generate_text_table_asset( "is_complete_visible_table": bool( table_spec.get("is_complete_visible_table", False) ), + "object_coverage_percent": table_spec.get("object_coverage_percent"), "status": status, "image_path": image_path, "raw_geometry_path": raw_geometry_path, diff --git a/embodichain/gen_sim/prompt2scene/prompts/data/scene_intake.yaml b/embodichain/gen_sim/prompt2scene/prompts/data/scene_intake.yaml index cabf99cb..bbdbbc8b 100644 --- a/embodichain/gen_sim/prompt2scene/prompts/data/scene_intake.yaml +++ b/embodichain/gen_sim/prompt2scene/prompts/data/scene_intake.yaml @@ -19,6 +19,10 @@ text_system: | + - CRITICAL: Include EVERY visible object on the tabletop without omission. Do + not skip, ignore, or drop any object, no matter how small, blurry, partially + occluded, or unfamiliar it appears. An incomplete assets list is the most + severe error you can make. - Output only real physical objects that can become 3D asset generation targets. - Do not include the table or tabletop region in assets. - assets is a list of object category groups, not a list of individual object @@ -178,6 +182,10 @@ image_system: | + - CRITICAL: Include EVERY visible object on the tabletop without omission. Do + not skip, ignore, or drop any object, no matter how small, blurry, partially + occluded, or unfamiliar it appears. An incomplete assets list is the most + severe error you can make. - Output only real physical objects that can become 3D asset generation targets. - Do not include the table or tabletop region in assets. 
- assets is a list of object category groups, not a list of individual object @@ -293,6 +301,19 @@ image_system: | table.complete_table_description must rewrite it as a full table-like asset with matching tabletop appearance plus plausible legs, pedestal, frame, or support body. + - For image input with is_complete_visible_table=true ONLY: choose + table.object_coverage_percent from exactly one of these four values. + Think in terms of SPATIAL SPREAD, not pixel area: imagine drawing the + smallest rectangle that encloses ALL objects on the tabletop, then ask + what fraction of the table surface that rectangle covers. Even sparse + small objects can score high if they are spread across the whole table. + 10 (objects clustered in one small region, most of the table is bare), + 30 (objects spread across a noticeable portion but large bare areas remain), + 50 (objects reach roughly half the table extent in at least one direction), + 70 (objects span most of the table, even if gaps exist between them). + Do not output any other value. + - For text input, or when is_complete_visible_table=false: OMIT the + object_coverage_percent field entirely. Do not include it in the output. @@ -301,8 +322,9 @@ image_system: | "name": "table", "description": "A rectangular wooden table with a brown top and four straight legs.", "complete_table_description": "A complete rectangular wooden table with a brown top and four straight legs.", - "is_complete_visible_table": false, - "class_candidate": ["table", "dining_table", "desk", "wooden_table", "furniture"] + "is_complete_visible_table": true, + "class_candidate": ["table", "dining_table", "desk", "wooden_table", "furniture"], + "object_coverage_percent": 30 }, "assets": [ { @@ -374,6 +396,10 @@ verifier_system: | + - CRITICAL: Do NOT remove any asset row from the draft assets list. Your job is + to check and correct counts, names, and class_candidate values — not to drop + objects. 
If an object exists in the draft, it must remain in the corrected + output. Only add new rows if objects were clearly missed. - assets is a list of object category groups, not individual instances. - Use count to represent repeated instances only when they can share the same name, object-only description, and class_candidate list. @@ -419,6 +445,13 @@ verifier_system: | use the most conservative count supported by the image. - For text inputs, count only objects explicitly stated or strongly implied by the text. + - For image input with is_complete_visible_table=true: independently + re-assess the tabletop coverage against the original image and pick + table.object_coverage_percent from exactly one of 10, 30, 50, 70. + Correct the draft value if the bucket does not match the spatial + spread of the objects across the tabletop (not their pixel area). + - For text input or when is_complete_visible_table is false: remove + object_coverage_percent from table entirely if it is present in the draft. @@ -427,8 +460,9 @@ verifier_system: | "name": "table", "description": "A rectangular wooden table with a brown top and four straight legs.", "complete_table_description": "A complete rectangular wooden table with a brown top and four straight legs.", - "is_complete_visible_table": false, - "class_candidate": ["table", "dining_table", "desk", "wooden_table", "furniture"] + "is_complete_visible_table": true, + "class_candidate": ["table", "dining_table", "desk", "wooden_table", "furniture"], + "object_coverage_percent": 30 }, "assets": [ { @@ -444,7 +478,8 @@ verifier_system: | - The top-level object must contain only table and assets. - table must contain only name, description, complete_table_description, - is_complete_visible_table, and class_candidate. + is_complete_visible_table, class_candidate, and optionally + object_coverage_percent (only when is_complete_visible_table is true). - Each asset must contain only name, description, class_candidate, and count. - Output JSON only. 
Do not include markdown or explanations outside JSON. diff --git a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/prompts.py b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/prompts.py index 611c5bf9..421ec979 100644 --- a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/prompts.py +++ b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/prompts.py @@ -50,19 +50,24 @@ def build_scene_intake_verifier_messages( scene_intake: SceneIntakeSpec, ) -> list[dict[str, Any]]: """Build messages for scene-intake group and count verification.""" + table_draft: dict[str, object] = { + "name": scene_intake.table.name, + "description": scene_intake.table.description, + "complete_table_description": ( + scene_intake.table.complete_table_description + ), + "is_complete_visible_table": ( + scene_intake.table.is_complete_visible_table + ), + "class_candidate": list(scene_intake.table.class_candidate), + } + if scene_intake.table.object_coverage_percent is not None: + table_draft["object_coverage_percent"] = ( + scene_intake.table.object_coverage_percent + ) scene_intake_json = json.dumps( { - "table": { - "name": scene_intake.table.name, - "description": scene_intake.table.description, - "complete_table_description": ( - scene_intake.table.complete_table_description - ), - "is_complete_visible_table": ( - scene_intake.table.is_complete_visible_table - ), - "class_candidate": list(scene_intake.table.class_candidate), - }, + "table": table_draft, "assets": [ { "name": asset.name, diff --git a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/schema.py b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/schema.py index 80c9ca27..31b55e6d 100644 --- a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/schema.py +++ b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/schema.py @@ -92,6 +92,18 @@ "minLength": 1, }, }, + "object_coverage_percent": { + "type": "integer", + "enum": [10, 30, 50, 70], + "description": ( + "For image input with 
a complete visible table ONLY: " + "choose the closest coverage bucket for objects on the " + "tabletop: 10 (mostly empty, a few small objects), " + "30 (lightly cluttered), 50 (moderately cluttered), " + "70 (densely packed). Omit this field entirely for " + "text input or when is_complete_visible_table is false." + ), + }, }, "required": [ "name", @@ -193,10 +205,11 @@ class SceneIntakeTable: complete_table_description: str = "" is_complete_visible_table: bool = False class_candidate: list[str] = field(default_factory=list) + object_coverage_percent: int | None = None def to_manifest(self) -> dict[str, object]: """Convert the table record to JSON-safe data.""" - return { + manifest: dict[str, object] = { "id": self.id, "name": self.name, "description": self.description, @@ -204,6 +217,9 @@ def to_manifest(self) -> dict[str, object]: "is_complete_visible_table": self.is_complete_visible_table, "class_candidate": list(self.class_candidate), } + if self.object_coverage_percent is not None: + manifest["object_coverage_percent"] = self.object_coverage_percent + return manifest @dataclass(frozen=True) diff --git a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/utils.py b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/utils.py index e49fe9b3..da084f55 100644 --- a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/utils.py +++ b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/utils.py @@ -19,6 +19,7 @@ import re from typing import Any +from embodichain.gen_sim.prompt2scene.utils.log import log_warning from embodichain.gen_sim.prompt2scene.workflows.request import Prompt2SceneInput from embodichain.gen_sim.prompt2scene.workflows.scene_intake.schema import ( SceneIntakeAsset, @@ -66,6 +67,7 @@ def _parse_table(raw_table: dict[str, Any]) -> SceneIntakeTable: "complete_table_description", "is_complete_visible_table", "class_candidate", + "object_coverage_percent", }, context="Scene intake table", ) @@ -107,12 +109,34 @@ def 
_parse_table(raw_table: dict[str, Any]) -> SceneIntakeTable: raw_name=name, ) + object_coverage_percent: int | None = None + raw_percent = raw_table.get("object_coverage_percent") + if raw_percent is not None: + if isinstance(raw_percent, bool): + raise ValueError( + "Scene intake table.object_coverage_percent must be an integer, " + "not a boolean." + ) + try: + object_coverage_percent = int(raw_percent) + except (TypeError, ValueError): + raise ValueError( + "Scene intake table.object_coverage_percent must be an integer " + f"between 1 and 100, got {raw_percent!r}." + ) + if object_coverage_percent not in (10, 30, 50, 70): + raise ValueError( + "Scene intake table.object_coverage_percent must be one of " + f"10, 30, 50, 70, got {object_coverage_percent}." + ) + return SceneIntakeTable( name=name, description=description, complete_table_description=complete_table_description, is_complete_visible_table=is_complete_visible_table, class_candidate=class_candidate, + object_coverage_percent=object_coverage_percent, ) @@ -214,7 +238,10 @@ def _validate_exact_keys( ) -> None: extra_keys = sorted(set(value) - allowed_keys) if extra_keys: - raise ValueError(f"{context} has unexpected keys: {extra_keys}.") + log_warning( + f"{context} has unexpected keys: {extra_keys}. " + f"These fields will be ignored." 
+ ) def _require_mapping(value: Any, context: str) -> dict[str, Any]: diff --git a/embodichain/gen_sim/prompt2scene/workflows/unified_scene/schema.py b/embodichain/gen_sim/prompt2scene/workflows/unified_scene/schema.py index f3d13125..baca2beb 100644 --- a/embodichain/gen_sim/prompt2scene/workflows/unified_scene/schema.py +++ b/embodichain/gen_sim/prompt2scene/workflows/unified_scene/schema.py @@ -42,10 +42,11 @@ class UnifiedTable: image_path: str | None = None mesh_path: str | None = None grid_cells: dict[str, list[str]] | None = None + object_coverage_percent: int | None = None def to_manifest(self) -> dict[str, Any]: """Convert the table to JSON-safe data.""" - return { + manifest: dict[str, Any] = { "id": self.id, "name": self.name, "description": self.description, @@ -56,6 +57,9 @@ def to_manifest(self) -> dict[str, Any]: "mesh_path": self.mesh_path, "grid_cells": self.grid_cells, } + if self.object_coverage_percent is not None: + manifest["object_coverage_percent"] = self.object_coverage_percent + return manifest @dataclass(frozen=True) diff --git a/embodichain/gen_sim/prompt2scene/workflows/unified_scene/utils.py b/embodichain/gen_sim/prompt2scene/workflows/unified_scene/utils.py index e17b5e7b..49e4a70c 100644 --- a/embodichain/gen_sim/prompt2scene/workflows/unified_scene/utils.py +++ b/embodichain/gen_sim/prompt2scene/workflows/unified_scene/utils.py @@ -93,7 +93,7 @@ def build_unified_table( grid_cells: dict[str, list[str]] | None = None, ) -> dict[str, Any]: """Build a unified table record from scene intake.""" - return { + table: dict[str, Any] = { "id": scene_intake.table.id, "name": scene_intake.table.name, "description": scene_intake.table.description, @@ -106,6 +106,11 @@ def build_unified_table( "mesh_path": None, "grid_cells": grid_cells, } + if scene_intake.table.object_coverage_percent is not None: + table["object_coverage_percent"] = ( + scene_intake.table.object_coverage_percent + ) + return table def build_unified_spatial_anchor(anchor: 
ImageAnchor | None) -> dict[str, Any] | None: From 8293f2cff7c17420083ea5afa3adc8064934e38e Mon Sep 17 00:00:00 2001 From: Muzi Wong <178915912+MuziWong@users.noreply.github.com> Date: Tue, 30 Jun 2026 10:57:15 +0800 Subject: [PATCH 4/4] Fixed gym export bug; --- .../table_clutter_fit_manager/manager.py | 18 +- .../agent_tools/tools/gym_export.py | 315 +++++++++++++----- 2 files changed, 240 insertions(+), 93 deletions(-) diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/manager.py index 3a9a86e5..eeb79a18 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/manager.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/manager.py @@ -252,7 +252,23 @@ def fit_table_to_clutter( for oid, scene in shifted_clutter: object_path = output_dir / f"{oid}_on_table.glb" _copy_scene_with_transform(scene, z_to_y).export(object_path) - placed_objects.append({"id": oid, "path": str(object_path)}) + # Compute world-space AABB bottom-centre (sim Z-up coords) before + # the scene is converted to GLB Y-up for export. This is the + # reference position that gym_export uses to derive ``init_pos``. + _placed_mesh = _scene_to_mesh(scene, trimesh=trimesh) + _placed_b = np.asarray(_placed_mesh.bounds, dtype=np.float64) + world_aabb_bottom_center = [ + float(0.5 * (_placed_b[0, 0] + _placed_b[1, 0])), + float(0.5 * (_placed_b[0, 1] + _placed_b[1, 1])), + float(_placed_b[0, 2]), + ] + placed_objects.append( + { + "id": oid, + "path": str(object_path), + "world_aabb_bottom_center": world_aabb_bottom_center, + } + ) # Write the fit manifest. 
final_clutter_bounds = _table_fit_scene_union_bounds( diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/gym_export.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/gym_export.py index 0dcd6718..d26a1484 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/tools/gym_export.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/gym_export.py @@ -20,6 +20,7 @@ import math import shutil import time +from collections.abc import Sequence from pathlib import Path from typing import Any @@ -49,7 +50,12 @@ "restitution": 0.01, } -_DEFAULT_MAX_CONVEX_HULL_NUM = 8 +_DEFAULT_MAX_CONVEX_HULL_NUM = 32 + + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- def _resolve_path(value: str, output_root: Path) -> Path: @@ -83,13 +89,37 @@ def _matrix_to_euler_xyz_deg(matrix: list[list[float]]) -> list[float]: return [math.degrees(x), math.degrees(y), math.degrees(z)] -def _glb_aabb_bottom_center(glb_path: Path) -> list[float]: - """``[x, y, z]`` bottom-centre position in **simulation Z-up** space. +def _glb_to_sim_rotation() -> np.ndarray: + """Return the loader basis conversion from GLB Y-up to sim Z-up.""" + return np.array( + [ + [1.0, 0.0, 0.0], + [0.0, 0.0, -1.0], + [0.0, 1.0, 0.0], + ], + dtype=np.float64, + ) - The GLB is stored in Y-up convention (X=width, Y=up, Z=depth). - EmbodiChain simulation converts to Z-up on load, so we return the - position in Z-up space: ``center_X``, ``-center_Z``, ``min_Y``. 
- """ + +def _glb_rotation_to_sim(rotation_matrix: list[list[float]]) -> list[list[float]]: + """Convert a GLB-space local rotation into simulation-space rotation.""" + rot = np.asarray(rotation_matrix, dtype=np.float64) + if rot.shape == (4, 4): + rot = rot[:3, :3] + basis = _glb_to_sim_rotation() + return (basis @ rot @ basis.T).tolist() + + +def _glb_scale_to_sim(scale: Sequence[float]) -> list[float]: + """Convert GLB-axis scale components to sim-axis body_scale components.""" + values = [float(v) for v in scale] + if len(values) != 3: + raise ValueError("scale must have three components") + return [values[0], values[2], values[1]] + + +def _glb_max_z(glb_path: Path) -> float: + """Maximum height (Y in GLB, Z in simulation) of a mesh.""" import trimesh scene = trimesh.load(glb_path, force="scene") @@ -104,16 +134,23 @@ def _glb_aabb_bottom_center(glb_path: Path) -> list[float]: [m for m in dumped if isinstance(m, trimesh.Trimesh)] ) ) - b = np.asarray(mesh.bounds, dtype=np.float64) - return [ - float(0.5 * (b[0, 0] + b[1, 0])), # centre X - float(-0.5 * (b[0, 2] + b[1, 2])), # -centre Z (GLB Z → internal -Y) - float(b[0, 1]), # min Y (GLB up → internal Z) - ] + return float(np.asarray(mesh.bounds, dtype=np.float64)[1, 1]) # max Y -def _glb_max_z(glb_path: Path) -> float: - """Maximum height (Y in GLB, Z in simulation) of a mesh.""" +def _rotated_aabb_offsets( + glb_path: Path, + rotation_matrix: list[list[float]] | None, + scale: float | Sequence[float] = 1.0, +) -> tuple[float, float, float]: + """Compute the AABB shift caused by rotation + scale alone. + + Loads the simready GLB, applies *rotation_matrix* and *scale_factor* + around the local origin (the AABB bottom-centre), and returns the XY + centre and minimum Z of the resulting AABB. These offsets are + subtracted from the fitted AABB bottom-centre to recover the true + world-space position of the simready local origin (the ``init_pos`` + that the simulation expects). 
+ """ import trimesh scene = trimesh.load(glb_path, force="scene") @@ -128,7 +165,122 @@ def _glb_max_z(glb_path: Path) -> float: [m for m in dumped if isinstance(m, trimesh.Trimesh)] ) ) - return float(np.asarray(mesh.bounds, dtype=np.float64)[1, 1]) # max Y + verts = mesh.vertices.copy() + if isinstance(scale, Sequence) and not isinstance(scale, (str, bytes)): + scale_array = np.asarray(list(scale), dtype=np.float64) + if scale_array.shape != (3,): + raise ValueError("scale must be a scalar or a 3-vector") + verts *= scale_array + else: + verts *= float(scale) + if rotation_matrix is not None: + rot = np.asarray(rotation_matrix, dtype=np.float64) + if rot.shape == (4, 4): + rot = rot[:3, :3] + verts = (rot @ verts.T).T + b = np.zeros((2, 3), dtype=np.float64) + b[0] = verts.min(axis=0) + b[1] = verts.max(axis=0) + return ( + float(0.5 * (b[0, 0] + b[1, 0])), # AABB centre X → sim X + float(-0.5 * (b[0, 2] + b[1, 2])), # -centre Z → sim Y + float(b[0, 1]), # min Y → sim Z + ) + + +# --------------------------------------------------------------------------- +# consolidated object manifest +# --------------------------------------------------------------------------- + + +def _build_object_manifest( + output_root: Path, + step_result: dict[str, Any], + table_fit_manifest: dict[str, Any], + aligned_by_id: dict[str, dict[str, Any]], +) -> dict[str, Any]: + """Merge world_bc, rotation, scale into one per-object record. + + Returns a dict keyed by object id, each value containing everything + needed to compute ``init_pos`` / ``init_rot`` / ``body_scale``. 
+ """ + objects_info = step_result.get("objects") or [] + + # index metric_scale by object id + metric_by_id: dict[str, float] = {} + for obj in objects_info: + oid = str(obj.get("id", "")) + if not oid: + continue + ms = obj.get("metric_scale") + sf = float(ms.get("scale_factor", 1.0)) if isinstance(ms, dict) else 1.0 + metric_by_id[oid] = sf + + # index world_aabb_bottom_center from table-fit manifest + world_bc_by_id: dict[str, list[float]] = {} + for e in table_fit_manifest.get("objects") or []: + eid = str(e.get("id", "")) if isinstance(e, dict) else "" + wbc = e.get("world_aabb_bottom_center") if isinstance(e, dict) else None + if eid and isinstance(wbc, list) and len(wbc) == 3: + world_bc_by_id[eid] = [float(v) for v in wbc] + + consolidated: dict[str, Any] = {} + skipped_no_glb: list[str] = [] + for obj in objects_info: + oid = str(obj.get("id", "")) + if not oid: + continue + + source = obj.get("simready_geometry_path") or obj.get("mesh_path") + simready_path = _resolve_path(source or "", output_root) + if not simready_path.is_file(): + skipped_no_glb.append(oid) + continue + + description = str(obj.get("description") or obj.get("name") or "").strip() + scale_factor = metric_by_id.get(oid, 1.0) + + aligned = aligned_by_id.get(oid) + rot_matrix: list[list[float]] | None = None + transform_scale: list[float] | None = None + if aligned: + raw = aligned.get("rotation_matrix") + if raw and isinstance(raw, list): + rot_matrix = raw + raw_scale = aligned.get("scale") + if isinstance(raw_scale, list) and len(raw_scale) == 3: + transform_scale = [float(v) for v in raw_scale] + + wbc = world_bc_by_id.get(oid) + + consolidated[oid] = { + "id": oid, + "description": description, + "simready_path": simready_path, + "scale_factor": scale_factor, + "transform_scale": transform_scale, + "rotation_matrix": rot_matrix, + "world_aabb_bottom_center": wbc, + } + + if skipped_no_glb: + print( + " [WARN] object(s) skipped (simready GLB not found): " + + ", ".join(skipped_no_glb) 
+ ) + extra_in_manifest = set(world_bc_by_id) - set(consolidated) + if extra_in_manifest: + print( + " [WARN] object(s) in table-fit manifest but not in step_result: " + + ", ".join(sorted(extra_in_manifest)) + ) + + return consolidated + + +# --------------------------------------------------------------------------- +# main export +# --------------------------------------------------------------------------- def export_gym_config( @@ -148,45 +300,33 @@ def export_gym_config( export_dir = export_dir.expanduser().resolve() export_dir.mkdir(parents=True, exist_ok=True) - # ── step result & table-fit manifest ────────────────────────────── + # ── data sources ──────────────────────────────────────────────────── step_result = _read_json( output_root / UNIFIED_SCENE_GEN_STEP / STEP_RESULT_FILENAME ) table_fit = step_result.get("table_fit_to_clutter") or {} - manifest = _read_json( + table_fit_manifest = _read_json( _resolve_path(table_fit.get("manifest_path", ""), output_root) ) - # ── per-object metadata from simready→aligned manifest ──────────── aligned_by_id: dict[str, dict[str, Any]] = {} aligned_manifest_path = ( - output_root / UNIFIED_SCENE_GEN_STEP / "glb_gen" / "simready_to_aligned_manifest.json" + output_root + / UNIFIED_SCENE_GEN_STEP + / "glb_gen" + / "simready_to_aligned_manifest.json" ) if aligned_manifest_path.is_file(): - aligned_manifest = _read_json(aligned_manifest_path) - for item in aligned_manifest.get("items", []) or []: - if isinstance(item, dict): - aligned_by_id[str(item.get("id", ""))] = item - - # ── table surface Z (from fitted table GLB) ─────────────────────── - fitted_table_path = _resolve_path( - manifest.get("table_output_path", ""), output_root - ) - table_surface_z = ( - _glb_max_z(fitted_table_path) if fitted_table_path.is_file() else 0.0 - ) + for item in _read_json(aligned_manifest_path).get("items", []) or []: + if isinstance(item, dict) and item.get("id"): + aligned_by_id[str(item["id"])] = item - # ── description lookup 
──────────────────────────────────────────── - object_meta_by_id: dict[str, dict[str, str]] = {} - for obj in step_result.get("objects", []) or []: - if isinstance(obj, dict): - oid = str(obj.get("id", "")) - if oid: - object_meta_by_id[oid] = { - "description": str(obj.get("description") or "").strip(), - "name": str(obj.get("name") or "").strip(), - } + # ── consolidated per-object manifest ───────────────────────────────── + object_manifest = _build_object_manifest( + output_root, step_result, table_fit_manifest, aligned_by_id + ) + # ── table ──────────────────────────────────────────────────────────── table_info = step_result.get("table") or {} table_desc = str( table_info.get("complete_table_description") @@ -196,7 +336,6 @@ def export_gym_config( mesh_assets_dir = export_dir / "mesh_assets" mesh_assets_dir.mkdir(parents=True, exist_ok=True) - # ── table ───────────────────────────────────────────────────────── table_simready = _resolve_path( table_info.get("simready_geometry_path") or table_info.get("mesh_path", ""), @@ -208,67 +347,52 @@ def export_gym_config( table_dst.parent.mkdir(parents=True, exist_ok=True) shutil.copy2(table_simready, table_dst) + table_surface_z = _glb_max_z(table_simready) + uniform_scale = 1.0 - ts = manifest.get("table_xy_scale") + ts = table_fit_manifest.get("table_xy_scale") if isinstance(ts, dict): uniform_scale = float(ts.get("uniform_scale", 1.0)) - # ── objects ─────────────────────────────────────────────────────── - table_fit_objects = { - str(e["id"]): _resolve_path(e["path"], output_root) - for e in (manifest.get("objects") or []) - if isinstance(e, dict) - } - objects_info = step_result.get("objects") or [] + # ── objects ────────────────────────────────────────────────────────── rigid_objects: list[dict[str, Any]] = [] - def _obj_desc(obj: dict[str, Any]) -> str: - meta = object_meta_by_id.get(str(obj.get("id", ""))) - return (meta["description"] or meta["name"]) if meta else "" - - for obj in objects_info: - if not 
isinstance(obj, dict): - continue - object_id = str(obj.get("id", "")) - if not object_id: - continue - - # ── GLB: simready (normalised, no baked transforms) ────────── - source = obj.get("simready_geometry_path") or obj.get("mesh_path") - object_src = _resolve_path(source, output_root) - if not object_src.is_file(): - continue - - safe_name = object_id.replace("interact_", "").strip("_") or "object" - obj_dir = mesh_assets_dir / safe_name / object_id + total = len(object_manifest) + for idx, (oid, om) in enumerate(object_manifest.items()): + # Copy simready GLB + safe_name = oid.replace("interact_", "").strip("_") or "object" + obj_dir = mesh_assets_dir / safe_name / oid obj_dir.mkdir(parents=True, exist_ok=True) - object_dst = obj_dir / f"{object_id}.glb" - shutil.copy2(object_src, object_dst) + object_dst = obj_dir / f"{oid}.glb" + shutil.copy2(om["simready_path"], object_dst) - # ── body_scale ──────────────────────────────────────────────── - ms = obj.get("metric_scale") - scale_factor = float(ms.get("scale_factor", 1.0)) if isinstance(ms, dict) else 1.0 - body_scale = [scale_factor, scale_factor, scale_factor] + # body_scale. Image-scene alignment may contain a full simready→aligned + # scale; text-scene layout only has the per-object metric scale. 
+ sf = om["scale_factor"] + scale_glb = om.get("transform_scale") or [sf, sf, sf] + body_scale = _glb_scale_to_sim(scale_glb) - # ── init_pos: read from fitted on-table GLB ─────────────────── - fitted_path = table_fit_objects.get(object_id) - if fitted_path and fitted_path.is_file(): - init_pos = _glb_aabb_bottom_center(fitted_path) - else: - init_pos = [0.0, 0.0, table_surface_z] - - # ── init_rot: decompose from simready→aligned rotation ──────── + # init_rot init_rot: list[float] = [0.0, 0.0, 0.0] - aligned = aligned_by_id.get(object_id) - if aligned: - rot = aligned.get("rotation_matrix") - if rot and isinstance(rot, list): - init_rot = _matrix_to_euler_xyz_deg(rot) + if om["rotation_matrix"] is not None: + init_rot = _matrix_to_euler_xyz_deg( + _glb_rotation_to_sim(om["rotation_matrix"]) + ) + + # init_pos = world_bc - rotated_aabb_offset + ro = _rotated_aabb_offsets( + om["simready_path"], om["rotation_matrix"], scale_glb + ) + wbc = om["world_aabb_bottom_center"] + if wbc is not None: + init_pos = [wbc[0] - ro[0], wbc[1] - ro[1], wbc[2] - ro[2]] + else: + init_pos = [-ro[0], -ro[1], table_surface_z - ro[2]] rigid_objects.append( { - "uid": object_id, - "description": _obj_desc(obj), + "uid": oid, + "description": om["description"], "shape": { "shape_type": "Mesh", "fpath": str(object_dst.relative_to(export_dir)), @@ -282,8 +406,14 @@ def _obj_desc(obj: dict[str, Any]) -> str: "max_convex_hull_num": _DEFAULT_MAX_CONVEX_HULL_NUM, } ) + wbc = om["world_aabb_bottom_center"] + wbc_flag = "wbc" if wbc is not None else "fallback" + print( + f" [{idx+1}/{total}] [{oid}] {om['description']}" + f" pos={init_pos} rot={init_rot} scale={body_scale} src={wbc_flag}" + ) - # ── write config ────────────────────────────────────────────────── + # ── write gym config ───────────────────────────────────────────────── config = { "id": f"Prompt2Scene-{int(time.time() * 1000)}-v0", "max_episodes": 10, @@ -316,4 +446,5 @@ def _obj_desc(obj: dict[str, Any]) -> str: 
json.dumps(config, indent=4, ensure_ascii=False) + "\n", encoding="utf-8", ) + return config_path