evaluate_dreamer4_minecraft

Evaluation script for Dreamer4 Minecraft agent using VPT infrastructure.

Runs the trained Dreamer4 agent in MineRL HumanSurvival environment and tracks diamond tech tree progression. Supports parallel evaluation using multiple processes with Xvfb for headless rendering.

This is a corrected version of examples/evaluate_dreamer4_parallel.py that:

  • Uses the corrected Dreamer4MinecraftAgent (not the broken DreamerV4Agent)
  • Accepts a single checkpoint path (not three separate checkpoints)
  • Adds proper error handling and timeouts
  • Validates checkpoint existence before spawning workers
Usage:

Single process (for debugging):

python evaluate_dreamer4_minecraft.py --checkpoint ./checkpoints/dreamer4_minecraft.pt --n_episodes 10 --n_workers 1

Parallel headless evaluation:

python evaluate_dreamer4_minecraft.py --checkpoint ./checkpoints/dreamer4_minecraft.pt --n_episodes 100 --n_workers 8 --max_steps 36000 --output_json eval_results.json

  1"""
  2Evaluation script for Dreamer4 Minecraft agent using VPT infrastructure.
  3
  4Runs the trained Dreamer4 agent in MineRL HumanSurvival environment and
  5tracks diamond tech tree progression. Supports parallel evaluation using
  6multiple processes with Xvfb for headless rendering.
  7
  8This is a corrected version of examples/evaluate_dreamer4_parallel.py that:
  9  - Uses the corrected Dreamer4MinecraftAgent (not the broken DreamerV4Agent)
 10  - Accepts a single checkpoint path (not three separate checkpoints)
 11  - Adds proper error handling and timeouts
 12  - Validates checkpoint existence before spawning workers
 13
 14Usage:
 15    # Single process (for debugging):
 16    python evaluate_dreamer4_minecraft.py \
 17        --checkpoint ./checkpoints/dreamer4_minecraft.pt \
 18        --n_episodes 10 \
 19        --n_workers 1
 20
 21    # Parallel headless evaluation:
 22    python evaluate_dreamer4_minecraft.py \
 23        --checkpoint ./checkpoints/dreamer4_minecraft.pt \
 24        --n_episodes 100 \
 25        --n_workers 8 \
 26        --max_steps 36000 \
 27        --output_json eval_results.json
 28"""
 29
 30import os
 31import sys
 32import json
 33import time
 34import atexit
 35import signal
 36import argparse
 37import subprocess
 38import traceback
 39from dataclasses import dataclass, field, asdict
 40from typing import Dict, List
 41from multiprocessing import Queue, get_context
 42
 43import numpy as np
 44
 45
 46# Sentinel result payload signalling a worker exited before producing results.
 47# Carries the worker id and the raw traceback so the parent can surface it
 48# instead of silently hanging on result_queue.get() forever.
@dataclass
class WorkerFailure:
    """Sentinel payload a worker puts on the result queue when it exits
    before producing an ``EpisodeResult``.

    Carries the worker id and the raw traceback so the parent process can
    surface the error instead of silently blocking on ``result_queue.get()``.
    """
    worker_id: int       # index of the worker process that failed
    stage: str           # "startup", "reset", "step", etc.
    error: str           # repr(exception)
    tb: str              # traceback.format_exc()
 55
 56
 57# Parent-side registry of live child processes so SIGTERM/SIGINT in the
 58# driver reaps them. Populated in main(), consumed in _terminate_children.
 59_CHILDREN = []
 60
 61
 62def _terminate_children(*_args):
 63    """Kill any still-running worker processes. Safe to call multiple times."""
 64    for p in _CHILDREN:
 65        try:
 66            if p.is_alive():
 67                p.terminate()
 68        except Exception:
 69            pass
 70    # Give them a moment to die, then hard-kill anything left.
 71    deadline = time.time() + 5
 72    for p in _CHILDREN:
 73        try:
 74            remaining = max(0.0, deadline - time.time())
 75            p.join(timeout=remaining)
 76            if p.is_alive():
 77                p.kill()
 78                p.join(timeout=2)
 79        except Exception:
 80            pass
 81
# Add paths for imports — this script lives inside Video-Pre-Training/,
# so go up one level to find the dreamer4 package and minecraft_vpt_dataset.
# Inserting at position 0 means these directories shadow any identically
# named installed packages for the rest of this process.
_this_dir = os.path.dirname(os.path.abspath(__file__))  # .../Video-Pre-Training
_project_root = os.path.dirname(_this_dir)              # .../dreamer4 (repo root)
sys.path.insert(0, _this_dir)      # VPT local imports (agent.py, lib/)
sys.path.insert(0, _project_root)  # dreamer4 package + minecraft_vpt_dataset
 88
 89
# Diamond tech tree tasks in progression order.
# Matches the tasks reported in the Dreamer4 paper.
# Each entry is (task_name, inventory_item_key): task_name labels the task
# in results/stats, item_key is what check_inventory() looks up in the
# MineRL inventory dict. They are currently identical but kept separate so
# a display name could diverge from the Minecraft item id.
TECH_TREE_TASKS = [
    ("log", "log"),
    ("planks", "planks"),
    ("crafting_table", "crafting_table"),
    ("wooden_pickaxe", "wooden_pickaxe"),
    ("cobblestone", "cobblestone"),
    ("stone_pickaxe", "stone_pickaxe"),
    ("iron_ore", "iron_ore"),
    ("furnace", "furnace"),
    ("iron_ingot", "iron_ingot"),
    ("iron_pickaxe", "iron_pickaxe"),
    ("diamond", "diamond"),
]
105
106
@dataclass
class EpisodeResult:
    """Results from a single evaluation episode."""
    episode_id: int           # index of the episode within the whole run
    worker_id: int            # worker process that executed the episode
    total_steps: int          # number of env steps actually taken
    total_reward: float       # sum of per-step rewards
    wall_time_seconds: float  # wall-clock duration of the episode
    # First step index at which each tech-tree task was achieved (-1 = never).
    tech_tree_steps: Dict[str, int] = field(default_factory=dict)
    # Whether each tech-tree task was achieved at any point in the episode.
    tech_tree_achieved: Dict[str, bool] = field(default_factory=dict)
117
118
def check_inventory(info: dict, item_name: str) -> bool:
    """Check if ``item_name`` is present in the player's MineRL inventory.

    Args:
        info: MineRL ``info`` dict returned by ``env.step``.
        item_name: Minecraft item id (e.g. ``"log"``, ``"iron_pickaxe"``).

    Returns:
        True only when ``info["inventory"]`` is a dict whose count for
        ``item_name`` is greater than zero; False in every other case
        (missing inventory field, non-dict inventory, zero count).
    """
    inv = info.get("inventory", {})
    if not isinstance(inv, dict):
        return False
    return inv.get(item_name, 0) > 0
135
136
def run_worker(
    worker_id: int,
    episode_queue: Queue,
    result_queue: Queue,
    args_dict: dict,
):
    """Worker process for parallel evaluation.

    Each worker:
    1. Starts its own Xvfb display for headless rendering
    2. Creates a MineRL HumanSurvival environment
    3. Loads the Dreamer4 agent
    4. Runs episodes from the queue until poison pill
    5. Reports results

    Args:
        worker_id: Zero-based worker index; used to derive a unique X
            display number and Malmo base port.
        episode_queue: Queue of episode ids to run; ``None`` is the
            poison pill that tells this worker to exit.
        result_queue: Queue the worker pushes ``EpisodeResult`` (success)
            or ``WorkerFailure`` (fatal error) objects onto.
        args_dict: Plain-dict config from main(): ``checkpoint``,
            ``max_steps``, ``minecraft_mem``, ``deterministic``,
            ``headless``.
    """
    # Each worker gets its own X display (:100, :101, ...) so parallel
    # Xvfb instances don't collide.
    display_num = 99 + worker_id
    os.environ["DISPLAY"] = f":{display_num}"

    xvfb_proc = None
    env = None

    def _emit_failure(stage, exc):
        """Send traceback back to parent so it lands in sbatch stdout."""
        try:
            result_queue.put(WorkerFailure(
                worker_id=worker_id,
                stage=stage,
                error=repr(exc),
                tb=traceback.format_exc(),
            ))
        except Exception:
            pass

    # Ensure SIGTERM from the parent kills the Minecraft JVMs before we exit.
    # Without this, an aborted run leaves a tree of Java processes on HPC.
    # Raising SystemExit lets the finally-block below run env/Xvfb cleanup.
    def _worker_signal_handler(signum, _frame):
        raise SystemExit(f"worker {worker_id} received signal {signum}")
    signal.signal(signal.SIGTERM, _worker_signal_handler)

    if args_dict.get("headless", True):
        try:
            xvfb_proc = subprocess.Popen(
                ["Xvfb", f":{display_num}", "-screen", "0", "1024x768x24", "-ac"],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            # Give Xvfb a moment to bind the display before MineRL uses it.
            time.sleep(1)
        except FileNotFoundError:
            print(f"[Worker {worker_id}] WARNING: Xvfb not found, using existing display")
            xvfb_proc = None

    try:
        try:
            # Imports happen inside the worker (spawn context) and after
            # DISPLAY is set, so they must not be hoisted to module level.
            from minerl.herobraine.env_specs.human_survival_specs import HumanSurvival
            from agent import ENV_KWARGS
            from dreamer4_minecraft_agent import Dreamer4MinecraftAgent
            from minerl.env.malmo import InstanceManager, MinecraftInstance
        except Exception as e:
            # Import-time failure (e.g. the Python 3.9 / PEP 604 issue this file
            # has tripped on before). Surface it through the result queue so the
            # parent can abort instead of waiting an hour per missing result.
            print(f"[Worker {worker_id}] Import error: {e}")
            traceback.print_exc()
            _emit_failure("import", e)
            return

        # Give each worker a disjoint port range so parallel Malmo
        # instances don't fight over ports.
        InstanceManager.configure_malmo_base_port(9000 + worker_id * 100)

        # Reduce JVM heap from 4G to 2G per instance to prevent OOM kills.
        # Done by monkey-patching MinecraftInstance.__init__ to override its
        # max_mem default — there is no public knob for this in MineRL.
        _orig_mc_init = MinecraftInstance.__init__
        _mc_max_mem = args_dict["minecraft_mem"]

        def _mc_init_with_reduced_mem(self, port=None, existing=False,
                                      status_dir=None, seed=None,
                                      instance_id=None, max_mem=_mc_max_mem):
            # NOTE(review): forwards positionally — assumes the upstream
            # MinecraftInstance.__init__ signature matches this parameter
            # order; verify against the installed minerl version.
            _orig_mc_init(self, port, existing, status_dir, seed,
                          instance_id, max_mem)
        MinecraftInstance.__init__ = _mc_init_with_reduced_mem

        try:
            env = HumanSurvival(**ENV_KWARGS).make()
        except Exception as e:
            print(f"[Worker {worker_id}] env construction failed: {e}")
            traceback.print_exc()
            _emit_failure("env_make", e)
            return

        try:
            agent = Dreamer4MinecraftAgent(
                env,
                checkpoint_path=args_dict["checkpoint"],
                stochastic=not args_dict.get("deterministic", False),
            )
        except Exception as e:
            print(f"[Worker {worker_id}] agent construction failed: {e}")
            traceback.print_exc()
            _emit_failure("agent_init", e)
            return

        # Main episode loop: pull ids until the None poison pill arrives.
        while True:
            episode_id = episode_queue.get()
            if episode_id is None:
                break

            print(f"[Worker {worker_id}] Starting episode {episode_id}")
            t0 = time.time()

            # MineRL can transiently fail, so give it a few tries with backoff.
            # On the final attempt the exception propagates to the outer
            # handler, which reports a "fatal" WorkerFailure and exits.
            obs = None
            for attempt in range(5):
                try:
                    obs = env.reset()
                    break
                except Exception as e:
                    print(f"[Worker {worker_id}] env.reset() attempt {attempt+1}/5 failed: {e}")
                    if attempt < 4:
                        try:
                            env.close()  # have to make new or will just keep getting error
                        except Exception:
                            pass
                        # Linear backoff: 10s, 20s, 30s, 40s.
                        time.sleep(10 * (attempt + 1))
                        env = HumanSurvival(**ENV_KWARGS).make()
                    else:
                        raise

            agent.reset()

            total_reward = 0.0
            tech_tree_steps = {name: -1 for name, _ in TECH_TREE_TASKS}
            tech_tree_achieved = {name: False for name, _ in TECH_TREE_TASKS}

            # NOTE(review): assumes max_steps >= 1 — with max_steps == 0 the
            # loop never runs and `step` below would be unbound.
            for step in range(args_dict["max_steps"]):
                try:
                    action = agent.get_action(obs)
                    obs, reward, done, info = env.step(action)
                except Exception as e:
                    print(f"[Worker {worker_id}] Error at step {step}: {e}")
                    break

                # Detect MineRL returning a random obs due to a dead Minecraft process
                # (presumably info is always a dict here — verify upstream).
                if 'error' in info:
                    print(f"[Worker {worker_id}] Minecraft connection lost at "
                          f"step {step}, ending episode")
                    break

                total_reward += reward

                # Check tech tree progression; record only the FIRST step at
                # which each item appears in the inventory.
                for task_name, item_key in TECH_TREE_TASKS:
                    if not tech_tree_achieved[task_name]:
                        if check_inventory(info, item_key):
                            tech_tree_achieved[task_name] = True
                            tech_tree_steps[task_name] = step
                            print(
                                f"[Worker {worker_id}] Episode {episode_id}: "
                                f"achieved '{task_name}' at step {step}"
                            )

                if done:
                    break

            # Partial episodes (step-loop broke early on error) are still
            # reported with whatever progress was made.
            wall_time = time.time() - t0
            result = EpisodeResult(
                episode_id=episode_id,
                worker_id=worker_id,
                total_steps=step + 1,
                total_reward=total_reward,
                wall_time_seconds=wall_time,
                tech_tree_steps=tech_tree_steps,
                tech_tree_achieved=tech_tree_achieved,
            )
            result_queue.put(result)
            print(
                f"[Worker {worker_id}] Episode {episode_id} done: "
                f"{step+1} steps, reward={total_reward:.2f}, "
                f"time={wall_time:.1f}s"
            )

    except Exception as e:
        print(f"[Worker {worker_id}] Fatal error: {e}")
        traceback.print_exc()
        _emit_failure("fatal", e)
    finally:
        # Close the MineRL env first — this shuts down the Minecraft JVM.
        # Without it, the Java child of this worker survives and lingers
        # as a zombie until the Slurm job hits its wall time.
        if env is not None:
            try:
                env.close()
            except Exception:
                pass
        if xvfb_proc is not None:
            try:
                xvfb_proc.terminate()
                xvfb_proc.wait(timeout=5)
            except Exception:
                try:
                    xvfb_proc.kill()
                except Exception:
                    pass
337
338
def aggregate_results(results: List[EpisodeResult]) -> Dict:
    """Compute aggregate statistics across all episodes.

    Returns an empty dict when ``results`` is empty; otherwise a dict
    with overall reward/length statistics plus, per tech-tree task, the
    success rate, success count, and mean steps-to-success (``inf`` when
    the task was never achieved).
    """
    n = len(results)
    if not results:
        return {}

    rewards = [r.total_reward for r in results]
    lengths = [r.total_steps for r in results]
    stats = {
        "n_episodes": n,
        "mean_reward": float(np.mean(rewards)),
        "std_reward": float(np.std(rewards)),
        "mean_steps": float(np.mean(lengths)),
        "tasks": {},
    }

    for task_name, _ in TECH_TREE_TASKS:
        n_successes = sum(r.tech_tree_achieved[task_name] for r in results)
        steps_when_achieved = [
            r.tech_tree_steps[task_name]
            for r in results
            if r.tech_tree_achieved[task_name]
        ]
        if steps_when_achieved:
            mean_steps = float(np.mean(steps_when_achieved))
        else:
            mean_steps = float("inf")

        stats["tasks"][task_name] = {
            "success_rate": n_successes / n,
            "n_successes": n_successes,
            "mean_steps_to_success": mean_steps,
        }

    return stats
370
371
def main():
    """CLI entry point: spawn worker processes and collect evaluation results.

    Builds a shared ``episode_queue`` of ``n_episodes`` work items (plus
    one ``None`` poison pill per worker), forks ``n_workers`` worker
    processes running :func:`run_worker`, drains the ``result_queue``
    (polling every 30s so dead workers are noticed promptly), aggregates
    stats via :func:`aggregate_results`, prints a tech-tree progression
    table, and writes the full JSON report to ``--output_json``.
    """
    parser = argparse.ArgumentParser("Dreamer4 Minecraft Evaluation")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to dreamer4_minecraft.pt")
    parser.add_argument("--n_episodes", type=int, default=100)
    parser.add_argument("--n_workers", type=int, default=1)
    parser.add_argument("--minecraft_mem", type=str, default="4G",
                        help="Max JVM heap per Minecraft instance "
                             "(e.g., '2G', '1G'). Enables parallelism "
                             "with low VRAM.")
    parser.add_argument("--max_steps", type=int, default=36000,
                        help="Max steps per episode (36000 = 30min at 20fps)")
    parser.add_argument("--deterministic", action="store_true")
    parser.add_argument("--no_headless", action="store_true",
                        help="Don't start Xvfb (use existing display)")
    parser.add_argument("--output_json", type=str, default="eval_results.json")
    args = parser.parse_args()

    # Validate checkpoint exists before paying the cost of spawning workers.
    if not os.path.exists(args.checkpoint):
        print(f"ERROR: Checkpoint not found: {args.checkpoint}")
        sys.exit(1)

    ctx = get_context("spawn")
    episode_queue = ctx.Queue()
    result_queue = ctx.Queue()

    for i in range(args.n_episodes):
        episode_queue.put(i)
    for _ in range(args.n_workers):
        episode_queue.put(None)  # one poison pill per worker

    args_dict = {
        "checkpoint": os.path.abspath(args.checkpoint),
        "max_steps": args.max_steps,
        "minecraft_mem": args.minecraft_mem,
        "deterministic": args.deterministic,
        "headless": not args.no_headless,
    }

    # Install the global cleanup hooks BEFORE any child starts. atexit covers
    # normal exits, the signal handlers cover sbatch SIGTERM / user Ctrl-C.
    # Whatever happens below, worker processes (and therefore their Xvfb and
    # Minecraft JVM children) get reaped rather than becoming zombies.
    # NOTE(review): the signal handlers only reap children and then return —
    # the parent itself does not exit on SIGTERM/SIGINT; the drain loop below
    # notices all workers have died and aborts. Confirm this is intended.
    atexit.register(_terminate_children)
    signal.signal(signal.SIGTERM, _terminate_children)
    signal.signal(signal.SIGINT, _terminate_children)

    workers = []
    for w in range(args.n_workers):
        p = ctx.Process(target=run_worker, args=(w, episode_queue, result_queue, args_dict))
        p.start()
        workers.append(p)
        _CHILDREN.append(p)
        if w < args.n_workers - 1:
            time.sleep(10)  # Stagger Minecraft JVM launches to avoid resource contention

    print(f"Started {args.n_workers} workers for {args.n_episodes} episodes")

    results = []
    failures = []
    for _ in range(args.n_episodes):
        # Poll result_queue with a short timeout so we can notice that every
        # worker has died (e.g. import error) instead of hanging for an hour.
        # BUGFIX: `item` is initialized to None so that breaking out of the
        # polling loop with no result (all workers dead) hits the sentinel
        # check below instead of raising UnboundLocalError.
        item = None
        while True:
            try:
                item = result_queue.get(timeout=30)
            except Exception:
                # Nothing came through in 30s — are any workers still running?
                if not any(p.is_alive() for p in workers):
                    print(
                        f"[main] All {len(workers)} workers have exited; "
                        f"aborting with {len(results)}/{args.n_episodes} results "
                        f"({len(failures)} worker failure(s) reported)."
                    )
                    break  # item stays None -> outer loop aborts below
                continue  # workers still alive, keep waiting
            break

        if item is None:
            # Timeout with no live workers: stop draining entirely.
            break

        if isinstance(item, WorkerFailure):
            failures.append(item)
            print(
                f"[main] Worker {item.worker_id} failed during "
                f"'{item.stage}': {item.error}"
            )
            print(item.tb)
            # If every worker has died, no point draining further — abort now.
            if not any(p.is_alive() for p in workers):
                print(
                    f"[main] No workers remain; aborting after "
                    f"{len(failures)} failure(s)."
                )
                break
            continue

        results.append(item)
        if len(results) % 10 == 0:
            print(f"Progress: {len(results)}/{args.n_episodes} episodes")

    # Best-effort graceful shutdown; _terminate_children (via atexit) will
    # force-kill anything still alive.
    for p in workers:
        p.join(timeout=30)
    _terminate_children()

    if failures and not results:
        print("\n" + "=" * 60)
        print("EVALUATION ABORTED — all workers failed before producing results")
        print("=" * 60)
        for f in failures[:3]:  # show first few tracebacks
            print(f"\n[Worker {f.worker_id}] stage={f.stage} error={f.error}")
            print(f.tb)
        sys.exit(2)

    if results:
        stats = aggregate_results(results)

        print("\n" + "=" * 60)
        print("EVALUATION RESULTS")
        print("=" * 60)
        print(f"Episodes: {stats['n_episodes']}")
        print(f"Mean reward: {stats['mean_reward']:.2f} +/- {stats['std_reward']:.2f}")
        print(f"Mean length: {stats['mean_steps']:.0f} steps")
        print()
        print(f"{'Task':<20} {'Success Rate':>12} {'N Success':>10} {'Mean Steps':>12}")
        print("-" * 56)
        for task_name, _ in TECH_TREE_TASKS:
            t = stats["tasks"][task_name]
            sr = f"{t['success_rate']*100:.1f}%"
            ns = str(t["n_successes"])
            mean_steps = t["mean_steps_to_success"]
            ms = f"{mean_steps:.0f}" if mean_steps != float("inf") else "N/A"
            print(f"{task_name:<20} {sr:>12} {ns:>10} {ms:>12}")

        output = {
            "aggregate": stats,
            "episodes": [asdict(r) for r in results],
            "config": vars(args),
        }
        # NOTE(review): float("inf") serializes as the non-standard token
        # `Infinity`; strict JSON parsers may reject this file — confirm
        # downstream consumers tolerate it.
        with open(args.output_json, "w", encoding="utf-8") as f:
            json.dump(output, f, indent=2, default=str)
        print(f"\nResults saved to {args.output_json}")
    else:
        print("No results collected!")


if __name__ == "__main__":
    main()
@dataclass
class WorkerFailure:
50@dataclass
51class WorkerFailure:
52    worker_id: int
53    stage: str           # "startup", "reset", "step", etc.
54    error: str           # repr(exception)
55    tb: str              # traceback.format_exc()
WorkerFailure(worker_id: int, stage: str, error: str, tb: str)
worker_id: int
stage: str
error: str
tb: str
TECH_TREE_TASKS = [('log', 'log'), ('planks', 'planks'), ('crafting_table', 'crafting_table'), ('wooden_pickaxe', 'wooden_pickaxe'), ('cobblestone', 'cobblestone'), ('stone_pickaxe', 'stone_pickaxe'), ('iron_ore', 'iron_ore'), ('furnace', 'furnace'), ('iron_ingot', 'iron_ingot'), ('iron_pickaxe', 'iron_pickaxe'), ('diamond', 'diamond')]
@dataclass
class EpisodeResult:
108@dataclass
109class EpisodeResult:
110    """Results from a single evaluation episode."""
111    episode_id: int
112    worker_id: int
113    total_steps: int
114    total_reward: float
115    wall_time_seconds: float
116    tech_tree_steps: Dict[str, int] = field(default_factory=dict)
117    tech_tree_achieved: Dict[str, bool] = field(default_factory=dict)

Results from a single evaluation episode.

EpisodeResult( episode_id: int, worker_id: int, total_steps: int, total_reward: float, wall_time_seconds: float, tech_tree_steps: Dict[str, int] = <factory>, tech_tree_achieved: Dict[str, bool] = <factory>)
episode_id: int
worker_id: int
total_steps: int
total_reward: float
wall_time_seconds: float
tech_tree_steps: Dict[str, int]
tech_tree_achieved: Dict[str, bool]
def check_inventory(info: dict, item_name: str) -> bool:
120def check_inventory(info: dict, item_name: str) -> bool:
121    """Check if ``item_name`` is present in the player's MineRL inventory.
122
123    Args:
124        info: MineRL ``info`` dict returned by ``env.step``.
125        item_name: Minecraft item id (e.g. ``"log"``, ``"iron_pickaxe"``).
126
127    Returns:
128        True if the inventory entry for ``item_name`` exists with a
129        count greater than zero, False otherwise (including when the
130        inventory field is missing or not a dict).
131    """
132    inventory = info.get("inventory", {})
133    if isinstance(inventory, dict):
134        return inventory.get(item_name, 0) > 0
135    return False

Check if item_name is present in the player's MineRL inventory.

Arguments:
  • info: MineRL info dict returned by env.step.
  • item_name: Minecraft item id (e.g. "log", "iron_pickaxe").
Returns:

True if the inventory entry for item_name exists with a count greater than zero, False otherwise (including when the inventory field is missing or not a dict).

def run_worker( worker_id: int, episode_queue: <bound method BaseContext.Queue of <multiprocessing.context.DefaultContext object>>, result_queue: <bound method BaseContext.Queue of <multiprocessing.context.DefaultContext object>>, args_dict: dict):
138def run_worker(
139    worker_id: int,
140    episode_queue: Queue,
141    result_queue: Queue,
142    args_dict: dict,
143):
144    """Worker process for parallel evaluation.
145
146    Each worker:
147    1. Starts its own Xvfb display for headless rendering
148    2. Creates a MineRL HumanSurvival environment
149    3. Loads the Dreamer4 agent
150    4. Runs episodes from the queue until poison pill
151    5. Reports results
152    """
153    display_num = 99 + worker_id
154    os.environ["DISPLAY"] = f":{display_num}"
155
156    xvfb_proc = None
157    env = None
158
159    def _emit_failure(stage, exc):
160        """Send traceback back to parent so it lands in sbatch stdout."""
161        try:
162            result_queue.put(WorkerFailure(
163                worker_id=worker_id,
164                stage=stage,
165                error=repr(exc),
166                tb=traceback.format_exc(),
167            ))
168        except Exception:
169            pass
170
171    # Ensure SIGTERM from the parent kills the Minecraft JVMs before we exit.
172    # Without this, an aborted run leaves a tree of Java processes on HPC.
173    def _worker_signal_handler(signum, _frame):
174        raise SystemExit(f"worker {worker_id} received signal {signum}")
175    signal.signal(signal.SIGTERM, _worker_signal_handler)
176
177    if args_dict.get("headless", True):
178        try:
179            xvfb_proc = subprocess.Popen(
180                ["Xvfb", f":{display_num}", "-screen", "0", "1024x768x24", "-ac"],
181                stdout=subprocess.DEVNULL,
182                stderr=subprocess.DEVNULL,
183            )
184            time.sleep(1)
185        except FileNotFoundError:
186            print(f"[Worker {worker_id}] WARNING: Xvfb not found, using existing display")
187            xvfb_proc = None
188
189    try:
190        try:
191            from minerl.herobraine.env_specs.human_survival_specs import HumanSurvival
192            from agent import ENV_KWARGS
193            from dreamer4_minecraft_agent import Dreamer4MinecraftAgent
194            from minerl.env.malmo import InstanceManager, MinecraftInstance
195        except Exception as e:
196            # Import-time failure (e.g. the Python 3.9 / PEP 604 issue this file
197            # has tripped on before). Surface it through the result queue so the
198            # parent can abort instead of waiting an hour per missing result.
199            print(f"[Worker {worker_id}] Import error: {e}")
200            traceback.print_exc()
201            _emit_failure("import", e)
202            return
203
204        InstanceManager.configure_malmo_base_port(9000 + worker_id * 100)
205
206        # Reduce JVM heap from 4G to 2G per instance to prevent OOM kills
207        _orig_mc_init = MinecraftInstance.__init__
208        _mc_max_mem = args_dict["minecraft_mem"]
209
210        def _mc_init_with_reduced_mem(self, port=None, existing=False,
211                                      status_dir=None, seed=None,
212                                      instance_id=None, max_mem=_mc_max_mem):
213            _orig_mc_init(self, port, existing, status_dir, seed,
214                          instance_id, max_mem)
215        MinecraftInstance.__init__ = _mc_init_with_reduced_mem
216
217        try:
218            env = HumanSurvival(**ENV_KWARGS).make()
219        except Exception as e:
220            print(f"[Worker {worker_id}] env construction failed: {e}")
221            traceback.print_exc()
222            _emit_failure("env_make", e)
223            return
224
225        try:
226            agent = Dreamer4MinecraftAgent(
227                env,
228                checkpoint_path=args_dict["checkpoint"],
229                stochastic=not args_dict.get("deterministic", False),
230            )
231        except Exception as e:
232            print(f"[Worker {worker_id}] agent construction failed: {e}")
233            traceback.print_exc()
234            _emit_failure("agent_init", e)
235            return
236
237        while True:
238            episode_id = episode_queue.get()
239            if episode_id is None:
240                break
241
242            print(f"[Worker {worker_id}] Starting episode {episode_id}")
243            t0 = time.time()
244
245            # MineRL can transiently fail, so give it a few tries with backoff
246            obs = None
247            for attempt in range(5):
248                try:
249                    obs = env.reset()
250                    break
251                except Exception as e:
252                    print(f"[Worker {worker_id}] env.reset() attempt {attempt+1}/5 failed: {e}")
253                    if attempt < 4:
254                        try:
255                            env.close()  # have to make new or will just keep getting error
256                        except Exception:
257                            pass
258                        time.sleep(10 * (attempt + 1))
259                        env = HumanSurvival(**ENV_KWARGS).make()
260                    else:
261                        raise
262
263            agent.reset()
264
265            total_reward = 0.0
266            tech_tree_steps = {name: -1 for name, _ in TECH_TREE_TASKS}
267            tech_tree_achieved = {name: False for name, _ in TECH_TREE_TASKS}
268
269            for step in range(args_dict["max_steps"]):
270                try:
271                    action = agent.get_action(obs)
272                    obs, reward, done, info = env.step(action)
273                except Exception as e:
274                    print(f"[Worker {worker_id}] Error at step {step}: {e}")
275                    break
276
277                # Detect MineRL returning a random obs due to a dead Minecraft process
278                if 'error' in info:
279                    print(f"[Worker {worker_id}] Minecraft connection lost at "
280                          f"step {step}, ending episode")
281                    break
282
283                total_reward += reward
284
285                # Check tech tree progression
286                for task_name, item_key in TECH_TREE_TASKS:
287                    if not tech_tree_achieved[task_name]:
288                        if check_inventory(info, item_key):
289                            tech_tree_achieved[task_name] = True
290                            tech_tree_steps[task_name] = step
291                            print(
292                                f"[Worker {worker_id}] Episode {episode_id}: "
293                                f"achieved '{task_name}' at step {step}"
294                            )
295
296                if done:
297                    break
298
299            wall_time = time.time() - t0
300            result = EpisodeResult(
301                episode_id=episode_id,
302                worker_id=worker_id,
303                total_steps=step + 1,
304                total_reward=total_reward,
305                wall_time_seconds=wall_time,
306                tech_tree_steps=tech_tree_steps,
307                tech_tree_achieved=tech_tree_achieved,
308            )
309            result_queue.put(result)
310            print(
311                f"[Worker {worker_id}] Episode {episode_id} done: "
312                f"{step+1} steps, reward={total_reward:.2f}, "
313                f"time={wall_time:.1f}s"
314            )
315
316    except Exception as e:
317        print(f"[Worker {worker_id}] Fatal error: {e}")
318        traceback.print_exc()
319        _emit_failure("fatal", e)
320    finally:
321        # Close the MineRL env first — this shuts down the Minecraft JVM.
322        # Without it, the Java child of this worker survives and lingers
323        # as a zombie until the Slurm job hits its wall time.
324        if env is not None:
325            try:
326                env.close()
327            except Exception:
328                pass
329        if xvfb_proc is not None:
330            try:
331                xvfb_proc.terminate()
332                xvfb_proc.wait(timeout=5)
333            except Exception:
334                try:
335                    xvfb_proc.kill()
336                except Exception:
337                    pass

Worker process for parallel evaluation.

Each worker:

  1. Starts its own Xvfb display for headless rendering
  2. Creates a MineRL HumanSurvival environment
  3. Loads the Dreamer4 agent
  4. Runs episodes from the queue until poison pill
  5. Reports results
def aggregate_results(results: List[EpisodeResult]) -> Dict:
    """Compute aggregate statistics across all episodes.

    Returns an empty dict when *results* is empty. Otherwise returns a
    dict with overall reward/length statistics plus a per-task entry for
    every tech-tree task: success rate, success count, and mean
    steps-to-success (``inf`` when no episode achieved the task).
    """
    n_episodes = len(results)
    if not results:
        return {}

    rewards = [r.total_reward for r in results]
    summary: Dict = {
        "n_episodes": n_episodes,
        "mean_reward": float(np.mean(rewards)),
        "std_reward": float(np.std(rewards)),
        "mean_steps": float(np.mean([r.total_steps for r in results])),
        "tasks": {},
    }

    for task_name, _ in TECH_TREE_TASKS:
        # Sum of booleans == number of episodes that achieved the task.
        n_successes = sum(r.tech_tree_achieved[task_name] for r in results)
        steps_when_achieved = [
            r.tech_tree_steps[task_name]
            for r in results
            if r.tech_tree_achieved[task_name]
        ]
        summary["tasks"][task_name] = {
            "success_rate": n_successes / n_episodes,
            "n_successes": n_successes,
            "mean_steps_to_success": (
                float(np.mean(steps_when_achieved))
                if steps_when_achieved
                else float("inf")
            ),
        }

    return summary

Compute aggregate statistics across all episodes.

def main():
    """CLI entry point: spawn worker processes and collect evaluation results.

    Builds a shared ``episode_queue`` of ``n_episodes`` work items (plus
    one ``None`` poison pill per worker), forks ``n_workers`` worker
    processes running :func:`run_worker`, drains the ``result_queue``
    in 30-second polls (aborting early once every worker has died),
    aggregates stats via :func:`aggregate_results`, prints a tech-tree
    progression table, and writes the full JSON report to
    ``--output_json``.
    """
    parser = argparse.ArgumentParser("Dreamer4 Minecraft Evaluation")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to dreamer4_minecraft.pt")
    parser.add_argument("--n_episodes", type=int, default=100)
    parser.add_argument("--n_workers", type=int, default=1)
    parser.add_argument("--minecraft_mem", type=str, default="4G",
                        help="Max JVM heap per Minecraft instance "
                             "(e.g., '2G', '1G'). Enables parallelism "
                             "with low VRAM.")
    parser.add_argument("--max_steps", type=int, default=36000,
                        help="Max steps per episode (36000 = 30min at 20fps)")
    parser.add_argument("--deterministic", action="store_true")
    parser.add_argument("--no_headless", action="store_true",
                        help="Don't start Xvfb (use existing display)")
    parser.add_argument("--output_json", type=str, default="eval_results.json")
    args = parser.parse_args()

    # Validate the checkpoint before spawning anything — failing fast here
    # is much cheaper than letting every worker boot a JVM and then die.
    if not os.path.exists(args.checkpoint):
        print(f"ERROR: Checkpoint not found: {args.checkpoint}")
        sys.exit(1)

    ctx = get_context("spawn")
    episode_queue = ctx.Queue()
    result_queue = ctx.Queue()

    # One work item per episode, then one poison pill per worker so every
    # worker eventually sees a None and exits its loop.
    for i in range(args.n_episodes):
        episode_queue.put(i)
    for _ in range(args.n_workers):
        episode_queue.put(None)

    args_dict = {
        "checkpoint": os.path.abspath(args.checkpoint),
        "max_steps": args.max_steps,
        "minecraft_mem": args.minecraft_mem,
        "deterministic": args.deterministic,
        "headless": not args.no_headless,
    }

    # Install the global cleanup hooks BEFORE any child starts. atexit covers
    # normal exits, the signal handlers cover sbatch SIGTERM / user Ctrl-C.
    # Whatever happens below, worker processes (and therefore their Xvfb and
    # Minecraft JVM children) get reaped rather than becoming zombies.
    # NOTE(review): signal handlers are invoked with (signum, frame), while
    # _terminate_children is also called below with no arguments — it is
    # assumed to accept optional positional args; confirm its signature.
    atexit.register(_terminate_children)
    signal.signal(signal.SIGTERM, _terminate_children)
    signal.signal(signal.SIGINT, _terminate_children)

    workers = []
    for w in range(args.n_workers):
        p = ctx.Process(target=run_worker, args=(w, episode_queue, result_queue, args_dict))
        p.start()
        workers.append(p)
        _CHILDREN.append(p)
        if w < args.n_workers - 1:
            time.sleep(10)  # Stagger Minecraft JVM launches to avoid resource contention

    print(f"Started {args.n_workers} workers for {args.n_episodes} episodes")

    results = []
    failures = []
    for _ in range(args.n_episodes):
        # Poll result_queue with a short timeout so we can notice that every
        # worker has died (e.g. import error) instead of hanging for an hour.
        while True:
            try:
                item = result_queue.get(timeout=30)
            except Exception:
                # Nothing came through in 30s — are any workers still running?
                if not any(p.is_alive() for p in workers):
                    print(
                        f"[main] All {len(workers)} workers have exited; "
                        f"aborting with {len(results)}/{args.n_episodes} results "
                        f"({len(failures)} worker failure(s) reported)."
                    )
                    # BUG FIX: mark the abort explicitly. Previously `item`
                    # was left unbound (NameError on the first iteration) or
                    # stale (a previously processed result re-appended once
                    # per remaining outer iteration, each after a 30s wait).
                    item = None
                    break
                continue  # workers still alive, keep waiting
            break

        if item is None:
            # All workers died before producing this result — stop draining.
            break

        if isinstance(item, WorkerFailure):
            failures.append(item)
            print(
                f"[main] Worker {item.worker_id} failed during "
                f"'{item.stage}': {item.error}"
            )
            print(item.tb)
            # If every worker has died, no point draining further — abort now.
            if not any(p.is_alive() for p in workers):
                print(
                    f"[main] No workers remain; aborting after "
                    f"{len(failures)} failure(s)."
                )
                break
            continue

        results.append(item)
        if len(results) % 10 == 0:
            print(f"Progress: {len(results)}/{args.n_episodes} episodes")

    # Best-effort graceful shutdown; _terminate_children (via atexit) will
    # force-kill anything still alive.
    for p in workers:
        p.join(timeout=30)
    _terminate_children()

    if failures and not results:
        print("\n" + "=" * 60)
        print("EVALUATION ABORTED — all workers failed before producing results")
        print("=" * 60)
        for fail in failures[:3]:  # show first few tracebacks
            print(f"\n[Worker {fail.worker_id}] stage={fail.stage} error={fail.error}")
            print(fail.tb)
        sys.exit(2)

    if results:
        stats = aggregate_results(results)

        print("\n" + "=" * 60)
        print("EVALUATION RESULTS")
        print("=" * 60)
        print(f"Episodes: {stats['n_episodes']}")
        print(f"Mean reward: {stats['mean_reward']:.2f} +/- {stats['std_reward']:.2f}")
        print(f"Mean length: {stats['mean_steps']:.0f} steps")
        print()
        print(f"{'Task':<20} {'Success Rate':>12} {'N Success':>10} {'Mean Steps':>12}")
        print("-" * 56)
        for task_name, _ in TECH_TREE_TASKS:
            t = stats["tasks"][task_name]
            sr = f"{t['success_rate']*100:.1f}%"
            ns = str(t["n_successes"])
            mean_steps = t["mean_steps_to_success"]
            ms = f"{mean_steps:.0f}" if mean_steps != float("inf") else "N/A"
            print(f"{task_name:<20} {sr:>12} {ns:>10} {ms:>12}")

        output = {
            "aggregate": stats,
            "episodes": [asdict(r) for r in results],
            "config": vars(args),
        }
        # default=str stringifies anything json can't natively encode.
        # NOTE(review): mean_steps_to_success may be float("inf"), which
        # json.dump serializes as the non-strict-JSON literal `Infinity`.
        with open(args.output_json, "w", encoding="utf-8") as f:
            json.dump(output, f, indent=2, default=str)
        print(f"\nResults saved to {args.output_json}")
    else:
        print("No results collected!")

CLI entry point: spawn worker processes and collect evaluation results.

Builds a shared episode_queue of n_episodes work items (plus one None poison pill per worker), forks n_workers worker processes running run_worker(), drains the result_queue in 30-second polls (aborting early once every worker has died), aggregates stats via aggregate_results(), prints a tech-tree progression table, and writes the full JSON report to --output_json.