MathisImm/CS_INTERNAL

CS_INTERNAL — Probing internal attributes of recurrent RL agents on Craftax

This repo trains a recurrent PPO agent on Craftax and trains probes on its recurrent hidden state to test whether the agent builds internal representations of quantities it cannot read directly from its observation. It continues the methodology of the MTRLCS / NMMO work (PPO-LSTM, probing cell states for position and tile type; Damasio-homeostasis framing) on a new, faster, fully-JAX environment.

The idea in one paragraph

Two kinds of probe target:

  • Position (sanity check / reproduction). Craftax's symbolic observation is egocentric — the map is always centred on the agent — so the absolute world position (x, y) is never in the input. The agent can only know where it is by integrating its own movement over time in its recurrent state. A probe that decodes position from the hidden state, while a probe on the raw observation cannot, reproduces the NMMO position result.
  • HP and the homeostatic intrinsics (the focus). player_health, food, drink, energy are in the symbolic obs (4 values at obs[-10:-6]). We remove that block at training time and keep an HP-grounded reward, so the agent must infer its own homeostatic state to survive. Probing the hidden state then tests for an emergent internal homeostatic model. The secondary intrinsics (recover/hunger/thirst/fatigue) are never in the obs yet fully determine when HP drops — pure interoception targets.
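As a rough illustration of the masking step, the sketch below zeroes the four intrinsic values at obs[-10:-6] in a flat observation. It is only a sketch: `mask_intrinsics` here is a hypothetical helper, not the repo's MaskIntrinsics wrapper, and zeroing (rather than deleting the block and shrinking the observation) is an assumption.

```python
def mask_intrinsics(obs):
    """Return a copy of a flat symbolic observation with the intrinsics zeroed.

    Assumption: the block is zeroed in place so the observation keeps its
    length; the real wrapper may instead drop the four values.
    """
    masked = list(obs)           # copy, so the caller's observation is untouched
    masked[-10:-6] = [0.0] * 4   # health, food, drink, energy
    return masked


# Toy observation of length 20; a real Craftax observation is much longer.
obs = [float(i) for i in range(20)]
masked = mask_intrinsics(obs)
```

Everything outside the four-value block, including the entries after it, is passed through unchanged.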

Evidence is relative: the hidden-state probe is compared against probes on the masked observation and on the feedforward encoder embedding. The advantage of the hidden state over the masked-obs baseline is the result of interest.
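For a regression target such as HP, the comparison can be sketched as a difference of R² scores between the hidden-state probe and the baseline probe. The function names and toy numbers below are illustrative assumptions, not the repo's reporting code, and in practice both scores would be computed on held-out data.

```python
def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


def probe_advantage(y_true, pred_hidden, pred_baseline):
    """R² of the hidden-state probe minus R² of the baseline probe."""
    return r2_score(y_true, pred_hidden) - r2_score(y_true, pred_baseline)


# Toy HP labels and probe predictions (made-up numbers).
hp = [9.0, 7.0, 5.0, 3.0]
pred_hidden = [8.0, 7.0, 5.0, 4.0]   # probe on the recurrent hidden state
pred_obs = [6.0, 6.0, 6.0, 6.0]      # baseline that can only predict the mean
advantage = probe_advantage(hp, pred_hidden, pred_obs)
```

A baseline that can do no better than predicting the mean scores R² = 0, so in this toy case the advantage equals the hidden-state probe's own R².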

Internal attributes Craftax-Classic provides

| group | fields | in symbolic obs? |
| --- | --- | --- |
| spatial | player_position (x,y), player_direction | position no (egocentric) · direction yes |
| primary intrinsics | player_health/food/drink/energy, is_sleeping | yes → masked at train time |
| secondary intrinsics | player_recover/hunger/thirst/fatigue | no (never observable) |
| inventory | 12 counts | yes |
| world | light_level, achievements(22), timestep, mobs | mixed |

(Full Craftax additionally exposes mana, xp, dexterity/strength/intelligence, floor level, enchantments, boss progress — a deferred scaling step.)

Pipeline

train  →  collect  →  probe
  1. train: trains a GRU PPO agent on Craftax with the intrinsics masked and a ΔHP reward. Checkpoints are saved at evenly spaced points in training so that emergence can be studied.
  2. collect: rolls out a checkpoint and saves, per step, the GRU hidden state, the encoder embedding, the masked obs, and ground-truth labels read from the EnvState (aligned to the obs the agent saw).
  3. probe: trains linear and MLP probes for each target on each input, and reports accuracy / R² against chance together with the hidden-over-obs advantage.
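The label alignment in the collect step can be illustrated with a toy 1-D world. Everything here is a stand-in: `pos` plays the role of the EnvState position, `hidden` the role of the GRU state, and the policy always moves right. Whether the repo stores the hidden state before or after it consumes the observation is not stated; this sketch simply records all three quantities at the same step, before the environment transitions.

```python
def rollout(num_steps):
    """Collect aligned (hidden, obs, label) records from a toy 1-D world."""
    pos = 0       # ground truth, stands in for the EnvState position
    hidden = 0    # stands in for the recurrent hidden state
    records = []
    for _ in range(num_steps):
        obs = 0           # egocentric toy obs: carries no absolute position
        label = pos       # read the label from the state that produced `obs`
        action = 1        # toy policy: always move right
        records.append({"hidden": hidden, "obs": obs, "pos": label})
        pos += action     # environment transition happens after recording
        hidden += action  # toy memory update: integrate own movement
    return records


data = rollout(4)
```

In this toy case the position is recoverable from `hidden` at every step but not from `obs`, which is the pattern the position probe looks for.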

Setup

python3.12 -m venv .venv
source .venv/bin/activate
pip install -e .
# probe stage uses CPU torch:
pip install torch --index-url https://download.pytorch.org/whl/cpu
# GPU JAX (optional, for real-scale training):
pip install "jax[cuda12]==0.10.1"

Running

# 1) train (real run ~3e8 steps on GPU; example is a quick dev run)
cs-train --env_name Craftax-Classic-Symbolic-v1 --mask_intrinsics \
         --total_timesteps 5e6 --num_envs 256 --num_checkpoints 5 \
         --out_dir runs/exp1

# 2) collect a probe dataset from a checkpoint (or --all_ckpts)
cs-collect --run_dir runs/exp1 --ckpt last --num_envs 256 --num_steps 400

# 3) train + report probes
cs-probe --run_dir runs/exp1 --block last

Key flags: --no-mask_intrinsics (control agent that sees HP), --reward_hp_coef / --reward_ach_coef (homeostatic vs achievement reward mix), --cell gru|ema.
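One plausible reading of the reward mix is a weighted sum of the per-step HP change and the environment's achievement reward. The sketch below assumes that linear form; the function name, argument names, and default coefficients are hypothetical and may not match the HomeostaticReward wrapper.

```python
def mixed_reward(hp_prev, hp_next, ach_reward, hp_coef=1.0, ach_coef=0.0):
    """Weighted sum of the HP change and the achievement reward.

    Assumption: a plain linear mix; the defaults are illustrative only.
    """
    delta_hp = hp_next - hp_prev
    return hp_coef * delta_hp + ach_coef * ach_reward


# Losing one HP on a step that also unlocks an achievement (toy numbers).
r_homeostatic = mixed_reward(9, 8, 1.0)                        # HP term only
r_mixed = mixed_reward(9, 8, 1.0, hp_coef=0.5, ach_coef=2.0)   # both terms
```

With the achievement coefficient at zero the reward reduces to ΔHP, which is why feeding a raw reward channel back into the observation would leak the masked signal.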

Layout

src/cs_internal/
  envs/        wrappers (vendored Craftax_Baselines + MaskIntrinsics + HomeostaticReward), make_env, label extraction
  training/    model.py (ScannedMemory + ActorCritic), ppo_gru.py (+ checkpointing)
  collection/  collect.py (rollout → hidden/encoder/obs + EnvState labels → npz)
  probing/     probe_models.py, train_probes.py (chance + advantage reporting)
  checkpointing.py
tests/         env-layer invariants

Deferred / future scaling

  • Full Craftax-Symbolic-v1 (mana/xp/floor/enchantments).
  • Combined homeostatic reward over all intrinsic changes (not HP alone).
  • Previous action + reward in the obs (RL²-style) — note: reward is a function of ΔHP, so a raw reward channel re-leaks the masked signal; use a masked/transformed channel.
  • Difficulty/scarcity sweep (probe advantage emerges only under survival pressure).
  • Control agent trained with intrinsics visible, to contrast probe advantage.
