MLP State-based Policy for RoboTwin
This module implements a simple MLP (Multi-Layer Perceptron) policy that uses state observations (robot proprioception + object poses) to predict robot actions for the stack_bowls_two task.
Prerequisites
Make sure you have the RoboTwin conda environment activated:
source ~/miniconda3/etc/profile.d/conda.sh
conda activate RoboTwin
Overview
Unlike vision-based policies (ACT, DP, etc.) that take camera images as input, this policy uses a compact 26-dimensional state vector as observation. This makes it much faster to train and deploy while still being effective for tasks where the relevant state can be directly observed.
Observation Space (26-dim)
| Feature | Dimensions | Description |
|---|---|---|
bowlA_pos |
3 | Position of bowl A (the bottom bowl) in world frame |
bowlA_quat |
4 | Quaternion of bowl A (w, x, y, z) |
bowlA_to_bowlB_pos |
3 | Relative position: bowl B position - bowl A position |
eef_pos_L |
3 | Left end-effector position |
eef_quat_L |
4 | Left end-effector quaternion (w, x, y, z) |
eef_pos_R |
3 | Right end-effector position |
eef_quat_R |
4 | Right end-effector quaternion (w, x, y, z) |
gripper_L |
1 | Left gripper state (0=closed, 1=open) |
gripper_R |
1 | Right gripper state |
Action Space (14-dim)
| Feature | Dimensions | Description |
|---|---|---|
delta_pose_left |
6 | Left arm end-effector pose change [dx, dy, dz, droll, dpitch, dyaw] |
gripper_left |
1 | Left gripper command (>=0 open, <0 close) |
delta_pose_right |
6 | Right arm end-effector pose change [dx, dy, dz, droll, dpitch, dyaw] |
gripper_right |
1 | Right gripper command (>=0 open, <0 close) |
The action is a delta action - it specifies the change in end-effector pose relative to the current pose, rather than an absolute target pose.
Pipeline
1. Data Collection
First, collect state-based demonstration data:
# From the repository root
bash collect_state_data.sh stack_bowls_two state_mlp_clean 0
This will:
- Run the expert demonstration for 50 episodes (configurable in task_config/state_mlp_clean.yml)
- Save HDF5 files containing endpose, qpos, and object_state data
- No videos or RGB images are saved (for efficiency)
The data will be saved to: ./data/stack_bowls_two/state_mlp_clean/data/
2. Training
Train the MLP policy on the collected data:
# From the repository root
bash train_mlp_policy.sh stack_bowls_two state_mlp_clean v1 50 0
Arguments:
- stack_bowls_two: Task name
- state_mlp_clean: Task config (data directory name)
- v1: Checkpoint setting name (for versioning)
- 50: Number of episodes to use for training
- 0: GPU ID
Training hyperparameters can be modified in the script or passed as arguments:
- --num_epochs 500: Total training epochs
- --batch_size 256: Batch size
- --lr 1e-4: Learning rate
- --hidden_dims 256 256 256: Hidden layer dimensions
- --obs_horizon 1: Number of past observations to stack (history)
- --action_horizon 1: Number of future actions to predict (chunk size)
- --dropout 0.0: Dropout rate
Checkpoints are saved to: ./policy/MLP_state/ckpts/stack_bowls_two/v1/
3. Evaluation
Evaluate the trained policy in simulation:
# From the policy/MLP_state directory
cd policy/MLP_state
# Fastest mode (no video recording):
bash eval.sh stack_bowls_two state_mlp_eval_novideo v1 0 0
# With video recording:
bash eval.sh stack_bowls_two state_mlp_eval v1 0 0
Arguments:
- stack_bowls_two: Task name
- state_mlp_eval_novideo or state_mlp_eval: Evaluation config
- v1: Checkpoint setting (matches training)
- 0: Random seed
- 0: GPU ID
Evaluation Configs:
| Config | Description |
|--------|-------------|
| state_mlp_eval_novideo | Fastest - no video, no RGB rendering |
| state_mlp_eval | With video recording (slower) |
Alternatively, run directly:
# From repository root (activate conda first)
source ~/miniconda3/etc/profile.d/conda.sh
conda activate RoboTwin
python script/eval_policy.py --config policy/MLP_state/deploy_policy.yml \
--overrides \
--task_name stack_bowls_two \
--task_config state_mlp_eval_novideo \
--ckpt_setting v1 \
--ckpt_dir policy/MLP_state/ckpts/stack_bowls_two/v1 \
--seed 0 \
--policy_name MLP_state
Evaluation results are saved to: ./eval_result/stack_bowls_two/MLP_state/<config>/v1/<timestamp>/
This includes:
- _result.txt: Success rate
- episode<N>.mp4: Videos of evaluation episodes (only with state_mlp_eval config)
Configuration Files
task_config/state_mlp_clean.yml
Task configuration for state-based data collection:
render_freq: 0 # No rendering during collection
episode_num: 50 # Number of episodes to collect
use_seed: false # Generate new seeds
save_freq: 15 # Save observation every 15 simulation steps
embodiment: [aloha-agilex] # Robot type
# Disable domain randomization for clean training data
domain_randomization:
random_background: false
cluttered_table: false
clean_background_rate: 1
random_head_camera_dis: 0
random_table_height: 0
random_light: false
crazy_random_light_rate: 0
# Disable camera collection
camera:
head_camera_type: D435
wrist_camera_type: D435
collect_head_camera: false
collect_wrist_camera: false
# Only collect endpose and qpos
data_type:
rgb: false
third_view: false
depth: false
pointcloud: false
observer: false
endpose: true
qpos: true
mesh_segmentation: false
actor_segmentation: false
save_path: ./data
collect_data: true
eval_video_log: true
policy/MLP_state/deploy_policy.yml
Evaluation configuration for the MLP policy:
policy_name: MLP_state
obs_dim: 26
action_dim: 14
hidden_dims: [256, 256, 256]
obs_horizon: 1
action_horizon: 1
dropout: 0.0
device: cuda:0
Architecture
The MLP policy is a simple feedforward network:
Input: [obs_horizon * 26] normalized observations
↓
Linear(input, 256) → LayerNorm → ReLU
↓
Linear(256, 256) → LayerNorm → ReLU
↓
Linear(256, 256) → LayerNorm → ReLU
↓
Linear(256, action_horizon * 14) → Reshape
↓
Output: [action_horizon, 14] normalized actions
File Structure
policy/MLP_state/
├── __init__.py # Module exports
├── mlp_model.py # MLPPolicy network definition
├── dataset.py # StateEpisodicDataset for loading HDF5
├── train.py # Training script
├── deploy_policy.py # Evaluation interface (get_model, eval, reset_model)
├── deploy_policy.yml # Evaluation configuration
├── eval.sh # Evaluation shell script
└── README.md # This file
script/
├── collect_data_state.py # State-based data collection (modified for bowl poses)
├── eval_policy.py # Generic policy evaluation
task_config/
└── state_mlp_clean.yml # Config for state-based collection
# Root level convenience scripts
collect_state_data.sh # Data collection wrapper
train_mlp_policy.sh # Training wrapper
Optional: Observation History
The policy supports stacking multiple past observations for temporal context:
# Train with 3 timesteps of history
python policy/MLP_state/train.py \
--data_dir ./data/stack_bowls_two/state_mlp_clean/data \
--ckpt_dir ./policy/MLP_state/ckpts/stack_bowls_two/v2_history \
--num_episodes 50 \
--obs_horizon 3 \
--action_horizon 1
Then update deploy_policy.yml:
obs_horizon: 3
Optional: Action Chunking
The policy also supports predicting multiple future actions (action chunking):
# Train with action chunk size of 4
python policy/MLP_state/train.py \
--data_dir ./data/stack_bowls_two/state_mlp_clean/data \
--ckpt_dir ./policy/MLP_state/ckpts/stack_bowls_two/v3_chunk \
--num_episodes 50 \
--obs_horizon 1 \
--action_horizon 4
Then update deploy_policy.yml:
action_horizon: 4
Notes
-
Dual-arm only: This policy is designed for dual-arm manipulation. Single-arm configurations are not supported.
-
Delta actions: The policy outputs delta poses (changes from current pose) rather than absolute target poses. This makes the policy more robust to small errors that accumulate over time.
-
Gripper commands: Gripper actions are mapped as:
cmd >= 0: Open grippercmd < 0: Close gripper-
During training, the mapping is:
cmd = 2 * target_gripper - 1 -
Coordinate frame: All poses are in the robot's world frame. The robot base is at approximately (0, -0.65, 0).
-
Normalization: Observations and actions are normalized using per-dimension mean and standard deviation computed from the training data. The normalization statistics are saved alongside the checkpoint.