dreamer4_minecraft_agent

Dreamer4 Minecraft Agent — drop-in replacement for VPT's MineRLAgent.

This module wraps a trained Dreamer4 world model (VideoTokenizer + DynamicsWorldModel with trained policy/value heads) to create an agent that can play Minecraft through the VPT evaluation infrastructure.

It replaces the broken examples/dreamer4_agent.py which imported non-existent classes (Encoder, Decoder, Tokenizer, Dynamics) from a non-existent 'model' module. This version uses the actual Dreamer4 API:

  • VideoTokenizer (from dreamer4.dreamer4)
  • DynamicsWorldModel (from dreamer4.dreamer4)

Inference pipeline at each step:
  1. Receive MineRL observation dict with "pov" key (H, W, 3) uint8
  2. Zero-pad to 384x640, convert to (1, 3, 1, 384, 640) float tensor
  3. VideoTokenizer.tokenize() → latent (1, 1, num_latents, dim_latent)
  4. Concatenate with history latents, add view dim for DynamicsWorldModel
  5. DynamicsWorldModel.forward() with signal_level=max_steps-1 (clean) → agent embedding h_t
  6. Policy head MLP → action distribution
  7. ActionEmbedder.sample() → discrete actions (buttons + camera bin)
  8. Convert discrete actions to MineRL env action dict
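
Steps 7 and 8 hinge on how the camera is encoded: the last of the 21 discrete action
components is a single joint index that factors into pitch and yaw bins before
mu-law undiscretization. A minimal sketch of the decomposition, assuming VPT's
default camera quantization (camera_maxval=10, camera_binsize=2, hence 11 bins
per axis and 11*11 = 121 joint indices):

    N_CAMERA_BINS = 11                         # assumed: 2 * maxval / binsize + 1
    camera_joint = 60                          # example joint index in [0, 120]
    pitch_bin = camera_joint // N_CAMERA_BINS  # -> 5, the center bin
    yaw_bin = camera_joint % N_CAMERA_BINS     # -> 5, the center bin
    # The center bin maps back to roughly 0 degrees of camera movement after
    # ActionTransformer.undiscretize_camera(), i.e. "keep the camera still".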

Key differences from examples/dreamer4_agent.py:

  • Uses actual VideoTokenizer/DynamicsWorldModel classes
  • Uses DynamicsWorldModel's built-in policy_head and action_embedder
  • No separate PolicyHead class needed
  • Correct action space mapping matching training
  • Proper signal_level/step_size conditioning
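
The intended drop-in usage mirrors VPT's MineRLAgent loop. A minimal sketch
(the checkpoint path and step count are placeholders; ENV_KWARGS is imported
from VPT's agent.py and re-exported by this module):

    from minerl.herobraine.env_specs.human_survival_specs import HumanSurvival
    from dreamer4_minecraft_agent import Dreamer4MinecraftAgent, ENV_KWARGS

    env = HumanSurvival(**ENV_KWARGS).make()   # same env settings VPT evaluates with
    agent = Dreamer4MinecraftAgent(env, checkpoint_path="dreamer4_minecraft.pt")

    obs = env.reset()
    agent.reset()
    for _ in range(1000):
        action = agent.get_action(obs)         # MineRL-compatible action dict
        obs, reward, done, info = env.step(action)
        if done:
            break
    env.close()

Full module source: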
  1"""
  2Dreamer4 Minecraft Agent — drop-in replacement for VPT's MineRLAgent.
  3
  4This module wraps a trained Dreamer4 world model (VideoTokenizer +
  5DynamicsWorldModel with trained policy/value heads) to create an agent
  6that can play Minecraft through the VPT evaluation infrastructure.
  7
  8It replaces the broken examples/dreamer4_agent.py which imported
  9non-existent classes (Encoder, Decoder, Tokenizer, Dynamics) from a
 10non-existent 'model' module. This version uses the actual Dreamer4 API:
 11  - VideoTokenizer (from dreamer4.dreamer4)
 12  - DynamicsWorldModel (from dreamer4.dreamer4)
 13
 14Inference pipeline at each step:
 15  1. Receive MineRL observation dict with "pov" key (H, W, 3) uint8
 16  2. Zero-pad to 384x640, convert to (1, 3, 1, 384, 640) float tensor
 17  3. VideoTokenizer.tokenize() → latent (1, 1, num_latents, dim_latent)
 18  4. Concatenate with history latents, add view dim for DynamicsWorldModel
 19  5. DynamicsWorldModel.forward() with signal_level=max_steps-1 (clean)
 20     → agent embedding h_t
 21  6. Policy head MLP → action distribution
 22  7. ActionEmbedder.sample() → discrete actions (buttons + camera bin)
 23  8. Convert discrete actions to MineRL env action dict
 24
 25Key differences from examples/dreamer4_agent.py:
 26  - Uses actual VideoTokenizer/DynamicsWorldModel classes
 27  - Uses DynamicsWorldModel's built-in policy_head and action_embedder
 28  - No separate PolicyHead class needed
 29  - Correct action space mapping matching training
 30  - Proper signal_level/step_size conditioning
 31"""

import os
import sys
from typing import Optional

import numpy as np
import torch
import cv2

# Add paths for imports
_this_dir = os.path.dirname(os.path.abspath(__file__))  # .../Video-Pre-Training
_project_root = os.path.dirname(_this_dir)              # .../dreamer4 (repo root)
sys.path.insert(0, _this_dir)      # VPT local imports (agent.py, lib/)
sys.path.insert(0, _project_root)  # dreamer4 package + minecraft_vpt_dataset

# pylint: disable=wrong-import-position
# These imports must follow the sys.path edits above because agent.py and
# lib.actions live inside Video-Pre-Training/, and dreamer4 / minecraft_vpt_dataset
# live at the repo root — neither is on sys.path when this module is imported
# directly (not as part of a package).
from dreamer4 import VideoTokenizer, DynamicsWorldModel  # noqa: E402

from agent import ENV_KWARGS, validate_env, resize_image  # noqa: E402
from lib.actions import ActionTransformer  # noqa: E402

from minecraft_vpt_dataset import (  # noqa: E402
    BUTTONS_ALL,
    N_BUTTONS,
    N_CAMERA_BINS,
    CAMERA_MAXVAL,
    CAMERA_BINSIZE,
    CAMERA_MU,
)
# pylint: enable=wrong-import-position


# VPT action transformer for converting discrete bins back to env actions
ACTION_TRANSFORMER_KWARGS = dict(
    camera_binsize=CAMERA_BINSIZE,
    camera_maxval=CAMERA_MAXVAL,
    camera_mu=CAMERA_MU,
    camera_quantization_scheme="mu_law",
)
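# Note: with VPT's default camera settings (camera_binsize=2, camera_maxval=10),
# each camera axis has 2 * maxval / binsize + 1 = 11 bins, so the joint camera
# index predicted by the policy spans 11 * 11 = 121 values (0..120).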


def dreamer4_actions_to_minerl(
    discrete_actions: torch.Tensor,
    action_transformer: ActionTransformer,
) -> dict:
    """Convert Dreamer4 discrete action tensor to MineRL env action dict.

    Args:
        discrete_actions: (21,) long tensor
            [0:20] = button states (0 or 1)
            [20]   = camera joint index (0 to 120)
        action_transformer: VPT ActionTransformer for camera undiscretization

    Returns:
        MineRL-compatible action dict with all required keys
    """
    actions_np = discrete_actions.cpu().numpy()

    env_action = {}

    # Buttons: map discrete values to button names
    for i, button_name in enumerate(BUTTONS_ALL):
        env_action[button_name] = int(actions_np[i])

    # Keys not in our 20-button set default to 0
    env_action["ESC"] = 0
    env_action["pickItem"] = 0
    env_action["swapHands"] = 0

    # Camera: decompose joint index to pitch/yaw bins, then undiscretize
    camera_joint = int(actions_np[N_BUTTONS])
    pitch_bin = camera_joint // N_CAMERA_BINS
    yaw_bin = camera_joint % N_CAMERA_BINS

    # Use VPT's undiscretize to convert bins back to degrees
    camera_bins = np.array([[pitch_bin, yaw_bin]])
    camera_degrees = action_transformer.undiscretize_camera(camera_bins)
    env_action["camera"] = camera_degrees[0]  # (2,) array [pitch, yaw]

    return env_action


class Dreamer4MinecraftAgent:
    """Dreamer4-based Minecraft agent compatible with VPT evaluation scripts.

    This class provides the same interface as VPT's MineRLAgent:
      - __init__(env, ...) validates the environment
      - reset() resets internal state
      - get_action(minerl_obs) returns MineRL-compatible action dict

    Architecture:
      The agent uses the DynamicsWorldModel in a special "inference" mode:
      - Signal level = max_steps - 1 (fully denoised / clean observation)
      - Step size = max_steps // num_steps (the denoising granularity the model is conditioned on)
      - The model processes the observation history and outputs agent embeddings
      - Policy head maps embeddings to action distributions

      This is equivalent to what interact_with_env() does internally,
      but we handle the environment stepping ourselves to match VPT's interface.
    """

    def __init__(
        self,
        env,
        *,
        checkpoint_path: str,
        device: Optional[str] = None,
        stochastic: bool = True,
        max_context_len: int = 32,
    ):
        """
        Args:
            env: MineRL environment (validated against VPT settings)
            checkpoint_path: Path to dreamer4_minecraft.pt from training
            device: "cuda" or "cpu"
            stochastic: Sample from policy (True) or take argmax (False)
            max_context_len: Maximum history length for the dynamics model
        """
        validate_env(env)

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.stochastic = stochastic
        self.max_context_len = max_context_len

        # VPT action transformer for converting bins back to degrees
        self.action_transformer = ActionTransformer(**ACTION_TRANSFORMER_KWARGS)

        # Load checkpoint
        ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
        tok_config = ckpt.get('tokenizer_config', {})
        dyn_config = ckpt.get('config', {})

        # Rebuild models
        self.tokenizer = VideoTokenizer(**tok_config)
        self.tokenizer.eval()
        for p in self.tokenizer.parameters():
            p.requires_grad_(False)

        self.dynamics = DynamicsWorldModel(
            video_tokenizer=self.tokenizer,
            **dyn_config,
        )
        self.dynamics.load_state_dict(ckpt['model'])
        self.dynamics.eval()
        for p in self.dynamics.parameters():
            p.requires_grad_(False)

        self.dynamics.to(self.device)

        n_params = sum(p.numel() for p in self.dynamics.parameters())
        print(f"Loaded Dreamer4 agent ({n_params:,} params)")

        # Internal state
        self._latent_history = []     # List of (1, 1, num_latents, dim_latent) tensors
        self._action_history = []     # List of (1, 1, 21) discrete action tensors
        self._reward_history = []     # List of (1, 1) reward tensors
        self._step_count = 0

    def load_weights(self, path: str):  # noqa: ARG002  (interface compat)
        """Compatibility stub — weights are loaded in __init__."""

    def reset(self):
        """Reset agent state for a new episode."""
        self._latent_history = []
        self._action_history = []
        self._reward_history = []
        self._step_count = 0

    @torch.no_grad()
    def _encode_observation(self, minerl_obs: dict) -> torch.Tensor:
        """Encode a MineRL observation to latent tokens.

        Args:
            minerl_obs: MineRL observation dict with "pov" key

        Returns:
            latents: (1, 1, num_latent_tokens, dim_latent) tensor
        """
        # Extract frame and zero-pad to 384x640 (paper: 360x640 → 384x640)
        frame = minerl_obs["pov"]  # (H, W, 3) uint8, typically 360x640
        h, w = frame.shape[:2]
        target_h, target_w = 384, 640
        if w == target_w and h < target_h:
            pad_total = target_h - h
            pad_top = pad_total // 2
            pad_bottom = pad_total - pad_top
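            # For a native 360-row MineRL frame this adds 12 rows of black
            # padding at the top and bottom (384 - 360 = 24, split evenly).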
            frame = cv2.copyMakeBorder(
                frame, pad_top, pad_bottom, 0, 0,
                cv2.BORDER_CONSTANT, value=(0, 0, 0)
            )
        elif h != target_h or w != target_w:
            frame = resize_image(frame, (target_w, target_h))

        # Convert to (1, 3, 1, 384, 640) float tensor
        frame_t = torch.from_numpy(frame).float() / 255.0
        frame_t = frame_t.permute(2, 0, 1)  # (3, 384, 640)
        video = frame_t.unsqueeze(0).unsqueeze(2).to(self.device)  # (1, 3, 1, 384, 640)

        # Tokenize: returns (1, 1, num_latents, dim_latent)
        latents = self.tokenizer.tokenize(video)
        return latents

    @torch.no_grad()
    def get_action(self, minerl_obs: dict) -> dict:
        """Get action for a MineRL observation.

        Main inference entry point. Matches MineRLAgent.get_action() interface.

        Flow:
        1. Encode observation to latent
        2. Build history context (latents, actions, rewards)
        3. Run dynamics model forward pass with fully-denoised signal
        4. Extract agent embedding → policy head → sample action
        5. Convert to MineRL format

        Args:
            minerl_obs: MineRL observation dict with "pov" key

        Returns:
            MineRL action dict compatible with env.step()
        """
        # 1. Encode current observation
        latents = self._encode_observation(minerl_obs)  # (1, 1, n, d)
        self._latent_history.append(latents)

        # 2. Build context window
        # Trim to max context length
        start = max(0, len(self._latent_history) - self.max_context_len)
        ctx_latents = self._latent_history[start:]
        ctx_actions = self._action_history[start:]
        ctx_rewards = self._reward_history[start:]

        # Stack latents: (1, T, num_latents, dim_latent)
        latents_seq = torch.cat(ctx_latents, dim=1)

        # Stack actions: (1, T-1, 21) or None
        # Actions are shifted: action[t] is the action taken AFTER observing frame[t]
        # So for T frames we have T-1 past actions (none before first frame)
        discrete_actions = None
        if len(ctx_actions) > 0:
            discrete_actions = torch.cat(ctx_actions, dim=1)

        # Stack rewards: (1, T-1) or None
        rewards = None
        if len(ctx_rewards) > 0:
            rewards = torch.cat(ctx_rewards, dim=1)

        # 3. Forward pass through dynamics model
        # Signal level = max_steps - 1 means "fully denoised / clean observation"
        # Step size = max_steps // num_steps, where num_steps sets the denoising
        # granularity the model is conditioned on (the latents themselves are clean)
        max_steps = self.dynamics.max_steps
        num_steps = 4  # Number of denoising steps (must divide max_steps)
        assert max_steps % num_steps == 0, \
            f"max_steps ({max_steps}) must be divisible by num_steps ({num_steps})"
        step_size = max_steps // num_steps
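        # Illustrative only: if max_steps were 64 and num_steps 4, step_size would be 16.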

        _, (embeds, _, _) = self.dynamics(
            latents=latents_seq,
            signal_levels=max_steps - 1,    # Clean signal (fully denoised)
            step_sizes=step_size,
            rewards=rewards,
            discrete_actions=discrete_actions,
            # Skip noise injection — latents are clean from tokenizer
            latent_is_noised=True,
            return_pred_only=True,
            return_intermediates=True,
        )

        # 4. Extract agent embedding and sample action
        agent_embed = embeds.agent  # (B, T, num_agents, dim)
        assert agent_embed.ndim == 4, \
            f"Expected 4D agent_embed, got shape {agent_embed.shape}"
        # Take last timestep, first agent
        one_agent_embed = agent_embed[:, -1:, 0, :]  # (1, 1, dim)

        # Policy head → action distribution
        policy_embed = self.dynamics.policy_head(one_agent_embed)

        # Sample actions from the action embedder
        if self.stochastic:
            sampled_discrete, _ = self.dynamics.action_embedder.sample(
                policy_embed, pred_head_index=0, squeeze=True
            )
        else:
            # Deterministic mode: near-zero temperature approximates argmax
            sampled_discrete, _ = self.dynamics.action_embedder.sample(
                policy_embed, pred_head_index=0,
                discrete_temperature=0.01, squeeze=True
            )

        # sampled_discrete: (1, 1, 21) → squeeze to (21,) for action conversion
        action_tensor = sampled_discrete.squeeze()  # (21,)

        # 5. Store action and dummy reward in history
        self._action_history.append(sampled_discrete)
        # Use zero reward (we don't know the real reward until env.step)
        dummy_reward = torch.zeros(1, 1, device=self.device)
        self._reward_history.append(dummy_reward)

        # 6. Convert to MineRL action format
        minerl_action = dreamer4_actions_to_minerl(action_tensor, self.action_transformer)

        self._step_count += 1
        return minerl_action


# ─── CLI for quick testing ──────────────────────────────────────────

if __name__ == "__main__":
    from argparse import ArgumentParser
    from minerl.herobraine.env_specs.human_survival_specs import HumanSurvival

    parser = ArgumentParser("Run Dreamer4 agent on MineRL")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to dreamer4_minecraft.pt checkpoint")
    parser.add_argument("--deterministic", action="store_true")
    parser.add_argument("--render", action="store_true")
    parser.add_argument("--max_steps", type=int, default=36000)
    args = parser.parse_args()

    env = HumanSurvival(**ENV_KWARGS).make()
    agent = Dreamer4MinecraftAgent(
        env,
        checkpoint_path=args.checkpoint,
        stochastic=not args.deterministic,
    )

    obs = env.reset()
    agent.reset()
    total_reward = 0.0

    for step in range(args.max_steps):
        action = agent.get_action(obs)
        obs, reward, done, info = env.step(action)
        total_reward += reward

        if step % 100 == 0:
            print(f"Step {step}, reward: {total_reward:.2f}")
        if args.render:
            frame = obs["pov"]  # (360, 640, 3) native resolution
            cv2.imshow("Dreamer4 Minecraft", frame[:, :, ::-1])
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
        if done:
            print(f"Episode done at step {step}, reward: {total_reward:.2f}")
            obs = env.reset()
            agent.reset()
            total_reward = 0.0
    # Shut down cleanly after the evaluation loop
    if args.render:
        cv2.destroyAllWindows()
    env.close()
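
Assuming the file is saved as dreamer4_minecraft_agent.py inside Video-Pre-Training/
(so the sys.path setup at the top resolves agent.py, lib/, and the repo-root packages),
a quick smoke test might look like:

    python dreamer4_minecraft_agent.py --checkpoint /path/to/dreamer4_minecraft.pt --render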