minecraft_vpt_dataset

VPT Minecraft Dataset for Dreamer4 World Model Training.

Converts OpenAI VPT .mp4/.jsonl recording pairs into the tensor format required by Dreamer4's VideoTokenizer and DynamicsWorldModel.

VPT recordings consist of:
  • .mp4 files: 640x360 Minecraft gameplay video at 20 FPS
  • .jsonl files: One JSON object per frame with keyboard/mouse/camera actions

Dreamer4 expects:

  • video: (batch, channels, time, height, width) float32 tensors in [0, 1]
  • discrete_actions: (batch, time, num_discrete_action_types) long tensors
  • rewards: (batch, time) float32 tensors

The key design decision is how to map VPT's action space to Dreamer4's:

VPT Action Space:
  • 20 binary keyboard/mouse buttons (attack, forward, jump, ...)
  • 2D continuous camera (pitch, yaw) in degrees, discretized to 11x11=121 bins

Dreamer4 Action Space Options:

  1. All discrete: 20 binary buttons (each Discrete(2)) + 1 camera (Discrete(121)) → num_discrete_actions = (2,2,2,...,2, 121) = tuple of 21 ints
  2. Hybrid: 20 binary buttons discrete + 2 continuous camera → num_discrete_actions = (2,)*20, num_continuous_actions = 2

We use option 1 (all discrete) because:

  • VPT already discretizes camera via mu-law encoding
  • Avoids continuous action normalization complexity
  • Matches how VPT's own policy head works (CategoricalActionHead)
  • Camera bins preserve the mu-law foveation (more precision near center)

Action encoding for Dreamer4:

  • discrete_actions[:, :, 0:20] = button states (0 or 1 each, num_choices=2)
  • discrete_actions[:, :, 20] = camera bin index (0 to 120, num_choices=121)

This gives: num_discrete_actions = (2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2, 121)
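
As a concrete (hedged) sketch of this layout, one timestep's action vector looks like the following; the index constants mirror the source below, and the specific values are illustrative:

    action = [0] * 21        # one Dreamer4 discrete action per frame
    action[2] = 1            # "forward" held (BUTTONS_ALL index 2)
    action[20] = 5 * 11 + 6  # pitch bin 5 (center), yaw bin 6 → joint index 61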

  1"""
  2VPT Minecraft Dataset for Dreamer4 World Model Training.
  3
  4Converts OpenAI VPT .mp4/.jsonl recording pairs into the tensor format
  5required by Dreamer4's VideoTokenizer and DynamicsWorldModel.
  6
  7VPT recordings consist of:
  8  - .mp4 files: 640x360 Minecraft gameplay video at 20 FPS
  9  - .jsonl files: One JSON object per frame with keyboard/mouse/camera actions
 10
 11Dreamer4 expects:
 12  - video: (batch, channels, time, height, width) float32 tensors in [0, 1]
 13  - discrete_actions: (batch, time, num_discrete_action_types) long tensors
 14  - rewards: (batch, time) float32 tensors
 15
 16The key design decision is how to map VPT's action space to Dreamer4's:
 17
 18VPT Action Space:
 19  - 20 binary keyboard/mouse buttons (attack, forward, jump, ...)
 20  - 2D continuous camera (pitch, yaw) in degrees, discretized to 11x11=121 bins
 21
 22Dreamer4 Action Space Options:
 23  1. All discrete: 20 binary buttons (each Discrete(2)) + 1 camera (Discrete(121))
 24     → num_discrete_actions = (2,2,2,...,2, 121) = tuple of 21 ints
 25  2. Hybrid: 20 binary buttons discrete + 2 continuous camera
 26     → num_discrete_actions = (2,)*20, num_continuous_actions = 2
 27
 28We use option 1 (all discrete) because:
 29  - VPT already discretizes camera via mu-law encoding
 30  - Avoids continuous action normalization complexity
 31  - Matches how VPT's own policy head works (CategoricalActionHead)
 32  - Camera bins preserve the mu-law foveation (more precision near center)
 33
 34Action encoding for Dreamer4:
 35  discrete_actions[:, :, 0:20]  = button states (0 or 1 each, num_choices=2)
 36  discrete_actions[:, :, 20]    = camera bin index (0 to 120, num_choices=121)
 37
 38This gives: num_discrete_actions = (2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2, 121)
 39"""

import os
import json
import glob
import random
from typing import Optional

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset

# Prefer decord for fast random-access frame seeking; fall back to cv2
try:
    import decord
    decord.bridge.set_bridge("native")
    HAS_DECORD = True
except ImportError:
    HAS_DECORD = False

# ─── VPT Action Constants ───────────────────────────────────────────

# The 20 binary button actions in VPT, in the order defined by Buttons.ALL
# in Video-Pre-Training/lib/actions.py
BUTTONS_ALL = [
    "attack", "back", "forward", "jump", "left", "right",
    "sneak", "sprint", "use", "drop", "inventory",
    "hotbar.1", "hotbar.2", "hotbar.3", "hotbar.4", "hotbar.5",
    "hotbar.6", "hotbar.7", "hotbar.8", "hotbar.9",
]

N_BUTTONS = len(BUTTONS_ALL)  # 20

# VPT keyboard key → MineRL action name mapping
# (from Video-Pre-Training/run_inverse_dynamics_model.py)
KEYBOARD_BUTTON_MAPPING = {
    "key.keyboard.escape": "ESC",
    "key.keyboard.s": "back",
    "key.keyboard.q": "drop",
    "key.keyboard.w": "forward",
    "key.keyboard.1": "hotbar.1",
    "key.keyboard.2": "hotbar.2",
    "key.keyboard.3": "hotbar.3",
    "key.keyboard.4": "hotbar.4",
    "key.keyboard.5": "hotbar.5",
    "key.keyboard.6": "hotbar.6",
    "key.keyboard.7": "hotbar.7",
    "key.keyboard.8": "hotbar.8",
    "key.keyboard.9": "hotbar.9",
    "key.keyboard.e": "inventory",
    "key.keyboard.space": "jump",
    "key.keyboard.a": "left",
    "key.keyboard.d": "right",
    "key.keyboard.left.shift": "sneak",
    "key.keyboard.left.control": "sprint",
    "key.keyboard.f": "swapHands",
}

# Camera quantization parameters matching VPT's ActionTransformer
# from Video-Pre-Training/agent.py: ACTION_TRANSFORMER_KWARGS
CAMERA_MAXVAL = 10
CAMERA_BINSIZE = 2
CAMERA_MU = 10  # mu-law parameter
N_CAMERA_BINS = 11  # per axis → 11*11 = 121 total camera actions

# Sensitivity scaler from VPT recording format
CAMERA_SCALER = 360.0 / 2400.0

# Dreamer4 action space specification:
# 20 binary buttons (each with 2 choices) + 1 camera (121 choices)
DREAMER4_NUM_DISCRETE_ACTIONS = tuple([2] * N_BUTTONS + [N_CAMERA_BINS * N_CAMERA_BINS])
# Total action dimensions for the discrete_actions tensor: 21


# ─── Camera Discretization (matching VPT's mu-law scheme) ───────────

def mu_law_encode(x: np.ndarray, mu: float = CAMERA_MU) -> np.ndarray:
    """Apply mu-law compression: maps [-maxval, maxval] → [-maxval, maxval].

    This is the same encoding used by VPT's CameraQuantizer with
    quantization_scheme="mu_law". It compresses the dynamic range so
    small camera movements get more precision (foveated discretization).
    """
    x_clipped = np.clip(x, -CAMERA_MAXVAL, CAMERA_MAXVAL)
    x_norm = x_clipped / CAMERA_MAXVAL
    encoded = np.sign(x_norm) * (np.log(1.0 + mu * np.abs(x_norm)) / np.log(1.0 + mu))
    return encoded * CAMERA_MAXVAL
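
# A hedged numeric sketch (values computed from the formula above, mu = 10):
#   mu_law_encode(np.array([0.5, 5.0, 10.0])) ≈ [1.69, 7.47, 10.0]
# A half-degree nudge keeps ~17% of the encoded range, so small camera
# movements retain far finer bin resolution than a linear scheme would give.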


def discretize_camera(camera_xy: np.ndarray) -> np.ndarray:
    """Discretize continuous camera (pitch, yaw) to bin indices.

    Matches VPT's CameraQuantizer.discretize() with mu_law scheme:
      1. Clip to [-CAMERA_MAXVAL, CAMERA_MAXVAL]
      2. Apply mu-law encoding
      3. Linear quantization to N_CAMERA_BINS bins

    Args:
        camera_xy: (..., 2) array of [pitch, yaw] in degrees

    Returns:
        (..., 2) array of bin indices, each in [0, N_CAMERA_BINS-1]
    """
    encoded = mu_law_encode(camera_xy)
    bins = np.round((encoded + CAMERA_MAXVAL) / CAMERA_BINSIZE).astype(np.int64)
    bins = np.clip(bins, 0, N_CAMERA_BINS - 1)
    return bins
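
# A hedged numeric sketch (values follow from mu_law_encode above):
#   discretize_camera(np.array([0.0, 0.5]))    -> [5, 6]   (noop pitch, tiny yaw)
#   discretize_camera(np.array([-10.0, 10.0])) -> [0, 10]  (extreme pitch/yaw)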


def camera_bins_to_joint_index(pitch_bin: int, yaw_bin: int) -> int:
    """Combine 2D camera bins into a single index for Dreamer4.

    Joint index = pitch_bin * N_CAMERA_BINS + yaw_bin
    Range: [0, 120] for 11x11 grid
    """
    return pitch_bin * N_CAMERA_BINS + yaw_bin
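
# Example: the noop camera (pitch_bin=5, yaw_bin=5, the center of the
# 11x11 grid) maps to joint index 5 * 11 + 5 = 60.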


# ─── JSONL Action Parsing ───────────────────────────────────────────

def parse_jsonl_action(step_data: dict, attack_is_stuck: bool = False) -> tuple:
    """Parse a single JSONL action record into MineRL-style env action.

    This replicates the logic from VPT's data_loader.py and
    run_inverse_dynamics_model.py:json_action_to_env_action().

    Args:
        step_data: One parsed JSON object from the .jsonl file
        attack_is_stuck: Whether the attack button is stuck down (recorder bug)

    Returns:
        (env_action_dict, is_null_action)
    """
    # Handle attack-stuck bug (same as VPT data_loader.py lines 86-95)
    if attack_is_stuck:
        step_data["mouse"]["buttons"] = [
            b for b in step_data["mouse"]["buttons"] if b != 0
        ]

    # Build env action dict (matches NOOP_ACTION structure)
    env_action = {b: 0 for b in BUTTONS_ALL}
    env_action["ESC"] = 0
    env_action["pickItem"] = 0
    env_action["swapHands"] = 0
    env_action["camera"] = np.array([0.0, 0.0])

    is_null_action = True

    # Keyboard keys
    keyboard_keys = step_data.get("keyboard", {}).get("keys", [])
    for key in keyboard_keys:
        if key in KEYBOARD_BUTTON_MAPPING:
            action_name = KEYBOARD_BUTTON_MAPPING[key]
            if action_name in env_action:
                env_action[action_name] = 1
                if action_name in BUTTONS_ALL:
                    is_null_action = False

    # Mouse buttons: 0=attack, 1=use, 2=pickItem
    mouse = step_data.get("mouse", {})
    mouse_buttons = mouse.get("buttons", [])
    if 0 in mouse_buttons:
        env_action["attack"] = 1
        is_null_action = False
    if 1 in mouse_buttons:
        env_action["use"] = 1
        is_null_action = False
    if 2 in mouse_buttons:
        env_action["pickItem"] = 1
        is_null_action = False

    # Camera: mouse dx/dy → degrees
    camera_action = env_action["camera"]
    camera_action[0] = mouse.get("dy", 0) * CAMERA_SCALER  # pitch
    camera_action[1] = mouse.get("dx", 0) * CAMERA_SCALER  # yaw

    if mouse.get("dx", 0) != 0 or mouse.get("dy", 0) != 0:
        is_null_action = False
    else:
        if abs(camera_action[0]) > 180:
            camera_action[0] = 0
        if abs(camera_action[1]) > 180:
            camera_action[1] = 0

    return env_action, is_null_action
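
# A hedged minimal example (real JSONL records carry more fields):
#   step = {"keyboard": {"keys": ["key.keyboard.w"]},
#           "mouse": {"dx": 40, "dy": 0, "buttons": []}}
#   parse_jsonl_action(step) -> env_action with forward=1,
#   camera=[0.0, 6.0] (40 * CAMERA_SCALER), and is_null_action=False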


def env_action_to_dreamer4(env_action: dict) -> np.ndarray:
    """Convert MineRL env action dict to Dreamer4 discrete action vector.

    Returns:
        (21,) int64 array:
          [0:20] = button values (0 or 1)
          [20]   = camera joint index (0 to 120)
    """
    action = np.zeros(N_BUTTONS + 1, dtype=np.int64)

    # Buttons
    for i, button_name in enumerate(BUTTONS_ALL):
        action[i] = int(env_action.get(button_name, 0))

    # Camera: discretize continuous degrees to bins, then to joint index
    camera_xy = env_action["camera"]  # [pitch, yaw] in degrees
    bins = discretize_camera(camera_xy)
    action[N_BUTTONS] = camera_bins_to_joint_index(int(bins[0]), int(bins[1]))

    return action
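
# Continuing the example above (values computed from the helpers above):
#   vec = env_action_to_dreamer4(env_action)
#   vec[2]  == 1                  # "forward" is BUTTONS_ALL[2]
#   vec[20] == 5 * 11 + 9 == 64   # pitch 0.0 deg -> bin 5; yaw 6.0 deg -> bin 9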


# ─── Trajectory Loading ─────────────────────────────────────────────

def load_trajectory(
    video_path: str,
    jsonl_path: str,
    target_height: int = 384,
    target_width: int = 640,
    skip_null_actions: bool = True,
) -> tuple:
    """Load a single VPT recording into arrays.

    Reads the .mp4 and .jsonl pair, parsing actions frame-by-frame using
    the same logic as VPT's data_loader.py (null-action filtering,
    attack-stuck handling, hotbar tracking, cursor overlay for GUI).

    Args:
        video_path: Path to .mp4 file
        jsonl_path: Path to .jsonl file
        target_height: Pad (or resize) frames to this height
        target_width: Pad (or resize) frames to this width
        skip_null_actions: Whether to skip null actions (as VPT paper does)

    Returns:
        (frames, actions, rewards) where:
          frames: (T, H, W, 3) uint8 array
          actions: (T, 21) int64 array (Dreamer4 discrete format)
          rewards: (T,) float32 array (zeros — VPT recordings have no reward)
    """
    # Load JSONL actions
    with open(jsonl_path, encoding="utf-8") as f:
        json_lines = f.readlines()
        json_data = json.loads("[" + ",".join(json_lines) + "]")

    # Open video
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise RuntimeError(f"Cannot open video: {video_path}")

    frames = []
    actions = []

    # Attack-stuck workaround (same as VPT data_loader.py)
    attack_is_stuck = False
    last_hotbar = 0

    # Optional: load cursor image for GUI overlay
    cursor_path = os.path.join(
        os.path.dirname(__file__), "Video-Pre-Training", "cursors",
        "mouse_cursor_white_16x16.png"
    )
    cursor_image = None
    cursor_alpha = None
    if os.path.exists(cursor_path):
        cursor_img = cv2.imread(cursor_path, cv2.IMREAD_UNCHANGED)
        if cursor_img is not None:
            cursor_img = cursor_img[:16, :16, :]
            cursor_alpha = cursor_img[:, :, 3:] / 255.0
            cursor_image = cursor_img[:, :, :3]

    for i, step_data in enumerate(json_data):
        # Handle attack-stuck bug
        if i == 0:
            if step_data.get("mouse", {}).get("newButtons") == [0]:
                attack_is_stuck = True
        elif attack_is_stuck:
            if 0 in step_data.get("mouse", {}).get("newButtons", []):
                attack_is_stuck = False

        if attack_is_stuck:
            step_data["mouse"]["buttons"] = [
                b for b in step_data["mouse"]["buttons"] if b != 0
            ]

        # Parse action (stuck-attack stripping was already applied above,
        # so pass attack_is_stuck=False here)
        env_action, is_null = parse_jsonl_action(step_data, attack_is_stuck=False)

        # Hotbar tracking (VPT data_loader.py lines 99-103)
        current_hotbar = step_data.get("hotbar", 0)
        if current_hotbar != last_hotbar:
            env_action[f"hotbar.{current_hotbar + 1}"] = 1
            is_null = False
        last_hotbar = current_hotbar

        # Read corresponding video frame
        ret, frame = cap.read()
        if not ret:
            break

        # Skip null actions (as done in VPT paper)
        if skip_null_actions and is_null:
            continue

        # GUI cursor overlay (VPT data_loader.py lines 113-117)
        if step_data.get("isGuiOpen", False) and cursor_image is not None:
            h_orig = 720  # MINEREC_ORIGINAL_HEIGHT_PX
            scale = frame.shape[0] / h_orig
            cx = int(step_data["mouse"]["x"] * scale)
            cy = int(step_data["mouse"]["y"] * scale)
            _composite_cursor(frame, cursor_image, cursor_alpha, cx, cy)

        # BGR → RGB
        cv2.cvtColor(frame, code=cv2.COLOR_BGR2RGB, dst=frame)
        frame = np.clip(frame, 0, 255).astype(np.uint8)

        # Zero-pad to target resolution (paper: 360x640 → 384x640)
        # If native resolution matches target width, pad height only;
        # otherwise fall back to resize for non-standard source video.
        h, w = frame.shape[:2]
        if w == target_width and h < target_height:
            pad_total = target_height - h
            pad_top = pad_total // 2
            pad_bottom = pad_total - pad_top
            frame = cv2.copyMakeBorder(
                frame, pad_top, pad_bottom, 0, 0,
                cv2.BORDER_CONSTANT, value=(0, 0, 0)
            )
        elif h != target_height or w != target_width:
            frame = cv2.resize(frame, (target_width, target_height),
                               interpolation=cv2.INTER_LINEAR)

        # Convert to Dreamer4 action format
        d4_action = env_action_to_dreamer4(env_action)

        frames.append(frame)
        actions.append(d4_action)

    cap.release()

    if len(frames) == 0:
        return np.empty((0, target_height, target_width, 3), dtype=np.uint8), \
               np.empty((0, N_BUTTONS + 1), dtype=np.int64), \
               np.empty((0,), dtype=np.float32)

    frames = np.stack(frames, axis=0)
    actions = np.stack(actions, axis=0)
    # VPT recordings don't include reward, so we use zeros.
    # During behavior cloning this is fine — rewards are only needed
    # for RL training (Phase 3: DreamTrainer).
    rewards = np.zeros(len(frames), dtype=np.float32)

    return frames, actions, rewards
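
# A hedged usage sketch (paths are placeholders for a real VPT pair):
#   frames, actions, rewards = load_trajectory("episode.mp4", "episode.jsonl")
#   frames.shape -> (T, 384, 640, 3); actions.shape -> (T, 21)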


def _composite_cursor(image, cursor_img, cursor_alpha, x, y):
    """Draw cursor onto image at (x, y). Modifies image in-place."""
    ch = max(0, min(image.shape[0] - y, cursor_img.shape[0]))
    cw = max(0, min(image.shape[1] - x, cursor_img.shape[1]))
    if ch == 0 or cw == 0:
        return
    alpha = cursor_alpha[:ch, :cw]
    image[y:y+ch, x:x+cw, :] = (
        image[y:y+ch, x:x+cw, :] * (1 - alpha) +
        cursor_img[:ch, :cw, :] * alpha
    ).astype(np.uint8)


# ─── Lazy Trajectory Pre-scan ──────────────────────────────────────

def prescan_trajectory(
    mp4_path: str,
    jsonl_path: str,
    skip_null_actions: bool = True,
) -> tuple:
    """Pre-scan a VPT recording to extract metadata without loading video frames.

    Reads only the JSONL (for actions and null-action filtering) and the MP4
    header (for frame count). No pixel data is loaded into memory.

    Returns:
        valid_frame_indices: (N,) int32 array of raw MP4 frame indices that
            survived null-action filtering.
        actions: (N, 21) int16 array of Dreamer4 discrete actions.
    """
    # Read JSONL
    with open(jsonl_path, encoding="utf-8") as f:
        json_data = [json.loads(line) for line in f]

    # Get frame count from video header only (no pixel decode)
    cap = cv2.VideoCapture(mp4_path)
    n_video_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()

    valid_indices = []
    actions = []

    attack_is_stuck = False
    last_hotbar = 0

    for i, step_data in enumerate(json_data):
        if i >= n_video_frames:
            break

        # Handle attack-stuck bug (same as load_trajectory)
        if i == 0:
            if step_data.get("mouse", {}).get("newButtons") == [0]:
                attack_is_stuck = True
        elif attack_is_stuck:
            if 0 in step_data.get("mouse", {}).get("newButtons", []):
                attack_is_stuck = False

        if attack_is_stuck:
            step_data["mouse"]["buttons"] = [
                b for b in step_data["mouse"]["buttons"] if b != 0
            ]

        # Parse action
        env_action, is_null = parse_jsonl_action(step_data, attack_is_stuck=False)

        # Hotbar tracking (VPT data_loader.py lines 99-103)
        current_hotbar = step_data.get("hotbar", 0)
        if current_hotbar != last_hotbar:
            env_action[f"hotbar.{current_hotbar + 1}"] = 1
            is_null = False
        last_hotbar = current_hotbar

        # Skip null actions (as done in VPT paper)
        if skip_null_actions and is_null:
            continue

        # Convert to Dreamer4 format
        d4_action = env_action_to_dreamer4(env_action)

        valid_indices.append(i)
        actions.append(d4_action)

    valid_indices = np.array(valid_indices, dtype=np.int32)
    if len(actions) > 0:
        actions = np.stack(actions, axis=0).astype(np.int16)
    else:
        actions = np.empty((0, N_BUTTONS + 1), dtype=np.int16)

    return valid_indices, actions
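
# A hedged sizing example: a 5-minute recording at 20 FPS has 6000 frames,
# so the prescan stores at most 6000 * (4 + 42) bytes ≈ 276 KB of metadata
# and zero pixel data.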


# ─── Frame Decoding Backends ───────────────────────────────────────

def _zero_pad_frame(frame, image_height, image_width):
    """Zero-pad a frame to (image_height, image_width) if width matches.

    Follows the paper: 360x640 → 384x640 via symmetric zero-padding on height.
    Falls back to resize if dimensions are unexpected.
    """
    h, w = frame.shape[:2]
    if w == image_width and h < image_height:
        pad_total = image_height - h
        pad_top = pad_total // 2
        pad_bottom = pad_total - pad_top
        return cv2.copyMakeBorder(
            frame, pad_top, pad_bottom, 0, 0,
            cv2.BORDER_CONSTANT, value=(0, 0, 0)
        )
    if h != image_height or w != image_width:
        return cv2.resize(frame, (image_width, image_height),
                          interpolation=cv2.INTER_LINEAR)
    return frame
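
# Example: a native 360x640 VPT frame gains 384 - 360 = 24 rows of black,
# split 12 top / 12 bottom, to reach the 384x640 target.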


def _decode_frames_decord(mp4_path, frame_indices, image_height, image_width):
    """Decode specific frames from MP4 using decord (fast random access)."""
    vr = decord.VideoReader(mp4_path, num_threads=1)
    frames = vr.get_batch(frame_indices.tolist()).asnumpy()  # (T, H, W, 3) RGB
    if frames.shape[1] != image_height or frames.shape[2] != image_width:
        frames = np.stack([
            _zero_pad_frame(f, image_height, image_width) for f in frames
        ])
    return frames


def _decode_frames_cv2(mp4_path, frame_indices, image_height, image_width):
    """Decode specific frames from MP4 using cv2 sequential read.

    Seeks to the first needed frame and reads sequentially to the last,
    keeping only the requested indices. Efficient when frame_indices are
    roughly contiguous (as they are after null-action filtering).
    """
    first_idx = int(frame_indices[0])
    last_idx = int(frame_indices[-1])
    needed = set(int(i) for i in frame_indices)

    cap = cv2.VideoCapture(mp4_path)
    cap.set(cv2.CAP_PROP_POS_FRAMES, first_idx)

    frames = {}
    for raw_idx in range(first_idx, last_idx + 1):
        ret, frame = cap.read()
        if not ret:
            break
        if raw_idx in needed:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = _zero_pad_frame(frame, image_height, image_width)
            frames[raw_idx] = frame
    cap.release()

    return np.stack([frames[int(i)] for i in frame_indices])


_decode_frames = _decode_frames_decord if HAS_DECORD else _decode_frames_cv2


# ─── PyTorch Dataset ────────────────────────────────────────────────

class MinecraftVPTDataset(Dataset):
    """Lazy-loading PyTorch Dataset for VPT Minecraft recordings.

    Decodes video frames on-the-fly from the original MP4 files instead of
    pre-loading everything into RAM.  This means:
      - Zero extra disk storage (the MP4 files ARE the dataset)
      - RAM usage is independent of video count (~1 MB metadata per trajectory)
      - Scales to thousands of videos without issue

    At init time, only the JSONL files and MP4 headers are read to build a
    lightweight index of valid frame positions and pre-parsed actions.  The
    actual pixel data is decoded from the MP4 on each __getitem__ call.

    For best throughput, use DataLoader with num_workers > 0 so that frame
    decoding runs in parallel worker processes.  Install ``decord`` for
    faster random-access seeking (falls back to cv2 otherwise).

    Shape convention (per sample):
        video:   (3, seq_len, H, W) float32 in [0, 1]
        actions: (seq_len, 21) int64
        rewards: (seq_len,) float32
    """

    def __init__(
        self,
        data_dir: str,
        seq_len: int = 16,
        stride: int = 8,
        image_height: int = 384,
        image_width: int = 640,
        skip_null_actions: bool = True,
        max_trajectories: Optional[int] = None,
    ):
        """
        Args:
            data_dir: Directory containing .mp4/.jsonl pairs
            seq_len: Number of frames per training clip
            stride: Step size between consecutive clips (for overlap)
            image_height: Target image height
            image_width: Target image width
            skip_null_actions: Skip null actions as VPT paper does
            max_trajectories: Limit number of loaded trajectories (for debugging)
        """
        super().__init__()
        self.seq_len = seq_len
        self.image_height = image_height
        self.image_width = image_width

        # Find all .mp4 files and their matching .jsonl files
        mp4_files = sorted(glob.glob(os.path.join(data_dir, "*.mp4")))
        if max_trajectories is not None:
            mp4_files = mp4_files[:max_trajectories]

        # Pre-scan: read only JSONLs + MP4 headers (no pixel data loaded)
        # Per trajectory we store:
        #   mp4_path:       string          (~100 bytes)
        #   valid_indices:  int32 array     (~4 bytes/frame)
        #   actions:        int16 array     (~42 bytes/frame)
        # Total: ~46 bytes per valid frame — vs. 196,608 bytes per frame
        # for the old eager float32 approach (4,272x smaller).
        self.trajectories = []   # list of (mp4_path, valid_indices, actions)
        self.clip_index = []     # list of (traj_idx, offset_in_valid_frames)

        print(f"Scanning {len(mp4_files)} trajectories from {data_dir}...")
        for mp4_path in mp4_files:
            base = os.path.splitext(mp4_path)[0]
            jsonl_path = base + ".jsonl"
            if not os.path.exists(jsonl_path):
                print(f"  Warning: no .jsonl for {mp4_path}, skipping")
                continue

            try:
                valid_indices, actions = prescan_trajectory(
                    mp4_path, jsonl_path,
                    skip_null_actions=skip_null_actions,
                )
            except Exception as e:
                print(f"  Error scanning {mp4_path}: {e}")
                continue

            if len(valid_indices) < seq_len:
                print(f"  Trajectory too short ({len(valid_indices)} valid frames "
                      f"< {seq_len}), skipping")
                continue

            traj_idx = len(self.trajectories)
            self.trajectories.append((mp4_path, valid_indices, actions))

            # Build clip entries with sliding window
            num_clips = (len(valid_indices) - seq_len) // stride + 1
            for c in range(num_clips):
                self.clip_index.append((traj_idx, c * stride))

        print(f"Created {len(self.clip_index)} clips of length {seq_len} "
              f"from {len(self.trajectories)} trajectories"
              f" (backend: {'decord' if HAS_DECORD else 'cv2'})")

    def __len__(self):
        """Return the number of sliding-window clips across all trajectories."""
        return len(self.clip_index)

    def __getitem__(self, idx, _retries=3):
        """Decode a clip on-the-fly and return as a dict.

        If frame decoding fails (corrupt MP4 data), retries with a random
        different clip up to ``_retries`` times so a single bad file doesn't
        crash the entire training run.

        Returns dict compatible with Dreamer4's BehaviorCloneTrainer:
            'video': (3, seq_len, H, W) float32 in [0, 1]
            'discrete_actions': (seq_len, 21) int64
            'rewards': (seq_len,) float32
        """
        for attempt in range(_retries + 1):
            # Look up the clip outside the try block so an invalid index
            # raises immediately (and mp4_path is always bound when the
            # except branch formats its warning message).
            traj_idx, offset = self.clip_index[idx]
            mp4_path, valid_indices, actions = self.trajectories[traj_idx]

            try:
                # Slice the frame indices and actions for this clip
                clip_frame_indices = valid_indices[offset:offset + self.seq_len]
                clip_actions = actions[offset:offset + self.seq_len]

                # Decode frames on-the-fly from the MP4 (no data stored in RAM)
                frames = _decode_frames(
                    mp4_path, clip_frame_indices,
                    self.image_height, self.image_width,
                )  # (seq_len, H, W, 3) uint8 RGB

                # (T, H, W, 3) uint8 → (3, T, H, W) float32 in [0, 1]
                video = torch.from_numpy(frames).permute(3, 0, 1, 2).float().div_(255.0)

                return {
                    'video': video,
                    'discrete_actions': torch.from_numpy(clip_actions.astype(np.int64)),
                    'rewards': torch.zeros(self.seq_len, dtype=torch.float32),
                }
            except Exception as e:
                if attempt < _retries:
                    print(f"  Warning: decode error at clip {idx} "
                          f"({mp4_path}), retrying with a different clip: {e}")
                    idx = random.randint(0, len(self.clip_index) - 1)
                else:
                    raise RuntimeError(
                        f"Failed to decode clip after {_retries + 1} attempts. "
                        f"Last error: {e}"
                    ) from e
696
697
698def collate_minecraft_batch(batch: list) -> dict:
699    """Custom collate function that stacks dicts into batched tensors.
700
701    Transforms list of per-sample dicts into a single dict of batched tensors,
702    matching the shapes expected by DynamicsWorldModel.forward():
703      video:            (B, 3, T, H, W)
704      discrete_actions: (B, T, 21)
705      rewards:          (B, T)
706    """
707    return {
708        'video': torch.stack([s['video'] for s in batch]),
709        'discrete_actions': torch.stack([s['discrete_actions'] for s in batch]),
710        'rewards': torch.stack([s['rewards'] for s in batch]),
711    }