diff --git a/embodichain/lab/sim/atomic_actions/actions.py b/embodichain/lab/sim/atomic_actions/actions.py index dd7991d9..99904727 100644 --- a/embodichain/lab/sim/atomic_actions/actions.py +++ b/embodichain/lab/sim/atomic_actions/actions.py @@ -177,7 +177,15 @@ def _arm_qpos_from_state( class MoveEndEffector(AtomicAction): - """Plan a free-space end-effector move to a target pose.""" + """Plan a free-space end-effector move to a target pose. + + The :class:`EndEffectorPoseTarget` may carry either a single waypoint + ``(n_envs, 4, 4)`` (or a broadcastable ``(4, 4)``) or a multi-waypoint + trajectory ``(n_envs, n_waypoint, 4, 4)``. In the multi-waypoint case the + action plans a single trajectory that visits every waypoint in order, + starting from the inherited ``WorldState.last_qpos`` — IK is solved for each + waypoint with the previous waypoint's solution as the seed. + """ TargetType: ClassVar[type] = EndEffectorPoseTarget @@ -201,10 +209,7 @@ def execute(self, target: EndEffectorPoseTarget, state: WorldState) -> ActionRes arm_dof=self.arm_dof, control_part=self.cfg.control_part, ) - target_states_list = [ - [PlanState(xpos=move_xpos[i], move_type=MoveType.EEF_MOVE)] - for i in range(self.n_envs) - ] + target_states_list = self._build_target_states(move_xpos) ok, arm_traj = self.builder.plan_arm_traj( target_states_list, start_qpos, @@ -223,6 +228,23 @@ def execute(self, target: EndEffectorPoseTarget, state: WorldState) -> ActionRes ), ) + def _build_target_states(self, move_xpos: torch.Tensor) -> list[list[PlanState]]: + """Build per-env PlanState lists from a single- or multi-waypoint target. + + ``move_xpos`` is the resolved target: 3D ``(n_envs, 4, 4)`` for a single + waypoint or 4D ``(n_envs, n_waypoint, 4, 4)`` for a trajectory. + """ + if move_xpos.dim() == 3: + move_xpos = move_xpos.unsqueeze(1) + n_waypoint = move_xpos.shape[1] + return [ + [ + PlanState(xpos=move_xpos[i, j], move_type=MoveType.EEF_MOVE) + for j in range(n_waypoint) + ] + for i in range(self.n_envs) + ] + def _embed( self, arm_traj: torch.Tensor, last_full_qpos: torch.Tensor ) -> torch.Tensor: @@ -254,7 +276,14 @@ def _fail(self, state: WorldState) -> ActionResult: class MoveJoints(AtomicAction): - """Plan a joint-space move for the configured control part.""" + """Plan a joint-space move for the configured control part. + + The :class:`JointPositionTarget` may carry either a single waypoint + ``(n_envs, control_dof)`` or a multi-waypoint trajectory + ``(n_envs, n_waypoint, control_dof)``. In the multi-waypoint case the + action plans a single trajectory that visits every waypoint in order, + starting from the inherited ``WorldState.last_qpos``. + """ TargetType: ClassVar[tuple[type, ...]] = ( JointPositionTarget, diff --git a/embodichain/lab/sim/atomic_actions/core.py b/embodichain/lab/sim/atomic_actions/core.py index 5d378ccb..0f9f2909 100644 --- a/embodichain/lab/sim/atomic_actions/core.py +++ b/embodichain/lab/sim/atomic_actions/core.py @@ -71,7 +71,14 @@ class EndEffectorPoseTarget: """End-effector pose target. Used by MoveEndEffector and Place.""" xpos: torch.Tensor - """(4, 4) or (n_envs, 4, 4) homogeneous transform.""" + """Target end-effector homogeneous transform. + + Accepts: + + - ``(4, 4)`` or ``(n_envs, 4, 4)`` — a single waypoint. + - ``(n_envs, n_waypoint, 4, 4)`` — a multi-waypoint trajectory; waypoints + are visited in order. (Only consumed as multi-waypoint by MoveEndEffector.) + """ @dataclass(frozen=True) @@ -79,7 +86,14 @@ class JointPositionTarget: """Joint-space target for a configured robot control part.""" qpos: torch.Tensor - """(control_dof,) or (n_envs, control_dof) target joint positions.""" + """Target joint positions. + + Accepts: + + - ``(control_dof,)`` or ``(n_envs, control_dof)`` — a single waypoint. + - ``(n_envs, n_waypoint, control_dof)`` — a multi-waypoint trajectory; + waypoints are visited in order. + """ @dataclass(frozen=True) diff --git a/embodichain/lab/sim/atomic_actions/trajectory.py b/embodichain/lab/sim/atomic_actions/trajectory.py index be584cbc..ded8a2be 100644 --- a/embodichain/lab/sim/atomic_actions/trajectory.py +++ b/embodichain/lab/sim/atomic_actions/trajectory.py @@ -57,19 +57,45 @@ def all_envs_success(self, is_success: bool | torch.Tensor) -> bool: return bool(is_success) def resolve_pose_target(self, target: torch.Tensor, *, n_envs: int) -> torch.Tensor: - """Broadcast a (4, 4) pose to (n_envs, 4, 4) or validate batched shape.""" + """Resolve an end-effector pose target into batched homogeneous transforms. + + Accepts the following shapes for ``target``: + + - ``(4, 4)`` — broadcast to ``(n_envs, 4, 4)`` (single waypoint). + - ``(n_envs, 4, 4)`` — single waypoint, validated and passed through. + - ``(n_envs, n_waypoint, 4, 4)`` — a multi-waypoint trajectory; each + waypoint is visited in order. ``n_waypoint`` may be 1. + + Returns a 3D tensor for single-waypoint inputs and a 4D tensor for + multi-waypoint inputs. + """ if not isinstance(target, torch.Tensor): logger.log_error( - f"target must be torch.Tensor of shape (4, 4) or ({n_envs}, 4, 4)", + f"target must be torch.Tensor of shape (4, 4), ({n_envs}, 4, 4), " + f"or ({n_envs}, n_waypoint, 4, 4)", TypeError, ) target = target.to(device=self.device, dtype=torch.float32) if target.shape == (4, 4): target = target.unsqueeze(0).repeat(n_envs, 1, 1) - if target.shape != (n_envs, 4, 4): + if target.dim() == 3: + if target.shape != (n_envs, 4, 4): + logger.log_error( + f"target tensor must have shape (4, 4) or ({n_envs}, 4, 4), " + f"but got {target.shape}", + ValueError, + ) + elif target.dim() == 4: + if target.shape[0] != n_envs or target.shape[2:] != (4, 4): + logger.log_error( + f"multi-waypoint target tensor must have shape " + f"({n_envs}, n_waypoint, 4, 4), but got {target.shape}", + ValueError, + ) + else: logger.log_error( - f"target tensor must have shape (4, 4) or ({n_envs}, 4, 4), " - f"but got {target.shape}", + f"target tensor must be (4, 4), ({n_envs}, 4, 4), or " + f"({n_envs}, n_waypoint, 4, 4), but got {target.shape}", ValueError, ) return target @@ -103,20 +129,47 @@ def resolve_joint_target( joint_dof: int, control_part: str, ) -> torch.Tensor: - """Resolve a joint-space target into batched control-part joint positions.""" + """Resolve a joint-space target into batched control-part joint positions. + + Accepts the following shapes for ``target_qpos``: + + - ``(joint_dof,)`` — broadcast to ``(n_envs, joint_dof)`` (single waypoint). + - ``(n_envs, joint_dof)`` — single waypoint, validated and passed through. + - ``(n_envs, n_waypoint, joint_dof)`` — a multi-waypoint trajectory; each + waypoint is visited in order. ``n_waypoint`` may be 1. + + Returns a 2D tensor for single-waypoint inputs and a 3D tensor for + multi-waypoint inputs, leaving downstream planners to treat the trailing + axis as the joint dimension. + """ if not isinstance(target_qpos, torch.Tensor): logger.log_error( f"target qpos for '{control_part}' must be a torch.Tensor with shape " - f"({joint_dof},) or ({n_envs}, {joint_dof})", + f"({joint_dof},), ({n_envs}, {joint_dof}), or " + f"({n_envs}, n_waypoint, {joint_dof})", TypeError, ) target_qpos = target_qpos.to(device=self.device, dtype=torch.float32) if target_qpos.shape == (joint_dof,): target_qpos = target_qpos.unsqueeze(0).repeat(n_envs, 1) - if target_qpos.shape != (n_envs, joint_dof): + if target_qpos.dim() == 2: + if target_qpos.shape != (n_envs, joint_dof): + logger.log_error( + f"target qpos for '{control_part}' must have shape ({joint_dof},) " + f"or ({n_envs}, {joint_dof}), but got {target_qpos.shape}", + ValueError, + ) + elif target_qpos.dim() == 3: + if target_qpos.shape[0] != n_envs or target_qpos.shape[2] != joint_dof: + logger.log_error( + f"multi-waypoint target qpos for '{control_part}' must have shape " + f"({n_envs}, n_waypoint, {joint_dof}), but got {target_qpos.shape}", + ValueError, + ) + else: logger.log_error( - f"target qpos for '{control_part}' must have shape ({joint_dof},) " - f"or ({n_envs}, {joint_dof}), but got {target_qpos.shape}", + f"target qpos for '{control_part}' must be 1D, 2D, or 3D with " + f"trailing dim {joint_dof}, but got {target_qpos.shape}", ValueError, ) return target_qpos @@ -291,10 +344,21 @@ def plan_joint_traj( target_qpos: torch.Tensor, n_waypoints: int, ) -> torch.Tensor: - """Interpolate a joint-space trajectory from ``start_qpos`` to ``target_qpos``.""" - trajectory = torch.stack([start_qpos, target_qpos], dim=1) + """Interpolate a joint-space trajectory through one or more target waypoints. + + ``start_qpos`` has shape ``(n_envs, joint_dof)``. ``target_qpos`` is + either a single waypoint ``(n_envs, joint_dof)`` or a sequence of + waypoints ``(n_envs, n_waypoint, joint_dof)``. The start configuration is + prepended to the target waypoints to build the keyframe sequence + ``(n_envs, 1 + n_waypoint, joint_dof)``, which is then resampled to + ``n_waypoints`` output samples by cumulative-distance piecewise-linear + interpolation — so each consecutive waypoint pair is traversed in turn. + """ + if target_qpos.dim() == 2: + target_qpos = target_qpos.unsqueeze(1) + keyframes = torch.cat([start_qpos.unsqueeze(1), target_qpos], dim=1) return interpolate_with_distance( - trajectory=trajectory, interp_num=n_waypoints, device=self.device + trajectory=keyframes, interp_num=n_waypoints, device=self.device ) # ------------------------------------------------------------------ diff --git a/scripts/tutorials/atomic_action/move_end_effector.py b/scripts/tutorials/atomic_action/move_end_effector.py index 2194be8c..a7396e92 100644 --- a/scripts/tutorials/atomic_action/move_end_effector.py +++ b/scripts/tutorials/atomic_action/move_end_effector.py @@ -14,7 +14,7 @@ # limitations under the License. # ---------------------------------------------------------------------------- -"""Demonstrate MoveEndEffector with a single pose target.""" +"""Demonstrate MoveEndEffector with a multi-waypoint pose trajectory.""" from __future__ import annotations @@ -66,7 +66,7 @@ def parse_arguments() -> argparse.Namespace: parser = argparse.ArgumentParser( - description="Demonstrate MoveEndEffector with a top-down target pose." + description="Demonstrate MoveEndEffector with a multi-waypoint pose trajectory." ) add_env_launcher_args_to_parser(parser) parser.add_argument( @@ -159,13 +159,29 @@ def make_top_down_eef_pose(device: torch.device) -> torch.Tensor: return pose +def make_side_eef_pose(device: torch.device) -> torch.Tensor: + """A second waypoint offset from the top-down pose for the multi-waypoint demo.""" + pose = torch.eye(4, dtype=torch.float32, device=device) + pose[:3, :3] = torch.tensor( + [ + [-0.0539, -0.9985, -0.0022], + [-0.9977, 0.0540, -0.0401], + [0.0401, 0.0000, -0.9992], + ], + dtype=torch.float32, + device=device, + ) + pose[:3, 3] = torch.tensor([0.45, 0.10, 0.30], dtype=torch.float32, device=device) + return pose + + def format_tensor(tensor: torch.Tensor) -> str: rounded = (tensor.detach().cpu() * 10000.0).round() / 10000.0 return str(rounded.tolist()) def main() -> None: - """Move the robot end effector to one target pose using atomic actions.""" + """Move the robot end effector through a multi-waypoint pose trajectory.""" args = parse_arguments() # ------------------------------------------------------------------ # @@ -199,21 +215,38 @@ def main() -> None: # Step 5: Define and visualize the end-effector target # # ------------------------------------------------------------------ # target_pose = make_top_down_eef_pose(sim.device) + side_pose = make_side_eef_pose(sim.device) if not args.headless: sim.open_window() if not args.no_vis_eef_axis: draw_axis_marker(sim, "move_end_effector_target_axis", target_pose) + draw_axis_marker(sim, "move_end_effector_side_axis", side_pose) if not args.auto_play: input("Inspect the robot, then press Enter to plan MoveEndEffector...") # ------------------------------------------------------------------ # # Step 6: Plan the declared (name, typed_target) sequence # # ------------------------------------------------------------------ # + # Pass a multi-waypoint trajectory (n_envs, n_waypoint, 4, 4): the + # end-effector visits `target_pose` then `side_pose` in a single plan. + n_envs = robot.get_qpos().shape[0] + multi_waypoint_xpos = ( + torch.stack([target_pose, side_pose], dim=0) + .unsqueeze(0) + .repeat(n_envs, 1, 1, 1) + ) logger.log_info( - f"Planning MoveEndEffector to xpos={format_tensor(target_pose[:3, 3])}" + "Planning MoveEndEffector through multi-waypoint trajectory: " + f"xpos0={format_tensor(target_pose[:3, 3])} -> " + f"xpos1={format_tensor(side_pose[:3, 3])}" ) is_success, traj, _ = atomic_engine.run( - steps=[("move_end_effector", EndEffectorPoseTarget(xpos=target_pose))] + steps=[ + ( + "move_end_effector", + EndEffectorPoseTarget(xpos=multi_waypoint_xpos), + ) + ] ) if not is_success: logger.log_warning("Failed to plan MoveEndEffector demo trajectory.") diff --git a/scripts/tutorials/atomic_action/move_joints.py b/scripts/tutorials/atomic_action/move_joints.py index 32cfc44d..e1ca0bd2 100644 --- a/scripts/tutorials/atomic_action/move_joints.py +++ b/scripts/tutorials/atomic_action/move_joints.py @@ -179,6 +179,7 @@ def main() -> None: # Step 3: Configure the MoveJoints atomic action # # ------------------------------------------------------------------ # ready_qpos = make_arm_qpos([0.35, -1.20, 1.30, -1.65, -1.57, 0.20], sim.device) + mid_qpos = make_arm_qpos([0.15, -1.40, 1.45, -1.60, -1.57, 0.10], sim.device) home_qpos = make_arm_qpos([0.0, -1.57, 1.57, -1.57, -1.57, 0.0], sim.device) move_joints_cfg = MoveJointsCfg( control_part="arm", @@ -205,13 +206,21 @@ def main() -> None: # ------------------------------------------------------------------ # # Step 6: Plan the declared (name, typed_target) sequence # # ------------------------------------------------------------------ # + # The second MoveJoints step passes a multi-waypoint trajectory + # (n_envs, n_waypoint, control_dof): the arm visits `mid_qpos` then + # `home_qpos` in a single planned trajectory. + n_envs = robot.get_qpos().shape[0] + multi_waypoint_qpos = ( + torch.stack([mid_qpos, home_qpos], dim=0).unsqueeze(0).repeat(n_envs, 1, 1) + ) logger.log_info( - "Planning MoveJoints: NamedJointPositionTarget('ready') -> explicit home qpos" + "Planning MoveJoints: NamedJointPositionTarget('ready') -> " + "multi-waypoint trajectory (mid -> home)" ) is_success, traj, _ = atomic_engine.run( steps=[ ("move_joints", NamedJointPositionTarget(name="ready")), - ("move_joints", JointPositionTarget(qpos=home_qpos)), + ("move_joints", JointPositionTarget(qpos=multi_waypoint_qpos)), ] ) if not is_success: diff --git a/tests/sim/atomic_actions/test_actions.py b/tests/sim/atomic_actions/test_actions.py index fb5a1bf0..cf184796 100644 --- a/tests/sim/atomic_actions/test_actions.py +++ b/tests/sim/atomic_actions/test_actions.py @@ -148,6 +148,47 @@ def test_execute_returns_full_dof_trajectory(self): # MoveEndEffector preserves held_object. assert result.next_state.held_object is None + def test_execute_with_multi_waypoint_visits_each_waypoint(self): + action = MoveEndEffector(self.mg, MoveEndEffectorCfg(sample_interval=10)) + pose0 = torch.eye(4) + pose1 = torch.eye(4) + pose1[0, 3] = 1.0 + # (n_envs, n_waypoint, 4, 4) trajectory target + multi_xpos = ( + torch.stack([pose0, pose1], dim=0).unsqueeze(0).repeat(NUM_ENVS, 1, 1, 1) + ) + seen_poses = [] + + def compute_ik(pose=None, name=None, joint_seed=None, **kwargs): + seen_poses.append(pose.clone()) + return torch.ones(NUM_ENVS, dtype=torch.bool), joint_seed.clone() + + self.mg.robot.compute_ik = Mock(side_effect=compute_ik) + + captured = {} + + def interpolate(trajectory, interp_num, device): + captured["keyframes"] = trajectory + return trajectory[:, -1:, :].repeat(1, interp_num, 1) + + with patch( + "embodichain.lab.sim.atomic_actions.trajectory.interpolate_with_distance", + side_effect=interpolate, + ): + result = action.execute( + EndEffectorPoseTarget(xpos=multi_xpos), + WorldState(last_qpos=torch.zeros(NUM_ENVS, TOTAL_DOF)), + ) + + assert result.success is True + assert result.trajectory.shape == (NUM_ENVS, 10, TOTAL_DOF) + # Two waypoints -> two IK calls, in order. + assert len(seen_poses) == 2 + assert torch.allclose(seen_poses[0], pose0.unsqueeze(0).repeat(NUM_ENVS, 1, 1)) + assert torch.allclose(seen_poses[1], pose1.unsqueeze(0).repeat(NUM_ENVS, 1, 1)) + # start prepended to the two IK solutions -> 3 keyframes. + assert captured["keyframes"].shape == (NUM_ENVS, 3, ARM_DOF) + # --------------------------------------------------------------------------- # MoveJoints @@ -221,6 +262,51 @@ def test_execute_with_named_qpos_resolves_cfg_target(self): torch.full((NUM_ENVS, ARM_DOF), 0.2), ) + def test_execute_with_multi_waypoint_qpos_visits_each_waypoint(self): + action = MoveJoints(self.mg, MoveJointsCfg(sample_interval=10)) + # (n_envs, n_waypoint, control_dof) trajectory target + waypoint_qpos = ( + torch.stack( + [ + torch.full((ARM_DOF,), 0.3), + torch.full((ARM_DOF,), 0.7), + ], + dim=0, + ) + .unsqueeze(0) + .repeat(NUM_ENVS, 1, 1) + ) + last_qpos = torch.zeros(NUM_ENVS, TOTAL_DOF) + + captured = {} + + def interpolate(trajectory, interp_num, device): + captured["keyframes"] = trajectory + return trajectory[:, -1:, :].repeat(1, interp_num, 1) + + with patch( + "embodichain.lab.sim.atomic_actions.trajectory.interpolate_with_distance", + side_effect=interpolate, + ): + result = action.execute( + JointPositionTarget(qpos=waypoint_qpos), + WorldState(last_qpos=last_qpos), + ) + + assert result.success is True + assert result.trajectory.shape == (NUM_ENVS, 10, TOTAL_DOF) + # start prepended to the two waypoints -> 3 keyframes + keyframes = captured["keyframes"] + assert keyframes.shape == (NUM_ENVS, 3, ARM_DOF) + assert torch.allclose(keyframes[:, 0, :], torch.zeros(NUM_ENVS, ARM_DOF)) + assert torch.allclose(keyframes[:, 1, :], torch.full((NUM_ENVS, ARM_DOF), 0.3)) + assert torch.allclose(keyframes[:, 2, :], torch.full((NUM_ENVS, ARM_DOF), 0.7)) + # final state lands on the last waypoint + assert torch.allclose( + result.next_state.last_qpos[:, :ARM_DOF], + torch.full((NUM_ENVS, ARM_DOF), 0.7), + ) + def test_unknown_named_qpos_raises(self): action = MoveJoints(self.mg, MoveJointsCfg()) with pytest.raises(KeyError, match="missing"): diff --git a/tests/sim/atomic_actions/test_trajectory.py b/tests/sim/atomic_actions/test_trajectory.py index ef28ab8b..8177f050 100644 --- a/tests/sim/atomic_actions/test_trajectory.py +++ b/tests/sim/atomic_actions/test_trajectory.py @@ -96,6 +96,19 @@ def test_wrong_shape_raises(self): with pytest.raises(Exception): self.builder.resolve_pose_target(torch.eye(3), n_envs=2) + def test_multi_waypoint_passes_through(self): + pose = torch.eye(4).unsqueeze(0).unsqueeze(0).repeat(2, 3, 1, 1) + pose[0, 1, :3, 3] = torch.tensor([1.0, 0.0, 0.0]) + out = self.builder.resolve_pose_target(pose, n_envs=2) + assert out.shape == (2, 3, 4, 4) + assert torch.equal(out, pose.to(torch.float32)) + + def test_multi_waypoint_wrong_envs_raises(self): + with pytest.raises(Exception): + self.builder.resolve_pose_target( + torch.eye(4).unsqueeze(0).unsqueeze(0).repeat(3, 2, 1, 1), n_envs=2 + ) + class TestResolveJointTarget: def setup_method(self): @@ -123,6 +136,26 @@ def test_wrong_shape_raises(self): torch.zeros(5), n_envs=2, joint_dof=6, control_part="arm" ) + def test_multi_waypoint_passes_through(self): + qpos = torch.arange(24, dtype=torch.float32).reshape(2, 2, 6) + out = self.builder.resolve_joint_target( + qpos, n_envs=2, joint_dof=6, control_part="arm" + ) + assert out.shape == (2, 2, 6) + assert torch.equal(out, qpos.to(torch.float32)) + + def test_multi_waypoint_wrong_envs_raises(self): + with pytest.raises(Exception): + self.builder.resolve_joint_target( + torch.zeros(3, 2, 6), n_envs=2, joint_dof=6, control_part="arm" + ) + + def test_multi_waypoint_wrong_dof_raises(self): + with pytest.raises(Exception): + self.builder.resolve_joint_target( + torch.zeros(2, 2, 5), n_envs=2, joint_dof=6, control_part="arm" + ) + class TestSplitThreePhase: def setup_method(self): @@ -212,6 +245,25 @@ def test_interpolates_start_to_target(self): assert torch.equal(kwargs["trajectory"][:, 0, :], start) assert torch.equal(kwargs["trajectory"][:, 1, :], target) + def test_interpolates_start_through_multi_waypoints(self): + start = torch.zeros(2, 6) + waypoints = torch.arange(24, dtype=torch.float32).reshape(2, 2, 6) + expected = torch.ones(2, 5, 6) + with patch( + "embodichain.lab.sim.atomic_actions.trajectory.interpolate_with_distance", + return_value=expected, + ) as interpolate: + out = self.builder.plan_joint_traj(start, waypoints, n_waypoints=5) + + assert out is expected + _, kwargs = interpolate.call_args + assert kwargs["interp_num"] == 5 + # start prepended, then every waypoint in order + assert kwargs["trajectory"].shape == (2, 3, 6) + assert torch.equal(kwargs["trajectory"][:, 0, :], start) + assert torch.equal(kwargs["trajectory"][:, 1, :], waypoints[:, 0, :]) + assert torch.equal(kwargs["trajectory"][:, 2, :], waypoints[:, 1, :]) + class TestIkSolve: def test_uses_first_env_seed_for_single_pose(self):