This repo trains a recurrent PPO agent on Craftax and then fits probes to its recurrent hidden state. The goal is to test whether the agent builds internal representations of quantities it cannot read directly from its observation. It continues the methodology of the MTRLCS / NMMO work (PPO-LSTM, probing cell states for position and tile type; Damasio-homeostasis framing) in a new, faster, fully JAX environment.
Two kinds of probe target:
- Position (sanity check / reproduction). Craftax's symbolic observation is egocentric (the map is always centred on the agent), so the absolute world position `(x, y)` is never in the input. The agent can only know where it is by integrating its own movement over time in its recurrent state. If a probe can decode position from the hidden state while a probe on the raw observation cannot, that reproduces the NMMO position result.
- HP and the homeostatic intrinsics (the focus). `player_health`, `food`, `drink` and `energy` are in the symbolic obs (4 values at `obs[-10:-6]`). We remove that block at training time and keep an HP-grounded reward, so the agent must infer its own homeostatic state to survive. Probing the hidden state then tests for an emergent internal homeostatic model. The secondary intrinsics (recover/hunger/thirst/fatigue) are never in the obs, yet, apart from external damage, they determine when HP drops. That makes them pure interoception targets.
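A minimal sketch of what removing the intrinsics block means for a flat Craftax-Classic symbolic obs. It uses NumPy for illustration only. The repo's `MaskIntrinsics` wrapper presumably operates on JAX arrays and may zero the values instead of deleting them. The helper names `intrinsics_of` and `mask_intrinsics` are hypothetical.

```python
import numpy as np

# Intrinsics block (health, food, drink, energy) in the flat Craftax-Classic
# symbolic obs, as stated above: obs[-10:-6].
START, END = -10, -6

def intrinsics_of(obs: np.ndarray) -> np.ndarray:
    """Return the 4 intrinsics values (e.g. to use as probe labels)."""
    return obs[..., START:END]

def mask_intrinsics(obs: np.ndarray) -> np.ndarray:
    """Hypothetical helper: delete the intrinsics block from the last axis.

    Works on a single obs (D,) or a batch (..., D); the output has D - 4 features.
    """
    return np.concatenate([obs[..., :START], obs[..., END:]], axis=-1)

obs = np.arange(20.0)              # stand-in for a real (much longer) symbolic obs
print(intrinsics_of(obs))          # [10. 11. 12. 13.]
print(mask_intrinsics(obs).shape)  # (16,)
```

Deleting the block shrinks the input by four features. Zeroing it would instead keep the network's input size unchanged. Either way, the masked-obs probe baseline should be fed exactly the obs the agent was trained on.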
Evidence is relative: the hidden-state probe is compared against probes on the masked observation and on the feedforward encoder embedding. The advantage of the hidden state over the masked-obs baseline is the result of interest.
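The probe-and-compare step can be made concrete with a toy that needs no Craftax. Here an agent random-walks on a grid. Its "obs" is only its current heading, which is egocentric and analogous to the direction one-hot in Craftax. Its "hidden state" is an idealized perfect path integrator. Everything in the sketch is a hypothetical illustration of the comparison, not the repo's `train_probes.py` and not a result from it. That includes the toy data, the helper names, and the 80/20 split by episode.

```python
import numpy as np

def _with_bias(x):
    return np.hstack([x, np.ones((len(x), 1))])

def fit_linear_probe(x, y):
    """Least-squares linear probe (features -> targets) with a bias term."""
    w, *_ = np.linalg.lstsq(_with_bias(x), y, rcond=None)
    return w

def r2_score(y_true, y_pred):
    """Pooled R^2 over all target columns; ~0 is the predict-the-mean baseline."""
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean(axis=0)) ** 2)
    return 1.0 - ss_res / ss_tot

def toy_dataset(num_episodes=200, num_steps=50, seed=0):
    """Random walk on a grid. Returns (obs, hidden, position), each (E, T, d)."""
    rng = np.random.default_rng(seed)
    moves = np.array([[0, 1], [0, -1], [-1, 0], [1, 0]])        # toy 4-way movement
    actions = rng.integers(0, 4, size=(num_episodes, num_steps))
    position = np.cumsum(moves[actions], axis=1).astype(float)  # absolute (x, y)
    obs = np.eye(4)[actions]     # egocentric: current heading only, no position
    hidden = position.copy()     # idealized recurrent state: a perfect integrator
    return obs, hidden, position

def _flat(z):
    return z.reshape(-1, z.shape[-1])  # (E, T, d) -> (E*T, d)

def probe_advantage(seed=0, train_frac=0.8):
    obs, hidden, position = toy_dataset(seed=seed)
    split = int(train_frac * len(position))  # split by episode, not by step
    y_train, y_test = _flat(position[:split]), _flat(position[split:])
    scores = {}
    for name, feats in (("obs", obs), ("hidden", hidden)):
        w = fit_linear_probe(_flat(feats[:split]), y_train)
        scores[name] = r2_score(y_test, _with_bias(_flat(feats[split:])) @ w)
    scores["advantage"] = scores["hidden"] - scores["obs"]
    return scores

print(probe_advantage())
```

By construction the hidden probe reaches R² ≈ 1 while the heading-only obs probe stays close to 0, so the advantage is large. In the real pipeline the hidden state is learned, and the size of that gap is the open question. Splitting by episode rather than by step keeps strongly correlated neighbouring timesteps from landing on both sides of the split.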
| group | fields | in symbolic obs? |
|---|---|---|
| spatial | `player_position` (x, y), `player_direction` | position no (egocentric) · direction yes |
| primary intrinsics | `player_health`/`food`/`drink`/`energy`, `is_sleeping` | yes → masked at train time |
| secondary intrinsics | `player_recover`/`hunger`/`thirst`/`fatigue` | no (never observable) |
| inventory | 12 counts | yes |
| world | `light_level`, `achievements` (22), `timestep`, mobs | mixed |
(Full Craftax additionally exposes mana, xp, dexterity/strength/intelligence, floor level, enchantments and boss progress; supporting these is a deferred scaling step.)
Pipeline: train → collect → probe.
- train (`cs-train`): trains a GRU PPO agent on Craftax with the intrinsics masked and a ΔHP reward. Checkpoints are saved at evenly spaced points during training so the emergence of representations can be studied.
- collect (`cs-collect`): rolls out a checkpoint and saves, per step, the GRU hidden state, the encoder embedding, the masked obs, and ground-truth labels from the `EnvState` (aligned to the obs the agent saw).
- probe (`cs-probe`): trains linear + MLP probes for each target on each input and reports accuracy / R² against chance, plus the hidden-over-obs advantage.
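Alignment in the collect step matters because an off-by-one silently shifts every label. The label for row t must come from the state that produced obs t, and it must be read before the env is stepped. Below is a minimal sketch of such a loop. Every callable (`step_fn`, `observe_fn`, `label_fn`, `rnn_fn`, `policy_fn`) and the npz key names are hypothetical stand-ins, not `collect.py`'s API.

```python
import io
import numpy as np

def collect_aligned(state, h, step_fn, observe_fn, label_fn, rnn_fn, policy_fn, num_steps):
    """Roll out one episode; row t of every array describes the same env state."""
    rows = {"hidden": [], "obs": [], "label": []}
    for _ in range(num_steps):
        obs = observe_fn(state)
        h = rnn_fn(h, obs)                     # hidden state after consuming obs_t
        rows["hidden"].append(h)
        rows["obs"].append(obs)
        rows["label"].append(label_fn(state))  # read BEFORE stepping: matches obs_t
        state = step_fn(state, policy_fn(h))   # only now advance to state_{t+1}
    return {k: np.stack(v) for k, v in rows.items()}

def to_npz_bytes(data):
    """Serialize the per-step arrays as a compressed .npz (in memory here)."""
    buf = io.BytesIO()
    np.savez_compressed(buf, **data)
    return buf.getvalue()

# Toy usage: the state is a step counter; the "hidden state" counts obs seen.
data = collect_aligned(
    state=0, h=np.zeros(1),
    step_fn=lambda s, a: s + 1,
    observe_fn=lambda s: np.array([s % 2]),
    label_fn=lambda s: s,
    rnn_fn=lambda h, o: h + 1,
    policy_fn=lambda h: 0,
    num_steps=5,
)
loaded = np.load(io.BytesIO(to_npz_bytes(data)))
print(loaded["label"])  # [0 1 2 3 4]
```

If `label_fn(state)` were called after `step_fn`, the labels would come out as `[1 2 3 4 5]`, one step ahead of the obs and hidden state they are paired with.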
```bash
python3.12 -m venv .venv
source .venv/bin/activate
pip install -e .
# probe stage uses CPU torch:
pip install torch --index-url https://download.pytorch.org/whl/cpu
# GPU JAX (optional, for real-scale training):
pip install "jax[cuda12]==0.10.1"
```

```bash
# 1) train (real run ~3e8 steps on GPU; example is a quick dev run)
cs-train --env_name Craftax-Classic-Symbolic-v1 --mask_intrinsics \
  --total_timesteps 5e6 --num_envs 256 --num_checkpoints 5 \
  --out_dir runs/exp1
# 2) collect a probe dataset from a checkpoint (or --all_ckpts)
cs-collect --run_dir runs/exp1 --ckpt last --num_envs 256 --num_steps 400
# 3) train + report probes
cs-probe --run_dir runs/exp1 --block last
```

Key flags: `--no-mask_intrinsics` (control agent that sees HP), `--reward_hp_coef` / `--reward_ach_coef` (homeostatic vs achievement reward mix), `--cell gru|ema`.
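The ΔHP reward and the `--reward_hp_coef` / `--reward_ach_coef` mix can be sketched as follows. The sketch assumes the per-step reward is a weighted sum of the change in `player_health` and the number of newly unlocked achievements. The function name and default coefficients are hypothetical and not taken from the repo's `HomeostaticReward` wrapper.

```python
import numpy as np

def homeostatic_reward(hp_prev, hp_next, ach_prev, ach_next,
                       reward_hp_coef=1.0, reward_ach_coef=0.0):
    """Hypothetical per-step reward: weighted ΔHP plus newly unlocked achievements.

    hp_prev, hp_next: player_health before / after the step (scalar or array).
    ach_prev, ach_next: boolean achievement flags, achievements on the last axis.
    Coefficient names mirror the CLI flags; the defaults are illustrative only.
    """
    delta_hp = np.asarray(hp_next, dtype=np.float32) - np.asarray(hp_prev, dtype=np.float32)
    # An achievement counts once, on the step where its flag flips False -> True.
    newly_unlocked = np.sum(np.asarray(ach_next, bool) & ~np.asarray(ach_prev, bool), axis=-1)
    return reward_hp_coef * delta_hp + reward_ach_coef * newly_unlocked

# HP drops from 9 to 8 on a step where one new achievement unlocks.
print(homeostatic_reward(9, 8, [False, False], [False, True]))                       # -1.0
print(homeostatic_reward(9, 8, [False, False], [False, True], reward_ach_coef=1.0))  # 0.0
```

With `reward_ach_coef=0`, the rewards summed over an episode equal `reward_hp_coef * (HP_T - HP_0)`, so HP can be reconstructed from the reward stream. This is why feeding the raw reward back into the obs would re-leak the masked intrinsics.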
```text
src/cs_internal/
  envs/         wrappers (vendored Craftax_Baselines + MaskIntrinsics + HomeostaticReward), make_env, label extraction
  training/     model.py (ScannedMemory + ActorCritic), ppo_gru.py (+ checkpointing)
  collection/   collect.py (rollout → hidden/encoder/obs + EnvState labels → npz)
  probing/      probe_models.py, train_probes.py (chance + advantage reporting)
  checkpointing.py
tests/          env-layer invariants
```
Next steps:

- Full `Craftax-Symbolic-v1` (mana/xp/floor/enchantments).
- Combined homeostatic reward over all intrinsic changes (not HP alone).
- Previous action + reward in the obs (RL²-style). Note: the reward is a function of ΔHP, so a raw reward channel re-leaks the masked signal; use a masked or transformed channel instead.
- Difficulty/scarcity sweep (hypothesis: the probe advantage emerges only under survival pressure).
- Control agent trained with intrinsics visible, to contrast probe advantage.