Skip to content

MLP State-based Policy for RoboTwin

This module implements a simple MLP (Multi-Layer Perceptron) policy that uses state observations (robot proprioception + object poses) to predict robot actions for the stack_bowls_two task.

Prerequisites

Make sure you have the RoboTwin conda environment activated:

source ~/miniconda3/etc/profile.d/conda.sh
conda activate RoboTwin

Overview

Unlike vision-based policies (ACT, DP, etc.) that take camera images as input, this policy uses a compact 26-dimensional state vector as observation. This makes it much faster to train and deploy while still being effective for tasks where the relevant state can be directly observed.

Observation Space (26-dim)

Feature Dimensions Description
bowlA_pos 3 Position of bowl A (the bottom bowl) in world frame
bowlA_quat 4 Quaternion of bowl A (w, x, y, z)
bowlA_to_bowlB_pos 3 Relative position: bowl B position - bowl A position
eef_pos_L 3 Left end-effector position
eef_quat_L 4 Left end-effector quaternion (w, x, y, z)
eef_pos_R 3 Right end-effector position
eef_quat_R 4 Right end-effector quaternion (w, x, y, z)
gripper_L 1 Left gripper state (0=closed, 1=open)
gripper_R 1 Right gripper state

Action Space (14-dim)

Feature Dimensions Description
delta_pose_left 6 Left arm end-effector pose change [dx, dy, dz, droll, dpitch, dyaw]
gripper_left 1 Left gripper command (>=0 open, <0 close)
delta_pose_right 6 Right arm end-effector pose change [dx, dy, dz, droll, dpitch, dyaw]
gripper_right 1 Right gripper command (>=0 open, <0 close)

The action is a delta action - it specifies the change in end-effector pose relative to the current pose, rather than an absolute target pose.

Pipeline

1. Data Collection

First, collect state-based demonstration data:

# From the repository root
bash collect_state_data.sh stack_bowls_two state_mlp_clean 0

This will: - Run the expert demonstration for 50 episodes (configurable in task_config/state_mlp_clean.yml) - Save HDF5 files containing endpose, qpos, and object_state data - No videos or RGB images are saved (for efficiency)

The data will be saved to: ./data/stack_bowls_two/state_mlp_clean/data/

2. Training

Train the MLP policy on the collected data:

# From the repository root
bash train_mlp_policy.sh stack_bowls_two state_mlp_clean v1 50 0

Arguments: - stack_bowls_two: Task name - state_mlp_clean: Task config (data directory name) - v1: Checkpoint setting name (for versioning) - 50: Number of episodes to use for training - 0: GPU ID

Training hyperparameters can be modified in the script or passed as arguments: - --num_epochs 500: Total training epochs - --batch_size 256: Batch size - --lr 1e-4: Learning rate - --hidden_dims 256 256 256: Hidden layer dimensions - --obs_horizon 1: Number of past observations to stack (history) - --action_horizon 1: Number of future actions to predict (chunk size) - --dropout 0.0: Dropout rate

Checkpoints are saved to: ./policy/MLP_state/ckpts/stack_bowls_two/v1/

3. Evaluation

Evaluate the trained policy in simulation:

# From the policy/MLP_state directory
cd policy/MLP_state

# Fastest mode (no video recording):
bash eval.sh stack_bowls_two state_mlp_eval_novideo v1 0 0

# With video recording:
bash eval.sh stack_bowls_two state_mlp_eval v1 0 0

Arguments: - stack_bowls_two: Task name - state_mlp_eval_novideo or state_mlp_eval: Evaluation config - v1: Checkpoint setting (matches training) - 0: Random seed - 0: GPU ID

Evaluation Configs: | Config | Description | |--------|-------------| | state_mlp_eval_novideo | Fastest - no video, no RGB rendering | | state_mlp_eval | With video recording (slower) |

Alternatively, run directly:

# From repository root (activate conda first)
source ~/miniconda3/etc/profile.d/conda.sh
conda activate RoboTwin

python script/eval_policy.py --config policy/MLP_state/deploy_policy.yml \
    --overrides \
    --task_name stack_bowls_two \
    --task_config state_mlp_eval_novideo \
    --ckpt_setting v1 \
    --ckpt_dir policy/MLP_state/ckpts/stack_bowls_two/v1 \
    --seed 0 \
    --policy_name MLP_state

Evaluation results are saved to: ./eval_result/stack_bowls_two/MLP_state/<config>/v1/<timestamp>/

This includes: - _result.txt: Success rate - episode<N>.mp4: Videos of evaluation episodes (only with state_mlp_eval config)

Configuration Files

task_config/state_mlp_clean.yml

Task configuration for state-based data collection:

render_freq: 0            # No rendering during collection
episode_num: 50           # Number of episodes to collect
use_seed: false           # Generate new seeds
save_freq: 15             # Save observation every 15 simulation steps
embodiment: [aloha-agilex]  # Robot type

# Disable domain randomization for clean training data
domain_randomization:
  random_background: false
  cluttered_table: false
  clean_background_rate: 1
  random_head_camera_dis: 0
  random_table_height: 0
  random_light: false
  crazy_random_light_rate: 0

# Disable camera collection
camera:
  head_camera_type: D435
  wrist_camera_type: D435
  collect_head_camera: false
  collect_wrist_camera: false

# Only collect endpose and qpos
data_type:
  rgb: false
  third_view: false
  depth: false
  pointcloud: false
  observer: false
  endpose: true
  qpos: true
  mesh_segmentation: false
  actor_segmentation: false

save_path: ./data
collect_data: true
eval_video_log: true

policy/MLP_state/deploy_policy.yml

Evaluation configuration for the MLP policy:

policy_name: MLP_state
obs_dim: 26
action_dim: 14
hidden_dims: [256, 256, 256]
obs_horizon: 1
action_horizon: 1
dropout: 0.0
device: cuda:0

Architecture

The MLP policy is a simple feedforward network:

Input: [obs_horizon * 26] normalized observations
    ↓
Linear(input, 256) → LayerNorm → ReLU
    ↓
Linear(256, 256) → LayerNorm → ReLU
    ↓
Linear(256, 256) → LayerNorm → ReLU
    ↓
Linear(256, action_horizon * 14) → Reshape
    ↓
Output: [action_horizon, 14] normalized actions

File Structure

policy/MLP_state/
├── __init__.py           # Module exports
├── mlp_model.py          # MLPPolicy network definition
├── dataset.py            # StateEpisodicDataset for loading HDF5
├── train.py              # Training script
├── deploy_policy.py      # Evaluation interface (get_model, eval, reset_model)
├── deploy_policy.yml     # Evaluation configuration
├── eval.sh               # Evaluation shell script
└── README.md             # This file

script/
├── collect_data_state.py # State-based data collection (modified for bowl poses)
├── eval_policy.py        # Generic policy evaluation

task_config/
└── state_mlp_clean.yml   # Config for state-based collection

# Root level convenience scripts
collect_state_data.sh     # Data collection wrapper
train_mlp_policy.sh       # Training wrapper

Optional: Observation History

The policy supports stacking multiple past observations for temporal context:

# Train with 3 timesteps of history
python policy/MLP_state/train.py \
    --data_dir ./data/stack_bowls_two/state_mlp_clean/data \
    --ckpt_dir ./policy/MLP_state/ckpts/stack_bowls_two/v2_history \
    --num_episodes 50 \
    --obs_horizon 3 \
    --action_horizon 1

Then update deploy_policy.yml:

obs_horizon: 3

Optional: Action Chunking

The policy also supports predicting multiple future actions (action chunking):

# Train with action chunk size of 4
python policy/MLP_state/train.py \
    --data_dir ./data/stack_bowls_two/state_mlp_clean/data \
    --ckpt_dir ./policy/MLP_state/ckpts/stack_bowls_two/v3_chunk \
    --num_episodes 50 \
    --obs_horizon 1 \
    --action_horizon 4

Then update deploy_policy.yml:

action_horizon: 4

Notes

  1. Dual-arm only: This policy is designed for dual-arm manipulation. Single-arm configurations are not supported.

  2. Delta actions: The policy outputs delta poses (changes from current pose) rather than absolute target poses. This makes the policy more robust to small errors that accumulate over time.

  3. Gripper commands: Gripper actions are mapped as:

  4. cmd >= 0: Open gripper
  5. cmd < 0: Close gripper
  6. During training, the mapping is: cmd = 2 * target_gripper - 1

  7. Coordinate frame: All poses are in the robot's world frame. The robot base is at approximately (0, -0.65, 0).

  8. Normalization: Observations and actions are normalized using per-dimension mean and standard deviation computed from the training data. The normalization statistics are saved alongside the checkpoint.