From 4faa4d664675014fb9607c5b1e4e39dc126c1731 Mon Sep 17 00:00:00 2001 From: matafela Date: Mon, 29 Jun 2026 16:21:04 +0800 Subject: [PATCH 1/4] move joints can have multiple waypoint input --- embodichain/lab/sim/atomic_actions/actions.py | 9 +++- embodichain/lab/sim/atomic_actions/core.py | 9 +++- .../lab/sim/atomic_actions/trajectory.py | 54 ++++++++++++++++--- .../tutorials/atomic_action/move_joints.py | 13 ++++- tests/sim/atomic_actions/test_actions.py | 45 ++++++++++++++++ tests/sim/atomic_actions/test_trajectory.py | 39 ++++++++++++++ 6 files changed, 157 insertions(+), 12 deletions(-) diff --git a/embodichain/lab/sim/atomic_actions/actions.py b/embodichain/lab/sim/atomic_actions/actions.py index dd7991d9..a608d10d 100644 --- a/embodichain/lab/sim/atomic_actions/actions.py +++ b/embodichain/lab/sim/atomic_actions/actions.py @@ -254,7 +254,14 @@ def _fail(self, state: WorldState) -> ActionResult: class MoveJoints(AtomicAction): - """Plan a joint-space move for the configured control part.""" + """Plan a joint-space move for the configured control part. + + The :class:`JointPositionTarget` may carry either a single waypoint + ``(n_envs, control_dof)`` or a multi-waypoint trajectory + ``(n_envs, n_waypoint, control_dof)``. In the multi-waypoint case the + action plans a single trajectory that visits every waypoint in order, + starting from the inherited ``WorldState.last_qpos``. + """ TargetType: ClassVar[tuple[type, ...]] = ( JointPositionTarget, diff --git a/embodichain/lab/sim/atomic_actions/core.py b/embodichain/lab/sim/atomic_actions/core.py index 5d378ccb..4e71ffcf 100644 --- a/embodichain/lab/sim/atomic_actions/core.py +++ b/embodichain/lab/sim/atomic_actions/core.py @@ -79,7 +79,14 @@ class JointPositionTarget: """Joint-space target for a configured robot control part.""" qpos: torch.Tensor - """(control_dof,) or (n_envs, control_dof) target joint positions.""" + """Target joint positions. + + Accepts: + + - ``(control_dof,)`` or ``(n_envs, control_dof)`` — a single waypoint. + - ``(n_envs, n_waypoint, control_dof)`` — a multi-waypoint trajectory; + waypoints are visited in order. + """ @dataclass(frozen=True) diff --git a/embodichain/lab/sim/atomic_actions/trajectory.py b/embodichain/lab/sim/atomic_actions/trajectory.py index be584cbc..98b4aaa3 100644 --- a/embodichain/lab/sim/atomic_actions/trajectory.py +++ b/embodichain/lab/sim/atomic_actions/trajectory.py @@ -103,20 +103,47 @@ def resolve_joint_target( joint_dof: int, control_part: str, ) -> torch.Tensor: - """Resolve a joint-space target into batched control-part joint positions.""" + """Resolve a joint-space target into batched control-part joint positions. + + Accepts the following shapes for ``target_qpos``: + + - ``(joint_dof,)`` — broadcast to ``(n_envs, joint_dof)`` (single waypoint). + - ``(n_envs, joint_dof)`` — single waypoint, validated and passed through. + - ``(n_envs, n_waypoint, joint_dof)`` — a multi-waypoint trajectory; each + waypoint is visited in order. ``n_waypoint`` may be 1. + + Returns a 2D tensor for single-waypoint inputs and a 3D tensor for + multi-waypoint inputs, leaving downstream planners to treat the trailing + axis as the joint dimension. + """ if not isinstance(target_qpos, torch.Tensor): logger.log_error( f"target qpos for '{control_part}' must be a torch.Tensor with shape " - f"({joint_dof},) or ({n_envs}, {joint_dof})", + f"({joint_dof},), ({n_envs}, {joint_dof}), or " + f"({n_envs}, n_waypoint, {joint_dof})", TypeError, ) target_qpos = target_qpos.to(device=self.device, dtype=torch.float32) if target_qpos.shape == (joint_dof,): target_qpos = target_qpos.unsqueeze(0).repeat(n_envs, 1) - if target_qpos.shape != (n_envs, joint_dof): + if target_qpos.dim() == 2: + if target_qpos.shape != (n_envs, joint_dof): + logger.log_error( + f"target qpos for '{control_part}' must have shape ({joint_dof},) " + f"or ({n_envs}, {joint_dof}), but got {target_qpos.shape}", + ValueError, + ) + elif target_qpos.dim() == 3: + if target_qpos.shape[0] != n_envs or target_qpos.shape[2] != joint_dof: + logger.log_error( + f"multi-waypoint target qpos for '{control_part}' must have shape " + f"({n_envs}, n_waypoint, {joint_dof}), but got {target_qpos.shape}", + ValueError, + ) + else: logger.log_error( - f"target qpos for '{control_part}' must have shape ({joint_dof},) " - f"or ({n_envs}, {joint_dof}), but got {target_qpos.shape}", + f"target qpos for '{control_part}' must be 1D, 2D, or 3D with " + f"trailing dim {joint_dof}, but got {target_qpos.shape}", ValueError, ) return target_qpos @@ -291,10 +318,21 @@ def plan_joint_traj( target_qpos: torch.Tensor, n_waypoints: int, ) -> torch.Tensor: - """Interpolate a joint-space trajectory from ``start_qpos`` to ``target_qpos``.""" - trajectory = torch.stack([start_qpos, target_qpos], dim=1) + """Interpolate a joint-space trajectory through one or more target waypoints. + + ``start_qpos`` has shape ``(n_envs, joint_dof)``. ``target_qpos`` is + either a single waypoint ``(n_envs, joint_dof)`` or a sequence of + waypoints ``(n_envs, n_waypoint, joint_dof)``. The start configuration is + prepended to the target waypoints to build the keyframe sequence + ``(n_envs, 1 + n_waypoint, joint_dof)``, which is then resampled to + ``n_waypoints`` output samples by cumulative-distance piecewise-linear + interpolation — so each consecutive waypoint pair is traversed in turn. + """ + if target_qpos.dim() == 2: + target_qpos = target_qpos.unsqueeze(1) + keyframes = torch.cat([start_qpos.unsqueeze(1), target_qpos], dim=1) return interpolate_with_distance( - trajectory=trajectory, interp_num=n_waypoints, device=self.device + trajectory=keyframes, interp_num=n_waypoints, device=self.device ) # ------------------------------------------------------------------ diff --git a/scripts/tutorials/atomic_action/move_joints.py b/scripts/tutorials/atomic_action/move_joints.py index 32cfc44d..e1ca0bd2 100644 --- a/scripts/tutorials/atomic_action/move_joints.py +++ b/scripts/tutorials/atomic_action/move_joints.py @@ -179,6 +179,7 @@ def main() -> None: # Step 3: Configure the MoveJoints atomic action # # ------------------------------------------------------------------ # ready_qpos = make_arm_qpos([0.35, -1.20, 1.30, -1.65, -1.57, 0.20], sim.device) + mid_qpos = make_arm_qpos([0.15, -1.40, 1.45, -1.60, -1.57, 0.10], sim.device) home_qpos = make_arm_qpos([0.0, -1.57, 1.57, -1.57, -1.57, 0.0], sim.device) move_joints_cfg = MoveJointsCfg( control_part="arm", @@ -205,13 +206,21 @@ def main() -> None: # ------------------------------------------------------------------ # # Step 6: Plan the declared (name, typed_target) sequence # # ------------------------------------------------------------------ # + # The second MoveJoints step passes a multi-waypoint trajectory + # (n_envs, n_waypoint, control_dof): the arm visits `mid_qpos` then + # `home_qpos` in a single planned trajectory. + n_envs = robot.get_qpos().shape[0] + multi_waypoint_qpos = ( + torch.stack([mid_qpos, home_qpos], dim=0).unsqueeze(0).repeat(n_envs, 1, 1) + ) logger.log_info( - "Planning MoveJoints: NamedJointPositionTarget('ready') -> explicit home qpos" + "Planning MoveJoints: NamedJointPositionTarget('ready') -> " + "multi-waypoint trajectory (mid -> home)" ) is_success, traj, _ = atomic_engine.run( steps=[ ("move_joints", NamedJointPositionTarget(name="ready")), - ("move_joints", JointPositionTarget(qpos=home_qpos)), + ("move_joints", JointPositionTarget(qpos=multi_waypoint_qpos)), ] ) if not is_success: diff --git a/tests/sim/atomic_actions/test_actions.py b/tests/sim/atomic_actions/test_actions.py index fb5a1bf0..ca3b0bf3 100644 --- a/tests/sim/atomic_actions/test_actions.py +++ b/tests/sim/atomic_actions/test_actions.py @@ -221,6 +221,51 @@ def test_execute_with_named_qpos_resolves_cfg_target(self): torch.full((NUM_ENVS, ARM_DOF), 0.2), ) + def test_execute_with_multi_waypoint_qpos_visits_each_waypoint(self): + action = MoveJoints(self.mg, MoveJointsCfg(sample_interval=10)) + # (n_envs, n_waypoint, control_dof) trajectory target + waypoint_qpos = ( + torch.stack( + [ + torch.full((ARM_DOF,), 0.3), + torch.full((ARM_DOF,), 0.7), + ], + dim=0, + ) + .unsqueeze(0) + .repeat(NUM_ENVS, 1, 1) + ) + last_qpos = torch.zeros(NUM_ENVS, TOTAL_DOF) + + captured = {} + + def interpolate(trajectory, interp_num, device): + captured["keyframes"] = trajectory + return trajectory[:, -1:, :].repeat(1, interp_num, 1) + + with patch( + "embodichain.lab.sim.atomic_actions.trajectory.interpolate_with_distance", + side_effect=interpolate, + ): + result = action.execute( + JointPositionTarget(qpos=waypoint_qpos), + WorldState(last_qpos=last_qpos), + ) + + assert result.success is True + assert result.trajectory.shape == (NUM_ENVS, 10, TOTAL_DOF) + # start prepended to the two waypoints -> 3 keyframes + keyframes = captured["keyframes"] + assert keyframes.shape == (NUM_ENVS, 3, ARM_DOF) + assert torch.allclose(keyframes[:, 0, :], torch.zeros(NUM_ENVS, ARM_DOF)) + assert torch.allclose(keyframes[:, 1, :], torch.full((NUM_ENVS, ARM_DOF), 0.3)) + assert torch.allclose(keyframes[:, 2, :], torch.full((NUM_ENVS, ARM_DOF), 0.7)) + # final state lands on the last waypoint + assert torch.allclose( + result.next_state.last_qpos[:, :ARM_DOF], + torch.full((NUM_ENVS, ARM_DOF), 0.7), + ) + def test_unknown_named_qpos_raises(self): action = MoveJoints(self.mg, MoveJointsCfg()) with pytest.raises(KeyError, match="missing"): diff --git a/tests/sim/atomic_actions/test_trajectory.py b/tests/sim/atomic_actions/test_trajectory.py index ef28ab8b..7390a40a 100644 --- a/tests/sim/atomic_actions/test_trajectory.py +++ b/tests/sim/atomic_actions/test_trajectory.py @@ -123,6 +123,26 @@ def test_wrong_shape_raises(self): torch.zeros(5), n_envs=2, joint_dof=6, control_part="arm" ) + def test_multi_waypoint_passes_through(self): + qpos = torch.arange(24, dtype=torch.float32).reshape(2, 2, 6) + out = self.builder.resolve_joint_target( + qpos, n_envs=2, joint_dof=6, control_part="arm" + ) + assert out.shape == (2, 2, 6) + assert torch.equal(out, qpos.to(torch.float32)) + + def test_multi_waypoint_wrong_envs_raises(self): + with pytest.raises(Exception): + self.builder.resolve_joint_target( + torch.zeros(3, 2, 6), n_envs=2, joint_dof=6, control_part="arm" + ) + + def test_multi_waypoint_wrong_dof_raises(self): + with pytest.raises(Exception): + self.builder.resolve_joint_target( + torch.zeros(2, 2, 5), n_envs=2, joint_dof=6, control_part="arm" + ) + class TestSplitThreePhase: def setup_method(self): @@ -212,6 +232,25 @@ def test_interpolates_start_to_target(self): assert torch.equal(kwargs["trajectory"][:, 0, :], start) assert torch.equal(kwargs["trajectory"][:, 1, :], target) + def test_interpolates_start_through_multi_waypoints(self): + start = torch.zeros(2, 6) + waypoints = torch.arange(24, dtype=torch.float32).reshape(2, 2, 6) + expected = torch.ones(2, 5, 6) + with patch( + "embodichain.lab.sim.atomic_actions.trajectory.interpolate_with_distance", + return_value=expected, + ) as interpolate: + out = self.builder.plan_joint_traj(start, waypoints, n_waypoints=5) + + assert out is expected + _, kwargs = interpolate.call_args + assert kwargs["interp_num"] == 5 + # start prepended, then every waypoint in order + assert kwargs["trajectory"].shape == (2, 3, 6) + assert torch.equal(kwargs["trajectory"][:, 0, :], start) + assert torch.equal(kwargs["trajectory"][:, 1, :], waypoints[:, 0, :]) + assert torch.equal(kwargs["trajectory"][:, 2, :], waypoints[:, 1, :]) + class TestIkSolve: def test_uses_first_env_seed_for_single_pose(self): From 9d6a2a569675cf8e14f450f09e70229861298f7e Mon Sep 17 00:00:00 2001 From: matafela Date: Mon, 29 Jun 2026 17:45:04 +0800 Subject: [PATCH 2/4] update --- embodichain/lab/sim/atomic_actions/actions.py | 32 +++++++++++--- embodichain/lab/sim/atomic_actions/core.py | 9 +++- .../lab/sim/atomic_actions/trajectory.py | 36 +++++++++++++--- .../atomic_action/move_end_effector.py | 43 ++++++++++++++++--- tests/sim/atomic_actions/test_actions.py | 41 ++++++++++++++++++ tests/sim/atomic_actions/test_trajectory.py | 13 ++++++ 6 files changed, 158 insertions(+), 16 deletions(-) diff --git a/embodichain/lab/sim/atomic_actions/actions.py b/embodichain/lab/sim/atomic_actions/actions.py index a608d10d..99904727 100644 --- a/embodichain/lab/sim/atomic_actions/actions.py +++ b/embodichain/lab/sim/atomic_actions/actions.py @@ -177,7 +177,15 @@ def _arm_qpos_from_state( class MoveEndEffector(AtomicAction): - """Plan a free-space end-effector move to a target pose.""" + """Plan a free-space end-effector move to a target pose. + + The :class:`EndEffectorPoseTarget` may carry either a single waypoint + ``(n_envs, 4, 4)`` (or a broadcastable ``(4, 4)``) or a multi-waypoint + trajectory ``(n_envs, n_waypoint, 4, 4)``. In the multi-waypoint case the + action plans a single trajectory that visits every waypoint in order, + starting from the inherited ``WorldState.last_qpos`` — IK is solved for each + waypoint with the previous waypoint's solution as the seed. + """ TargetType: ClassVar[type] = EndEffectorPoseTarget @@ -201,10 +209,7 @@ def execute(self, target: EndEffectorPoseTarget, state: WorldState) -> ActionRes arm_dof=self.arm_dof, control_part=self.cfg.control_part, ) - target_states_list = [ - [PlanState(xpos=move_xpos[i], move_type=MoveType.EEF_MOVE)] - for i in range(self.n_envs) - ] + target_states_list = self._build_target_states(move_xpos) ok, arm_traj = self.builder.plan_arm_traj( target_states_list, start_qpos, @@ -223,6 +228,23 @@ def execute(self, target: EndEffectorPoseTarget, state: WorldState) -> ActionRes ), ) + def _build_target_states(self, move_xpos: torch.Tensor) -> list[list[PlanState]]: + """Build per-env PlanState lists from a single- or multi-waypoint target. + + ``move_xpos`` is the resolved target: 3D ``(n_envs, 4, 4)`` for a single + waypoint or 4D ``(n_envs, n_waypoint, 4, 4)`` for a trajectory. + """ + if move_xpos.dim() == 3: + move_xpos = move_xpos.unsqueeze(1) + n_waypoint = move_xpos.shape[1] + return [ + [ + PlanState(xpos=move_xpos[i, j], move_type=MoveType.EEF_MOVE) + for j in range(n_waypoint) + ] + for i in range(self.n_envs) + ] + def _embed( self, arm_traj: torch.Tensor, last_full_qpos: torch.Tensor ) -> torch.Tensor: diff --git a/embodichain/lab/sim/atomic_actions/core.py b/embodichain/lab/sim/atomic_actions/core.py index 4e71ffcf..0f9f2909 100644 --- a/embodichain/lab/sim/atomic_actions/core.py +++ b/embodichain/lab/sim/atomic_actions/core.py @@ -71,7 +71,14 @@ class EndEffectorPoseTarget: """End-effector pose target. Used by MoveEndEffector and Place.""" xpos: torch.Tensor - """(4, 4) or (n_envs, 4, 4) homogeneous transform.""" + """Target end-effector homogeneous transform. + + Accepts: + + - ``(4, 4)`` or ``(n_envs, 4, 4)`` — a single waypoint. + - ``(n_envs, n_waypoint, 4, 4)`` — a multi-waypoint trajectory; waypoints + are visited in order. (Only consumed as multi-waypoint by MoveEndEffector.) + """ @dataclass(frozen=True) diff --git a/embodichain/lab/sim/atomic_actions/trajectory.py b/embodichain/lab/sim/atomic_actions/trajectory.py index 98b4aaa3..ded8a2be 100644 --- a/embodichain/lab/sim/atomic_actions/trajectory.py +++ b/embodichain/lab/sim/atomic_actions/trajectory.py @@ -57,19 +57,45 @@ def all_envs_success(self, is_success: bool | torch.Tensor) -> bool: return bool(is_success) def resolve_pose_target(self, target: torch.Tensor, *, n_envs: int) -> torch.Tensor: - """Broadcast a (4, 4) pose to (n_envs, 4, 4) or validate batched shape.""" + """Resolve an end-effector pose target into batched homogeneous transforms. + + Accepts the following shapes for ``target``: + + - ``(4, 4)`` — broadcast to ``(n_envs, 4, 4)`` (single waypoint). + - ``(n_envs, 4, 4)`` — single waypoint, validated and passed through. + - ``(n_envs, n_waypoint, 4, 4)`` — a multi-waypoint trajectory; each + waypoint is visited in order. ``n_waypoint`` may be 1. + + Returns a 3D tensor for single-waypoint inputs and a 4D tensor for + multi-waypoint inputs. + """ if not isinstance(target, torch.Tensor): logger.log_error( - f"target must be torch.Tensor of shape (4, 4) or ({n_envs}, 4, 4)", + f"target must be torch.Tensor of shape (4, 4), ({n_envs}, 4, 4), " + f"or ({n_envs}, n_waypoint, 4, 4)", TypeError, ) target = target.to(device=self.device, dtype=torch.float32) if target.shape == (4, 4): target = target.unsqueeze(0).repeat(n_envs, 1, 1) - if target.shape != (n_envs, 4, 4): + if target.dim() == 3: + if target.shape != (n_envs, 4, 4): + logger.log_error( + f"target tensor must have shape (4, 4) or ({n_envs}, 4, 4), " + f"but got {target.shape}", + ValueError, + ) + elif target.dim() == 4: + if target.shape[0] != n_envs or target.shape[2:] != (4, 4): + logger.log_error( + f"multi-waypoint target tensor must have shape " + f"({n_envs}, n_waypoint, 4, 4), but got {target.shape}", + ValueError, + ) + else: logger.log_error( - f"target tensor must have shape (4, 4) or ({n_envs}, 4, 4), " - f"but got {target.shape}", + f"target tensor must be (4, 4), ({n_envs}, 4, 4), or " + f"({n_envs}, n_waypoint, 4, 4), but got {target.shape}", ValueError, ) return target diff --git a/scripts/tutorials/atomic_action/move_end_effector.py b/scripts/tutorials/atomic_action/move_end_effector.py index 2194be8c..a7396e92 100644 --- a/scripts/tutorials/atomic_action/move_end_effector.py +++ b/scripts/tutorials/atomic_action/move_end_effector.py @@ -14,7 +14,7 @@ # limitations under the License. # ---------------------------------------------------------------------------- -"""Demonstrate MoveEndEffector with a single pose target.""" +"""Demonstrate MoveEndEffector with a multi-waypoint pose trajectory.""" from __future__ import annotations @@ -66,7 +66,7 @@ def parse_arguments() -> argparse.Namespace: parser = argparse.ArgumentParser( - description="Demonstrate MoveEndEffector with a top-down target pose." + description="Demonstrate MoveEndEffector with a multi-waypoint pose trajectory." ) add_env_launcher_args_to_parser(parser) parser.add_argument( @@ -159,13 +159,29 @@ def make_top_down_eef_pose(device: torch.device) -> torch.Tensor: return pose +def make_side_eef_pose(device: torch.device) -> torch.Tensor: + """A second waypoint offset from the top-down pose for the multi-waypoint demo.""" + pose = torch.eye(4, dtype=torch.float32, device=device) + pose[:3, :3] = torch.tensor( + [ + [-0.0539, -0.9985, -0.0022], + [-0.9977, 0.0540, -0.0401], + [0.0401, 0.0000, -0.9992], + ], + dtype=torch.float32, + device=device, + ) + pose[:3, 3] = torch.tensor([0.45, 0.10, 0.30], dtype=torch.float32, device=device) + return pose + + def format_tensor(tensor: torch.Tensor) -> str: rounded = (tensor.detach().cpu() * 10000.0).round() / 10000.0 return str(rounded.tolist()) def main() -> None: - """Move the robot end effector to one target pose using atomic actions.""" + """Move the robot end effector through a multi-waypoint pose trajectory.""" args = parse_arguments() # ------------------------------------------------------------------ # @@ -199,21 +215,38 @@ def main() -> None: # Step 5: Define and visualize the end-effector target # # ------------------------------------------------------------------ # target_pose = make_top_down_eef_pose(sim.device) + side_pose = make_side_eef_pose(sim.device) if not args.headless: sim.open_window() if not args.no_vis_eef_axis: draw_axis_marker(sim, "move_end_effector_target_axis", target_pose) + draw_axis_marker(sim, "move_end_effector_side_axis", side_pose) if not args.auto_play: input("Inspect the robot, then press Enter to plan MoveEndEffector...") # ------------------------------------------------------------------ # # Step 6: Plan the declared (name, typed_target) sequence # # ------------------------------------------------------------------ # + # Pass a multi-waypoint trajectory (n_envs, n_waypoint, 4, 4): the + # end-effector visits `target_pose` then `side_pose` in a single plan. + n_envs = robot.get_qpos().shape[0] + multi_waypoint_xpos = ( + torch.stack([target_pose, side_pose], dim=0) + .unsqueeze(0) + .repeat(n_envs, 1, 1, 1) + ) logger.log_info( - f"Planning MoveEndEffector to xpos={format_tensor(target_pose[:3, 3])}" + "Planning MoveEndEffector through multi-waypoint trajectory: " + f"xpos0={format_tensor(target_pose[:3, 3])} -> " + f"xpos1={format_tensor(side_pose[:3, 3])}" ) is_success, traj, _ = atomic_engine.run( - steps=[("move_end_effector", EndEffectorPoseTarget(xpos=target_pose))] + steps=[ + ( + "move_end_effector", + EndEffectorPoseTarget(xpos=multi_waypoint_xpos), + ) + ] ) if not is_success: logger.log_warning("Failed to plan MoveEndEffector demo trajectory.") diff --git a/tests/sim/atomic_actions/test_actions.py b/tests/sim/atomic_actions/test_actions.py index ca3b0bf3..cf184796 100644 --- a/tests/sim/atomic_actions/test_actions.py +++ b/tests/sim/atomic_actions/test_actions.py @@ -148,6 +148,47 @@ def test_execute_returns_full_dof_trajectory(self): # MoveEndEffector preserves held_object. assert result.next_state.held_object is None + def test_execute_with_multi_waypoint_visits_each_waypoint(self): + action = MoveEndEffector(self.mg, MoveEndEffectorCfg(sample_interval=10)) + pose0 = torch.eye(4) + pose1 = torch.eye(4) + pose1[0, 3] = 1.0 + # (n_envs, n_waypoint, 4, 4) trajectory target + multi_xpos = ( + torch.stack([pose0, pose1], dim=0).unsqueeze(0).repeat(NUM_ENVS, 1, 1, 1) + ) + seen_poses = [] + + def compute_ik(pose=None, name=None, joint_seed=None, **kwargs): + seen_poses.append(pose.clone()) + return torch.ones(NUM_ENVS, dtype=torch.bool), joint_seed.clone() + + self.mg.robot.compute_ik = Mock(side_effect=compute_ik) + + captured = {} + + def interpolate(trajectory, interp_num, device): + captured["keyframes"] = trajectory + return trajectory[:, -1:, :].repeat(1, interp_num, 1) + + with patch( + "embodichain.lab.sim.atomic_actions.trajectory.interpolate_with_distance", + side_effect=interpolate, + ): + result = action.execute( + EndEffectorPoseTarget(xpos=multi_xpos), + WorldState(last_qpos=torch.zeros(NUM_ENVS, TOTAL_DOF)), + ) + + assert result.success is True + assert result.trajectory.shape == (NUM_ENVS, 10, TOTAL_DOF) + # Two waypoints -> two IK calls, in order. + assert len(seen_poses) == 2 + assert torch.allclose(seen_poses[0], pose0.unsqueeze(0).repeat(NUM_ENVS, 1, 1)) + assert torch.allclose(seen_poses[1], pose1.unsqueeze(0).repeat(NUM_ENVS, 1, 1)) + # start prepended to the two IK solutions -> 3 keyframes. + assert captured["keyframes"].shape == (NUM_ENVS, 3, ARM_DOF) + # --------------------------------------------------------------------------- # MoveJoints diff --git a/tests/sim/atomic_actions/test_trajectory.py b/tests/sim/atomic_actions/test_trajectory.py index 7390a40a..8177f050 100644 --- a/tests/sim/atomic_actions/test_trajectory.py +++ b/tests/sim/atomic_actions/test_trajectory.py @@ -96,6 +96,19 @@ def test_wrong_shape_raises(self): with pytest.raises(Exception): self.builder.resolve_pose_target(torch.eye(3), n_envs=2) + def test_multi_waypoint_passes_through(self): + pose = torch.eye(4).unsqueeze(0).unsqueeze(0).repeat(2, 3, 1, 1) + pose[0, 1, :3, 3] = torch.tensor([1.0, 0.0, 0.0]) + out = self.builder.resolve_pose_target(pose, n_envs=2) + assert out.shape == (2, 3, 4, 4) + assert torch.equal(out, pose.to(torch.float32)) + + def test_multi_waypoint_wrong_envs_raises(self): + with pytest.raises(Exception): + self.builder.resolve_pose_target( + torch.eye(4).unsqueeze(0).unsqueeze(0).repeat(3, 2, 1, 1), n_envs=2 + ) + class TestResolveJointTarget: def setup_method(self): From 3b5571606ded81e92f4a0f77eda2850c5f5bddca Mon Sep 17 00:00:00 2001 From: matafela Date: Tue, 30 Jun 2026 16:55:10 +0800 Subject: [PATCH 3/4] place action support multi waypoint --- embodichain/lab/sim/atomic_actions/actions.py | 39 ++++++++--- .../lab/sim/atomic_actions/trajectory.py | 12 ++++ scripts/tutorials/atomic_action/place.py | 42 ++++++++--- tests/sim/atomic_actions/test_actions.py | 69 +++++++++++++++++++ tests/sim/atomic_actions/test_trajectory.py | 12 ++++ 5 files changed, 152 insertions(+), 22 deletions(-) diff --git a/embodichain/lab/sim/atomic_actions/actions.py b/embodichain/lab/sim/atomic_actions/actions.py index 99904727..ada50294 100644 --- a/embodichain/lab/sim/atomic_actions/actions.py +++ b/embodichain/lab/sim/atomic_actions/actions.py @@ -658,7 +658,16 @@ def _fail(self, state: WorldState) -> ActionResult: class Place(AtomicAction): - """Lower the held object to a place pose, open the gripper, retract.""" + """Lower the held object to a place pose, open the gripper, retract. + + The :class:`EndEffectorPoseTarget` may carry either a single waypoint + ``(n_envs, 4, 4)`` (or a broadcastable ``(4, 4)``) or a multi-waypoint + trajectory ``(n_envs, n_waypoint, 4, 4)``. In the multi-waypoint case the + down phase visits every waypoint in order — approaching from above the + first waypoint, descending through each waypoint, then opening the gripper + at the final waypoint and retracting to above the last waypoint. Starting + joint positions are inherited from ``WorldState.last_qpos``. + """ TargetType: ClassVar[type] = EndEffectorPoseTarget @@ -686,6 +695,12 @@ def __init__( def execute(self, target: EndEffectorPoseTarget, state: WorldState) -> ActionResult: place_xpos = self.builder.resolve_pose_target(target.xpos, n_envs=self.n_envs) + # Normalize a single-waypoint (n_envs, 4, 4) target to (n_envs, 1, 4, 4) + # so the multi-waypoint descent path below is uniform. + if place_xpos.dim() == 3: + place_xpos = place_xpos.unsqueeze(1) + n_waypoint = place_xpos.shape[1] + start_arm_qpos = self.builder.resolve_start_qpos( _arm_qpos_from_state(state, self.arm_joint_ids, self.robot_dof), n_envs=self.n_envs, @@ -699,16 +714,18 @@ def execute(self, target: EndEffectorPoseTarget, state: WorldState) -> ActionRes third_phase_name="back", ) - lift_xpos = self.builder.apply_local_offset( - place_xpos, - torch.tensor([0, 0, 1], device=self.device) * self.cfg.lift_height, - ) + lift_offset = torch.tensor([0, 0, 1], device=self.device) * self.cfg.lift_height + # Approach from above the first waypoint; retract to above the last. + # For a single waypoint these coincide, matching the legacy behavior. + approach_xpos = self.builder.apply_local_offset(place_xpos[:, 0], lift_offset) + retract_xpos = self.builder.apply_local_offset(place_xpos[:, -1], lift_offset) - # Phase 1: down (lift → place) + # Phase 1: down (approach → every place waypoint in order) target_states_list = [ - [ - PlanState(xpos=lift_xpos[i], move_type=MoveType.EEF_MOVE), - PlanState(xpos=place_xpos[i], move_type=MoveType.EEF_MOVE), + [PlanState(xpos=approach_xpos[i], move_type=MoveType.EEF_MOVE)] + + [ + PlanState(xpos=place_xpos[i, j], move_type=MoveType.EEF_MOVE) + for j in range(n_waypoint) ] for i in range(self.n_envs) ] @@ -723,9 +740,9 @@ def execute(self, target: EndEffectorPoseTarget, state: WorldState) -> ActionRes return self._fail(state) reach_arm_qpos = down_arm[:, -1, :] - # Phase 3: back (retract to lift) + # Phase 3: back (retract to above the last waypoint) target_states_list = [ - [PlanState(xpos=lift_xpos[i], move_type=MoveType.EEF_MOVE)] + [PlanState(xpos=retract_xpos[i], move_type=MoveType.EEF_MOVE)] for i in range(self.n_envs) ] ok, back_arm = self.builder.plan_arm_traj( diff --git a/embodichain/lab/sim/atomic_actions/trajectory.py b/embodichain/lab/sim/atomic_actions/trajectory.py index ded8a2be..58c7a8bf 100644 --- a/embodichain/lab/sim/atomic_actions/trajectory.py +++ b/embodichain/lab/sim/atomic_actions/trajectory.py @@ -92,6 +92,12 @@ def resolve_pose_target(self, target: torch.Tensor, *, n_envs: int) -> torch.Ten f"({n_envs}, n_waypoint, 4, 4), but got {target.shape}", ValueError, ) + if target.shape[1] == 0: + logger.log_error( + "multi-waypoint target tensor has zero waypoints (shape[1] == 0); " + "at least one waypoint is required.", + ValueError, + ) else: logger.log_error( f"target tensor must be (4, 4), ({n_envs}, 4, 4), or " @@ -166,6 +172,12 @@ def resolve_joint_target( f"({n_envs}, n_waypoint, {joint_dof}), but got {target_qpos.shape}", ValueError, ) + if target_qpos.shape[1] == 0: + logger.log_error( + f"multi-waypoint target qpos for '{control_part}' has zero waypoints " + f"(shape[1] == 0); at least one waypoint is required.", + ValueError, + ) else: logger.log_error( f"target qpos for '{control_part}' must be 1D, 2D, or 3D with " diff --git a/scripts/tutorials/atomic_action/place.py b/scripts/tutorials/atomic_action/place.py index f6ec66bc..91ba2895 100644 --- a/scripts/tutorials/atomic_action/place.py +++ b/scripts/tutorials/atomic_action/place.py @@ -304,19 +304,34 @@ def initialize_pre_pick_robot_pose( robot.clear_dynamics() -def make_place_eef_pose(device: torch.device) -> torch.Tensor: - pose = torch.eye(4, dtype=torch.float32, device=device) - pose[:3, :3] = torch.tensor( +def make_place_eef_poses(device: torch.device) -> torch.Tensor: + """Build a multi-waypoint place trajectory ``(n_waypoint, 4, 4)``. + + Two waypoints are returned: a higher hover pose and the final release pose. + ``Place`` approaches from above the first waypoint, descends through each + waypoint in order, opens the gripper at the last, and retracts — so this + exercises the multi-waypoint descent path. + """ + rotation = torch.tensor( [ - [-0.0539, -0.9985, -0.0022], - [-0.9977, 0.0540, -0.0401], - [0.0401, 0.0000, -0.9992], + [0.0539, 0.9985, -0.0022], + [0.9977, -0.0540, -0.0401], + [-0.0401, -0.0000, -0.9992], ], dtype=torch.float32, device=device, ) - pose[:3, 3] = torch.tensor([-0.20, 0.28, 0.1], dtype=torch.float32, device=device) - return pose + hover_pose = torch.eye(4, dtype=torch.float32, device=device) + hover_pose[:3, :3] = rotation + hover_pose[:3, 3] = torch.tensor( + [-0.40, 0.48, 0.20], dtype=torch.float32, device=device + ) + place_pose = torch.eye(4, dtype=torch.float32, device=device) + place_pose[:3, :3] = rotation + place_pose[:3, 3] = torch.tensor( + [-0.40, 0.48, 0.10], dtype=torch.float32, device=device + ) + return torch.stack([hover_pose, place_pose], dim=0) def compute_pick_close_end_step() -> int: @@ -385,17 +400,22 @@ def main() -> None: # Step 5: Describe the object and define the place target # # ------------------------------------------------------------------ # semantics = create_object_semantics(obj, args) - place_eef_pose = make_place_eef_pose(sim.device) + place_eef_poses = make_place_eef_poses(sim.device) if not args.no_vis_eef_axis: - draw_axis_marker(sim, "place_target_axis", place_eef_pose) + draw_axis_marker(sim, "place_target_axis", place_eef_poses[-1]) if not args.auto_play: input("Inspect the object, then press Enter to plan PickUp -> Place...") # ------------------------------------------------------------------ # # Step 6: Plan the declared (name, typed_target) sequence # # ------------------------------------------------------------------ # - place_target = EndEffectorPoseTarget(xpos=place_eef_pose) + # Pass a multi-waypoint trajectory (n_envs, n_waypoint, 4, 4): Place + # approaches from above the first waypoint, descends through each + # waypoint in order, opens the gripper at the last, and retracts. + n_envs = robot.get_qpos().shape[0] + multi_waypoint_xpos = place_eef_poses.unsqueeze(0).repeat(n_envs, 1, 1, 1) + place_target = EndEffectorPoseTarget(xpos=multi_waypoint_xpos) logger.log_info("Planning PickUp precondition -> Place release trajectory") is_success, traj, _ = atomic_engine.run( steps=[ diff --git a/tests/sim/atomic_actions/test_actions.py b/tests/sim/atomic_actions/test_actions.py index cf184796..c593e036 100644 --- a/tests/sim/atomic_actions/test_actions.py +++ b/tests/sim/atomic_actions/test_actions.py @@ -466,3 +466,72 @@ def test_execute_clears_held_object(self): assert result.success is True assert result.trajectory.shape[2] == TOTAL_DOF assert result.next_state.held_object is None + + def test_execute_with_multi_waypoint_visits_each_waypoint(self): + cfg = PlaceCfg( + hand_open_qpos=_hand_open(), + hand_close_qpos=_hand_close(), + sample_interval=20, + hand_interp_steps=4, + lift_height=0.1, + ) + action = Place(self.mg, cfg) + sem = ObjectSemantics( + affordance=AntipodalAffordance(), geometry={}, label="mug" + ) + held = HeldObjectState( + semantics=sem, + object_to_eef=torch.eye(4).unsqueeze(0).repeat(NUM_ENVS, 1, 1), + grasp_xpos=torch.eye(4).unsqueeze(0).repeat(NUM_ENVS, 1, 1), + ) + state = WorldState(last_qpos=torch.zeros(NUM_ENVS, TOTAL_DOF), held_object=held) + + pose0 = torch.eye(4) + pose1 = torch.eye(4) + pose1[0, 3] = 1.0 + # (n_envs, n_waypoint, 4, 4) trajectory target + multi_xpos = ( + torch.stack([pose0, pose1], dim=0).unsqueeze(0).repeat(NUM_ENVS, 1, 1, 1) + ) + seen_poses = [] + + def compute_ik(pose=None, name=None, joint_seed=None, **kwargs): + seen_poses.append(pose.clone()) + return torch.ones(NUM_ENVS, dtype=torch.bool), joint_seed.clone() + + self.mg.robot.compute_ik = Mock(side_effect=compute_ik) + + captured = {} + + def interpolate(trajectory, interp_num, device): + # Only the down phase carries more than 2 keyframes; capture it. + if trajectory.shape[1] > 2: + captured["down_keyframes"] = trajectory + return trajectory[:, -1:, :].repeat(1, interp_num, 1) + + with patch( + "embodichain.lab.sim.atomic_actions.trajectory.interpolate_with_distance", + side_effect=interpolate, + ): + result = action.execute(EndEffectorPoseTarget(xpos=multi_xpos), state) + + assert result.success is True + assert result.trajectory.shape[2] == TOTAL_DOF + assert result.next_state.held_object is None + # IK order: down phase (approach, pose0, pose1) then back phase (retract). + assert len(seen_poses) == 4 + lift_height = cfg.lift_height + approach = pose0.clone() + approach[2, 3] += lift_height + retract = pose1.clone() + retract[2, 3] += lift_height + assert torch.allclose( + seen_poses[0], approach.unsqueeze(0).repeat(NUM_ENVS, 1, 1) + ) + assert torch.allclose(seen_poses[1], pose0.unsqueeze(0).repeat(NUM_ENVS, 1, 1)) + assert torch.allclose(seen_poses[2], pose1.unsqueeze(0).repeat(NUM_ENVS, 1, 1)) + assert torch.allclose( + seen_poses[3], retract.unsqueeze(0).repeat(NUM_ENVS, 1, 1) + ) + # start prepended to the 3 down-phase IK solutions -> 4 keyframes. + assert captured["down_keyframes"].shape == (NUM_ENVS, 4, ARM_DOF) diff --git a/tests/sim/atomic_actions/test_trajectory.py b/tests/sim/atomic_actions/test_trajectory.py index 8177f050..006a713c 100644 --- a/tests/sim/atomic_actions/test_trajectory.py +++ b/tests/sim/atomic_actions/test_trajectory.py @@ -109,6 +109,11 @@ def test_multi_waypoint_wrong_envs_raises(self): torch.eye(4).unsqueeze(0).unsqueeze(0).repeat(3, 2, 1, 1), n_envs=2 ) + def test_multi_waypoint_empty_raises(self): + empty = torch.zeros((2, 0, 4, 4), dtype=torch.float32) + with pytest.raises(ValueError, match="zero waypoints"): + self.builder.resolve_pose_target(empty, n_envs=2) + class TestResolveJointTarget: def setup_method(self): @@ -156,6 +161,13 @@ def test_multi_waypoint_wrong_dof_raises(self): torch.zeros(2, 2, 5), n_envs=2, joint_dof=6, control_part="arm" ) + def test_multi_waypoint_empty_raises(self): + empty = torch.zeros((2, 0, 6), dtype=torch.float32) + with pytest.raises(ValueError, match="zero waypoints"): + self.builder.resolve_joint_target( + empty, n_envs=2, joint_dof=6, control_part="arm" + ) + class TestSplitThreePhase: def setup_method(self): From a3ad7d035bd6f637ccf27c756f8a03d0b2cb58c5 Mon Sep 17 00:00:00 2001 From: matafela Date: Tue, 30 Jun 2026 17:03:38 +0800 Subject: [PATCH 4/4] update document --- .../source/overview/sim/atomic_actions/builtin_actions.md | 8 +++----- docs/source/overview/sim/atomic_actions/index.md | 4 ++-- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/docs/source/overview/sim/atomic_actions/builtin_actions.md b/docs/source/overview/sim/atomic_actions/builtin_actions.md index 4507d28b..01320709 100644 --- a/docs/source/overview/sim/atomic_actions/builtin_actions.md +++ b/docs/source/overview/sim/atomic_actions/builtin_actions.md @@ -26,8 +26,7 @@ Moves the end-effector to a target pose in free space. | `control_part` | `"arm"` | Robot control part to move | | `sample_interval` | `50` | Number of waypoints in the trajectory | -**Target:** `EndEffectorPoseTarget(xpos=...)` where `xpos` is a `torch.Tensor` of shape `(4, 4)` or -`(n_envs, 4, 4)` — a homogeneous EEF pose. +**Target:** `EndEffectorPoseTarget(xpos=...)` where `xpos` is a `torch.Tensor` of shape `(4, 4)`, `(n_envs, 4, 4)` or `(n_envs, n_waypoint, 4, 4)` — a homogeneous EEF pose. ![MoveEndEffector demo](../../../_static/atomic_actions/move_end_effector.gif) @@ -45,8 +44,7 @@ home poses, recovery motions, or any motion where a qpos target is clearer than | `named_joint_positions` | `None` | Optional `dict[str, torch.Tensor]` for named qpos targets | **Targets:** -- `JointPositionTarget(qpos=...)` where `qpos` is a `torch.Tensor` of shape `(control_dof,)` or - `(n_envs, control_dof)`. +- `JointPositionTarget(qpos=...)` where `qpos` is a `torch.Tensor` of shape `(control_dof,)`, `(n_envs, control_dof)` or `(n_envs, n_waypoint, control_dof)`. - `NamedJointPositionTarget(name=...)` where `name` is resolved from `MoveJointsCfg.named_joint_positions`. @@ -120,6 +118,6 @@ down to the target pose. On success, the returned `WorldState` clears `held_obje | `sample_interval` | `80` | Total waypoints across all three phases | **Target:** `EndEffectorPoseTarget(xpos=...)` — the EEF pose at release, a `torch.Tensor` of shape -`(4, 4)` or `(n_envs, 4, 4)`. +`(4, 4)`, `(n_envs, 4, 4)` or `(n_envs, n_waypoint, 4, 4)`. ![Place demo](../../../_static/atomic_actions/place.gif) diff --git a/docs/source/overview/sim/atomic_actions/index.md b/docs/source/overview/sim/atomic_actions/index.md index cee58179..119021bd 100644 --- a/docs/source/overview/sim/atomic_actions/index.md +++ b/docs/source/overview/sim/atomic_actions/index.md @@ -94,8 +94,8 @@ action's `TargetType` before calling `execute`: | Target | Holds | Accepted by | |---|---|---| -| `EndEffectorPoseTarget(xpos)` | EEF pose tensor `(4,4)` or `(n_envs,4,4)` | `MoveEndEffector`, `Place` | -| `JointPositionTarget(qpos)` | Control-part qpos tensor `(control_dof,)` or `(n_envs, control_dof)` | `MoveJoints` | +| `EndEffectorPoseTarget(xpos)` | EEF pose tensor `(4,4)`, `(n_envs,4,4)` or `(n_envs, n_waypoint, 4, 4)` | `MoveEndEffector`, `Place` | +| `JointPositionTarget(qpos)` | Control-part qpos tensor `(control_dof,)`, `(n_envs, control_dof)` or `(n_envs, n_waypoint, control_dof)` | `MoveJoints` | | `NamedJointPositionTarget(name)` | Name resolved from `MoveJointsCfg.named_joint_positions` | `MoveJoints` | | `GraspTarget(semantics)` | `ObjectSemantics` (affordance + entity) | `PickUp` | | `HeldObjectPoseTarget(object_target_pose)` | Desired held-object pose tensor | `MoveHeldObject` |