Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 35 additions & 6 deletions embodichain/lab/sim/atomic_actions/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,15 @@ def _arm_qpos_from_state(


class MoveEndEffector(AtomicAction):
"""Plan a free-space end-effector move to a target pose."""
"""Plan a free-space end-effector move to a target pose.

The :class:`EndEffectorPoseTarget` may carry either a single waypoint
``(n_envs, 4, 4)`` (or a broadcastable ``(4, 4)``) or a multi-waypoint
trajectory ``(n_envs, n_waypoint, 4, 4)``. In the multi-waypoint case the
action plans a single trajectory that visits every waypoint in order,
starting from the inherited ``WorldState.last_qpos`` — IK is solved for each
waypoint with the previous waypoint's solution as the seed.
"""

TargetType: ClassVar[type] = EndEffectorPoseTarget

Expand All @@ -201,10 +209,7 @@ def execute(self, target: EndEffectorPoseTarget, state: WorldState) -> ActionRes
arm_dof=self.arm_dof,
control_part=self.cfg.control_part,
)
target_states_list = [
[PlanState(xpos=move_xpos[i], move_type=MoveType.EEF_MOVE)]
for i in range(self.n_envs)
]
target_states_list = self._build_target_states(move_xpos)
ok, arm_traj = self.builder.plan_arm_traj(
target_states_list,
start_qpos,
Expand All @@ -223,6 +228,23 @@ def execute(self, target: EndEffectorPoseTarget, state: WorldState) -> ActionRes
),
)

def _build_target_states(self, move_xpos: torch.Tensor) -> list[list[PlanState]]:
"""Build per-env PlanState lists from a single- or multi-waypoint target.

``move_xpos`` is the resolved target: 3D ``(n_envs, 4, 4)`` for a single
waypoint or 4D ``(n_envs, n_waypoint, 4, 4)`` for a trajectory.
"""
if move_xpos.dim() == 3:
move_xpos = move_xpos.unsqueeze(1)
n_waypoint = move_xpos.shape[1]
return [
[
PlanState(xpos=move_xpos[i, j], move_type=MoveType.EEF_MOVE)
for j in range(n_waypoint)
]
for i in range(self.n_envs)
]

def _embed(
self, arm_traj: torch.Tensor, last_full_qpos: torch.Tensor
) -> torch.Tensor:
Expand Down Expand Up @@ -254,7 +276,14 @@ def _fail(self, state: WorldState) -> ActionResult:


class MoveJoints(AtomicAction):
"""Plan a joint-space move for the configured control part."""
"""Plan a joint-space move for the configured control part.

The :class:`JointPositionTarget` may carry either a single waypoint
``(n_envs, control_dof)`` or a multi-waypoint trajectory
``(n_envs, n_waypoint, control_dof)``. In the multi-waypoint case the
action plans a single trajectory that visits every waypoint in order,
starting from the inherited ``WorldState.last_qpos``.
"""

TargetType: ClassVar[tuple[type, ...]] = (
JointPositionTarget,
Expand Down
18 changes: 16 additions & 2 deletions embodichain/lab/sim/atomic_actions/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,29 @@ class EndEffectorPoseTarget:
"""End-effector pose target. Used by MoveEndEffector and Place."""

xpos: torch.Tensor
"""(4, 4) or (n_envs, 4, 4) homogeneous transform."""
"""Target end-effector homogeneous transform.

Accepts:

- ``(4, 4)`` or ``(n_envs, 4, 4)`` — a single waypoint.
- ``(n_envs, n_waypoint, 4, 4)`` — a multi-waypoint trajectory; waypoints
are visited in order. (Only consumed as multi-waypoint by MoveEndEffector.)
"""


@dataclass(frozen=True)
class JointPositionTarget:
"""Joint-space target for a configured robot control part."""

qpos: torch.Tensor
"""(control_dof,) or (n_envs, control_dof) target joint positions."""
"""Target joint positions.

Accepts:

- ``(control_dof,)`` or ``(n_envs, control_dof)`` — a single waypoint.
- ``(n_envs, n_waypoint, control_dof)`` — a multi-waypoint trajectory;
waypoints are visited in order.
"""


@dataclass(frozen=True)
Expand Down
90 changes: 77 additions & 13 deletions embodichain/lab/sim/atomic_actions/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,45 @@ def all_envs_success(self, is_success: bool | torch.Tensor) -> bool:
return bool(is_success)

def resolve_pose_target(self, target: torch.Tensor, *, n_envs: int) -> torch.Tensor:
"""Broadcast a (4, 4) pose to (n_envs, 4, 4) or validate batched shape."""
"""Resolve an end-effector pose target into batched homogeneous transforms.

Accepts the following shapes for ``target``:

- ``(4, 4)`` — broadcast to ``(n_envs, 4, 4)`` (single waypoint).
- ``(n_envs, 4, 4)`` — single waypoint, validated and passed through.
- ``(n_envs, n_waypoint, 4, 4)`` — a multi-waypoint trajectory; each
waypoint is visited in order. ``n_waypoint`` may be 1.

Returns a 3D tensor for single-waypoint inputs and a 4D tensor for
multi-waypoint inputs.
"""
if not isinstance(target, torch.Tensor):
logger.log_error(
f"target must be torch.Tensor of shape (4, 4) or ({n_envs}, 4, 4)",
f"target must be torch.Tensor of shape (4, 4), ({n_envs}, 4, 4), "
f"or ({n_envs}, n_waypoint, 4, 4)",
TypeError,
)
target = target.to(device=self.device, dtype=torch.float32)
if target.shape == (4, 4):
target = target.unsqueeze(0).repeat(n_envs, 1, 1)
if target.shape != (n_envs, 4, 4):
if target.dim() == 3:
if target.shape != (n_envs, 4, 4):
logger.log_error(
f"target tensor must have shape (4, 4) or ({n_envs}, 4, 4), "
f"but got {target.shape}",
ValueError,
)
elif target.dim() == 4:
if target.shape[0] != n_envs or target.shape[2:] != (4, 4):
logger.log_error(
f"multi-waypoint target tensor must have shape "
f"({n_envs}, n_waypoint, 4, 4), but got {target.shape}",
ValueError,
)
else:
logger.log_error(
f"target tensor must have shape (4, 4) or ({n_envs}, 4, 4), "
f"but got {target.shape}",
f"target tensor must be (4, 4), ({n_envs}, 4, 4), or "
f"({n_envs}, n_waypoint, 4, 4), but got {target.shape}",
ValueError,
)
return target
Expand Down Expand Up @@ -103,20 +129,47 @@ def resolve_joint_target(
joint_dof: int,
control_part: str,
) -> torch.Tensor:
"""Resolve a joint-space target into batched control-part joint positions."""
"""Resolve a joint-space target into batched control-part joint positions.

Accepts the following shapes for ``target_qpos``:

- ``(joint_dof,)`` — broadcast to ``(n_envs, joint_dof)`` (single waypoint).
- ``(n_envs, joint_dof)`` — single waypoint, validated and passed through.
- ``(n_envs, n_waypoint, joint_dof)`` — a multi-waypoint trajectory; each
waypoint is visited in order. ``n_waypoint`` may be 1.

Returns a 2D tensor for single-waypoint inputs and a 3D tensor for
multi-waypoint inputs, leaving downstream planners to treat the trailing
axis as the joint dimension.
"""
if not isinstance(target_qpos, torch.Tensor):
logger.log_error(
f"target qpos for '{control_part}' must be a torch.Tensor with shape "
f"({joint_dof},) or ({n_envs}, {joint_dof})",
f"({joint_dof},), ({n_envs}, {joint_dof}), or "
f"({n_envs}, n_waypoint, {joint_dof})",
TypeError,
)
target_qpos = target_qpos.to(device=self.device, dtype=torch.float32)
if target_qpos.shape == (joint_dof,):
target_qpos = target_qpos.unsqueeze(0).repeat(n_envs, 1)
if target_qpos.shape != (n_envs, joint_dof):
if target_qpos.dim() == 2:
if target_qpos.shape != (n_envs, joint_dof):
logger.log_error(
f"target qpos for '{control_part}' must have shape ({joint_dof},) "
f"or ({n_envs}, {joint_dof}), but got {target_qpos.shape}",
ValueError,
)
elif target_qpos.dim() == 3:
if target_qpos.shape[0] != n_envs or target_qpos.shape[2] != joint_dof:
logger.log_error(
f"multi-waypoint target qpos for '{control_part}' must have shape "
f"({n_envs}, n_waypoint, {joint_dof}), but got {target_qpos.shape}",
ValueError,
)
else:
logger.log_error(
f"target qpos for '{control_part}' must have shape ({joint_dof},) "
f"or ({n_envs}, {joint_dof}), but got {target_qpos.shape}",
f"target qpos for '{control_part}' must be 1D, 2D, or 3D with "
f"trailing dim {joint_dof}, but got {target_qpos.shape}",
ValueError,
)
return target_qpos
Expand Down Expand Up @@ -291,10 +344,21 @@ def plan_joint_traj(
target_qpos: torch.Tensor,
n_waypoints: int,
) -> torch.Tensor:
"""Interpolate a joint-space trajectory from ``start_qpos`` to ``target_qpos``."""
trajectory = torch.stack([start_qpos, target_qpos], dim=1)
"""Interpolate a joint-space trajectory through one or more target waypoints.

``start_qpos`` has shape ``(n_envs, joint_dof)``. ``target_qpos`` is
either a single waypoint ``(n_envs, joint_dof)`` or a sequence of
waypoints ``(n_envs, n_waypoint, joint_dof)``. The start configuration is
prepended to the target waypoints to build the keyframe sequence
``(n_envs, 1 + n_waypoint, joint_dof)``, which is then resampled to
``n_waypoints`` output samples by cumulative-distance piecewise-linear
interpolation — so each consecutive waypoint pair is traversed in turn.
"""
if target_qpos.dim() == 2:
target_qpos = target_qpos.unsqueeze(1)
keyframes = torch.cat([start_qpos.unsqueeze(1), target_qpos], dim=1)
return interpolate_with_distance(
trajectory=trajectory, interp_num=n_waypoints, device=self.device
trajectory=keyframes, interp_num=n_waypoints, device=self.device
)

# ------------------------------------------------------------------
Expand Down
43 changes: 38 additions & 5 deletions scripts/tutorials/atomic_action/move_end_effector.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
# ----------------------------------------------------------------------------

"""Demonstrate MoveEndEffector with a single pose target."""
"""Demonstrate MoveEndEffector with a multi-waypoint pose trajectory."""

from __future__ import annotations

Expand Down Expand Up @@ -66,7 +66,7 @@

def parse_arguments() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Demonstrate MoveEndEffector with a top-down target pose."
description="Demonstrate MoveEndEffector with a multi-waypoint pose trajectory."
)
add_env_launcher_args_to_parser(parser)
parser.add_argument(
Expand Down Expand Up @@ -159,13 +159,29 @@ def make_top_down_eef_pose(device: torch.device) -> torch.Tensor:
return pose


def make_side_eef_pose(device: torch.device) -> torch.Tensor:
"""A second waypoint offset from the top-down pose for the multi-waypoint demo."""
pose = torch.eye(4, dtype=torch.float32, device=device)
pose[:3, :3] = torch.tensor(
[
[-0.0539, -0.9985, -0.0022],
[-0.9977, 0.0540, -0.0401],
[0.0401, 0.0000, -0.9992],
],
dtype=torch.float32,
device=device,
)
pose[:3, 3] = torch.tensor([0.45, 0.10, 0.30], dtype=torch.float32, device=device)
return pose


def format_tensor(tensor: torch.Tensor) -> str:
rounded = (tensor.detach().cpu() * 10000.0).round() / 10000.0
return str(rounded.tolist())


def main() -> None:
"""Move the robot end effector to one target pose using atomic actions."""
"""Move the robot end effector through a multi-waypoint pose trajectory."""
args = parse_arguments()

# ------------------------------------------------------------------ #
Expand Down Expand Up @@ -199,21 +215,38 @@ def main() -> None:
# Step 5: Define and visualize the end-effector target #
# ------------------------------------------------------------------ #
target_pose = make_top_down_eef_pose(sim.device)
side_pose = make_side_eef_pose(sim.device)
if not args.headless:
sim.open_window()
if not args.no_vis_eef_axis:
draw_axis_marker(sim, "move_end_effector_target_axis", target_pose)
draw_axis_marker(sim, "move_end_effector_side_axis", side_pose)
if not args.auto_play:
input("Inspect the robot, then press Enter to plan MoveEndEffector...")

# ------------------------------------------------------------------ #
# Step 6: Plan the declared (name, typed_target) sequence #
# ------------------------------------------------------------------ #
# Pass a multi-waypoint trajectory (n_envs, n_waypoint, 4, 4): the
# end-effector visits `target_pose` then `side_pose` in a single plan.
n_envs = robot.get_qpos().shape[0]
multi_waypoint_xpos = (
torch.stack([target_pose, side_pose], dim=0)
.unsqueeze(0)
.repeat(n_envs, 1, 1, 1)
)
logger.log_info(
f"Planning MoveEndEffector to xpos={format_tensor(target_pose[:3, 3])}"
"Planning MoveEndEffector through multi-waypoint trajectory: "
f"xpos0={format_tensor(target_pose[:3, 3])} -> "
f"xpos1={format_tensor(side_pose[:3, 3])}"
)
is_success, traj, _ = atomic_engine.run(
steps=[("move_end_effector", EndEffectorPoseTarget(xpos=target_pose))]
steps=[
(
"move_end_effector",
EndEffectorPoseTarget(xpos=multi_waypoint_xpos),
)
]
)
if not is_success:
logger.log_warning("Failed to plan MoveEndEffector demo trajectory.")
Expand Down
13 changes: 11 additions & 2 deletions scripts/tutorials/atomic_action/move_joints.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def main() -> None:
# Step 3: Configure the MoveJoints atomic action #
# ------------------------------------------------------------------ #
ready_qpos = make_arm_qpos([0.35, -1.20, 1.30, -1.65, -1.57, 0.20], sim.device)
mid_qpos = make_arm_qpos([0.15, -1.40, 1.45, -1.60, -1.57, 0.10], sim.device)
home_qpos = make_arm_qpos([0.0, -1.57, 1.57, -1.57, -1.57, 0.0], sim.device)
move_joints_cfg = MoveJointsCfg(
control_part="arm",
Expand All @@ -205,13 +206,21 @@ def main() -> None:
# ------------------------------------------------------------------ #
# Step 6: Plan the declared (name, typed_target) sequence #
# ------------------------------------------------------------------ #
# The second MoveJoints step passes a multi-waypoint trajectory
# (n_envs, n_waypoint, control_dof): the arm visits `mid_qpos` then
# `home_qpos` in a single planned trajectory.
n_envs = robot.get_qpos().shape[0]
multi_waypoint_qpos = (
torch.stack([mid_qpos, home_qpos], dim=0).unsqueeze(0).repeat(n_envs, 1, 1)
)
logger.log_info(
"Planning MoveJoints: NamedJointPositionTarget('ready') -> explicit home qpos"
"Planning MoveJoints: NamedJointPositionTarget('ready') -> "
"multi-waypoint trajectory (mid -> home)"
)
is_success, traj, _ = atomic_engine.run(
steps=[
("move_joints", NamedJointPositionTarget(name="ready")),
("move_joints", JointPositionTarget(qpos=home_qpos)),
("move_joints", JointPositionTarget(qpos=multi_waypoint_qpos)),
]
)
if not is_success:
Expand Down
Loading
Loading