minecraft_vpt_dataset
VPT Minecraft Dataset for Dreamer4 World Model Training.
Converts OpenAI VPT .mp4/.jsonl recording pairs into the tensor format required by Dreamer4's VideoTokenizer and DynamicsWorldModel.
VPT recordings consist of:
- .mp4 files: 640x360 Minecraft gameplay video at 20 FPS
- .jsonl files: One JSON object per frame with keyboard/mouse/camera actions
Dreamer4 expects:
- video: (batch, channels, time, height, width) float32 tensors in [0, 1]
- discrete_actions: (batch, time, num_discrete_action_types) long tensors
- rewards: (batch, time) float32 tensors
The key design decision is how to map VPT's action space to Dreamer4's:
VPT Action Space:
- 20 binary keyboard/mouse buttons (attack, forward, jump, ...)
- 2D continuous camera (pitch, yaw) in degrees, discretized to 11x11=121 bins
Dreamer4 Action Space Options:
1. All discrete: 20 binary buttons (each Discrete(2)) + 1 camera (Discrete(121)) → num_discrete_actions = (2,2,2,...,2, 121) = tuple of 21 ints
2. Hybrid: 20 binary buttons discrete + 2 continuous camera → num_discrete_actions = (2,)*20, num_continuous_actions = 2
We use option 1 (all discrete) because:
- VPT already discretizes camera via mu-law encoding
- Avoids continuous action normalization complexity
- Matches how VPT's own policy head works (CategoricalActionHead)
- Camera bins preserve the mu-law foveation (more precision near center)
Action encoding for Dreamer4:
- discrete_actions[:, :, 0:20] = button states (0 or 1 each, num_choices=2)
- discrete_actions[:, :, 20] = camera bin index (0 to 120, num_choices=121)
This gives: num_discrete_actions = (2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2, 121)
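As a concrete sketch of this layout (the frame's action below is made up; the helpers and constants come from this module, documented further down):

```python
import numpy as np

from minecraft_vpt_dataset import (
    BUTTONS_ALL, N_BUTTONS, camera_bins_to_joint_index, discretize_camera,
)

# Hypothetical single frame: holding forward + jump while turning 1.5° right.
pressed = {"forward", "jump"}
camera_deg = np.array([0.0, 1.5])  # [pitch, yaw] in degrees

vec = np.zeros(N_BUTTONS + 1, dtype=np.int64)
for i, name in enumerate(BUTTONS_ALL):
    vec[i] = int(name in pressed)  # slots 0..19: one binary button each
pitch_bin, yaw_bin = discretize_camera(camera_deg)
vec[N_BUTTONS] = camera_bins_to_joint_index(int(pitch_bin), int(yaw_bin))

print(vec.shape)  # (21,)
print(vec[20])    # slot 20: joint camera bin in [0, 120]; here 5*11 + 7 = 62,
                  # while bin (5, 5) = 60 would mean "camera still"
```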
1""" 2VPT Minecraft Dataset for Dreamer4 World Model Training. 3 4Converts OpenAI VPT .mp4/.jsonl recording pairs into the tensor format 5required by Dreamer4's VideoTokenizer and DynamicsWorldModel. 6 7VPT recordings consist of: 8 - .mp4 files: 640x360 Minecraft gameplay video at 20 FPS 9 - .jsonl files: One JSON object per frame with keyboard/mouse/camera actions 10 11Dreamer4 expects: 12 - video: (batch, channels, time, height, width) float32 tensors in [0, 1] 13 - discrete_actions: (batch, time, num_discrete_action_types) long tensors 14 - rewards: (batch, time) float32 tensors 15 16The key design decision is how to map VPT's action space to Dreamer4's: 17 18VPT Action Space: 19 - 20 binary keyboard/mouse buttons (attack, forward, jump, ...) 20 - 2D continuous camera (pitch, yaw) in degrees, discretized to 11x11=121 bins 21 22Dreamer4 Action Space Options: 23 1. All discrete: 20 binary buttons (each Discrete(2)) + 1 camera (Discrete(121)) 24 → num_discrete_actions = (2,2,2,...,2, 121) = tuple of 21 ints 25 2. Hybrid: 20 binary buttons discrete + 2 continuous camera 26 → num_discrete_actions = (2,)*20, num_continuous_actions = 2 27 28We use option 1 (all discrete) because: 29 - VPT already discretizes camera via mu-law encoding 30 - Avoids continuous action normalization complexity 31 - Matches how VPT's own policy head works (CategoricalActionHead) 32 - Camera bins preserve the mu-law foveation (more precision near center) 33 34Action encoding for Dreamer4: 35 discrete_actions[:, :, 0:20] = button states (0 or 1 each, num_choices=2) 36 discrete_actions[:, :, 20] = camera bin index (0 to 120, num_choices=121) 37 38This gives: num_discrete_actions = (2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2, 121) 39""" 40 41import os 42import json 43import glob 44import random 45from typing import Optional 46 47import cv2 48import numpy as np 49import torch 50from torch.utils.data import Dataset 51 52# Prefer decord for fast random-access frame seeking; fall back to cv2 53try: 54 import decord 55 decord.bridge.set_bridge("native") 56 HAS_DECORD = True 57except ImportError: 58 HAS_DECORD = False 59 60# ─── VPT Action Constants ─────────────────────────────────────────── 61 62# The 20 binary button actions in VPT, in the order defined by Buttons.ALL 63# in Video-Pre-Training/lib/actions.py 64BUTTONS_ALL = [ 65 "attack", "back", "forward", "jump", "left", "right", 66 "sneak", "sprint", "use", "drop", "inventory", 67 "hotbar.1", "hotbar.2", "hotbar.3", "hotbar.4", "hotbar.5", 68 "hotbar.6", "hotbar.7", "hotbar.8", "hotbar.9", 69] 70 71N_BUTTONS = len(BUTTONS_ALL) # 20 72 73# VPT keyboard key → MineRL action name mapping 74# (from Video-Pre-Training/run_inverse_dynamics_model.py) 75KEYBOARD_BUTTON_MAPPING = { 76 "key.keyboard.escape": "ESC", 77 "key.keyboard.s": "back", 78 "key.keyboard.q": "drop", 79 "key.keyboard.w": "forward", 80 "key.keyboard.1": "hotbar.1", 81 "key.keyboard.2": "hotbar.2", 82 "key.keyboard.3": "hotbar.3", 83 "key.keyboard.4": "hotbar.4", 84 "key.keyboard.5": "hotbar.5", 85 "key.keyboard.6": "hotbar.6", 86 "key.keyboard.7": "hotbar.7", 87 "key.keyboard.8": "hotbar.8", 88 "key.keyboard.9": "hotbar.9", 89 "key.keyboard.e": "inventory", 90 "key.keyboard.space": "jump", 91 "key.keyboard.a": "left", 92 "key.keyboard.d": "right", 93 "key.keyboard.left.shift": "sneak", 94 "key.keyboard.left.control": "sprint", 95 "key.keyboard.f": "swapHands", 96} 97 98# Camera quantization parameters matching VPT's ActionTransformer 99# from Video-Pre-Training/agent.py: ACTION_TRANSFORMER_KWARGS 100CAMERA_MAXVAL 
= 10 101CAMERA_BINSIZE = 2 102CAMERA_MU = 10 # mu-law parameter 103N_CAMERA_BINS = 11 # per axis → 11*11 = 121 total camera actions 104 105# Sensitivity scaler from VPT recording format 106CAMERA_SCALER = 360.0 / 2400.0 107 108# Dreamer4 action space specification: 109# 20 binary buttons (each with 2 choices) + 1 camera (121 choices) 110DREAMER4_NUM_DISCRETE_ACTIONS = tuple([2] * N_BUTTONS + [N_CAMERA_BINS * N_CAMERA_BINS]) 111# Total action dimensions for the discrete_actions tensor: 21 112 113 114# ─── Camera Discretization (matching VPT's mu-law scheme) ─────────── 115 116def mu_law_encode(x: np.ndarray, mu: float = CAMERA_MU) -> np.ndarray: 117 """Apply mu-law compression: maps [-maxval, maxval] → [-maxval, maxval]. 118 119 This is the same encoding used by VPT's CameraQuantizer with 120 quantization_scheme="mu_law". It compresses the dynamic range so 121 small camera movements get more precision (foveated discretization). 122 """ 123 x_clipped = np.clip(x, -CAMERA_MAXVAL, CAMERA_MAXVAL) 124 x_norm = x_clipped / CAMERA_MAXVAL 125 encoded = np.sign(x_norm) * (np.log(1.0 + mu * np.abs(x_norm)) / np.log(1.0 + mu)) 126 return encoded * CAMERA_MAXVAL 127 128 129def discretize_camera(camera_xy: np.ndarray) -> np.ndarray: 130 """Discretize continuous camera (pitch, yaw) to bin indices. 131 132 Matches VPT's CameraQuantizer.discretize() with mu_law scheme: 133 1. Clip to [-CAMERA_MAXVAL, CAMERA_MAXVAL] 134 2. Apply mu-law encoding 135 3. Linear quantization to N_CAMERA_BINS bins 136 137 Args: 138 camera_xy: (..., 2) array of [pitch, yaw] in degrees 139 140 Returns: 141 (..., 2) array of bin indices, each in [0, N_CAMERA_BINS-1] 142 """ 143 encoded = mu_law_encode(camera_xy) 144 bins = np.round((encoded + CAMERA_MAXVAL) / CAMERA_BINSIZE).astype(np.int64) 145 bins = np.clip(bins, 0, N_CAMERA_BINS - 1) 146 return bins 147 148 149def camera_bins_to_joint_index(pitch_bin: int, yaw_bin: int) -> int: 150 """Combine 2D camera bins into a single index for Dreamer4. 151 152 Joint index = pitch_bin * N_CAMERA_BINS + yaw_bin 153 Range: [0, 120] for 11x11 grid 154 """ 155 return pitch_bin * N_CAMERA_BINS + yaw_bin 156 157 158# ─── JSONL Action Parsing ─────────────────────────────────────────── 159 160def parse_jsonl_action(step_data: dict, attack_is_stuck: bool = False) -> tuple: 161 """Parse a single JSONL action record into MineRL-style env action. 162 163 This replicates the logic from VPT's data_loader.py and 164 run_inverse_dynamics_model.py:json_action_to_env_action(). 
165 166 Args: 167 step_data: One parsed JSON object from the .jsonl file 168 attack_is_stuck: Whether the attack button is stuck down (recorder bug) 169 170 Returns: 171 (env_action_dict, is_null_action, new_attack_is_stuck) 172 """ 173 # Handle attack-stuck bug (same as VPT data_loader.py lines 86-95) 174 if attack_is_stuck: 175 step_data["mouse"]["buttons"] = [ 176 b for b in step_data["mouse"]["buttons"] if b != 0 177 ] 178 179 # Build env action dict (matches NOOP_ACTION structure) 180 env_action = {b: 0 for b in BUTTONS_ALL} 181 env_action["ESC"] = 0 182 env_action["pickItem"] = 0 183 env_action["swapHands"] = 0 184 env_action["camera"] = np.array([0.0, 0.0]) 185 186 is_null_action = True 187 188 # Keyboard keys 189 keyboard_keys = step_data.get("keyboard", {}).get("keys", []) 190 for key in keyboard_keys: 191 if key in KEYBOARD_BUTTON_MAPPING: 192 action_name = KEYBOARD_BUTTON_MAPPING[key] 193 if action_name in env_action: 194 env_action[action_name] = 1 195 if action_name in BUTTONS_ALL: 196 is_null_action = False 197 198 # Mouse buttons: 0=attack, 1=use, 2=pickItem 199 mouse = step_data.get("mouse", {}) 200 mouse_buttons = mouse.get("buttons", []) 201 if 0 in mouse_buttons: 202 env_action["attack"] = 1 203 is_null_action = False 204 if 1 in mouse_buttons: 205 env_action["use"] = 1 206 is_null_action = False 207 if 2 in mouse_buttons: 208 env_action["pickItem"] = 1 209 is_null_action = False 210 211 # Camera: mouse dx/dy → degrees 212 camera_action = env_action["camera"] 213 camera_action[0] = mouse.get("dy", 0) * CAMERA_SCALER # pitch 214 camera_action[1] = mouse.get("dx", 0) * CAMERA_SCALER # yaw 215 216 if mouse.get("dx", 0) != 0 or mouse.get("dy", 0) != 0: 217 is_null_action = False 218 else: 219 if abs(camera_action[0]) > 180: 220 camera_action[0] = 0 221 if abs(camera_action[1]) > 180: 222 camera_action[1] = 0 223 224 return env_action, is_null_action 225 226 227def env_action_to_dreamer4(env_action: dict) -> np.ndarray: 228 """Convert MineRL env action dict to Dreamer4 discrete action vector. 229 230 Returns: 231 (21,) int64 array: 232 [0:20] = button values (0 or 1) 233 [20] = camera joint index (0 to 120) 234 """ 235 action = np.zeros(N_BUTTONS + 1, dtype=np.int64) 236 237 # Buttons 238 for i, button_name in enumerate(BUTTONS_ALL): 239 action[i] = int(env_action.get(button_name, 0)) 240 241 # Camera: discretize continuous degrees to bins, then to joint index 242 camera_xy = env_action["camera"] # [pitch, yaw] in degrees 243 bins = discretize_camera(camera_xy) 244 action[N_BUTTONS] = camera_bins_to_joint_index(int(bins[0]), int(bins[1])) 245 246 return action 247 248 249# ─── Trajectory Loading ───────────────────────────────────────────── 250 251def load_trajectory( 252 video_path: str, 253 jsonl_path: str, 254 target_height: int = 384, 255 target_width: int = 640, 256 skip_null_actions: bool = True, 257) -> tuple: 258 """Load a single VPT recording into arrays. 259 260 Reads the .mp4 and .jsonl pair, parsing actions frame-by-frame using 261 the same logic as VPT's data_loader.py (null-action filtering, 262 attack-stuck handling, hotbar tracking, cursor overlay for GUI). 
263 264 Args: 265 video_path: Path to .mp4 file 266 jsonl_path: Path to .jsonl file 267 target_height: Resize frames to this height 268 target_width: Resize frames to this width 269 skip_null_actions: Whether to skip null actions (as VPT paper does) 270 271 Returns: 272 (frames, actions, rewards) where: 273 frames: (T, H, W, 3) uint8 array 274 actions: (T, 21) int64 array (Dreamer4 discrete format) 275 rewards: (T,) float32 array (zeros — VPT recordings have no reward) 276 """ 277 # Load JSONL actions 278 with open(jsonl_path, encoding="utf-8") as f: 279 json_lines = f.readlines() 280 json_data = json.loads("[" + ",".join(json_lines) + "]") 281 282 # Open video 283 cap = cv2.VideoCapture(video_path) 284 if not cap.isOpened(): 285 raise RuntimeError(f"Cannot open video: {video_path}") 286 287 frames = [] 288 actions = [] 289 290 # Attack-stuck workaround (same as VPT data_loader.py) 291 attack_is_stuck = False 292 last_hotbar = 0 293 294 # Optional: load cursor image for GUI overlay 295 cursor_path = os.path.join( 296 os.path.dirname(__file__), "Video-Pre-Training", "cursors", 297 "mouse_cursor_white_16x16.png" 298 ) 299 cursor_image = None 300 cursor_alpha = None 301 if os.path.exists(cursor_path): 302 cursor_img = cv2.imread(cursor_path, cv2.IMREAD_UNCHANGED) 303 if cursor_img is not None: 304 cursor_img = cursor_img[:16, :16, :] 305 cursor_alpha = cursor_img[:, :, 3:] / 255.0 306 cursor_image = cursor_img[:, :, :3] 307 308 for i, step_data in enumerate(json_data): 309 # Handle attack-stuck bug 310 if i == 0: 311 if step_data.get("mouse", {}).get("newButtons") == [0]: 312 attack_is_stuck = True 313 elif attack_is_stuck: 314 if 0 in step_data.get("mouse", {}).get("newButtons", []): 315 attack_is_stuck = False 316 317 if attack_is_stuck: 318 step_data["mouse"]["buttons"] = [ 319 b for b in step_data["mouse"]["buttons"] if b != 0 320 ] 321 322 # Parse action 323 env_action, is_null = parse_jsonl_action(step_data, attack_is_stuck=False) 324 325 # Hotbar tracking (VPT data_loader.py lines 99-103) 326 current_hotbar = step_data.get("hotbar", 0) 327 if current_hotbar != last_hotbar: 328 env_action[f"hotbar.{current_hotbar + 1}"] = 1 329 is_null = False 330 last_hotbar = current_hotbar 331 332 # Read corresponding video frame 333 ret, frame = cap.read() 334 if not ret: 335 break 336 337 # Skip null actions (as done in VPT paper) 338 if skip_null_actions and is_null: 339 continue 340 341 # GUI cursor overlay (VPT data_loader.py lines 113-117) 342 if step_data.get("isGuiOpen", False) and cursor_image is not None: 343 h_orig = 720 # MINEREC_ORIGINAL_HEIGHT_PX 344 scale = frame.shape[0] / h_orig 345 cx = int(step_data["mouse"]["x"] * scale) 346 cy = int(step_data["mouse"]["y"] * scale) 347 _composite_cursor(frame, cursor_image, cursor_alpha, cx, cy) 348 349 # BGR → RGB 350 cv2.cvtColor(frame, code=cv2.COLOR_BGR2RGB, dst=frame) 351 frame = np.clip(frame, 0, 255).astype(np.uint8) 352 353 # Zero-pad to target resolution (paper: 360x640 → 384x640) 354 # If native resolution matches target width, pad height only; 355 # otherwise fall back to resize for non-standard source video. 
356 h, w = frame.shape[:2] 357 if w == target_width and h < target_height: 358 pad_total = target_height - h 359 pad_top = pad_total // 2 360 pad_bottom = pad_total - pad_top 361 frame = cv2.copyMakeBorder( 362 frame, pad_top, pad_bottom, 0, 0, 363 cv2.BORDER_CONSTANT, value=(0, 0, 0) 364 ) 365 elif h != target_height or w != target_width: 366 frame = cv2.resize(frame, (target_width, target_height), 367 interpolation=cv2.INTER_LINEAR) 368 369 # Convert to Dreamer4 action format 370 d4_action = env_action_to_dreamer4(env_action) 371 372 frames.append(frame) 373 actions.append(d4_action) 374 375 cap.release() 376 377 if len(frames) == 0: 378 return np.empty((0, target_height, target_width, 3), dtype=np.uint8), \ 379 np.empty((0, N_BUTTONS + 1), dtype=np.int64), \ 380 np.empty((0,), dtype=np.float32) 381 382 frames = np.stack(frames, axis=0) 383 actions = np.stack(actions, axis=0) 384 # VPT recordings don't include reward, so we use zeros. 385 # During behavior cloning this is fine — rewards are only needed 386 # for RL training (Phase 3: DreamTrainer). 387 rewards = np.zeros(len(frames), dtype=np.float32) 388 389 return frames, actions, rewards 390 391 392def _composite_cursor(image, cursor_img, cursor_alpha, x, y): 393 """Draw cursor onto image at (x, y). Modifies image in-place.""" 394 ch = max(0, min(image.shape[0] - y, cursor_img.shape[0])) 395 cw = max(0, min(image.shape[1] - x, cursor_img.shape[1])) 396 if ch == 0 or cw == 0: 397 return 398 alpha = cursor_alpha[:ch, :cw] 399 image[y:y+ch, x:x+cw, :] = ( 400 image[y:y+ch, x:x+cw, :] * (1 - alpha) + 401 cursor_img[:ch, :cw, :] * alpha 402 ).astype(np.uint8) 403 404 405# ─── Lazy Trajectory Pre-scan ────────────────────────────────────── 406 407def prescan_trajectory( 408 mp4_path: str, 409 jsonl_path: str, 410 skip_null_actions: bool = True, 411) -> tuple: 412 """Pre-scan a VPT recording to extract metadata without loading video frames. 413 414 Reads only the JSONL (for actions and null-action filtering) and the MP4 415 header (for frame count). No pixel data is loaded into memory. 416 417 Returns: 418 valid_frame_indices: (N,) int32 array of raw MP4 frame indices that 419 survived null-action filtering. 420 actions: (N, 21) int16 array of Dreamer4 discrete actions. 
421 """ 422 # Read JSONL 423 with open(jsonl_path, encoding="utf-8") as f: 424 json_data = [json.loads(line) for line in f] 425 426 # Get frame count from video header only (no pixel decode) 427 cap = cv2.VideoCapture(mp4_path) 428 n_video_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 429 cap.release() 430 431 valid_indices = [] 432 actions = [] 433 434 attack_is_stuck = False 435 last_hotbar = 0 436 437 for i, step_data in enumerate(json_data): 438 if i >= n_video_frames: 439 break 440 441 # Handle attack-stuck bug (same as load_trajectory) 442 if i == 0: 443 if step_data.get("mouse", {}).get("newButtons") == [0]: 444 attack_is_stuck = True 445 elif attack_is_stuck: 446 if 0 in step_data.get("mouse", {}).get("newButtons", []): 447 attack_is_stuck = False 448 449 if attack_is_stuck: 450 step_data["mouse"]["buttons"] = [ 451 b for b in step_data["mouse"]["buttons"] if b != 0 452 ] 453 454 # Parse action 455 env_action, is_null = parse_jsonl_action(step_data, attack_is_stuck=False) 456 457 # Hotbar tracking (VPT data_loader.py lines 99-103) 458 current_hotbar = step_data.get("hotbar", 0) 459 if current_hotbar != last_hotbar: 460 env_action[f"hotbar.{current_hotbar + 1}"] = 1 461 is_null = False 462 last_hotbar = current_hotbar 463 464 # Skip null actions (as done in VPT paper) 465 if skip_null_actions and is_null: 466 continue 467 468 # Convert to Dreamer4 format 469 d4_action = env_action_to_dreamer4(env_action) 470 471 valid_indices.append(i) 472 actions.append(d4_action) 473 474 valid_indices = np.array(valid_indices, dtype=np.int32) 475 if len(actions) > 0: 476 actions = np.stack(actions, axis=0).astype(np.int16) 477 else: 478 actions = np.empty((0, N_BUTTONS + 1), dtype=np.int16) 479 480 return valid_indices, actions 481 482 483# ─── Frame Decoding Backends ─────────────────────────────────────── 484 485def _zero_pad_frame(frame, image_height, image_width): 486 """Zero-pad a frame to (image_height, image_width) if width matches. 487 488 Follows the paper: 360x640 → 384x640 via symmetric zero-padding on height. 489 Falls back to resize if dimensions are unexpected. 490 """ 491 h, w = frame.shape[:2] 492 if w == image_width and h < image_height: 493 pad_total = image_height - h 494 pad_top = pad_total // 2 495 pad_bottom = pad_total - pad_top 496 return cv2.copyMakeBorder( 497 frame, pad_top, pad_bottom, 0, 0, 498 cv2.BORDER_CONSTANT, value=(0, 0, 0) 499 ) 500 if h != image_height or w != image_width: 501 return cv2.resize(frame, (image_width, image_height), 502 interpolation=cv2.INTER_LINEAR) 503 return frame 504 505 506def _decode_frames_decord(mp4_path, frame_indices, image_height, image_width): 507 """Decode specific frames from MP4 using decord (fast random access).""" 508 vr = decord.VideoReader(mp4_path, num_threads=1) 509 frames = vr.get_batch(frame_indices.tolist()).asnumpy() # (T, H, W, 3) RGB 510 if frames.shape[1] != image_height or frames.shape[2] != image_width: 511 frames = np.stack([ 512 _zero_pad_frame(f, image_height, image_width) for f in frames 513 ]) 514 return frames 515 516 517def _decode_frames_cv2(mp4_path, frame_indices, image_height, image_width): 518 """Decode specific frames from MP4 using cv2 sequential read. 519 520 Seeks to the first needed frame and reads sequentially to the last, 521 keeping only the requested indices. Efficient when frame_indices are 522 roughly contiguous (as they are after null-action filtering). 
523 """ 524 first_idx = int(frame_indices[0]) 525 last_idx = int(frame_indices[-1]) 526 needed = set(int(i) for i in frame_indices) 527 528 cap = cv2.VideoCapture(mp4_path) 529 cap.set(cv2.CAP_PROP_POS_FRAMES, first_idx) 530 531 frames = {} 532 for raw_idx in range(first_idx, last_idx + 1): 533 ret, frame = cap.read() 534 if not ret: 535 break 536 if raw_idx in needed: 537 frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 538 frame = _zero_pad_frame(frame, image_height, image_width) 539 frames[raw_idx] = frame 540 cap.release() 541 542 return np.stack([frames[int(i)] for i in frame_indices]) 543 544 545_decode_frames = _decode_frames_decord if HAS_DECORD else _decode_frames_cv2 546 547 548# ─── PyTorch Dataset ──────────────────────────────────────────────── 549 550class MinecraftVPTDataset(Dataset): 551 """Lazy-loading PyTorch Dataset for VPT Minecraft recordings. 552 553 Decodes video frames on-the-fly from the original MP4 files instead of 554 pre-loading everything into RAM. This means: 555 - Zero extra disk storage (the MP4 files ARE the dataset) 556 - RAM usage is independent of video count (~1 MB metadata per trajectory) 557 - Scales to thousands of videos without issue 558 559 At init time, only the JSONL files and MP4 headers are read to build a 560 lightweight index of valid frame positions and pre-parsed actions. The 561 actual pixel data is decoded from the MP4 on each __getitem__ call. 562 563 For best throughput, use DataLoader with num_workers > 0 so that frame 564 decoding runs in parallel worker processes. Install ``decord`` for 565 faster random-access seeking (falls back to cv2 otherwise). 566 567 Shape convention (per sample): 568 video: (3, seq_len, H, W) float32 in [0, 1] 569 actions: (seq_len, 21) int64 570 rewards: (seq_len,) float32 571 """ 572 573 def __init__( 574 self, 575 data_dir: str, 576 seq_len: int = 16, 577 stride: int = 8, 578 image_height: int = 384, 579 image_width: int = 640, 580 skip_null_actions: bool = True, 581 max_trajectories: Optional[int] = None, 582 ): 583 """ 584 Args: 585 data_dir: Directory containing .mp4/.jsonl pairs 586 seq_len: Number of frames per training clip 587 stride: Step size between consecutive clips (for overlap) 588 image_height: Target image height 589 image_width: Target image width 590 skip_null_actions: Skip null actions as VPT paper does 591 max_trajectories: Limit number of loaded trajectories (for debugging) 592 """ 593 super().__init__() 594 self.seq_len = seq_len 595 self.image_height = image_height 596 self.image_width = image_width 597 598 # Find all .mp4 files and their matching .jsonl files 599 mp4_files = sorted(glob.glob(os.path.join(data_dir, "*.mp4"))) 600 if max_trajectories is not None: 601 mp4_files = mp4_files[:max_trajectories] 602 603 # Pre-scan: read only JSONLs + MP4 headers (no pixel data loaded) 604 # Per trajectory we store: 605 # mp4_path: string (~100 bytes) 606 # valid_indices: int32 array (~4 bytes/frame) 607 # actions: int16 array (~42 bytes/frame) 608 # Total: ~46 bytes per valid frame — vs. 196,608 bytes per frame 609 # for the old eager float32 approach (4,272x smaller). 
610 self.trajectories = [] # list of (mp4_path, valid_indices, actions) 611 self.clip_index = [] # list of (traj_idx, offset_in_valid_frames) 612 613 print(f"Scanning {len(mp4_files)} trajectories from {data_dir}...") 614 for mp4_path in mp4_files: 615 base = os.path.splitext(mp4_path)[0] 616 jsonl_path = base + ".jsonl" 617 if not os.path.exists(jsonl_path): 618 print(f" Warning: no .jsonl for {mp4_path}, skipping") 619 continue 620 621 try: 622 valid_indices, actions = prescan_trajectory( 623 mp4_path, jsonl_path, 624 skip_null_actions=skip_null_actions, 625 ) 626 except Exception as e: 627 print(f" Error scanning {mp4_path}: {e}") 628 continue 629 630 if len(valid_indices) < seq_len: 631 print(f" Trajectory too short ({len(valid_indices)} valid frames " 632 f"< {seq_len}), skipping") 633 continue 634 635 traj_idx = len(self.trajectories) 636 self.trajectories.append((mp4_path, valid_indices, actions)) 637 638 # Build clip entries with sliding window 639 num_clips = (len(valid_indices) - seq_len) // stride + 1 640 for c in range(num_clips): 641 self.clip_index.append((traj_idx, c * stride)) 642 643 print(f"Created {len(self.clip_index)} clips of length {seq_len} " 644 f"from {len(self.trajectories)} trajectories" 645 f" (backend: {'decord' if HAS_DECORD else 'cv2'})") 646 647 def __len__(self): 648 """Return the number of sliding-window clips across all trajectories.""" 649 return len(self.clip_index) 650 651 def __getitem__(self, idx, _retries=3): 652 """Decode a clip on-the-fly and return as a dict. 653 654 If frame decoding fails (corrupt MP4 data), retries with a random 655 different clip up to ``_retries`` times so a single bad file doesn't 656 crash the entire training run. 657 658 Returns dict compatible with Dreamer4's BehaviorCloneTrainer: 659 'video': (3, seq_len, H, W) float32 in [0, 1] 660 'discrete_actions': (seq_len, 21) int64 661 'rewards': (seq_len,) float32 662 """ 663 for attempt in range(_retries + 1): 664 try: 665 traj_idx, offset = self.clip_index[idx] 666 mp4_path, valid_indices, actions = self.trajectories[traj_idx] 667 668 # Slice the frame indices and actions for this clip 669 clip_frame_indices = valid_indices[offset:offset + self.seq_len] 670 clip_actions = actions[offset:offset + self.seq_len] 671 672 # Decode frames on-the-fly from the MP4 (no data stored in RAM) 673 frames = _decode_frames( 674 mp4_path, clip_frame_indices, 675 self.image_height, self.image_width, 676 ) # (seq_len, H, W, 3) uint8 RGB 677 678 # (T, H, W, 3) uint8 → (3, T, H, W) float32 in [0, 1] 679 video = torch.from_numpy(frames).permute(3, 0, 1, 2).float().div_(255.0) 680 681 return { 682 'video': video, 683 'discrete_actions': torch.from_numpy(clip_actions.astype(np.int64)), 684 'rewards': torch.zeros(self.seq_len, dtype=torch.float32), 685 } 686 except Exception as e: 687 if attempt < _retries: 688 print(f" Warning: decode error at clip {idx} " 689 f"({mp4_path}), retrying with a different clip: {e}") 690 idx = random.randint(0, len(self.clip_index) - 1) 691 else: 692 raise RuntimeError( 693 f"Failed to decode clip after {_retries + 1} attempts. " 694 f"Last error: {e}" 695 ) from e 696 697 698def collate_minecraft_batch(batch: list) -> dict: 699 """Custom collate function that stacks dicts into batched tensors. 
700 701 Transforms list of per-sample dicts into a single dict of batched tensors, 702 matching the shapes expected by DynamicsWorldModel.forward(): 703 video: (B, 3, T, H, W) 704 discrete_actions: (B, T, 21) 705 rewards: (B, T) 706 """ 707 return { 708 'video': torch.stack([s['video'] for s in batch]), 709 'discrete_actions': torch.stack([s['discrete_actions'] for s in batch]), 710 'rewards': torch.stack([s['rewards'] for s in batch]), 711 }
```python
def mu_law_encode(x: np.ndarray, mu: float = CAMERA_MU) -> np.ndarray
```
Apply mu-law compression: maps [-maxval, maxval] → [-maxval, maxval].
This is the same encoding used by VPT's CameraQuantizer with quantization_scheme="mu_law". It compresses the dynamic range so small camera movements get more precision (foveated discretization).
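To make the foveation concrete, here is what the encoder does to a few yaw deltas (a quick numeric check; values follow from the formula above with mu = 10 and maxval = 10, rounded to two decimals):

```python
import numpy as np

from minecraft_vpt_dataset import mu_law_encode

deltas = np.array([0.5, 1.0, 2.0, 5.0, 10.0])  # camera deltas in degrees
print(np.round(mu_law_encode(deltas), 2))
# ≈ [ 1.69  2.89  4.58  7.47 10.  ]
# Small movements are stretched apart (0.5° and 1° land ~1.2 units apart),
# while large movements are squeezed together near the ±10 limit.
```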
```python
def discretize_camera(camera_xy: np.ndarray) -> np.ndarray
```
Discretize continuous camera (pitch, yaw) to bin indices.
Matches VPT's CameraQuantizer.discretize() with mu_law scheme:
1. Clip to [-CAMERA_MAXVAL, CAMERA_MAXVAL]
2. Apply mu-law encoding
3. Linear quantization to N_CAMERA_BINS bins
Arguments:
- camera_xy: (..., 2) array of [pitch, yaw] in degrees
Returns:
(..., 2) array of bin indices, each in [0, N_CAMERA_BINS-1]
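A small usage sketch (the bin values follow from the constants above: 11 bins of size 2 over the mu-law-encoded range):

```python
import numpy as np

from minecraft_vpt_dataset import discretize_camera

# Center of the grid: no camera movement → bin 5 on both axes (of 0..10).
print(discretize_camera(np.array([0.0, 0.0])))      # [5 5]

# Small movements change bins quickly near the center...
print(discretize_camera(np.array([1.0, -1.0])))     # [6 4]

# ...while anything at or beyond ±10° saturates at the edge bins.
print(discretize_camera(np.array([25.0, -180.0])))  # [10  0]
```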
```python
def camera_bins_to_joint_index(pitch_bin: int, yaw_bin: int) -> int
```
Combine 2D camera bins into a single index for Dreamer4.
Joint index = pitch_bin * N_CAMERA_BINS + yaw_bin
Range: [0, 120] for 11x11 grid
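Nothing in the module needs the inverse, but if you ever decode sampled camera actions back into a bin pair, the row-major layout means divmod undoes it (a one-line sketch):

```python
from minecraft_vpt_dataset import N_CAMERA_BINS, camera_bins_to_joint_index

joint = camera_bins_to_joint_index(5, 7)  # 5 * 11 + 7 = 62
pitch_bin, yaw_bin = divmod(joint, N_CAMERA_BINS)
assert (pitch_bin, yaw_bin) == (5, 7)     # row-major, so divmod inverts it
```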
```python
def parse_jsonl_action(step_data: dict, attack_is_stuck: bool = False) -> tuple
```
Parse a single JSONL action record into MineRL-style env action.
This replicates the logic from VPT's data_loader.py and run_inverse_dynamics_model.py:json_action_to_env_action().
Arguments:
- step_data: One parsed JSON object from the .jsonl file
- attack_is_stuck: Whether the attack button is stuck down (recorder bug)
Returns:
(env_action_dict, is_null_action)
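A minimal sketch with a hypothetical record (only the fields the parser reads are shown; real JSONL lines carry more):

```python
from minecraft_vpt_dataset import parse_jsonl_action

# Hypothetical single JSONL record: W held, left mouse button down,
# mouse moved 40 px to the right.
step = {
    "keyboard": {"keys": ["key.keyboard.w"]},
    "mouse": {"dx": 40.0, "dy": 0.0, "buttons": [0], "newButtons": []},
}

env_action, is_null = parse_jsonl_action(step)
print(env_action["forward"], env_action["attack"])  # 1 1
print(env_action["camera"])  # [0. 6.]: yaw = 40 px * 360/2400 = 6 degrees
print(is_null)               # False
```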
```python
def env_action_to_dreamer4(env_action: dict) -> np.ndarray
```
Convert MineRL env action dict to Dreamer4 discrete action vector.
Returns:
(21,) int64 array:
- [0:20] = button values (0 or 1)
- [20] = camera joint index (0 to 120)
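For instance (a sketch; the camera value is arbitrary):

```python
import numpy as np

from minecraft_vpt_dataset import BUTTONS_ALL, env_action_to_dreamer4

env_action = {b: 0 for b in BUTTONS_ALL}
env_action["forward"] = 1
env_action["camera"] = np.array([0.0, 6.0])  # pitch 0°, yaw 6° right

vec = env_action_to_dreamer4(env_action)
print(vec.shape)                          # (21,)
print(vec[BUTTONS_ALL.index("forward")])  # 1
print(vec[20])  # 64 == 5*11 + 9: pitch bin 5 (center), yaw bin 9
```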
```python
def load_trajectory(
    video_path: str,
    jsonl_path: str,
    target_height: int = 384,
    target_width: int = 640,
    skip_null_actions: bool = True,
) -> tuple
```
Load a single VPT recording into arrays.
Reads the .mp4 and .jsonl pair, parsing actions frame-by-frame using the same logic as VPT's data_loader.py (null-action filtering, attack-stuck handling, hotbar tracking, cursor overlay for GUI).
Arguments:
- video_path: Path to .mp4 file
- jsonl_path: Path to .jsonl file
- target_height: Output frame height (360x640 video is zero-padded to 384x640; other sizes are resized)
- target_width: Output frame width
- skip_null_actions: Whether to skip null actions (as VPT paper does)
Returns:
(frames, actions, rewards) where:
- frames: (T, H, W, 3) uint8 array
- actions: (T, 21) int64 array (Dreamer4 discrete format)
- rewards: (T,) float32 array (zeros — VPT recordings have no reward)
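Typical usage might look like this (the file names are hypothetical; VPT .mp4/.jsonl pairs share a basename):

```python
from minecraft_vpt_dataset import load_trajectory

frames, actions, rewards = load_trajectory(
    "data/recording-0001.mp4",
    "data/recording-0001.jsonl",
)
print(frames.shape)   # (T, 384, 640, 3) uint8, zero-padded from 360x640
print(actions.shape)  # (T, 21) int64
print(rewards.sum())  # 0.0: VPT recordings carry no reward signal
```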
```python
def prescan_trajectory(
    mp4_path: str,
    jsonl_path: str,
    skip_null_actions: bool = True,
) -> tuple
```
Pre-scan a VPT recording to extract metadata without loading video frames.
Reads only the JSONL (for actions and null-action filtering) and the MP4 header (for frame count). No pixel data is loaded into memory.
Returns:
- valid_frame_indices: (N,) int32 array of raw MP4 frame indices that survived null-action filtering.
- actions: (N, 21) int16 array of Dreamer4 discrete actions.
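A quick sketch of what the pre-scan returns (paths hypothetical):

```python
from minecraft_vpt_dataset import prescan_trajectory

# Only the .jsonl and the MP4 header are touched; no frames are decoded.
valid_idx, actions = prescan_trajectory("data/demo.mp4", "data/demo.jsonl")
print(valid_idx.dtype, actions.dtype)   # int32 int16
print(len(valid_idx) == len(actions))   # True: one action per surviving frame
# Rough metadata cost per valid frame: 4 + 2*21 = 46 bytes.
```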
```python
class MinecraftVPTDataset(Dataset)
```
Lazy-loading PyTorch Dataset for VPT Minecraft recordings.
Decodes video frames on-the-fly from the original MP4 files instead of pre-loading everything into RAM. This means:
- Zero extra disk storage (the MP4 files ARE the dataset)
- RAM usage is independent of video count (~1 MB metadata per trajectory)
- Scales to thousands of videos without issue
At init time, only the JSONL files and MP4 headers are read to build a lightweight index of valid frame positions and pre-parsed actions. The actual pixel data is decoded from the MP4 on each __getitem__ call.
For best throughput, use DataLoader with num_workers > 0 so that frame decoding runs in parallel worker processes. Install decord for faster random-access seeking (falls back to cv2 otherwise).
Shape convention (per sample):
- video: (3, seq_len, H, W) float32 in [0, 1]
- actions: (seq_len, 21) int64
- rewards: (seq_len,) float32
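A per-sample sketch (the data directory is hypothetical; indexing triggers on-the-fly MP4 decoding):

```python
from minecraft_vpt_dataset import MinecraftVPTDataset

ds = MinecraftVPTDataset("data/vpt", seq_len=16, stride=8)
sample = ds[0]
print(sample["video"].shape)             # torch.Size([3, 16, 384, 640])
print(sample["discrete_actions"].shape)  # torch.Size([16, 21])
print(sample["rewards"].shape)           # torch.Size([16])
```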
```python
MinecraftVPTDataset(
    data_dir: str,
    seq_len: int = 16,
    stride: int = 8,
    image_height: int = 384,
    image_width: int = 640,
    skip_null_actions: bool = True,
    max_trajectories: Optional[int] = None,
)
```
Arguments:
- data_dir: Directory containing .mp4/.jsonl pairs
- seq_len: Number of frames per training clip
- stride: Step size between consecutive clips (for overlap)
- image_height: Target image height
- image_width: Target image width
- skip_null_actions: Skip null actions as VPT paper does
- max_trajectories: Limit number of loaded trajectories (for debugging)
```python
def collate_minecraft_batch(batch: list) -> dict
```
Custom collate function that stacks dicts into batched tensors.
Transforms a list of per-sample dicts into a single dict of batched tensors, matching the shapes expected by DynamicsWorldModel.forward():
- video: (B, 3, T, H, W)
- discrete_actions: (B, T, 21)
- rewards: (B, T)
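Putting the dataset and collate function together, a training input pipeline might look like this (a sketch; the data directory, batch size, and worker count are arbitrary):

```python
from torch.utils.data import DataLoader

from minecraft_vpt_dataset import MinecraftVPTDataset, collate_minecraft_batch

ds = MinecraftVPTDataset("data/vpt", seq_len=16, stride=8)
loader = DataLoader(
    ds, batch_size=4, shuffle=True, num_workers=4,  # workers do the MP4 decoding
    collate_fn=collate_minecraft_batch, pin_memory=True,
)

batch = next(iter(loader))
print(batch["video"].shape)             # torch.Size([4, 3, 16, 384, 640])
print(batch["discrete_actions"].shape)  # torch.Size([4, 16, 21])
print(batch["rewards"].shape)           # torch.Size([4, 16])
```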