1""" 2Training script for Dreamer4 on Minecraft using VPT data. 3 4This script implements the full 3-phase Dreamer4 training pipeline: 5 6 Phase 1: Video Tokenizer Training 7 Train the VideoTokenizer to compress 384x640 Minecraft frames into 8 compact latent representations. Uses MAE-style masking, LPIPS perceptual 9 loss, and axial space-time attention. 10 11 Phase 2: World Model (Dynamics) Training 12 Train the DynamicsWorldModel on tokenized video + actions from VPT data. 13 Uses flow matching with shortcut consistency training. The model learns 14 to predict future latent states given past observations and actions. 15 16 Phase 3: Agent Training (Dream-based RL) 17 Train the policy/value heads inside the learned world model using 18 imagined rollouts (DreamTrainer). No real environment needed — the 19 world model generates synthetic experience for PPO/PMPO training. 20 21Each phase saves checkpoints that the next phase loads. 22 23Usage: 24 # Phase 1: Train tokenizer 25 python train_dreamer4_minecraft.py --phase 1 \ 26 --data_dir ./data/vpt-recordings \ 27 --output_dir ./checkpoints \ 28 --num_steps 50000 29 30 # Phase 2: Train world model 31 python train_dreamer4_minecraft.py --phase 2 \ 32 --data_dir ./data/vpt-recordings \ 33 --output_dir ./checkpoints \ 34 --tokenizer_ckpt ./checkpoints/tokenizer.pt \ 35 --num_steps 100000 36 37 # Phase 3: Train agent in dreams 38 python train_dreamer4_minecraft.py --phase 3 \ 39 --output_dir ./checkpoints \ 40 --dynamics_ckpt ./checkpoints/dynamics.pt \ 41 --num_steps 50000 42""" 43 44import os 45import argparse 46import sys 47from pathlib import Path 48 49import torch 50 51from dreamer4 import VideoTokenizer, DynamicsWorldModel 52from dreamer4.trainers import ( 53 VideoTokenizerTrainer, 54 BehaviorCloneTrainer, 55 DreamTrainer, 56 _prune_previous_checkpoint, 57) 58 59 60def _cleanup_last_intermediate_checkpoint(output_dir, legacy_prefix): 61 """After the final phase-level .pt is saved, the most recent 62 state-<N>/ dir and 
<prefix>-<N>.pt files are no longer needed — 63 the per-step pruner already removed every *earlier* one, so this 64 just handles the tail. new_step=-1 guarantees the sentinel never 65 matches a real step, so the helper always deletes the target. 66 """ 67 pointer = Path(output_dir) / 'latest_state.txt' 68 if not pointer.exists(): 69 return 70 last_name = pointer.read_text().strip() 71 _prune_previous_checkpoint( 72 Path(output_dir), last_name, new_step=-1, legacy_prefix=legacy_prefix, 73 ) 74 try: 75 pointer.unlink() 76 except OSError: 77 pass 78 79from minecraft_vpt_dataset import ( 80 MinecraftVPTDataset, 81 DREAMER4_NUM_DISCRETE_ACTIONS, 82) 83 84# cuDNN autotuner picks the fastest conv/attn algorithm for the (fixed) input 85# shapes used by this script. One-line win, safe because input shapes are 86# static (384x640 video, constant seq_len per phase). 87torch.backends.cudnn.benchmark = True 88 89 90# ─── CUDA / resume helpers ─────────────────────────────────────────── 91 92def _enforce_cuda_or_exit(args, phase_name: str) -> bool: 93 """Return True if the phase should run on CPU, False for GPU. 94 95 If CUDA is unavailable and --allow_cpu was not passed, print a loud 96 warning and exit(1). This prevents silently burning HPC walltime 97 on a CPU fallback caused by e.g. a CUDA driver version mismatch. 
98 """ 99 if torch.cuda.is_available(): 100 print(f"[{phase_name}] CUDA OK — device={torch.cuda.get_device_name(0)} " 101 f"| torch={torch.__version__} | cuda={torch.version.cuda}") 102 return False 103 104 msg = ( 105 "\n" + "=" * 70 + "\n" 106 f"[{phase_name}] ERROR: torch.cuda.is_available() is False.\n" 107 "This is almost always a CUDA driver / PyTorch mismatch.\n" 108 f" torch version : {torch.__version__}\n" 109 f" torch.cuda : {torch.version.cuda}\n" 110 "Training on CPU for this model size is ~100x slower than a V100S.\n" 111 "To fix: install a PyTorch build matching your cluster driver, e.g.\n" 112 "pip install torch --index-url https://download.pytorch.org/whl/cu121\n" 113 "or ask your admin which CUDA version the driver supports.\n" 114 "\n" 115 "Pass --allow_cpu to override this check (e.g. for a laptop smoke test).\n" 116 + "=" * 70 + "\n" 117 ) 118 if not getattr(args, 'allow_cpu', False): 119 print(msg, file=sys.stderr) 120 sys.exit(1) 121 print(msg, file=sys.stderr) 122 print(f"[{phase_name}] --allow_cpu set, proceeding on CPU (will be slow).", 123 file=sys.stderr) 124 return True 125 126 127def _resolve_resume_from(raw: str | None, output_dir: str) -> str | None: 128 """Resolve --resume_from into a concrete state-<N> directory path. 129 130 Acceptable inputs: 131 None → no resume 132 'latest' → look in output_dir/latest_state.txt 133 <path to state-N dir> → use as-is 134 <path to dir containing state-N dirs> 135 → look for latest_state.txt inside it 136 """ 137 if raw is None: 138 return None 139 140 if raw == 'latest': 141 base = Path(output_dir) 142 else: 143 base = Path(raw) 144 145 if base.is_dir(): 146 # Accept either a state-<N> dir directly, or a parent with a pointer. 
147 pointer = base / 'latest_state.txt' 148 if pointer.exists(): 149 target = base / pointer.read_text().strip() 150 if target.exists(): 151 print(f"[resume] using {target} from {pointer}") 152 return str(target) 153 print(f"[resume] pointer {pointer} references {target} but it does not exist") 154 return None 155 # Might already be a state-<N> dir 156 if (base / 'random_states_0.pkl').exists() or any(base.glob('*.safetensors')): 157 print(f"[resume] using state dir {base}") 158 return str(base) 159 print(f"[resume] {base} is not a state dir and has no latest_state.txt; starting fresh") 160 return None 161 162 print(f"[resume] resume_from path {base} does not exist; starting fresh") 163 return None 164 165 166# ─── Default Hyperparameters ──────────────────────────────────────── 167 168# Paper-matched defaults for Minecraft at 384x640 resolution (360x640 zero-padded). 169# See data.txt: 16x16 patches → 960 tokens, bottleneck (N_b=512)x(D_b=16). 170 171TOKENIZER_DEFAULTS = dict( 172 dim=256, # Hidden dimension of transformer 173 dim_latent=16, # Latent bottleneck dimension (D_b=16) 174 patch_size=16, # 384/16=24, 640/16=40 → 960 patch tokens 175 image_height=384, # VPT 360x640 zero-padded to 384x640 176 image_width=640, 177 num_latent_tokens=512, # Bottleneck token count (N_b=512) 178 encoder_depth=4, # Transformer depth 179 decoder_depth=4, 180 time_block_every=4, # Temporal attention every 4th block 181 attn_heads=4, 182 attn_dim_head=64, 183 lpips_loss_weight=0.2, # Perceptual loss weight 184 per_image_patch_mask_prob=(0.0, 0.9), # MAE masking range 185 use_loss_normalization=True, 186) 187 188DYNAMICS_DEFAULTS = dict( 189 dim=256, # Hidden dimension 190 dim_latent=16, # Must match tokenizer (D_b=16) 191 max_steps=64, # K_max for flow matching (power of 2) 192 num_register_tokens=8, # Register tokens for temporal consistency 193 num_spatial_tokens=256, # Paper: N_z=256 spatial tokens 194 num_latent_tokens=512, # Must match tokenizer (N_b=512) 195 depth=8, # 
Transformer depth 196 time_block_every=4, 197 attn_heads=4, 198 attn_dim_head=64, 199 use_time_rnn=True, # GRU on temporal blocks 200 # Action space: 20 binary buttons + 1 camera (121 choices) 201 num_discrete_actions=DREAMER4_NUM_DISCRETE_ACTIONS, 202 num_continuous_actions=0, # All actions are discrete 203 multi_token_pred_len=8, # Multi-token prediction horizon 204 pred_orig_latent=True, # x-space prediction (better than v-space) 205 # RL hyperparameters (used in Phase 3) 206 gae_discount_factor=0.997, 207 gae_lambda=0.95, 208 ppo_eps_clip=0.2, 209 policy_entropy_weight=0.01, 210) 211 212# LeWM dynamics: replaces flow matching with JEPA-style next-embedding prediction 213LEWM_DYNAMICS_DEFAULTS = dict( 214 **DYNAMICS_DEFAULTS, 215 use_lewm_dynamics=True, # Enable LeWM mode 216 lewm_loss_weight=1.0, # Next-embedding prediction loss weight 217 lewm_sigreg_loss_weight=0.05, # SIGReg regularization weight 218 lewm_layer=-1, # Use last transformer layer for prediction 219 lewm_action_conditioned=True, # Condition predictions on actions 220) 221 222TRAINING_DEFAULTS = dict( 223 # Phase 1 224 tokenizer_batch_size=2, # Keep small — AttentionResidual layers are memory-intensive 225 tokenizer_lr=3e-4, 226 tokenizer_num_steps=50000, 227 tokenizer_max_grad_norm=1.0, 228 tokenizer_seq_len=16, 229 # Phase 2 230 dynamics_batch_size=4, 231 dynamics_lr=3e-4, 232 dynamics_num_steps=100000, 233 dynamics_max_grad_norm=1.0, 234 dynamics_seq_len=16, 235 # Phase 3 236 dream_batch_size=16, 237 dream_lr=3e-4, 238 dream_num_steps=50000, 239 dream_max_grad_norm=1.0, 240 dream_generate_timesteps=16, 241) 242 243 244# ─── Phase 1: Train Video Tokenizer ──────────────────────────────── 245 246def train_tokenizer(args): 247 """Train the VideoTokenizer on Minecraft video data. 
248 249 The tokenizer learns to compress 384x640 RGB frames into compact 250 latent representations using: 251 - Patch embedding (16x16 patches → 24x40 spatial grid) 252 - Axial space-time transformer encoder 253 - Latent bottleneck with Tanh activation 254 - MAE-style patch masking for regularization 255 - LPIPS perceptual loss for visual quality 256 - Temporal/spatial decorrelation losses 257 258 The encoder output shape per frame: (num_latent_tokens, dim_latent) = (512, 16) 259 """ 260 print("=" * 60) 261 print("PHASE 1: Training Video Tokenizer") 262 print("=" * 60) 263 264 use_cpu = _enforce_cuda_or_exit(args, "Phase 1") 265 resume_from = _resolve_resume_from(args.resume_from, args.output_dir) 266 267 # Load dataset — tokenizer only needs video, no actions 268 dataset = MinecraftVPTDataset( 269 data_dir=args.data_dir, 270 seq_len=args.tokenizer_seq_len, 271 stride=args.tokenizer_seq_len, 272 image_height=384, 273 image_width=640, 274 max_trajectories=args.max_trajectories, 275 ) 276 277 # The VideoTokenizerTrainer expects a dataset that yields video tensors. 278 # Our dataset yields dicts, so we wrap it to extract just the video. 279 class VideoOnlyDataset(torch.utils.data.Dataset): 280 """Adapter that returns only the ``'video'`` key from MinecraftVPTDataset.""" 281 282 def __init__(self, base_dataset): 283 self.base = base_dataset 284 285 def __len__(self): 286 return len(self.base) 287 288 def __getitem__(self, idx): 289 return self.base[idx]['video'] # (3, T, H, W) 290 291 video_dataset = VideoOnlyDataset(dataset) 292 293 if len(video_dataset) == 0: 294 raise RuntimeError( 295 f"No training clips were created from data in '{args.data_dir}'. " 296 f"Check that the --data_dir path is correct and contains .mp4/.jsonl pairs. 
" 297 f"(Note: the default folder is 'data/vpt-recordings' with hyphens, not underscores.)" 298 ) 299 300 # Create tokenizer 301 tokenizer = VideoTokenizer(**TOKENIZER_DEFAULTS) 302 print(f"VideoTokenizer parameters: {sum(p.numel() for p in tokenizer.parameters()):,}") 303 304 # Train 305 trainer = VideoTokenizerTrainer( 306 model=tokenizer, 307 dataset=video_dataset, 308 batch_size=args.tokenizer_batch_size, 309 learning_rate=args.tokenizer_lr, 310 max_grad_norm=args.tokenizer_max_grad_norm, 311 num_train_steps=args.num_steps or args.tokenizer_num_steps, 312 cpu=use_cpu, 313 mixed_precision=args.mixed_precision, 314 use_tensorboard_logger=args.use_tensorboard, 315 log_dir=args.output_dir, 316 log_video=args.log_video, 317 video_fps=20, 318 log_video_every=args.log_video_every, 319 checkpoint_folder=args.output_dir, 320 dataloader_num_workers=args.num_workers, 321 dataloader_pin_memory=not args.no_pin_memory, 322 dataloader_prefetch_factor=args.prefetch_factor, 323 resume_from=resume_from, 324 ) 325 326 trainer() 327 328 # Save checkpoint 329 os.makedirs(args.output_dir, exist_ok=True) 330 ckpt_path = os.path.join(args.output_dir, "tokenizer.pt") 331 torch.save({ 332 'model': tokenizer.state_dict(), 333 'config': TOKENIZER_DEFAULTS, 334 }, ckpt_path) 335 print(f"Tokenizer saved to {ckpt_path}") 336 337 # Only prune the last state-<N>/ if training finished cleanly. If the 338 # trainer was cancelled via SIGTERM (scancel / walltime), keep the 339 # state dir so the user can resume with --resume_from latest. 340 if not getattr(trainer, 'cancelled', False): 341 _cleanup_last_intermediate_checkpoint(args.output_dir, legacy_prefix='tokenizer') 342 else: 343 print(f"training cancelled — keeping state dir in {args.output_dir} for --resume_from") 344 345 346# ─── Phase 2: Train Dynamics World Model ─────────────────────────── 347 348def train_dynamics(args): 349 """Train the DynamicsWorldModel on tokenized video + actions. 
350 351 The dynamics model learns to predict future latent states using 352 flow matching with shortcut consistency training: 353 354 1. Takes clean latents from the tokenizer 355 2. Adds noise at random signal levels (flow matching) 356 3. Predicts the clean latents from noised versions 357 4. Also predicts rewards and actions (multi-token prediction) 358 359 Shortcut training (from Frans et al.): 360 - Sometimes trains with step_size > 1, allowing the model to 361 make larger jumps in the denoising process 362 - Consistency loss ensures half-step predictions compose correctly 363 - This makes generation faster at inference time 364 365 The world model processes sequences of: 366 - Spatial tokens (from latents) 367 - Action tokens (embedded discrete actions) 368 - Reward tokens (SymExp two-hot encoded) 369 - Register tokens (for temporal consistency) 370 - Agent tokens (for policy/value heads) 371 """ 372 print("=" * 60) 373 print("PHASE 2: Training Dynamics World Model") 374 print("=" * 60) 375 376 use_cpu = _enforce_cuda_or_exit(args, "Phase 2") 377 resume_from = _resolve_resume_from(args.resume_from, args.output_dir) 378 379 # Load tokenizer from Phase 1 checkpoint 380 assert args.tokenizer_ckpt is not None, "Must provide --tokenizer_ckpt for Phase 2" 381 tok_ckpt = torch.load(args.tokenizer_ckpt, map_location='cpu', weights_only=False) 382 tok_config = tok_ckpt.get('config', TOKENIZER_DEFAULTS) 383 384 tokenizer = VideoTokenizer(**tok_config) 385 tokenizer.load_state_dict(tok_ckpt['model']) 386 tokenizer.eval() 387 for p in tokenizer.parameters(): 388 p.requires_grad_(False) 389 print("Loaded frozen tokenizer") 390 391 # Load dataset with actions 392 dataset = MinecraftVPTDataset( 393 data_dir=args.data_dir, 394 seq_len=args.dynamics_seq_len, 395 stride=args.dynamics_seq_len, 396 image_height=384, 397 image_width=640, 398 max_trajectories=args.max_trajectories, 399 ) 400 401 if len(dataset) == 0: 402 raise RuntimeError( 403 f"No training clips were created 
from data in '{args.data_dir}'. " 404 f"Check that the --data_dir path is correct and contains .mp4/.jsonl pairs. " 405 f"(Note: the default folder is 'data/vpt-recordings' with hyphens, not underscores.)" 406 ) 407 408 # Create dynamics model with the tokenizer 409 base_defaults = LEWM_DYNAMICS_DEFAULTS if args.use_lewm else DYNAMICS_DEFAULTS 410 dynamics_config = base_defaults.copy() 411 dynamics_config['num_latent_tokens'] = tok_config.get('num_latent_tokens', 16) 412 dynamics_config['dim_latent'] = tok_config.get('dim_latent', 32) 413 414 variant = "LeWM" if args.use_lewm else "Dreamer4" 415 dynamics = DynamicsWorldModel( 416 video_tokenizer=tokenizer, 417 **dynamics_config, 418 ) 419 n_params = sum(p.numel() for p in dynamics.parameters()) 420 print(f"DynamicsWorldModel ({variant}) parameters: {n_params:,}") 421 422 # Train using BehaviorCloneTrainer 423 # The trainer accepts dict batches and calls dynamics(**batch_data) 424 trainer = BehaviorCloneTrainer( 425 model=dynamics, 426 dataset=dataset, 427 batch_size=args.dynamics_batch_size, 428 learning_rate=args.dynamics_lr, 429 max_grad_norm=args.dynamics_max_grad_norm, 430 num_train_steps=args.num_steps or args.dynamics_num_steps, 431 cpu=use_cpu, 432 mixed_precision=args.mixed_precision, 433 use_tensorboard_logger=args.use_tensorboard, 434 log_dir=args.output_dir, 435 checkpoint_folder=args.output_dir, 436 dataloader_num_workers=args.num_workers, 437 dataloader_pin_memory=not args.no_pin_memory, 438 dataloader_prefetch_factor=args.prefetch_factor, 439 resume_from=resume_from, 440 ) 441 442 trainer() 443 444 # Save checkpoint 445 os.makedirs(args.output_dir, exist_ok=True) 446 ckpt_name = "lewm_dynamics.pt" if args.use_lewm else "dynamics.pt" 447 ckpt_path = os.path.join(args.output_dir, ckpt_name) 448 torch.save({ 449 'model': dynamics.state_dict(), 450 'config': dynamics_config, 451 'tokenizer_config': tok_config, 452 'use_lewm': args.use_lewm, 453 }, ckpt_path) 454 print(f"Dynamics model saved to 
{ckpt_path}") 455 456 if not getattr(trainer, 'cancelled', False): 457 _cleanup_last_intermediate_checkpoint(args.output_dir, legacy_prefix='dynamics') 458 else: 459 print(f"training cancelled — keeping state dir in {args.output_dir} for --resume_from") 460 461 462# ─── Phase 3: Dream-based Agent Training ─────────────────────────── 463 464def train_agent(args): 465 """Train the policy/value heads using imagined rollouts. 466 467 DreamTrainer generates experience entirely inside the world model: 468 469 1. Start with random noise latents 470 2. Iteratively denoise using the dynamics model (generate()) 471 3. At each step, the agent token embedding provides: 472 - Policy distribution (via policy head MLP) 473 - Value estimate (via value head MLP) 474 - Predicted reward (via reward prediction heads) 475 4. After generating a trajectory, compute GAE returns 476 5. Update policy head with PPO/PMPO loss 477 6. Update value head with clipped value loss 478 479 Only the policy and value head parameters are updated. 480 The world model transformer weights remain frozen. 481 482 This is the key advantage of Dreamer-style methods: 483 the agent can train on unlimited imagined experience 484 without needing access to the real Minecraft environment. 
485 """ 486 print("=" * 60) 487 print("PHASE 3: Training Agent in Dreams") 488 print("=" * 60) 489 490 use_cpu = _enforce_cuda_or_exit(args, "Phase 3") 491 resume_from = _resolve_resume_from(args.resume_from, args.output_dir) 492 493 # Load dynamics model from Phase 2 494 assert args.dynamics_ckpt is not None, "Must provide --dynamics_ckpt for Phase 3" 495 dyn_ckpt = torch.load(args.dynamics_ckpt, map_location='cpu', weights_only=False) 496 is_lewm = dyn_ckpt.get('use_lewm', False) or args.use_lewm 497 default_config = LEWM_DYNAMICS_DEFAULTS if is_lewm else DYNAMICS_DEFAULTS 498 dyn_config = dyn_ckpt.get('config', default_config) 499 tok_config = dyn_ckpt.get('tokenizer_config', TOKENIZER_DEFAULTS) 500 if is_lewm: 501 print("Detected LeWM dynamics checkpoint") 502 503 # Rebuild tokenizer (frozen) 504 tokenizer = VideoTokenizer(**tok_config) 505 tokenizer.eval() 506 507 # Rebuild dynamics model 508 dynamics = DynamicsWorldModel( 509 video_tokenizer=tokenizer, 510 **dyn_config, 511 ) 512 dynamics.load_state_dict(dyn_ckpt['model']) 513 print("Loaded dynamics model") 514 515 # Freeze everything except policy and value heads 516 for p in dynamics.parameters(): 517 p.requires_grad_(False) 518 for p in dynamics.policy_head_parameters(): 519 p.requires_grad_(True) 520 for p in dynamics.value_head_parameters(): 521 p.requires_grad_(True) 522 523 n_trainable = sum(p.numel() for p in dynamics.parameters() if p.requires_grad) 524 print(f"Trainable parameters: {n_trainable:,}") 525 526 # Train using DreamTrainer 527 trainer = DreamTrainer( 528 model=dynamics, 529 batch_size=args.dream_batch_size, 530 generate_timesteps=args.dream_generate_timesteps, 531 learning_rate=args.dream_lr, 532 max_grad_norm=args.dream_max_grad_norm, 533 num_train_steps=args.num_steps or args.dream_num_steps, 534 cpu=use_cpu, 535 mixed_precision=args.mixed_precision, 536 use_tensorboard_logger=args.use_tensorboard, 537 log_dir=args.output_dir, 538 checkpoint_every=args.dream_checkpoint_every, 539 
checkpoint_folder=args.output_dir, 540 resume_from=resume_from, 541 ) 542 543 trainer() 544 545 # Save final checkpoint with everything 546 os.makedirs(args.output_dir, exist_ok=True) 547 ckpt_name = "lewm_minecraft.pt" if is_lewm else "dreamer4_minecraft.pt" 548 ckpt_path = os.path.join(args.output_dir, ckpt_name) 549 torch.save({ 550 'model': dynamics.state_dict(), 551 'config': dyn_config, 552 'tokenizer_config': tok_config, 553 'use_lewm': is_lewm, 554 }, ckpt_path) 555 print(f"Trained agent saved to {ckpt_path}") 556 557 if not getattr(trainer, 'cancelled', False): 558 _cleanup_last_intermediate_checkpoint(args.output_dir, legacy_prefix=None) 559 else: 560 print(f"training cancelled — keeping state dir in {args.output_dir} for --resume_from") 561 562 563# ─── CLI ──────────────────────────────────────────────────────────── 564 565def main(): 566 """CLI entry point: parse arguments and dispatch to the requested phase. 567 568 Wires up the full argparse surface for all three phases (tokenizer / 569 dynamics / dream), validates that a data directory is present for 570 phases that need it, and then calls :func:`train_tokenizer`, 571 :func:`train_dynamics`, or :func:`train_agent` based on ``--phase``. 
572 """ 573 parser = argparse.ArgumentParser( 574 description="Train Dreamer4 on Minecraft using VPT data" 575 ) 576 577 # Required arguments 578 parser.add_argument("--phase", type=int, required=True, choices=[1, 2, 3], 579 help="Training phase: 1=tokenizer, 2=dynamics, 3=agent") 580 parser.add_argument("--output_dir", type=str, default="./checkpoints", 581 help="Directory for saving checkpoints and logs") 582 583 # Data arguments 584 parser.add_argument("--data_dir", type=str, default=None, 585 help="Directory containing VPT .mp4/.jsonl pairs") 586 parser.add_argument("--max_trajectories", type=int, default=None, 587 help="Limit number of trajectories (for debugging)") 588 589 # Checkpoint arguments 590 parser.add_argument("--tokenizer_ckpt", type=str, default=None, 591 help="Path to tokenizer checkpoint (Phase 2)") 592 parser.add_argument("--dynamics_ckpt", type=str, default=None, 593 help="Path to dynamics checkpoint (Phase 3)") 594 595 # Training arguments 596 parser.add_argument("--num_steps", type=int, default=None, 597 help="Override number of training steps") 598 599 # Tokenizer hyperparameters 600 parser.add_argument("--tokenizer_batch_size", type=int, 601 default=TRAINING_DEFAULTS['tokenizer_batch_size']) 602 parser.add_argument("--tokenizer_lr", type=float, 603 default=TRAINING_DEFAULTS['tokenizer_lr']) 604 parser.add_argument("--tokenizer_max_grad_norm", type=float, 605 default=TRAINING_DEFAULTS['tokenizer_max_grad_norm']) 606 parser.add_argument("--tokenizer_num_steps", type=int, 607 default=TRAINING_DEFAULTS['tokenizer_num_steps']) 608 parser.add_argument("--tokenizer_seq_len", type=int, 609 default=TRAINING_DEFAULTS['tokenizer_seq_len']) 610 611 # Dynamics hyperparameters 612 parser.add_argument("--dynamics_batch_size", type=int, 613 default=TRAINING_DEFAULTS['dynamics_batch_size']) 614 parser.add_argument("--dynamics_lr", type=float, 615 default=TRAINING_DEFAULTS['dynamics_lr']) 616 parser.add_argument("--dynamics_max_grad_norm", type=float, 617 
default=TRAINING_DEFAULTS['dynamics_max_grad_norm']) 618 parser.add_argument("--dynamics_num_steps", type=int, 619 default=TRAINING_DEFAULTS['dynamics_num_steps']) 620 parser.add_argument("--dynamics_seq_len", type=int, 621 default=TRAINING_DEFAULTS['dynamics_seq_len']) 622 623 # Dream training hyperparameters 624 parser.add_argument("--dream_batch_size", type=int, 625 default=TRAINING_DEFAULTS['dream_batch_size']) 626 parser.add_argument("--dream_lr", type=float, 627 default=TRAINING_DEFAULTS['dream_lr']) 628 parser.add_argument("--dream_max_grad_norm", type=float, 629 default=TRAINING_DEFAULTS['dream_max_grad_norm']) 630 parser.add_argument("--dream_num_steps", type=int, 631 default=TRAINING_DEFAULTS['dream_num_steps']) 632 parser.add_argument("--dream_generate_timesteps", type=int, 633 default=TRAINING_DEFAULTS['dream_generate_timesteps']) 634 635 # Model variant 636 parser.add_argument("--use_lewm", action="store_true", 637 help="Use LeWM dynamics (JEPA-style next-embedding prediction) " 638 "instead of flow matching") 639 640 # Logging 641 parser.add_argument("--use_tensorboard", action="store_true", 642 help="Enable TensorBoard logging") 643 parser.add_argument("--log_video", action="store_true", 644 help="Log video reconstructions (Phase 1 only)") 645 parser.add_argument("--log_video_every", type=int, default=1000, 646 help="Log video every N steps") 647 648 # Resume / checkpointing 649 parser.add_argument("--resume_from", type=str, default=None, 650 help="Resume training from an accelerator.save_state() dump. " 651 "Accepts either a specific state-<N>/ directory, the " 652 "parent dir containing latest_state.txt, or the literal " 653 "'latest' to read the pointer inside --output_dir.") 654 parser.add_argument("--dream_checkpoint_every", type=int, default=500, 655 help="Phase 3: save a DreamTrainer checkpoint every N steps " 656 "(0 to disable). 
Phases 1/2 use the trainer defaults.") 657 658 # Performance / hardware 659 parser.add_argument("--num_workers", type=int, default=4, 660 help="DataLoader num_workers. Parallel frame decoding so " 661 "data prep does not bottleneck GPU training.") 662 parser.add_argument("--no_pin_memory", action="store_true", 663 help="Disable DataLoader pin_memory (default: enabled).") 664 parser.add_argument("--prefetch_factor", type=int, default=4, 665 help="DataLoader prefetch_factor (per worker). Bigger " 666 "value keeps the GPU fed while workers decode " 667 "the next batches of video.") 668 parser.add_argument("--mixed_precision", type=str, default="no", 669 choices=["no", "fp16", "bf16"], 670 help="Mixed precision mode passed to Accelerate. " 671 "Use 'fp16' on V100S (Volta), 'bf16' on A100/H100.") 672 parser.add_argument("--allow_cpu", action="store_true", 673 help="Allow running on CPU when torch.cuda.is_available() " 674 "is False. Without this flag, a missing GPU is a hard " 675 "error — preventing a silent 24h CPU fallback.") 676 677 args = parser.parse_args() 678 679 # Validate arguments 680 if args.phase in [1, 2]: 681 assert args.data_dir is not None, f"Phase {args.phase} requires --data_dir" 682 683 # Run the appropriate phase 684 if args.phase == 1: 685 train_tokenizer(args) 686 elif args.phase == 2: 687 train_dynamics(args) 688 elif args.phase == 3: 689 train_agent(args) 690 691 692if __name__ == "__main__": 693 main()
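The flow-matching recipe summarized in `train_dynamics`'s docstring — noise the latents at a random signal level, regress toward the clean latents, and require that larger "shortcut" steps compose out of half steps — can be illustrated on a single scalar latent. This is a toy sketch only, not the `dreamer4` implementation; all function names below are invented for the example.

```python
def noise_latent(x_clean: float, t: float, noise: float) -> float:
    """Flow-matching corruption: pure noise at t=0, the clean latent at t=1."""
    return (1.0 - t) * noise + t * x_clean


def velocity_target(x_clean: float, noise: float) -> float:
    """The regression target: the straight-line velocity from noise to data,
    which is constant along the interpolation path."""
    return x_clean - noise


def euler_step(x_t: float, v: float, dt: float) -> float:
    """One denoising step of size dt along a (predicted) velocity."""
    return x_t + dt * v


# With the *true* velocity, one full step and two half steps land in the same
# place. Shortcut consistency training enforces this same property on the
# learned velocity so that inference can take large steps.
noise, x_clean = -1.3, 0.7
v = velocity_target(x_clean, noise)
one_step = euler_step(noise, v, 1.0)
two_steps = euler_step(euler_step(noise, v, 0.5), v, 0.5)
assert abs(one_step - x_clean) < 1e-12
assert abs(two_steps - one_step) < 1e-12
```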
247def train_tokenizer(args): 248 """Train the VideoTokenizer on Minecraft video data. 249 250 The tokenizer learns to compress 384x640 RGB frames into compact 251 latent representations using: 252 - Patch embedding (16x16 patches → 24x40 spatial grid) 253 - Axial space-time transformer encoder 254 - Latent bottleneck with Tanh activation 255 - MAE-style patch masking for regularization 256 - LPIPS perceptual loss for visual quality 257 - Temporal/spatial decorrelation losses 258 259 The encoder output shape per frame: (num_latent_tokens, dim_latent) = (512, 16) 260 """ 261 print("=" * 60) 262 print("PHASE 1: Training Video Tokenizer") 263 print("=" * 60) 264 265 use_cpu = _enforce_cuda_or_exit(args, "Phase 1") 266 resume_from = _resolve_resume_from(args.resume_from, args.output_dir) 267 268 # Load dataset — tokenizer only needs video, no actions 269 dataset = MinecraftVPTDataset( 270 data_dir=args.data_dir, 271 seq_len=args.tokenizer_seq_len, 272 stride=args.tokenizer_seq_len, 273 image_height=384, 274 image_width=640, 275 max_trajectories=args.max_trajectories, 276 ) 277 278 # The VideoTokenizerTrainer expects a dataset that yields video tensors. 279 # Our dataset yields dicts, so we wrap it to extract just the video. 280 class VideoOnlyDataset(torch.utils.data.Dataset): 281 """Adapter that returns only the ``'video'`` key from MinecraftVPTDataset.""" 282 283 def __init__(self, base_dataset): 284 self.base = base_dataset 285 286 def __len__(self): 287 return len(self.base) 288 289 def __getitem__(self, idx): 290 return self.base[idx]['video'] # (3, T, H, W) 291 292 video_dataset = VideoOnlyDataset(dataset) 293 294 if len(video_dataset) == 0: 295 raise RuntimeError( 296 f"No training clips were created from data in '{args.data_dir}'. " 297 f"Check that the --data_dir path is correct and contains .mp4/.jsonl pairs. 
" 298 f"(Note: the default folder is 'data/vpt-recordings' with hyphens, not underscores.)" 299 ) 300 301 # Create tokenizer 302 tokenizer = VideoTokenizer(**TOKENIZER_DEFAULTS) 303 print(f"VideoTokenizer parameters: {sum(p.numel() for p in tokenizer.parameters()):,}") 304 305 # Train 306 trainer = VideoTokenizerTrainer( 307 model=tokenizer, 308 dataset=video_dataset, 309 batch_size=args.tokenizer_batch_size, 310 learning_rate=args.tokenizer_lr, 311 max_grad_norm=args.tokenizer_max_grad_norm, 312 num_train_steps=args.num_steps or args.tokenizer_num_steps, 313 cpu=use_cpu, 314 mixed_precision=args.mixed_precision, 315 use_tensorboard_logger=args.use_tensorboard, 316 log_dir=args.output_dir, 317 log_video=args.log_video, 318 video_fps=20, 319 log_video_every=args.log_video_every, 320 checkpoint_folder=args.output_dir, 321 dataloader_num_workers=args.num_workers, 322 dataloader_pin_memory=not args.no_pin_memory, 323 dataloader_prefetch_factor=args.prefetch_factor, 324 resume_from=resume_from, 325 ) 326 327 trainer() 328 329 # Save checkpoint 330 os.makedirs(args.output_dir, exist_ok=True) 331 ckpt_path = os.path.join(args.output_dir, "tokenizer.pt") 332 torch.save({ 333 'model': tokenizer.state_dict(), 334 'config': TOKENIZER_DEFAULTS, 335 }, ckpt_path) 336 print(f"Tokenizer saved to {ckpt_path}") 337 338 # Only prune the last state-<N>/ if training finished cleanly. If the 339 # trainer was cancelled via SIGTERM (scancel / walltime), keep the 340 # state dir so the user can resume with --resume_from latest. 341 if not getattr(trainer, 'cancelled', False): 342 _cleanup_last_intermediate_checkpoint(args.output_dir, legacy_prefix='tokenizer') 343 else: 344 print(f"training cancelled — keeping state dir in {args.output_dir} for --resume_from")
Train the VideoTokenizer on Minecraft video data.
The tokenizer learns to compress 384x640 RGB frames into compact latent representations using:
- Patch embedding (16x16 patches → 24x40 spatial grid)
- Axial space-time transformer encoder
- Latent bottleneck with Tanh activation
- MAE-style patch masking for regularization
- LPIPS perceptual loss for visual quality
- Temporal/spatial decorrelation losses
The encoder output shape per frame: (num_latent_tokens, dim_latent) = (512, 16)
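The dynamics docstring mentions that reward tokens are "SymExp two-hot encoded". The symlog/symexp pair (popularized by DreamerV3) squashes rewards before two-hot discretization and unsquashes predictions afterward. A minimal sketch of just the transform, independent of whatever `dreamer4` actually implements:

```python
import math


def symlog(x: float) -> float:
    """Signed log squashing: roughly linear near 0, logarithmic for large |x|."""
    return math.copysign(math.log1p(abs(x)), x)


def symexp(x: float) -> float:
    """Inverse of symlog, recovering the original reward scale."""
    return math.copysign(math.expm1(abs(x)), x)
```

The symmetry means sparse negative and positive rewards are compressed identically, and `symexp(symlog(r))` recovers `r` up to floating-point error.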
def train_dynamics(args):
    """Train the DynamicsWorldModel on tokenized video + actions.

    The dynamics model learns to predict future latent states using
    flow matching with shortcut consistency training:

    1. Takes clean latents from the tokenizer
    2. Adds noise at random signal levels (flow matching)
    3. Predicts the clean latents from noised versions
    4. Also predicts rewards and actions (multi-token prediction)

    Shortcut training (from Frans et al.):
    - Sometimes trains with step_size > 1, allowing the model to
      make larger jumps in the denoising process
    - Consistency loss ensures half-step predictions compose correctly
    - This makes generation faster at inference time

    The world model processes sequences of:
    - Spatial tokens (from latents)
    - Action tokens (embedded discrete actions)
    - Reward tokens (SymExp two-hot encoded)
    - Register tokens (for temporal consistency)
    - Agent tokens (for policy/value heads)
    """
    print("=" * 60)
    print("PHASE 2: Training Dynamics World Model")
    print("=" * 60)

    use_cpu = _enforce_cuda_or_exit(args, "Phase 2")
    resume_from = _resolve_resume_from(args.resume_from, args.output_dir)

    # Load tokenizer from Phase 1 checkpoint
    assert args.tokenizer_ckpt is not None, "Must provide --tokenizer_ckpt for Phase 2"
    tok_ckpt = torch.load(args.tokenizer_ckpt, map_location='cpu', weights_only=False)
    tok_config = tok_ckpt.get('config', TOKENIZER_DEFAULTS)

    tokenizer = VideoTokenizer(**tok_config)
    tokenizer.load_state_dict(tok_ckpt['model'])
    tokenizer.eval()
    for p in tokenizer.parameters():
        p.requires_grad_(False)
    print("Loaded frozen tokenizer")

    # Load dataset with actions
    dataset = MinecraftVPTDataset(
        data_dir=args.data_dir,
        seq_len=args.dynamics_seq_len,
        stride=args.dynamics_seq_len,
        image_height=384,
        image_width=640,
        max_trajectories=args.max_trajectories,
    )

    if len(dataset) == 0:
        raise RuntimeError(
            f"No training clips were created from data in '{args.data_dir}'. "
            f"Check that the --data_dir path is correct and contains .mp4/.jsonl pairs. "
            f"(Note: the default folder is 'data/vpt-recordings' with hyphens, not underscores.)"
        )

    # Create dynamics model with the tokenizer
    base_defaults = LEWM_DYNAMICS_DEFAULTS if args.use_lewm else DYNAMICS_DEFAULTS
    dynamics_config = base_defaults.copy()
    dynamics_config['num_latent_tokens'] = tok_config.get('num_latent_tokens', 16)
    dynamics_config['dim_latent'] = tok_config.get('dim_latent', 32)

    variant = "LeWM" if args.use_lewm else "Dreamer4"
    dynamics = DynamicsWorldModel(
        video_tokenizer=tokenizer,
        **dynamics_config,
    )
    n_params = sum(p.numel() for p in dynamics.parameters())
    print(f"DynamicsWorldModel ({variant}) parameters: {n_params:,}")

    # Train using BehaviorCloneTrainer
    # The trainer accepts dict batches and calls dynamics(**batch_data)
    trainer = BehaviorCloneTrainer(
        model=dynamics,
        dataset=dataset,
        batch_size=args.dynamics_batch_size,
        learning_rate=args.dynamics_lr,
        max_grad_norm=args.dynamics_max_grad_norm,
        num_train_steps=args.num_steps or args.dynamics_num_steps,
        cpu=use_cpu,
        mixed_precision=args.mixed_precision,
        use_tensorboard_logger=args.use_tensorboard,
        log_dir=args.output_dir,
        checkpoint_folder=args.output_dir,
        dataloader_num_workers=args.num_workers,
        dataloader_pin_memory=not args.no_pin_memory,
        dataloader_prefetch_factor=args.prefetch_factor,
        resume_from=resume_from,
    )

    trainer()

    # Save checkpoint
    os.makedirs(args.output_dir, exist_ok=True)
    ckpt_name = "lewm_dynamics.pt" if args.use_lewm else "dynamics.pt"
    ckpt_path = os.path.join(args.output_dir, ckpt_name)
    torch.save({
        'model': dynamics.state_dict(),
        'config': dynamics_config,
        'tokenizer_config': tok_config,
        'use_lewm': args.use_lewm,
    }, ckpt_path)
    print(f"Dynamics model saved to {ckpt_path}")

    if not getattr(trainer, 'cancelled', False):
        _cleanup_last_intermediate_checkpoint(args.output_dir, legacy_prefix='dynamics')
    else:
        print(f"training cancelled — keeping state dir in {args.output_dir} for --resume_from")
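The flow-matching noising described in the docstring above can be sketched in a few lines. This is an illustrative sketch only, not the internals of `DynamicsWorldModel`; the function name `flow_matching_target` and the `signal_level` argument are hypothetical names for this example.

```python
import torch

def flow_matching_target(clean_latents, noise, signal_level):
    """Linearly interpolate between noise and clean latents.

    At signal_level = 0 the input is pure noise; at 1 it is the clean
    latent. A flow-matching model is trained to recover the clean latents
    (equivalently, regress the velocity clean - noise) from the mixture.
    """
    # Broadcast the per-sample signal level over token and channel dims
    s = signal_level.view(-1, *([1] * (clean_latents.dim() - 1)))
    noised = s * clean_latents + (1 - s) * noise
    velocity = clean_latents - noise  # flow-matching regression target
    return noised, velocity

clean = torch.randn(4, 16, 32)   # (batch, num_latent_tokens, dim_latent)
noise = torch.randn_like(clean)
levels = torch.rand(4)           # random signal level per sample
noised, target = flow_matching_target(clean, noise, levels)
```

Shortcut training then additionally conditions the model on a step size, with a consistency loss tying one large denoising step to two composed half-steps.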
def train_agent(args):
    """Train the policy/value heads using imagined rollouts.

    DreamTrainer generates experience entirely inside the world model:

    1. Start with random noise latents
    2. Iteratively denoise using the dynamics model (generate())
    3. At each step, the agent token embedding provides:
       - Policy distribution (via policy head MLP)
       - Value estimate (via value head MLP)
       - Predicted reward (via reward prediction heads)
    4. After generating a trajectory, compute GAE returns
    5. Update policy head with PPO/PMPO loss
    6. Update value head with clipped value loss

    Only the policy and value head parameters are updated.
    The world model transformer weights remain frozen.

    This is the key advantage of Dreamer-style methods:
    the agent can train on unlimited imagined experience
    without needing access to the real Minecraft environment.
    """
    print("=" * 60)
    print("PHASE 3: Training Agent in Dreams")
    print("=" * 60)

    use_cpu = _enforce_cuda_or_exit(args, "Phase 3")
    resume_from = _resolve_resume_from(args.resume_from, args.output_dir)

    # Load dynamics model from Phase 2
    assert args.dynamics_ckpt is not None, "Must provide --dynamics_ckpt for Phase 3"
    dyn_ckpt = torch.load(args.dynamics_ckpt, map_location='cpu', weights_only=False)
    is_lewm = dyn_ckpt.get('use_lewm', False) or args.use_lewm
    default_config = LEWM_DYNAMICS_DEFAULTS if is_lewm else DYNAMICS_DEFAULTS
    dyn_config = dyn_ckpt.get('config', default_config)
    tok_config = dyn_ckpt.get('tokenizer_config', TOKENIZER_DEFAULTS)
    if is_lewm:
        print("Detected LeWM dynamics checkpoint")

    # Rebuild tokenizer (frozen; its weights are restored below as part of
    # the dynamics state dict, since the tokenizer is a submodule)
    tokenizer = VideoTokenizer(**tok_config)
    tokenizer.eval()

    # Rebuild dynamics model
    dynamics = DynamicsWorldModel(
        video_tokenizer=tokenizer,
        **dyn_config,
    )
    dynamics.load_state_dict(dyn_ckpt['model'])
    print("Loaded dynamics model")

    # Freeze everything except policy and value heads
    for p in dynamics.parameters():
        p.requires_grad_(False)
    for p in dynamics.policy_head_parameters():
        p.requires_grad_(True)
    for p in dynamics.value_head_parameters():
        p.requires_grad_(True)

    n_trainable = sum(p.numel() for p in dynamics.parameters() if p.requires_grad)
    print(f"Trainable parameters: {n_trainable:,}")

    # Train using DreamTrainer
    trainer = DreamTrainer(
        model=dynamics,
        batch_size=args.dream_batch_size,
        generate_timesteps=args.dream_generate_timesteps,
        learning_rate=args.dream_lr,
        max_grad_norm=args.dream_max_grad_norm,
        num_train_steps=args.num_steps or args.dream_num_steps,
        cpu=use_cpu,
        mixed_precision=args.mixed_precision,
        use_tensorboard_logger=args.use_tensorboard,
        log_dir=args.output_dir,
        checkpoint_every=args.dream_checkpoint_every,
        checkpoint_folder=args.output_dir,
        resume_from=resume_from,
    )

    trainer()

    # Save final checkpoint with everything
    os.makedirs(args.output_dir, exist_ok=True)
    ckpt_name = "lewm_minecraft.pt" if is_lewm else "dreamer4_minecraft.pt"
    ckpt_path = os.path.join(args.output_dir, ckpt_name)
    torch.save({
        'model': dynamics.state_dict(),
        'config': dyn_config,
        'tokenizer_config': tok_config,
        'use_lewm': is_lewm,
    }, ckpt_path)
    print(f"Trained agent saved to {ckpt_path}")

    if not getattr(trainer, 'cancelled', False):
        _cleanup_last_intermediate_checkpoint(args.output_dir, legacy_prefix=None)
    else:
        print(f"training cancelled — keeping state dir in {args.output_dir} for --resume_from")
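Step 4 of the docstring above (computing GAE returns over an imagined trajectory) can be sketched as follows. This is a hypothetical reference implementation of standard generalized advantage estimation; DreamTrainer's own version may differ in batching, bootstrapping, and normalization.

```python
import torch

def gae(rewards, values, gamma=0.99, lam=0.95):
    """Generalized advantage estimation for one imagined trajectory.

    rewards: (T,) predicted rewards from the world model.
    values:  (T + 1,) value estimates, including a bootstrap value
             for the state after the final step.
    """
    T = rewards.shape[0]
    advantages = torch.zeros(T)
    last = 0.0
    # Work backwards, accumulating discounted TD errors
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        last = delta + gamma * lam * last
        advantages[t] = last
    returns = advantages + values[:T]  # targets for the value head
    return advantages, returns

rewards = torch.tensor([1.0, 1.0, 1.0])
values = torch.zeros(4)  # no value signal yet, plus bootstrap
adv, ret = gae(rewards, values, gamma=1.0, lam=1.0)  # adv == [3., 2., 1.]
```

The advantages feed the PPO/PMPO policy loss; the returns are the regression targets for the clipped value loss.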
def main():
    """CLI entry point: parse arguments and dispatch to the requested phase.

    Wires up the full argparse surface for all three phases (tokenizer /
    dynamics / dream), validates that a data directory is present for
    phases that need it, and then calls :func:`train_tokenizer`,
    :func:`train_dynamics`, or :func:`train_agent` based on ``--phase``.
    """
    parser = argparse.ArgumentParser(
        description="Train Dreamer4 on Minecraft using VPT data"
    )

    # Required arguments
    parser.add_argument("--phase", type=int, required=True, choices=[1, 2, 3],
                        help="Training phase: 1=tokenizer, 2=dynamics, 3=agent")
    parser.add_argument("--output_dir", type=str, default="./checkpoints",
                        help="Directory for saving checkpoints and logs")

    # Data arguments
    parser.add_argument("--data_dir", type=str, default=None,
                        help="Directory containing VPT .mp4/.jsonl pairs")
    parser.add_argument("--max_trajectories", type=int, default=None,
                        help="Limit number of trajectories (for debugging)")

    # Checkpoint arguments
    parser.add_argument("--tokenizer_ckpt", type=str, default=None,
                        help="Path to tokenizer checkpoint (Phase 2)")
    parser.add_argument("--dynamics_ckpt", type=str, default=None,
                        help="Path to dynamics checkpoint (Phase 3)")

    # Training arguments
    parser.add_argument("--num_steps", type=int, default=None,
                        help="Override number of training steps")

    # Tokenizer hyperparameters
    parser.add_argument("--tokenizer_batch_size", type=int,
                        default=TRAINING_DEFAULTS['tokenizer_batch_size'])
    parser.add_argument("--tokenizer_lr", type=float,
                        default=TRAINING_DEFAULTS['tokenizer_lr'])
    parser.add_argument("--tokenizer_max_grad_norm", type=float,
                        default=TRAINING_DEFAULTS['tokenizer_max_grad_norm'])
    parser.add_argument("--tokenizer_num_steps", type=int,
                        default=TRAINING_DEFAULTS['tokenizer_num_steps'])
    parser.add_argument("--tokenizer_seq_len", type=int,
                        default=TRAINING_DEFAULTS['tokenizer_seq_len'])

    # Dynamics hyperparameters
    parser.add_argument("--dynamics_batch_size", type=int,
                        default=TRAINING_DEFAULTS['dynamics_batch_size'])
    parser.add_argument("--dynamics_lr", type=float,
                        default=TRAINING_DEFAULTS['dynamics_lr'])
    parser.add_argument("--dynamics_max_grad_norm", type=float,
                        default=TRAINING_DEFAULTS['dynamics_max_grad_norm'])
    parser.add_argument("--dynamics_num_steps", type=int,
                        default=TRAINING_DEFAULTS['dynamics_num_steps'])
    parser.add_argument("--dynamics_seq_len", type=int,
                        default=TRAINING_DEFAULTS['dynamics_seq_len'])

    # Dream training hyperparameters
    parser.add_argument("--dream_batch_size", type=int,
                        default=TRAINING_DEFAULTS['dream_batch_size'])
    parser.add_argument("--dream_lr", type=float,
                        default=TRAINING_DEFAULTS['dream_lr'])
    parser.add_argument("--dream_max_grad_norm", type=float,
                        default=TRAINING_DEFAULTS['dream_max_grad_norm'])
    parser.add_argument("--dream_num_steps", type=int,
                        default=TRAINING_DEFAULTS['dream_num_steps'])
    parser.add_argument("--dream_generate_timesteps", type=int,
                        default=TRAINING_DEFAULTS['dream_generate_timesteps'])

    # Model variant
    parser.add_argument("--use_lewm", action="store_true",
                        help="Use LeWM dynamics (JEPA-style next-embedding prediction) "
                             "instead of flow matching")

    # Logging
    parser.add_argument("--use_tensorboard", action="store_true",
                        help="Enable TensorBoard logging")
    parser.add_argument("--log_video", action="store_true",
                        help="Log video reconstructions (Phase 1 only)")
    parser.add_argument("--log_video_every", type=int, default=1000,
                        help="Log video every N steps")

    # Resume / checkpointing
    parser.add_argument("--resume_from", type=str, default=None,
                        help="Resume training from an accelerator.save_state() dump. "
                             "Accepts either a specific state-<N>/ directory, the "
                             "parent dir containing latest_state.txt, or the literal "
                             "'latest' to read the pointer inside --output_dir.")
    parser.add_argument("--dream_checkpoint_every", type=int, default=500,
                        help="Phase 3: save a DreamTrainer checkpoint every N steps "
                             "(0 to disable). Phases 1/2 use the trainer defaults.")

    # Performance / hardware
    parser.add_argument("--num_workers", type=int, default=4,
                        help="DataLoader num_workers. Parallel frame decoding so "
                             "data prep does not bottleneck GPU training.")
    parser.add_argument("--no_pin_memory", action="store_true",
                        help="Disable DataLoader pin_memory (default: enabled).")
    parser.add_argument("--prefetch_factor", type=int, default=4,
                        help="DataLoader prefetch_factor (per worker). Bigger "
                             "value keeps the GPU fed while workers decode "
                             "the next batches of video.")
    parser.add_argument("--mixed_precision", type=str, default="no",
                        choices=["no", "fp16", "bf16"],
                        help="Mixed precision mode passed to Accelerate. "
                             "Use 'fp16' on V100S (Volta), 'bf16' on A100/H100.")
    parser.add_argument("--allow_cpu", action="store_true",
                        help="Allow running on CPU when torch.cuda.is_available() "
                             "is False. Without this flag, a missing GPU is a hard "
                             "error — preventing a silent 24h CPU fallback.")

    args = parser.parse_args()

    # Validate arguments
    if args.phase in [1, 2]:
        assert args.data_dir is not None, f"Phase {args.phase} requires --data_dir"

    # Run the appropriate phase
    if args.phase == 1:
        train_tokenizer(args)
    elif args.phase == 2:
        train_dynamics(args)
    elif args.phase == 3:
        train_agent(args)
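The `--resume_from` resolution described in the help text above could look like the sketch below. This is a hypothetical re-implementation for illustration; the script's actual `_resolve_resume_from` helper may behave differently.

```python
import os

def resolve_resume_from(resume_from, output_dir):
    """Resolve a --resume_from value to a concrete state directory.

    Handles the three forms from the CLI help: a specific state-<N>/
    directory, a parent dir containing latest_state.txt, or the
    literal 'latest' (which reads the pointer inside output_dir).
    """
    if resume_from is None:
        return None
    root = output_dir if resume_from == 'latest' else resume_from
    pointer = os.path.join(root, 'latest_state.txt')
    if os.path.isfile(pointer):
        # Parent dir with a pointer file: follow it to the newest state
        with open(pointer) as f:
            return os.path.join(root, f.read().strip())
    # Otherwise assume the caller passed a specific state dir already
    return root
```

Keeping the pointer file next to the state dirs means `--resume_from latest` always picks up the most recent `accelerator.save_state()` dump without the user tracking step numbers.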