dreamer4_minecraft_agent
Dreamer4 Minecraft Agent — drop-in replacement for VPT's MineRLAgent.
This module wraps a trained Dreamer4 world model (VideoTokenizer + DynamicsWorldModel with trained policy/value heads) to create an agent that can play Minecraft through the VPT evaluation infrastructure.
It replaces the broken examples/dreamer4_agent.py, which imported non-existent classes (Encoder, Decoder, Tokenizer, Dynamics) from a non-existent 'model' module. This version uses the actual Dreamer4 API:
- VideoTokenizer (from dreamer4.dreamer4)
- DynamicsWorldModel (from dreamer4.dreamer4)
Inference pipeline at each step:
1. Receive MineRL observation dict with "pov" key (H, W, 3) uint8
2. Zero-pad to 384x640, convert to (1, 3, 1, 384, 640) float tensor
3. VideoTokenizer.tokenize() → latent (1, 1, num_latents, dim_latent)
4. Concatenate with history latents, add view dim for DynamicsWorldModel
5. DynamicsWorldModel.forward() with signal_level=max_steps-1 (clean) → agent embedding h_t
6. Policy head MLP → action distribution
7. ActionEmbedder.sample() → discrete actions (buttons + camera bin)
8. Convert discrete actions to MineRL env action dict (see the camera-index sketch below)
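Step 8 packs camera pitch and yaw into a single joint index. A minimal sketch of the decomposition, assuming N_CAMERA_BINS = 11 (consistent with the 0–120 joint-index range, since 11 × 11 = 121); split_camera_joint is an illustrative helper, not part of the module:

    N_CAMERA_BINS = 11  # assumed: 11 * 11 = 121 joint indices, i.e. 0..120

    def split_camera_joint(joint: int) -> tuple[int, int]:
        # Inverse of joint = pitch_bin * N_CAMERA_BINS + yaw_bin
        return joint // N_CAMERA_BINS, joint % N_CAMERA_BINS

    # Example: joint index 60 → pitch bin 5, yaw bin 5 — the center bin pair,
    # which under VPT's 11-bin mu-law camera scheme is typically zero movement.
    assert split_camera_joint(60) == (5, 5)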
Key differences from examples/dreamer4_agent.py:
- Uses actual VideoTokenizer/DynamicsWorldModel classes
- Uses DynamicsWorldModel's built-in policy_head and action_embedder
- No separate PolicyHead class needed
- Correct action space mapping matching training
- Proper signal_level/step_size conditioning
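Before the source, a minimal usage sketch mirroring the CLI block at the end of the module. It assumes this file is saved as dreamer4_minecraft_agent.py inside Video-Pre-Training/ (so the agent import resolves) and uses a placeholder checkpoint path:

    from minerl.herobraine.env_specs.human_survival_specs import HumanSurvival

    from agent import ENV_KWARGS
    from dreamer4_minecraft_agent import Dreamer4MinecraftAgent

    env = HumanSurvival(**ENV_KWARGS).make()
    agent = Dreamer4MinecraftAgent(env, checkpoint_path="dreamer4_minecraft.pt")

    obs = env.reset()
    agent.reset()
    action = agent.get_action(obs)  # MineRL-compatible action dict
    obs, reward, done, info = env.step(action)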
1""" 2Dreamer4 Minecraft Agent — drop-in replacement for VPT's MineRLAgent. 3 4This module wraps a trained Dreamer4 world model (VideoTokenizer + 5DynamicsWorldModel with trained policy/value heads) to create an agent 6that can play Minecraft through the VPT evaluation infrastructure. 7 8It replaces the broken examples/dreamer4_agent.py which imported 9non-existent classes (Encoder, Decoder, Tokenizer, Dynamics) from a 10non-existent 'model' module. This version uses the actual Dreamer4 API: 11 - VideoTokenizer (from dreamer4.dreamer4) 12 - DynamicsWorldModel (from dreamer4.dreamer4) 13 14Inference pipeline at each step: 15 1. Receive MineRL observation dict with "pov" key (H, W, 3) uint8 16 2. Zero-pad to 384x640, convert to (1, 3, 1, 384, 640) float tensor 17 3. VideoTokenizer.tokenize() → latent (1, 1, num_latents, dim_latent) 18 4. Concatenate with history latents, add view dim for DynamicsWorldModel 19 5. DynamicsWorldModel.forward() with signal_level=max_steps-1 (clean) 20 → agent embedding h_t 21 6. Policy head MLP → action distribution 22 7. ActionEmbedder.sample() → discrete actions (buttons + camera bin) 23 8. Convert discrete actions to MineRL env action dict 24 25Key differences from examples/dreamer4_agent.py: 26 - Uses actual VideoTokenizer/DynamicsWorldModel classes 27 - Uses DynamicsWorldModel's built-in policy_head and action_embedder 28 - No separate PolicyHead class needed 29 - Correct action space mapping matching training 30 - Proper signal_level/step_size conditioning 31""" 32 33import os 34import sys 35from typing import Optional 36 37import numpy as np 38import torch 39import cv2 40 41# Add paths for imports 42_this_dir = os.path.dirname(os.path.abspath(__file__)) # .../Video-Pre-Training 43_project_root = os.path.dirname(_this_dir) # .../dreamer4 (repo root) 44sys.path.insert(0, _this_dir) # VPT local imports (agent.py, lib/) 45sys.path.insert(0, _project_root) # dreamer4 package + minecraft_vpt_dataset 46 47# pylint: disable=wrong-import-position 48# These imports must follow the sys.path edits above because agent.py and 49# lib.actions live inside Video-Pre-Training/, and dreamer4 / minecraft_vpt_dataset 50# live at the repo root — neither is on sys.path when this module is imported 51# directly (not as part of a package). 52from dreamer4 import VideoTokenizer, DynamicsWorldModel # noqa: E402 53 54from agent import ENV_KWARGS, validate_env, resize_image # noqa: E402 55from lib.actions import ActionTransformer # noqa: E402 56 57from minecraft_vpt_dataset import ( # noqa: E402 58 BUTTONS_ALL, 59 N_BUTTONS, 60 N_CAMERA_BINS, 61 CAMERA_MAXVAL, 62 CAMERA_BINSIZE, 63 CAMERA_MU, 64) 65# pylint: enable=wrong-import-position 66 67 68# VPT action transformer for converting discrete bins back to env actions 69ACTION_TRANSFORMER_KWARGS = dict( 70 camera_binsize=CAMERA_BINSIZE, 71 camera_maxval=CAMERA_MAXVAL, 72 camera_mu=CAMERA_MU, 73 camera_quantization_scheme="mu_law", 74) 75 76 77def dreamer4_actions_to_minerl( 78 discrete_actions: torch.Tensor, 79 action_transformer: ActionTransformer, 80) -> dict: 81 """Convert Dreamer4 discrete action tensor to MineRL env action dict. 
82 83 Args: 84 discrete_actions: (21,) long tensor 85 [0:20] = button states (0 or 1) 86 [20] = camera joint index (0 to 120) 87 action_transformer: VPT ActionTransformer for camera undiscretization 88 89 Returns: 90 MineRL-compatible action dict with all required keys 91 """ 92 actions_np = discrete_actions.cpu().numpy() 93 94 env_action = {} 95 96 # Buttons: map discrete values to button names 97 for i, button_name in enumerate(BUTTONS_ALL): 98 env_action[button_name] = int(actions_np[i]) 99 100 # Keys not in our 20-button set default to 0 101 env_action["ESC"] = 0 102 env_action["pickItem"] = 0 103 env_action["swapHands"] = 0 104 105 # Camera: decompose joint index to pitch/yaw bins, then undiscretize 106 camera_joint = int(actions_np[N_BUTTONS]) 107 pitch_bin = camera_joint // N_CAMERA_BINS 108 yaw_bin = camera_joint % N_CAMERA_BINS 109 110 # Use VPT's undiscretize to convert bins back to degrees 111 camera_bins = np.array([[pitch_bin, yaw_bin]]) 112 camera_degrees = action_transformer.undiscretize_camera(camera_bins) 113 env_action["camera"] = camera_degrees[0] # (2,) array [pitch, yaw] 114 115 return env_action 116 117 118class Dreamer4MinecraftAgent: 119 """Dreamer4-based Minecraft agent compatible with VPT evaluation scripts. 120 121 This class provides the same interface as VPT's MineRLAgent: 122 - __init__(env, ...) validates the environment 123 - reset() resets internal state 124 - get_action(minerl_obs) returns MineRL-compatible action dict 125 126 Architecture: 127 The agent uses the DynamicsWorldModel in a special "inference" mode: 128 - Signal level = max_steps - 1 (fully denoised / clean observation) 129 - Step size = max_steps (single step, since observation is already clean) 130 - The model processes the observation history and outputs agent embeddings 131 - Policy head maps embeddings to action distributions 132 133 This is equivalent to what interact_with_env() does internally, 134 but we handle the environment stepping ourselves to match VPT's interface. 
135 """ 136 137 def __init__( 138 self, 139 env, 140 *, 141 checkpoint_path: str, 142 device: Optional[str] = None, 143 stochastic: bool = True, 144 max_context_len: int = 32, 145 ): 146 """ 147 Args: 148 env: MineRL environment (validated against VPT settings) 149 checkpoint_path: Path to dreamer4_minecraft.pt from training 150 device: "cuda" or "cpu" 151 stochastic: Sample from policy (True) or take argmax (False) 152 max_context_len: Maximum history length for the dynamics model 153 """ 154 validate_env(env) 155 156 if device is None: 157 device = "cuda" if torch.cuda.is_available() else "cpu" 158 self.device = torch.device(device) 159 self.stochastic = stochastic 160 self.max_context_len = max_context_len 161 162 # VPT action transformer for converting bins back to degrees 163 self.action_transformer = ActionTransformer(**ACTION_TRANSFORMER_KWARGS) 164 165 # Load checkpoint 166 ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False) 167 tok_config = ckpt.get('tokenizer_config', {}) 168 dyn_config = ckpt.get('config', {}) 169 170 # Rebuild models 171 self.tokenizer = VideoTokenizer(**tok_config) 172 self.tokenizer.eval() 173 for p in self.tokenizer.parameters(): 174 p.requires_grad_(False) 175 176 self.dynamics = DynamicsWorldModel( 177 video_tokenizer=self.tokenizer, 178 **dyn_config, 179 ) 180 self.dynamics.load_state_dict(ckpt['model']) 181 self.dynamics.eval() 182 for p in self.dynamics.parameters(): 183 p.requires_grad_(False) 184 185 self.dynamics.to(self.device) 186 187 n_params = sum(p.numel() for p in self.dynamics.parameters()) 188 print(f"Loaded Dreamer4 agent ({n_params:,} params)") 189 190 # Internal state 191 self._latent_history = [] # List of (1, 1, num_views, num_latents, dim_latent) tensors 192 self._action_history = [] # List of (1, 1, 21) discrete action tensors 193 self._reward_history = [] # List of (1, 1) reward tensors 194 self._step_count = 0 195 196 def load_weights(self, path: str): # noqa: ARG002 (interface compat) 197 """Compatibility stub — weights are loaded in __init__.""" 198 199 def reset(self): 200 """Reset agent state for a new episode.""" 201 self._latent_history = [] 202 self._action_history = [] 203 self._reward_history = [] 204 self._step_count = 0 205 206 @torch.no_grad() 207 def _encode_observation(self, minerl_obs: dict) -> torch.Tensor: 208 """Encode a MineRL observation to latent tokens. 
209 210 Args: 211 minerl_obs: MineRL observation dict with "pov" key 212 213 Returns: 214 latents: (1, 1, num_latent_tokens, dim_latent) tensor 215 """ 216 # Extract frame and zero-pad to 384x640 (paper: 360x640 → 384x640) 217 frame = minerl_obs["pov"] # (H, W, 3) uint8, typically 360x640 218 h, w = frame.shape[:2] 219 target_h, target_w = 384, 640 220 if w == target_w and h < target_h: 221 pad_total = target_h - h 222 pad_top = pad_total // 2 223 pad_bottom = pad_total - pad_top 224 frame = cv2.copyMakeBorder( 225 frame, pad_top, pad_bottom, 0, 0, 226 cv2.BORDER_CONSTANT, value=(0, 0, 0) 227 ) 228 elif h != target_h or w != target_w: 229 frame = resize_image(frame, (target_w, target_h)) 230 231 # Convert to (1, 3, 1, 384, 640) float tensor 232 frame_t = torch.from_numpy(frame).float() / 255.0 233 frame_t = frame_t.permute(2, 0, 1) # (3, 384, 640) 234 video = frame_t.unsqueeze(0).unsqueeze(2).to(self.device) # (1, 3, 1, 384, 640) 235 236 # Tokenize: returns (1, 1, num_latents, dim_latent) 237 latents = self.tokenizer.tokenize(video) 238 return latents 239 240 @torch.no_grad() 241 def get_action(self, minerl_obs: dict) -> dict: 242 """Get action for a MineRL observation. 243 244 Main inference entry point. Matches MineRLAgent.get_action() interface. 245 246 Flow: 247 1. Encode observation to latent 248 2. Build history context (latents, actions, rewards) 249 3. Run dynamics model forward pass with fully-denoised signal 250 4. Extract agent embedding → policy head → sample action 251 5. Convert to MineRL format 252 253 Args: 254 minerl_obs: MineRL observation dict with "pov" key 255 256 Returns: 257 MineRL action dict compatible with env.step() 258 """ 259 # 1. Encode current observation 260 latents = self._encode_observation(minerl_obs) # (1, 1, n, d) 261 self._latent_history.append(latents) 262 263 # 2. Build context window 264 # Trim to max context length 265 start = max(0, len(self._latent_history) - self.max_context_len) 266 ctx_latents = self._latent_history[start:] 267 ctx_actions = self._action_history[start:] 268 ctx_rewards = self._reward_history[start:] 269 270 # Stack latents: (1, T, num_latents, dim_latent) 271 latents_seq = torch.cat(ctx_latents, dim=1) 272 273 # Stack actions: (1, T-1, 21) or None 274 # Actions are shifted: action[t] is the action taken AFTER observing frame[t] 275 # So for T frames we have T-1 past actions (none before first frame) 276 discrete_actions = None 277 if len(ctx_actions) > 0: 278 discrete_actions = torch.cat(ctx_actions, dim=1) 279 280 # Stack rewards: (1, T-1) or None 281 rewards = None 282 if len(ctx_rewards) > 0: 283 rewards = torch.cat(ctx_rewards, dim=1) 284 285 # 3. 
Forward pass through dynamics model 286 # Signal level = max_steps - 1 means "fully denoised / clean observation" 287 # Step size = max_steps // num_steps where num_steps determines denoising granularity 288 # For inference on real observations, we use a single step 289 max_steps = self.dynamics.max_steps 290 num_steps = 4 # Number of denoising steps (must divide max_steps) 291 assert max_steps % num_steps == 0, \ 292 f"max_steps ({max_steps}) must be divisible by num_steps ({num_steps})" 293 step_size = max_steps // num_steps 294 295 _, (embeds, _, _) = self.dynamics( 296 latents=latents_seq, 297 signal_levels=max_steps - 1, # Clean signal (fully denoised) 298 step_sizes=step_size, 299 rewards=rewards, 300 discrete_actions=discrete_actions, 301 # Skip noise injection — latents are clean from tokenizer 302 latent_is_noised=True, 303 return_pred_only=True, 304 return_intermediates=True, 305 ) 306 307 # 4. Extract agent embedding and sample action 308 agent_embed = embeds.agent # (B, T, num_agents, dim) 309 assert agent_embed.ndim == 4, \ 310 f"Expected 4D agent_embed, got shape {agent_embed.shape}" 311 # Take last timestep, first agent 312 one_agent_embed = agent_embed[:, -1:, 0, :] # (1, 1, dim) 313 314 # Policy head → action distribution 315 policy_embed = self.dynamics.policy_head(one_agent_embed) 316 317 # Sample actions from the action embedder 318 if self.stochastic: 319 sampled_discrete, _ = self.dynamics.action_embedder.sample( 320 policy_embed, pred_head_index=0, squeeze=True 321 ) 322 else: 323 # For deterministic: sample with temperature 0 (argmax) 324 sampled_discrete, _ = self.dynamics.action_embedder.sample( 325 policy_embed, pred_head_index=0, 326 discrete_temperature=0.01, squeeze=True 327 ) 328 329 # sampled_discrete: (1, 1, 21) → squeeze to (21,) for action conversion 330 action_tensor = sampled_discrete.squeeze() # (21,) 331 332 # 5. Store action and dummy reward in history 333 self._action_history.append(sampled_discrete) 334 # Use zero reward (we don't know the real reward until env.step) 335 dummy_reward = torch.zeros(1, 1, device=self.device) 336 self._reward_history.append(dummy_reward) 337 338 # 6. 
Convert to MineRL action format 339 minerl_action = dreamer4_actions_to_minerl(action_tensor, self.action_transformer) 340 341 self._step_count += 1 342 return minerl_action 343 344 345# ─── CLI for quick testing ────────────────────────────────────────── 346 347if __name__ == "__main__": 348 from argparse import ArgumentParser 349 from minerl.herobraine.env_specs.human_survival_specs import HumanSurvival 350 351 parser = ArgumentParser("Run Dreamer4 agent on MineRL") 352 parser.add_argument("--checkpoint", type=str, required=True, 353 help="Path to dreamer4_minecraft.pt checkpoint") 354 parser.add_argument("--deterministic", action="store_true") 355 parser.add_argument("--render", action="store_true") 356 parser.add_argument("--max_steps", type=int, default=36000) 357 args = parser.parse_args() 358 359 env = HumanSurvival(**ENV_KWARGS).make() 360 agent = Dreamer4MinecraftAgent( 361 env, 362 checkpoint_path=args.checkpoint, 363 stochastic=not args.deterministic, 364 ) 365 366 obs = env.reset() 367 agent.reset() 368 total_reward = 0.0 369 370 for step in range(args.max_steps): 371 action = agent.get_action(obs) 372 obs, reward, done, info = env.step(action) 373 total_reward += reward 374 375 if step % 100 == 0: 376 print(f"Step {step}, reward: {total_reward:.2f}") 377 if args.render: 378 frame = obs["pov"] # (360, 640, 3) native resolution 379 cv2.imshow("Dreamer4 Minecraft", frame[:, :, ::-1]) 380 if cv2.waitKey(1) & 0xFF == ord('q'): 381 break 382 if done: 383 print(f"Episode done at step {step}, reward: {total_reward:.2f}") 384 obs = env.reset() 385 agent.reset() 386 total_reward = 0.0 387 # shut down at end of for loop 388 if args.render: 389 cv2.destroyAllWindows()
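The module can also be launched directly through its CLI block; a typical invocation, assuming the file name used above and a placeholder checkpoint path:

    python dreamer4_minecraft_agent.py --checkpoint dreamer4_minecraft.pt --render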