"""
Training script for Dreamer4 on Minecraft using VPT data.

This script implements the full 3-phase Dreamer4 training pipeline:

  Phase 1: Video Tokenizer Training
    Train the VideoTokenizer to compress 384x640 Minecraft frames into
    compact latent representations. Uses MAE-style masking, LPIPS perceptual
    loss, and axial space-time attention.

  Phase 2: World Model (Dynamics) Training
    Train the DynamicsWorldModel on tokenized video + actions from VPT data.
    Uses flow matching with shortcut consistency training. The model learns
    to predict future latent states given past observations and actions.

  Phase 3: Agent Training (Dream-based RL)
    Train the policy/value heads inside the learned world model using
    imagined rollouts (DreamTrainer). No real environment needed — the
    world model generates synthetic experience for PPO/PMPO training.

Each phase saves checkpoints that the next phase loads.

Usage:
    # Phase 1: Train tokenizer
    python train_dreamer4_minecraft.py --phase 1 \
        --data_dir ./data/vpt-recordings \
        --output_dir ./checkpoints \
        --num_steps 50000

    # Phase 2: Train world model
    python train_dreamer4_minecraft.py --phase 2 \
        --data_dir ./data/vpt-recordings \
        --output_dir ./checkpoints \
        --tokenizer_ckpt ./checkpoints/tokenizer.pt \
        --num_steps 100000

    # Phase 3: Train agent in dreams
    python train_dreamer4_minecraft.py --phase 3 \
        --output_dir ./checkpoints \
        --dynamics_ckpt ./checkpoints/dynamics.pt \
        --num_steps 50000
"""
 43
 44import os
 45import argparse
 46import sys
 47from pathlib import Path
 48
 49import torch
 50
 51from dreamer4 import VideoTokenizer, DynamicsWorldModel
 52from dreamer4.trainers import (
 53    VideoTokenizerTrainer,
 54    BehaviorCloneTrainer,
 55    DreamTrainer,
 56    _prune_previous_checkpoint,
 57)
 58
 59
 60def _cleanup_last_intermediate_checkpoint(output_dir, legacy_prefix):
 61    """After the final phase-level .pt is saved, the most recent
 62    state-<N>/ dir and <prefix>-<N>.pt files are no longer needed —
 63    the per-step pruner already removed every *earlier* one, so this
 64    just handles the tail. new_step=-1 guarantees the sentinel never
 65    matches a real step, so the helper always deletes the target.
 66    """
 67    pointer = Path(output_dir) / 'latest_state.txt'
 68    if not pointer.exists():
 69        return
 70    last_name = pointer.read_text().strip()
 71    _prune_previous_checkpoint(
 72        Path(output_dir), last_name, new_step=-1, legacy_prefix=legacy_prefix,
 73    )
 74    try:
 75        pointer.unlink()
 76    except OSError:
 77        pass
 78
 79from minecraft_vpt_dataset import (
 80    MinecraftVPTDataset,
 81    DREAMER4_NUM_DISCRETE_ACTIONS,
 82)
 83
 84# cuDNN autotuner picks the fastest conv/attn algorithm for the (fixed) input
 85# shapes used by this script. One-line win, safe because input shapes are
 86# static (384x640 video, constant seq_len per phase).
 87torch.backends.cudnn.benchmark = True
 88
 89
 90# ─── CUDA / resume helpers ───────────────────────────────────────────
 91
 92def _enforce_cuda_or_exit(args, phase_name: str) -> bool:
 93    """Return True if the phase should run on CPU, False for GPU.
 94
 95    If CUDA is unavailable and --allow_cpu was not passed, print a loud
 96    warning and exit(1). This prevents silently burning HPC walltime
 97    on a CPU fallback caused by e.g. a CUDA driver version mismatch.
 98    """
 99    if torch.cuda.is_available():
100        print(f"[{phase_name}] CUDA OK — device={torch.cuda.get_device_name(0)} "
101              f"| torch={torch.__version__} | cuda={torch.version.cuda}")
102        return False
103
104    msg = (
105        "\n" + "=" * 70 + "\n"
106        f"[{phase_name}] ERROR: torch.cuda.is_available() is False.\n"
107        "This is almost always a CUDA driver / PyTorch mismatch.\n"
108        f"  torch version : {torch.__version__}\n"
109        f"  torch.cuda    : {torch.version.cuda}\n"
110        "Training on CPU for this model size is ~100x slower than a V100S.\n"
111        "To fix: install a PyTorch build matching your cluster driver, e.g.\n"
112        "pip install torch --index-url https://download.pytorch.org/whl/cu121\n"
113        "or ask your admin which CUDA version the driver supports.\n"
114        "\n"
115        "Pass --allow_cpu to override this check (e.g. for a laptop smoke test).\n"
116        + "=" * 70 + "\n"
117    )
118    if not getattr(args, 'allow_cpu', False):
119        print(msg, file=sys.stderr)
120        sys.exit(1)
121    print(msg, file=sys.stderr)
122    print(f"[{phase_name}] --allow_cpu set, proceeding on CPU (will be slow).",
123          file=sys.stderr)
124    return True
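The fail-fast guard above can be exercised without a GPU by injecting the availability check — a minimal sketch of the same decision logic (`require_gpu` is illustrative, not part of this script):

```python
import sys

def require_gpu(cuda_available: bool, allow_cpu: bool) -> bool:
    """Return True when the caller should fall back to CPU.

    Mirrors the guard above: a missing GPU is fatal unless the
    caller explicitly opted into the slow CPU path.
    """
    if cuda_available:
        return False        # run on GPU
    if not allow_cpu:
        sys.exit(1)         # hard error: refuse a silent CPU fallback
    return True             # explicit, acknowledged CPU run

# GPU present: never falls back to CPU
assert require_gpu(True, allow_cpu=False) is False
# no GPU, but the override was given: proceed on CPU
assert require_gpu(False, allow_cpu=True) is True
```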


def _resolve_resume_from(raw: str | None, output_dir: str) -> str | None:
    """Resolve --resume_from into a concrete state-<N> directory path.

    Acceptable inputs:
      None                   → no resume
      'latest'               → look in output_dir/latest_state.txt
      <path to state-N dir>  → use as-is
      <path to dir containing state-N dirs>
                             → look for latest_state.txt inside it
    """
    if raw is None:
        return None

    if raw == 'latest':
        base = Path(output_dir)
    else:
        base = Path(raw)

    if base.is_dir():
        # Accept either a state-<N> dir directly, or a parent with a pointer.
        pointer = base / 'latest_state.txt'
        if pointer.exists():
            target = base / pointer.read_text().strip()
            if target.exists():
                print(f"[resume] using {target} from {pointer}")
                return str(target)
            print(f"[resume] pointer {pointer} references {target} but it does not exist")
            return None
        # Might already be a state-<N> dir
        if (base / 'random_states_0.pkl').exists() or any(base.glob('*.safetensors')):
            print(f"[resume] using state dir {base}")
            return str(base)
        print(f"[resume] {base} is not a state dir and has no latest_state.txt; starting fresh")
        return None

    print(f"[resume] resume_from path {base} does not exist; starting fresh")
    return None


# ─── Default Hyperparameters ────────────────────────────────────────

# Paper-matched defaults for Minecraft at 384x640 resolution (360x640 zero-padded).
# See data.txt: 16x16 patches → 960 tokens, bottleneck (N_b=512)x(D_b=16).

TOKENIZER_DEFAULTS = dict(
    dim=256,                    # Hidden dimension of transformer
    dim_latent=16,              # Latent bottleneck dimension (D_b=16)
    patch_size=16,              # 384/16=24, 640/16=40 → 960 patch tokens
    image_height=384,           # VPT 360x640 zero-padded to 384x640
    image_width=640,
    num_latent_tokens=512,      # Bottleneck token count (N_b=512)
    encoder_depth=4,            # Transformer depth
    decoder_depth=4,
    time_block_every=4,         # Temporal attention every 4th block
    attn_heads=4,
    attn_dim_head=64,
    lpips_loss_weight=0.2,      # Perceptual loss weight
    per_image_patch_mask_prob=(0.0, 0.9),  # MAE masking range
    use_loss_normalization=True,
)

DYNAMICS_DEFAULTS = dict(
    dim=256,                    # Hidden dimension
    dim_latent=16,              # Must match tokenizer (D_b=16)
    max_steps=64,               # K_max for flow matching (power of 2)
    num_register_tokens=8,      # Register tokens for temporal consistency
    num_spatial_tokens=256,     # Paper: N_z=256 spatial tokens
    num_latent_tokens=512,      # Must match tokenizer (N_b=512)
    depth=8,                    # Transformer depth
    time_block_every=4,
    attn_heads=4,
    attn_dim_head=64,
    use_time_rnn=True,          # GRU on temporal blocks
    # Action space: 20 binary buttons + 1 camera (121 choices)
    num_discrete_actions=DREAMER4_NUM_DISCRETE_ACTIONS,
    num_continuous_actions=0,   # All actions are discrete
    multi_token_pred_len=8,     # Multi-token prediction horizon
    pred_orig_latent=True,      # x-space prediction (better than v-space)
    # RL hyperparameters (used in Phase 3)
    gae_discount_factor=0.997,
    gae_lambda=0.95,
    ppo_eps_clip=0.2,
    policy_entropy_weight=0.01,
)
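The action-space comment above corresponds to a 21-way tuple: 20 binary buttons plus one camera action with 121 discretized positions (an 11x11 grid in the usual VPT binning). The dataset module exports this as `DREAMER4_NUM_DISCRETE_ACTIONS`; it is reconstructed here purely for illustration:

```python
# 20 binary buttons, then one camera action with 11x11 = 121 choices,
# matching the "20 binary buttons + 1 camera (121 choices)" comment above.
num_discrete_actions = (2,) * 20 + (121,)

assert len(num_discrete_actions) == 21
assert num_discrete_actions[-1] == 11 * 11
# total number of logits the discrete policy head must produce
assert sum(num_discrete_actions) == 161
```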

# LeWM dynamics: replaces flow matching with JEPA-style next-embedding prediction
LEWM_DYNAMICS_DEFAULTS = dict(
    **DYNAMICS_DEFAULTS,
    use_lewm_dynamics=True,         # Enable LeWM mode
    lewm_loss_weight=1.0,           # Next-embedding prediction loss weight
    lewm_sigreg_loss_weight=0.05,   # SIGReg regularization weight
    lewm_layer=-1,                  # Use last transformer layer for prediction
    lewm_action_conditioned=True,   # Condition predictions on actions
)

TRAINING_DEFAULTS = dict(
    # Phase 1
    tokenizer_batch_size=2,             # Keep small — AttentionResidual layers are memory-intensive
    tokenizer_lr=3e-4,
    tokenizer_num_steps=50000,
    tokenizer_max_grad_norm=1.0,
    tokenizer_seq_len=16,
    # Phase 2
    dynamics_batch_size=4,
    dynamics_lr=3e-4,
    dynamics_num_steps=100000,
    dynamics_max_grad_norm=1.0,
    dynamics_seq_len=16,
    # Phase 3
    dream_batch_size=16,
    dream_lr=3e-4,
    dream_num_steps=50000,
    dream_max_grad_norm=1.0,
    dream_generate_timesteps=16,
)
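The per-phase step counts are overridden via `args.num_steps or <phase_default>` in the trainers below. Note the truthiness pitfall: `or` treats an explicit `--num_steps 0` as unset. A sketch contrasting the two fallback styles (function names are illustrative):

```python
def resolve_steps(override, default):
    # explicit None check: only a genuinely missing override falls back
    return default if override is None else override

def resolve_steps_or(override, default):
    # truthiness check, as used by the trainer calls below: 0 also falls back
    return override or default

assert resolve_steps(None, 50000) == 50000
assert resolve_steps(100, 50000) == 100
assert resolve_steps(0, 50000) == 0          # 0 is respected
assert resolve_steps_or(0, 50000) == 50000   # 0 silently becomes the default
```

Since argparse defaults `--num_steps` to `None`, the two styles only diverge for `--num_steps 0`, which is not a meaningful run anyway; the `or` form is kept in the script.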


# ─── Phase 1: Train Video Tokenizer ────────────────────────────────

def train_tokenizer(args):
    """Train the VideoTokenizer on Minecraft video data.

    The tokenizer learns to compress 384x640 RGB frames into compact
    latent representations using:
      - Patch embedding (16x16 patches → 24x40 spatial grid)
      - Axial space-time transformer encoder
      - Latent bottleneck with Tanh activation
      - MAE-style patch masking for regularization
      - LPIPS perceptual loss for visual quality
      - Temporal/spatial decorrelation losses

    The encoder output shape per frame: (num_latent_tokens, dim_latent) = (512, 16)
    """
    print("=" * 60)
    print("PHASE 1: Training Video Tokenizer")
    print("=" * 60)

    use_cpu = _enforce_cuda_or_exit(args, "Phase 1")
    resume_from = _resolve_resume_from(args.resume_from, args.output_dir)

    # Load dataset — tokenizer only needs video, no actions
    dataset = MinecraftVPTDataset(
        data_dir=args.data_dir,
        seq_len=args.tokenizer_seq_len,
        stride=args.tokenizer_seq_len,
        image_height=384,
        image_width=640,
        max_trajectories=args.max_trajectories,
    )

    # The VideoTokenizerTrainer expects a dataset that yields video tensors.
    # Our dataset yields dicts, so we wrap it to extract just the video.
    class VideoOnlyDataset(torch.utils.data.Dataset):
        """Adapter that returns only the ``'video'`` key from MinecraftVPTDataset."""

        def __init__(self, base_dataset):
            self.base = base_dataset

        def __len__(self):
            return len(self.base)

        def __getitem__(self, idx):
            return self.base[idx]['video']  # (3, T, H, W)

    video_dataset = VideoOnlyDataset(dataset)

    if len(video_dataset) == 0:
        raise RuntimeError(
            f"No training clips were created from data in '{args.data_dir}'. "
            "Check that the --data_dir path is correct and contains .mp4/.jsonl pairs. "
            "(Note: the default folder is 'data/vpt-recordings' with hyphens, not underscores.)"
        )

    # Create tokenizer
    tokenizer = VideoTokenizer(**TOKENIZER_DEFAULTS)
    print(f"VideoTokenizer parameters: {sum(p.numel() for p in tokenizer.parameters()):,}")

    # Train
    trainer = VideoTokenizerTrainer(
        model=tokenizer,
        dataset=video_dataset,
        batch_size=args.tokenizer_batch_size,
        learning_rate=args.tokenizer_lr,
        max_grad_norm=args.tokenizer_max_grad_norm,
        num_train_steps=args.num_steps or args.tokenizer_num_steps,
        cpu=use_cpu,
        mixed_precision=args.mixed_precision,
        use_tensorboard_logger=args.use_tensorboard,
        log_dir=args.output_dir,
        log_video=args.log_video,
        video_fps=20,
        log_video_every=args.log_video_every,
        checkpoint_folder=args.output_dir,
        dataloader_num_workers=args.num_workers,
        dataloader_pin_memory=not args.no_pin_memory,
        dataloader_prefetch_factor=args.prefetch_factor,
        resume_from=resume_from,
    )

    trainer()

    # Save checkpoint
    os.makedirs(args.output_dir, exist_ok=True)
    ckpt_path = os.path.join(args.output_dir, "tokenizer.pt")
    torch.save({
        'model': tokenizer.state_dict(),
        'config': TOKENIZER_DEFAULTS,
    }, ckpt_path)
    print(f"Tokenizer saved to {ckpt_path}")

    # Only prune the last state-<N>/ if training finished cleanly. If the
    # trainer was cancelled via SIGTERM (scancel / walltime), keep the
    # state dir so the user can resume with --resume_from latest.
    if not getattr(trainer, 'cancelled', False):
        _cleanup_last_intermediate_checkpoint(args.output_dir, legacy_prefix='tokenizer')
    else:
        print(f"training cancelled — keeping state dir in {args.output_dir} for --resume_from")

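The `VideoOnlyDataset` adapter above is a small instance of a general pattern: wrap a dict-yielding dataset so a consumer that expects bare tensors sees only one key. A torch-free sketch of the same idea (`KeyAdapter` and the toy data are illustrative, not part of this codebase):

```python
class KeyAdapter:
    """Expose a single key of a dict-style dataset (simplified sketch)."""

    def __init__(self, base, key):
        self.base = base
        self.key = key

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        return self.base[idx][self.key]

# toy dict-yielding dataset standing in for MinecraftVPTDataset
samples = [{'video': [1, 2, 3], 'actions': [0, 1]},
           {'video': [4, 5, 6], 'actions': [1, 0]}]
video_only = KeyAdapter(samples, 'video')

assert len(video_only) == 2
assert video_only[0] == [1, 2, 3]   # actions are hidden from the consumer
```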

# ─── Phase 2: Train Dynamics World Model ───────────────────────────

def train_dynamics(args):
    """Train the DynamicsWorldModel on tokenized video + actions.

    The dynamics model learns to predict future latent states using
    flow matching with shortcut consistency training:

    1. Takes clean latents from the tokenizer
    2. Adds noise at random signal levels (flow matching)
    3. Predicts the clean latents from noised versions
    4. Also predicts rewards and actions (multi-token prediction)

    Shortcut training (from Frans et al.):
      - Sometimes trains with step_size > 1, allowing the model to
        make larger jumps in the denoising process
      - Consistency loss ensures half-step predictions compose correctly
      - This makes generation faster at inference time

    The world model processes sequences of:
      - Spatial tokens (from latents)
      - Action tokens (embedded discrete actions)
      - Reward tokens (SymExp two-hot encoded)
      - Register tokens (for temporal consistency)
      - Agent tokens (for policy/value heads)
    """
    print("=" * 60)
    print("PHASE 2: Training Dynamics World Model")
    print("=" * 60)

    use_cpu = _enforce_cuda_or_exit(args, "Phase 2")
    resume_from = _resolve_resume_from(args.resume_from, args.output_dir)

    # Load tokenizer from Phase 1 checkpoint
    assert args.tokenizer_ckpt is not None, "Must provide --tokenizer_ckpt for Phase 2"
    tok_ckpt = torch.load(args.tokenizer_ckpt, map_location='cpu', weights_only=False)
    tok_config = tok_ckpt.get('config', TOKENIZER_DEFAULTS)

    tokenizer = VideoTokenizer(**tok_config)
    tokenizer.load_state_dict(tok_ckpt['model'])
    tokenizer.eval()
    for p in tokenizer.parameters():
        p.requires_grad_(False)
    print("Loaded frozen tokenizer")

    # Load dataset with actions
    dataset = MinecraftVPTDataset(
        data_dir=args.data_dir,
        seq_len=args.dynamics_seq_len,
        stride=args.dynamics_seq_len,
        image_height=384,
        image_width=640,
        max_trajectories=args.max_trajectories,
    )

    if len(dataset) == 0:
        raise RuntimeError(
            f"No training clips were created from data in '{args.data_dir}'. "
            "Check that the --data_dir path is correct and contains .mp4/.jsonl pairs. "
            "(Note: the default folder is 'data/vpt-recordings' with hyphens, not underscores.)"
        )

    # Create dynamics model with the tokenizer. The latent geometry must
    # match the tokenizer; fall back to the tokenizer defaults
    # (N_b=512, D_b=16) if the checkpoint predates saved configs.
    base_defaults = LEWM_DYNAMICS_DEFAULTS if args.use_lewm else DYNAMICS_DEFAULTS
    dynamics_config = base_defaults.copy()
    dynamics_config['num_latent_tokens'] = tok_config.get('num_latent_tokens', 512)
    dynamics_config['dim_latent'] = tok_config.get('dim_latent', 16)

    variant = "LeWM" if args.use_lewm else "Dreamer4"
    dynamics = DynamicsWorldModel(
        video_tokenizer=tokenizer,
        **dynamics_config,
    )
    n_params = sum(p.numel() for p in dynamics.parameters())
    print(f"DynamicsWorldModel ({variant}) parameters: {n_params:,}")

    # Train using BehaviorCloneTrainer
    # The trainer accepts dict batches and calls dynamics(**batch_data)
    trainer = BehaviorCloneTrainer(
        model=dynamics,
        dataset=dataset,
        batch_size=args.dynamics_batch_size,
        learning_rate=args.dynamics_lr,
        max_grad_norm=args.dynamics_max_grad_norm,
        num_train_steps=args.num_steps or args.dynamics_num_steps,
        cpu=use_cpu,
        mixed_precision=args.mixed_precision,
        use_tensorboard_logger=args.use_tensorboard,
        log_dir=args.output_dir,
        checkpoint_folder=args.output_dir,
        dataloader_num_workers=args.num_workers,
        dataloader_pin_memory=not args.no_pin_memory,
        dataloader_prefetch_factor=args.prefetch_factor,
        resume_from=resume_from,
    )

    trainer()

    # Save checkpoint
    os.makedirs(args.output_dir, exist_ok=True)
    ckpt_name = "lewm_dynamics.pt" if args.use_lewm else "dynamics.pt"
    ckpt_path = os.path.join(args.output_dir, ckpt_name)
    torch.save({
        'model': dynamics.state_dict(),
        'config': dynamics_config,
        'tokenizer_config': tok_config,
        'use_lewm': args.use_lewm,
    }, ckpt_path)
    print(f"Dynamics model saved to {ckpt_path}")

    if not getattr(trainer, 'cancelled', False):
        _cleanup_last_intermediate_checkpoint(args.output_dir, legacy_prefix='dynamics')
    else:
        print(f"training cancelled — keeping state dir in {args.output_dir} for --resume_from")

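The flow-matching interpolation described in the docstring above can be sketched in a few lines. This toy, torch-free version uses one common convention where t=1 is the clean latent, and with `pred_orig_latent=True` the regression target is the clean latent itself (x-space) rather than the velocity; the function names are illustrative, not the dreamer4 API:

```python
import random

def noised(clean, noise, t):
    # linear interpolation used by flow matching:
    # t=0 → pure noise, t=1 → clean latent
    return [(1.0 - t) * n + t * c for c, n in zip(clean, noise)]

def x_space_target(clean, noise):
    # pred_orig_latent=True: regress the clean latent directly,
    # instead of the velocity (which would involve noise - clean)
    return clean

clean = [0.5, -1.0, 2.0]
noise = [random.gauss(0.0, 1.0) for _ in clean]

assert noised(clean, noise, 1.0) == clean   # fully denoised endpoint
assert noised(clean, noise, 0.0) == noise   # pure-noise endpoint
assert x_space_target(clean, noise) == clean
```

Shortcut consistency training additionally asks that one large denoising step match the composition of two half-size steps, which is what allows few-step generation at inference time.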

# ─── Phase 3: Dream-based Agent Training ───────────────────────────

def train_agent(args):
    """Train the policy/value heads using imagined rollouts.

    DreamTrainer generates experience entirely inside the world model:

    1. Start with random noise latents
    2. Iteratively denoise using the dynamics model (generate())
    3. At each step, the agent token embedding provides:
       - Policy distribution (via policy head MLP)
       - Value estimate (via value head MLP)
       - Predicted reward (via reward prediction heads)
    4. After generating a trajectory, compute GAE returns
    5. Update policy head with PPO/PMPO loss
    6. Update value head with clipped value loss

    Only the policy and value head parameters are updated.
    The world model transformer weights remain frozen.

    This is the key advantage of Dreamer-style methods:
    the agent can train on unlimited imagined experience
    without needing access to the real Minecraft environment.
    """
    print("=" * 60)
    print("PHASE 3: Training Agent in Dreams")
    print("=" * 60)

    use_cpu = _enforce_cuda_or_exit(args, "Phase 3")
    resume_from = _resolve_resume_from(args.resume_from, args.output_dir)

    # Load dynamics model from Phase 2
    assert args.dynamics_ckpt is not None, "Must provide --dynamics_ckpt for Phase 3"
    dyn_ckpt = torch.load(args.dynamics_ckpt, map_location='cpu', weights_only=False)
    is_lewm = dyn_ckpt.get('use_lewm', False) or args.use_lewm
    default_config = LEWM_DYNAMICS_DEFAULTS if is_lewm else DYNAMICS_DEFAULTS
    dyn_config = dyn_ckpt.get('config', default_config)
    tok_config = dyn_ckpt.get('tokenizer_config', TOKENIZER_DEFAULTS)
    if is_lewm:
        print("Detected LeWM dynamics checkpoint")

    # Rebuild tokenizer (frozen). Its weights are restored below along with
    # the dynamics state_dict, since the tokenizer is a submodule of
    # DynamicsWorldModel and was saved as part of the Phase 2 checkpoint.
    tokenizer = VideoTokenizer(**tok_config)
    tokenizer.eval()

    # Rebuild dynamics model
    dynamics = DynamicsWorldModel(
        video_tokenizer=tokenizer,
        **dyn_config,
    )
    dynamics.load_state_dict(dyn_ckpt['model'])
    print("Loaded dynamics model")

    # Freeze everything except policy and value heads
    for p in dynamics.parameters():
        p.requires_grad_(False)
    for p in dynamics.policy_head_parameters():
        p.requires_grad_(True)
    for p in dynamics.value_head_parameters():
        p.requires_grad_(True)

    n_trainable = sum(p.numel() for p in dynamics.parameters() if p.requires_grad)
    print(f"Trainable parameters: {n_trainable:,}")

    # Train using DreamTrainer
    trainer = DreamTrainer(
        model=dynamics,
        batch_size=args.dream_batch_size,
        generate_timesteps=args.dream_generate_timesteps,
        learning_rate=args.dream_lr,
        max_grad_norm=args.dream_max_grad_norm,
        num_train_steps=args.num_steps or args.dream_num_steps,
        cpu=use_cpu,
        mixed_precision=args.mixed_precision,
        use_tensorboard_logger=args.use_tensorboard,
        log_dir=args.output_dir,
        checkpoint_every=args.dream_checkpoint_every,
        checkpoint_folder=args.output_dir,
        resume_from=resume_from,
    )

    trainer()

    # Save final checkpoint with everything
    os.makedirs(args.output_dir, exist_ok=True)
    ckpt_name = "lewm_minecraft.pt" if is_lewm else "dreamer4_minecraft.pt"
    ckpt_path = os.path.join(args.output_dir, ckpt_name)
    torch.save({
        'model': dynamics.state_dict(),
        'config': dyn_config,
        'tokenizer_config': tok_config,
        'use_lewm': is_lewm,
    }, ckpt_path)
    print(f"Trained agent saved to {ckpt_path}")

    if not getattr(trainer, 'cancelled', False):
        _cleanup_last_intermediate_checkpoint(args.output_dir, legacy_prefix=None)
    else:
        print(f"training cancelled — keeping state dir in {args.output_dir} for --resume_from")

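Step 4 of the docstring above computes GAE returns over each imagined trajectory. A minimal torch-free sketch of GAE(λ), defaulting to the `gae_discount_factor` / `gae_lambda` values from `DYNAMICS_DEFAULTS` (the helper name is illustrative, not the dreamer4 API):

```python
def gae_advantages(rewards, values, last_value, gamma=0.997, lam=0.95):
    """Generalized Advantage Estimation over one trajectory.

    values[t] is V(s_t); last_value bootstraps the value beyond the
    final step. Runs backwards, accumulating discounted TD errors.
    """
    advantages = [0.0] * len(rewards)
    next_value, running = last_value, 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * next_value - values[t]   # TD error
        running = delta + gamma * lam * running               # GAE recursion
        advantages[t] = running
        next_value = values[t]
    return advantages

# single step: the advantage reduces to the TD error r - V(s)
adv = gae_advantages([1.0], [0.5], last_value=0.0)
assert abs(adv[0] - 0.5) < 1e-9
```

The policy head is then updated with the PPO/PMPO objective on these advantages, while the value head regresses toward `advantages + values`.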

# ─── CLI ────────────────────────────────────────────────────────────

def main():
    """CLI entry point: parse arguments and dispatch to the requested phase.

    Wires up the full argparse surface for all three phases (tokenizer /
    dynamics / dream), validates that a data directory is present for
    phases that need it, and then calls :func:`train_tokenizer`,
    :func:`train_dynamics`, or :func:`train_agent` based on ``--phase``.
    """
    parser = argparse.ArgumentParser(
        description="Train Dreamer4 on Minecraft using VPT data"
    )

    # Required arguments
    parser.add_argument("--phase", type=int, required=True, choices=[1, 2, 3],
                        help="Training phase: 1=tokenizer, 2=dynamics, 3=agent")
    parser.add_argument("--output_dir", type=str, default="./checkpoints",
                        help="Directory for saving checkpoints and logs")

    # Data arguments
    parser.add_argument("--data_dir", type=str, default=None,
                        help="Directory containing VPT .mp4/.jsonl pairs")
    parser.add_argument("--max_trajectories", type=int, default=None,
                        help="Limit number of trajectories (for debugging)")

    # Checkpoint arguments
    parser.add_argument("--tokenizer_ckpt", type=str, default=None,
                        help="Path to tokenizer checkpoint (Phase 2)")
    parser.add_argument("--dynamics_ckpt", type=str, default=None,
                        help="Path to dynamics checkpoint (Phase 3)")

    # Training arguments
    parser.add_argument("--num_steps", type=int, default=None,
                        help="Override number of training steps")

    # Tokenizer hyperparameters
    parser.add_argument("--tokenizer_batch_size", type=int,
                        default=TRAINING_DEFAULTS['tokenizer_batch_size'])
    parser.add_argument("--tokenizer_lr", type=float,
                        default=TRAINING_DEFAULTS['tokenizer_lr'])
    parser.add_argument("--tokenizer_max_grad_norm", type=float,
                        default=TRAINING_DEFAULTS['tokenizer_max_grad_norm'])
    parser.add_argument("--tokenizer_num_steps", type=int,
                        default=TRAINING_DEFAULTS['tokenizer_num_steps'])
    parser.add_argument("--tokenizer_seq_len", type=int,
                        default=TRAINING_DEFAULTS['tokenizer_seq_len'])

    # Dynamics hyperparameters
    parser.add_argument("--dynamics_batch_size", type=int,
                        default=TRAINING_DEFAULTS['dynamics_batch_size'])
    parser.add_argument("--dynamics_lr", type=float,
                        default=TRAINING_DEFAULTS['dynamics_lr'])
    parser.add_argument("--dynamics_max_grad_norm", type=float,
                        default=TRAINING_DEFAULTS['dynamics_max_grad_norm'])
    parser.add_argument("--dynamics_num_steps", type=int,
                        default=TRAINING_DEFAULTS['dynamics_num_steps'])
    parser.add_argument("--dynamics_seq_len", type=int,
                        default=TRAINING_DEFAULTS['dynamics_seq_len'])

    # Dream training hyperparameters
    parser.add_argument("--dream_batch_size", type=int,
                        default=TRAINING_DEFAULTS['dream_batch_size'])
    parser.add_argument("--dream_lr", type=float,
                        default=TRAINING_DEFAULTS['dream_lr'])
    parser.add_argument("--dream_max_grad_norm", type=float,
                        default=TRAINING_DEFAULTS['dream_max_grad_norm'])
    parser.add_argument("--dream_num_steps", type=int,
                        default=TRAINING_DEFAULTS['dream_num_steps'])
    parser.add_argument("--dream_generate_timesteps", type=int,
                        default=TRAINING_DEFAULTS['dream_generate_timesteps'])

    # Model variant
    parser.add_argument("--use_lewm", action="store_true",
                        help="Use LeWM dynamics (JEPA-style next-embedding prediction) "
                             "instead of flow matching")

    # Logging
    parser.add_argument("--use_tensorboard", action="store_true",
                        help="Enable TensorBoard logging")
    parser.add_argument("--log_video", action="store_true",
                        help="Log video reconstructions (Phase 1 only)")
    parser.add_argument("--log_video_every", type=int, default=1000,
                        help="Log video every N steps")

    # Resume / checkpointing
    parser.add_argument("--resume_from", type=str, default=None,
                        help="Resume training from an accelerator.save_state() dump. "
                             "Accepts either a specific state-<N>/ directory, the "
                             "parent dir containing latest_state.txt, or the literal "
                             "'latest' to read the pointer inside --output_dir.")
    parser.add_argument("--dream_checkpoint_every", type=int, default=500,
                        help="Phase 3: save a DreamTrainer checkpoint every N steps "
                             "(0 to disable). Phases 1/2 use the trainer defaults.")

    # Performance / hardware
    parser.add_argument("--num_workers", type=int, default=4,
                        help="DataLoader num_workers. Parallel frame decoding so "
                             "data prep does not bottleneck GPU training.")
    parser.add_argument("--no_pin_memory", action="store_true",
                        help="Disable DataLoader pin_memory (default: enabled).")
    parser.add_argument("--prefetch_factor", type=int, default=4,
                        help="DataLoader prefetch_factor (per worker). Bigger "
                             "value keeps the GPU fed while workers decode "
                             "the next batches of video.")
    parser.add_argument("--mixed_precision", type=str, default="no",
                        choices=["no", "fp16", "bf16"],
                        help="Mixed precision mode passed to Accelerate. "
                             "Use 'fp16' on V100S (Volta), 'bf16' on A100/H100.")
    parser.add_argument("--allow_cpu", action="store_true",
                        help="Allow running on CPU when torch.cuda.is_available() "
                             "is False. Without this flag, a missing GPU is a hard "
                             "error — preventing a silent 24h CPU fallback.")

    args = parser.parse_args()

    # Validate arguments
    if args.phase in [1, 2]:
        assert args.data_dir is not None, f"Phase {args.phase} requires --data_dir"

    # Run the appropriate phase
    if args.phase == 1:
        train_tokenizer(args)
    elif args.phase == 2:
        train_dynamics(args)
    elif args.phase == 3:
        train_agent(args)


if __name__ == "__main__":
    main()
TOKENIZER_DEFAULTS = {
    'dim': 256,
    'dim_latent': 16,
    'patch_size': 16,
    'image_height': 384,
    'image_width': 640,
    'num_latent_tokens': 512,
    'encoder_depth': 4,
    'decoder_depth': 4,
    'time_block_every': 4,
    'attn_heads': 4,
    'attn_dim_head': 64,
    'lpips_loss_weight': 0.2,
    'per_image_patch_mask_prob': (0.0, 0.9),
    'use_loss_normalization': True,
}

DYNAMICS_DEFAULTS = {
    'dim': 256,
    'dim_latent': 16,
    'max_steps': 64,
    'num_register_tokens': 8,
    'num_spatial_tokens': 256,
    'num_latent_tokens': 512,
    'depth': 8,
    'time_block_every': 4,
    'attn_heads': 4,
    'attn_dim_head': 64,
    'use_time_rnn': True,
    # twenty binary actions plus one 121-way action
    'num_discrete_actions': (2,) * 20 + (121,),
    'num_continuous_actions': 0,
    'multi_token_pred_len': 8,
    'pred_orig_latent': True,
    'gae_discount_factor': 0.997,
    'gae_lambda': 0.95,
    'ppo_eps_clip': 0.2,
    'policy_entropy_weight': 0.01,
}

# LeWM variant: identical base hyperparameters plus the LeWM-specific options.
LEWM_DYNAMICS_DEFAULTS = {
    **DYNAMICS_DEFAULTS,
    'use_lewm_dynamics': True,
    'lewm_loss_weight': 1.0,
    'lewm_sigreg_loss_weight': 0.05,
    'lewm_layer': -1,
    'lewm_action_conditioned': True,
}

TRAINING_DEFAULTS = {
    'tokenizer_batch_size': 2,
    'tokenizer_lr': 0.0003,
    'tokenizer_num_steps': 50000,
    'tokenizer_max_grad_norm': 1.0,
    'tokenizer_seq_len': 16,
    'dynamics_batch_size': 4,
    'dynamics_lr': 0.0003,
    'dynamics_num_steps': 100000,
    'dynamics_max_grad_norm': 1.0,
    'dynamics_seq_len': 16,
    'dream_batch_size': 16,
    'dream_lr': 0.0003,
    'dream_num_steps': 50000,
    'dream_max_grad_norm': 1.0,
    'dream_generate_timesteps': 16,
}
def train_tokenizer(args):
    """Train the VideoTokenizer on Minecraft video data.

    The tokenizer learns to compress 384x640 RGB frames into compact
    latent representations using:
      - Patch embedding (16x16 patches → 24x40 spatial grid)
      - Axial space-time transformer encoder
      - Latent bottleneck with Tanh activation
      - MAE-style patch masking for regularization
      - LPIPS perceptual loss for visual quality
      - Temporal/spatial decorrelation losses

    The encoder output shape per frame: (num_latent_tokens, dim_latent) = (512, 16)
    """
    print("=" * 60)
    print("PHASE 1: Training Video Tokenizer")
    print("=" * 60)

    use_cpu = _enforce_cuda_or_exit(args, "Phase 1")
    resume_from = _resolve_resume_from(args.resume_from, args.output_dir)

    # Load dataset — tokenizer only needs video, no actions
    dataset = MinecraftVPTDataset(
        data_dir=args.data_dir,
        seq_len=args.tokenizer_seq_len,
        stride=args.tokenizer_seq_len,
        image_height=384,
        image_width=640,
        max_trajectories=args.max_trajectories,
    )

    # The VideoTokenizerTrainer expects a dataset that yields video tensors.
    # Our dataset yields dicts, so we wrap it to extract just the video.
    class VideoOnlyDataset(torch.utils.data.Dataset):
        """Adapter that returns only the ``'video'`` key from MinecraftVPTDataset."""

        def __init__(self, base_dataset):
            self.base = base_dataset

        def __len__(self):
            return len(self.base)

        def __getitem__(self, idx):
            return self.base[idx]['video']  # (3, T, H, W)

    video_dataset = VideoOnlyDataset(dataset)

    if len(video_dataset) == 0:
        raise RuntimeError(
            f"No training clips were created from data in '{args.data_dir}'. "
            f"Check that the --data_dir path is correct and contains .mp4/.jsonl pairs. "
            f"(Note: the default folder is 'data/vpt-recordings' with hyphens, not underscores.)"
        )

    # Create tokenizer
    tokenizer = VideoTokenizer(**TOKENIZER_DEFAULTS)
    print(f"VideoTokenizer parameters: {sum(p.numel() for p in tokenizer.parameters()):,}")

    # Train
    trainer = VideoTokenizerTrainer(
        model=tokenizer,
        dataset=video_dataset,
        batch_size=args.tokenizer_batch_size,
        learning_rate=args.tokenizer_lr,
        max_grad_norm=args.tokenizer_max_grad_norm,
        num_train_steps=args.num_steps or args.tokenizer_num_steps,
        cpu=use_cpu,
        mixed_precision=args.mixed_precision,
        use_tensorboard_logger=args.use_tensorboard,
        log_dir=args.output_dir,
        log_video=args.log_video,
        video_fps=20,
        log_video_every=args.log_video_every,
        checkpoint_folder=args.output_dir,
        dataloader_num_workers=args.num_workers,
        dataloader_pin_memory=not args.no_pin_memory,
        dataloader_prefetch_factor=args.prefetch_factor,
        resume_from=resume_from,
    )

    trainer()

    # Save checkpoint
    os.makedirs(args.output_dir, exist_ok=True)
    ckpt_path = os.path.join(args.output_dir, "tokenizer.pt")
    torch.save({
        'model': tokenizer.state_dict(),
        'config': TOKENIZER_DEFAULTS,
    }, ckpt_path)
    print(f"Tokenizer saved to {ckpt_path}")

    # Only prune the last state-<N>/ if training finished cleanly. If the
    # trainer was cancelled via SIGTERM (scancel / walltime), keep the
    # state dir so the user can resume with --resume_from latest.
    if not getattr(trainer, 'cancelled', False):
        _cleanup_last_intermediate_checkpoint(args.output_dir, legacy_prefix='tokenizer')
    else:
        print(f"training cancelled — keeping state dir in {args.output_dir} for --resume_from")
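As a back-of-envelope check of the geometry stated in the docstring above: 384x640 frames cut into 16x16 patches give a 24x40 grid of 960 patch tokens, and the (512, 16) latent works out to a 90x compression of the raw RGB values. The `patch_grid` helper below is purely illustrative and not part of this script.

```python
# Illustrative only: the patch-grid and compression arithmetic behind the
# tokenizer docstring above (patch_grid is a hypothetical helper).

def patch_grid(image_height, image_width, patch_size):
    assert image_height % patch_size == 0 and image_width % patch_size == 0
    return image_height // patch_size, image_width // patch_size

rows, cols = patch_grid(384, 640, 16)               # -> (24, 40)
patch_tokens = rows * cols                          # 960 patches per frame
pixels_per_frame = 384 * 640 * 3                    # 737,280 raw values
latent_per_frame = 512 * 16                         # num_latent_tokens * dim_latent
compression = pixels_per_frame // latent_per_frame  # 90x
```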


def train_dynamics(args):
    """Train the DynamicsWorldModel on tokenized video + actions.

    The dynamics model learns to predict future latent states using
    flow matching with shortcut consistency training:

    1. Takes clean latents from the tokenizer
    2. Adds noise at random signal levels (flow matching)
    3. Predicts the clean latents from noised versions
    4. Also predicts rewards and actions (multi-token prediction)

    Shortcut training (from Frans et al.):
      - Sometimes trains with step_size > 1, allowing the model to
        make larger jumps in the denoising process
      - Consistency loss ensures half-step predictions compose correctly
      - This makes generation faster at inference time

    The world model processes sequences of:
      - Spatial tokens (from latents)
      - Action tokens (embedded discrete actions)
      - Reward tokens (SymExp two-hot encoded)
      - Register tokens (for temporal consistency)
      - Agent tokens (for policy/value heads)
    """
    print("=" * 60)
    print("PHASE 2: Training Dynamics World Model")
    print("=" * 60)

    use_cpu = _enforce_cuda_or_exit(args, "Phase 2")
    resume_from = _resolve_resume_from(args.resume_from, args.output_dir)

    # Load tokenizer from Phase 1 checkpoint
    assert args.tokenizer_ckpt is not None, "Must provide --tokenizer_ckpt for Phase 2"
    tok_ckpt = torch.load(args.tokenizer_ckpt, map_location='cpu', weights_only=False)
    tok_config = tok_ckpt.get('config', TOKENIZER_DEFAULTS)

    tokenizer = VideoTokenizer(**tok_config)
    tokenizer.load_state_dict(tok_ckpt['model'])
    tokenizer.eval()
    for p in tokenizer.parameters():
        p.requires_grad_(False)
    print("Loaded frozen tokenizer")

    # Load dataset with actions
    dataset = MinecraftVPTDataset(
        data_dir=args.data_dir,
        seq_len=args.dynamics_seq_len,
        stride=args.dynamics_seq_len,
        image_height=384,
        image_width=640,
        max_trajectories=args.max_trajectories,
    )

    if len(dataset) == 0:
        raise RuntimeError(
            f"No training clips were created from data in '{args.data_dir}'. "
            f"Check that the --data_dir path is correct and contains .mp4/.jsonl pairs. "
            f"(Note: the default folder is 'data/vpt-recordings' with hyphens, not underscores.)"
        )

    # Create dynamics model with the tokenizer. The latent geometry must
    # match the tokenizer, so the fallbacks mirror TOKENIZER_DEFAULTS
    # (num_latent_tokens=512, dim_latent=16).
    base_defaults = LEWM_DYNAMICS_DEFAULTS if args.use_lewm else DYNAMICS_DEFAULTS
    dynamics_config = base_defaults.copy()
    dynamics_config['num_latent_tokens'] = tok_config.get('num_latent_tokens', 512)
    dynamics_config['dim_latent'] = tok_config.get('dim_latent', 16)

    variant = "LeWM" if args.use_lewm else "Dreamer4"
    dynamics = DynamicsWorldModel(
        video_tokenizer=tokenizer,
        **dynamics_config,
    )
    n_params = sum(p.numel() for p in dynamics.parameters())
    print(f"DynamicsWorldModel ({variant}) parameters: {n_params:,}")

    # Train using BehaviorCloneTrainer
    # The trainer accepts dict batches and calls dynamics(**batch_data)
    trainer = BehaviorCloneTrainer(
        model=dynamics,
        dataset=dataset,
        batch_size=args.dynamics_batch_size,
        learning_rate=args.dynamics_lr,
        max_grad_norm=args.dynamics_max_grad_norm,
        num_train_steps=args.num_steps or args.dynamics_num_steps,
        cpu=use_cpu,
        mixed_precision=args.mixed_precision,
        use_tensorboard_logger=args.use_tensorboard,
        log_dir=args.output_dir,
        checkpoint_folder=args.output_dir,
        dataloader_num_workers=args.num_workers,
        dataloader_pin_memory=not args.no_pin_memory,
        dataloader_prefetch_factor=args.prefetch_factor,
        resume_from=resume_from,
    )

    trainer()

    # Save checkpoint
    os.makedirs(args.output_dir, exist_ok=True)
    ckpt_name = "lewm_dynamics.pt" if args.use_lewm else "dynamics.pt"
    ckpt_path = os.path.join(args.output_dir, ckpt_name)
    torch.save({
        'model': dynamics.state_dict(),
        'config': dynamics_config,
        'tokenizer_config': tok_config,
        'use_lewm': args.use_lewm,
    }, ckpt_path)
    print(f"Dynamics model saved to {ckpt_path}")

    if not getattr(trainer, 'cancelled', False):
        _cleanup_last_intermediate_checkpoint(args.output_dir, legacy_prefix='dynamics')
    else:
        print(f"training cancelled — keeping state dir in {args.output_dir} for --resume_from")
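Steps 1-3 of the flow-matching recipe in the docstring above can be sketched in plain Python. This is an illustrative toy, not the project's API (`noise_latent` and `velocity_target` are hypothetical names): a clean latent x1 is mixed with noise x0 at a signal level t, and the regression target is the constant velocity x1 - x0 that carries x0 to x1.

```python
# Toy rectified-flow sketch (hypothetical helpers, not the project's code).
import random

def noise_latent(x1, x0, t):
    """Linear interpolation used by rectified flow: x_t = (1 - t) * x0 + t * x1."""
    return [(1 - t) * n + t * c for n, c in zip(x0, x1)]

def velocity_target(x1, x0):
    """The flow-matching regression target, independent of t."""
    return [c - n for c, n in zip(x1, x0)]

clean = [0.5, -1.0, 2.0]                          # stands in for one clean latent vector
noise = [random.gauss(0.0, 1.0) for _ in clean]   # x0
t = 0.25                                          # signal level in [0, 1]
x_t = noise_latent(clean, noise, t)
v = velocity_target(clean, noise)

# Integrating the (constant) velocity from t to 1 recovers the clean latent,
# which is why a well-trained model can take the large "shortcut" steps
# described above: composing two half-steps must equal one full step.
recovered = [x + (1 - t) * vi for x, vi in zip(x_t, v)]
```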

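The "SymExp two-hot encoded" reward tokens named above can be illustrated with a toy encoder/decoder. All names here (`symlog`, `symexp`, `two_hot`, `decode`) are hypothetical, sketching the DreamerV3-style scheme of spreading a symlog-squashed reward over the two nearest bins of a fixed grid:

```python
# Toy SymExp two-hot sketch (hypothetical helpers, not the project's code).
import math

def symlog(x):
    return math.copysign(math.log1p(abs(x)), x)

def symexp(x):
    return math.copysign(math.expm1(abs(x)), x)

def two_hot(value, bins):
    """Distribute symlog(value) over the two surrounding bins of a sorted grid."""
    v = min(max(symlog(value), bins[0]), bins[-1])
    for i in range(len(bins) - 1):
        if bins[i] <= v <= bins[i + 1]:
            w_hi = (v - bins[i]) / (bins[i + 1] - bins[i])
            probs = [0.0] * len(bins)
            probs[i], probs[i + 1] = 1.0 - w_hi, w_hi
            return probs
    raise ValueError("value outside bin range after clamping")

def decode(probs, bins):
    """Expected bin value, mapped back through symexp."""
    return symexp(sum(p * b for p, b in zip(probs, bins)))

bins = [-3.0, -1.5, 0.0, 1.5, 3.0]
probs = two_hot(5.0, bins)   # two adjacent non-zero entries summing to 1
```

The squash/unsquash pair keeps large Minecraft reward magnitudes trainable while the two-hot target keeps the prediction a simple categorical cross-entropy.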
def train_agent(args):
    """Train the policy/value heads using imagined rollouts.

    DreamTrainer generates experience entirely inside the world model:

    1. Start with random noise latents
    2. Iteratively denoise using the dynamics model (generate())
    3. At each step, the agent token embedding provides:
       - Policy distribution (via policy head MLP)
       - Value estimate (via value head MLP)
       - Predicted reward (via reward prediction heads)
    4. After generating a trajectory, compute GAE returns
    5. Update policy head with PPO/PMPO loss
    6. Update value head with clipped value loss

    Only the policy and value head parameters are updated.
    The world model transformer weights remain frozen.

    This is the key advantage of Dreamer-style methods:
    the agent can train on unlimited imagined experience
    without needing access to the real Minecraft environment.
    """
    print("=" * 60)
    print("PHASE 3: Training Agent in Dreams")
    print("=" * 60)

    use_cpu = _enforce_cuda_or_exit(args, "Phase 3")
    resume_from = _resolve_resume_from(args.resume_from, args.output_dir)

    # Load dynamics model from Phase 2
    assert args.dynamics_ckpt is not None, "Must provide --dynamics_ckpt for Phase 3"
    dyn_ckpt = torch.load(args.dynamics_ckpt, map_location='cpu', weights_only=False)
    is_lewm = dyn_ckpt.get('use_lewm', False) or args.use_lewm
    default_config = LEWM_DYNAMICS_DEFAULTS if is_lewm else DYNAMICS_DEFAULTS
    dyn_config = dyn_ckpt.get('config', default_config)
    tok_config = dyn_ckpt.get('tokenizer_config', TOKENIZER_DEFAULTS)
    if is_lewm:
        print("Detected LeWM dynamics checkpoint")

    # Rebuild tokenizer (frozen)
    tokenizer = VideoTokenizer(**tok_config)
    tokenizer.eval()

    # Rebuild dynamics model
    dynamics = DynamicsWorldModel(
        video_tokenizer=tokenizer,
        **dyn_config,
    )
    dynamics.load_state_dict(dyn_ckpt['model'])
    print("Loaded dynamics model")

    # Freeze everything except policy and value heads
    for p in dynamics.parameters():
        p.requires_grad_(False)
    for p in dynamics.policy_head_parameters():
        p.requires_grad_(True)
    for p in dynamics.value_head_parameters():
        p.requires_grad_(True)

    n_trainable = sum(p.numel() for p in dynamics.parameters() if p.requires_grad)
    print(f"Trainable parameters: {n_trainable:,}")

    # Train using DreamTrainer
    trainer = DreamTrainer(
        model=dynamics,
        batch_size=args.dream_batch_size,
        generate_timesteps=args.dream_generate_timesteps,
        learning_rate=args.dream_lr,
        max_grad_norm=args.dream_max_grad_norm,
        num_train_steps=args.num_steps or args.dream_num_steps,
        cpu=use_cpu,
        mixed_precision=args.mixed_precision,
        use_tensorboard_logger=args.use_tensorboard,
        log_dir=args.output_dir,
        checkpoint_every=args.dream_checkpoint_every,
        checkpoint_folder=args.output_dir,
        resume_from=resume_from,
    )

    trainer()

    # Save final checkpoint with everything
    os.makedirs(args.output_dir, exist_ok=True)
    ckpt_name = "lewm_minecraft.pt" if is_lewm else "dreamer4_minecraft.pt"
    ckpt_path = os.path.join(args.output_dir, ckpt_name)
    torch.save({
        'model': dynamics.state_dict(),
        'config': dyn_config,
        'tokenizer_config': tok_config,
        'use_lewm': is_lewm,
    }, ckpt_path)
    print(f"Trained agent saved to {ckpt_path}")

    if not getattr(trainer, 'cancelled', False):
        _cleanup_last_intermediate_checkpoint(args.output_dir, legacy_prefix=None)
    else:
        print(f"training cancelled — keeping state dir in {args.output_dir} for --resume_from")
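Step 4 of the docstring above ("compute GAE returns") can be sketched in isolation. This is a minimal, hypothetical implementation, not the project's code; the defaults mirror `gae_discount_factor=0.997` and `gae_lambda=0.95` from `DYNAMICS_DEFAULTS`.

```python
# Toy Generalized Advantage Estimation sketch (hypothetical helper; the
# real computation lives inside DreamTrainer / DynamicsWorldModel).

def gae_returns(rewards, values, gamma=0.997, lam=0.95):
    """Backward GAE recursion over one imagined trajectory.

    `values` must contain len(rewards) + 1 entries: a value estimate for
    every visited state plus a bootstrap value for the final state.
    """
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        # TD error at step t
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        # exponentially weighted sum of future TD errors
        running = delta + gamma * lam * running
        advantages[t] = running
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns

# With gamma = lam = 1 and zero values, advantages reduce to reward-to-go:
adv, ret = gae_returns([1.0, 1.0], [0.0, 0.0, 0.0], gamma=1.0, lam=1.0)
# adv == [2.0, 1.0]
```

The advantages feed the PPO/PMPO policy loss (step 5) and the returns feed the clipped value loss (step 6).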


def main():
    """CLI entry point: parse arguments and dispatch to the requested phase.

    Wires up the full argparse surface for all three phases (tokenizer /
    dynamics / dream), validates that a data directory is present for
    phases that need it, and then calls :func:`train_tokenizer`,
    :func:`train_dynamics`, or :func:`train_agent` based on ``--phase``.
    """
    parser = argparse.ArgumentParser(
        description="Train Dreamer4 on Minecraft using VPT data"
    )

    # Required arguments
    parser.add_argument("--phase", type=int, required=True, choices=[1, 2, 3],
                        help="Training phase: 1=tokenizer, 2=dynamics, 3=agent")
    parser.add_argument("--output_dir", type=str, default="./checkpoints",
                        help="Directory for saving checkpoints and logs")

    # Data arguments
    parser.add_argument("--data_dir", type=str, default=None,
                        help="Directory containing VPT .mp4/.jsonl pairs")
    parser.add_argument("--max_trajectories", type=int, default=None,
                        help="Limit number of trajectories (for debugging)")

    # Checkpoint arguments
    parser.add_argument("--tokenizer_ckpt", type=str, default=None,
                        help="Path to tokenizer checkpoint (Phase 2)")
    parser.add_argument("--dynamics_ckpt", type=str, default=None,
                        help="Path to dynamics checkpoint (Phase 3)")

    # Training arguments
    parser.add_argument("--num_steps", type=int, default=None,
                        help="Override number of training steps")

    # Tokenizer hyperparameters
    parser.add_argument("--tokenizer_batch_size", type=int,
                        default=TRAINING_DEFAULTS['tokenizer_batch_size'])
    parser.add_argument("--tokenizer_lr", type=float,
                        default=TRAINING_DEFAULTS['tokenizer_lr'])
    parser.add_argument("--tokenizer_max_grad_norm", type=float,
                        default=TRAINING_DEFAULTS['tokenizer_max_grad_norm'])
    parser.add_argument("--tokenizer_num_steps", type=int,
                        default=TRAINING_DEFAULTS['tokenizer_num_steps'])
    parser.add_argument("--tokenizer_seq_len", type=int,
                        default=TRAINING_DEFAULTS['tokenizer_seq_len'])

    # Dynamics hyperparameters
    parser.add_argument("--dynamics_batch_size", type=int,
                        default=TRAINING_DEFAULTS['dynamics_batch_size'])
    parser.add_argument("--dynamics_lr", type=float,
                        default=TRAINING_DEFAULTS['dynamics_lr'])
    parser.add_argument("--dynamics_max_grad_norm", type=float,
                        default=TRAINING_DEFAULTS['dynamics_max_grad_norm'])
    parser.add_argument("--dynamics_num_steps", type=int,
                        default=TRAINING_DEFAULTS['dynamics_num_steps'])
    parser.add_argument("--dynamics_seq_len", type=int,
                        default=TRAINING_DEFAULTS['dynamics_seq_len'])

    # Dream training hyperparameters
    parser.add_argument("--dream_batch_size", type=int,
                        default=TRAINING_DEFAULTS['dream_batch_size'])
    parser.add_argument("--dream_lr", type=float,
                        default=TRAINING_DEFAULTS['dream_lr'])
    parser.add_argument("--dream_max_grad_norm", type=float,
                        default=TRAINING_DEFAULTS['dream_max_grad_norm'])
    parser.add_argument("--dream_num_steps", type=int,
                        default=TRAINING_DEFAULTS['dream_num_steps'])
    parser.add_argument("--dream_generate_timesteps", type=int,
                        default=TRAINING_DEFAULTS['dream_generate_timesteps'])

    # Model variant
    parser.add_argument("--use_lewm", action="store_true",
                        help="Use LeWM dynamics (JEPA-style next-embedding prediction) "
                             "instead of flow matching")

    # Logging
    parser.add_argument("--use_tensorboard", action="store_true",
                        help="Enable TensorBoard logging")
    parser.add_argument("--log_video", action="store_true",
                        help="Log video reconstructions (Phase 1 only)")
    parser.add_argument("--log_video_every", type=int, default=1000,
                        help="Log video every N steps")

    # Resume / checkpointing
    parser.add_argument("--resume_from", type=str, default=None,
                        help="Resume training from an accelerator.save_state() dump. "
                             "Accepts either a specific state-<N>/ directory, the "
                             "parent dir containing latest_state.txt, or the literal "
                             "'latest' to read the pointer inside --output_dir.")
    parser.add_argument("--dream_checkpoint_every", type=int, default=500,
                        help="Phase 3: save a DreamTrainer checkpoint every N steps "
                             "(0 to disable). Phases 1/2 use the trainer defaults.")

    # Performance / hardware
    parser.add_argument("--num_workers", type=int, default=4,
                        help="DataLoader num_workers. Parallel frame decoding so "
                             "data prep does not bottleneck GPU training.")
    parser.add_argument("--no_pin_memory", action="store_true",
                        help="Disable DataLoader pin_memory (default: enabled).")
    parser.add_argument("--prefetch_factor", type=int, default=4,
                        help="DataLoader prefetch_factor (per worker). Bigger "
                             "values keep the GPU fed while workers decode "
                             "the next batches of video.")
    parser.add_argument("--mixed_precision", type=str, default="no",
                        choices=["no", "fp16", "bf16"],
                        help="Mixed precision mode passed to Accelerate. "
                             "Use 'fp16' on V100S (Volta), 'bf16' on A100/H100.")
    parser.add_argument("--allow_cpu", action="store_true",
                        help="Allow running on CPU when torch.cuda.is_available() "
                             "is False. Without this flag, a missing GPU is a hard "
                             "error — preventing a silent 24h CPU fallback.")

    args = parser.parse_args()

    # Validate arguments
    if args.phase in [1, 2]:
        assert args.data_dir is not None, f"Phase {args.phase} requires --data_dir"

    # Run the appropriate phase
    if args.phase == 1:
        train_tokenizer(args)
    elif args.phase == 2:
        train_dynamics(args)
    elif args.phase == 3:
        train_agent(args)


if __name__ == "__main__":
    main()