evaluate_dreamer4_minecraft
Evaluation script for Dreamer4 Minecraft agent using VPT infrastructure.
Runs the trained Dreamer4 agent in MineRL HumanSurvival environment and tracks diamond tech tree progression. Supports parallel evaluation using multiple processes with Xvfb for headless rendering.
This is a corrected version of examples/evaluate_dreamer4_parallel.py that:
- Uses the corrected Dreamer4MinecraftAgent (not the broken DreamerV4Agent)
- Accepts a single checkpoint path (not three separate checkpoints)
- Adds proper error handling and timeouts
- Validates checkpoint existence before spawning workers
Usage:
Single process (for debugging):
python evaluate_dreamer4_minecraft.py --checkpoint ./checkpoints/dreamer4_minecraft.pt --n_episodes 10 --n_workers 1
Parallel headless evaluation:
python evaluate_dreamer4_minecraft.py --checkpoint ./checkpoints/dreamer4_minecraft.pt --n_episodes 100 --n_workers 8 --max_steps 36000 --output_json eval_results.json
"""
Evaluation script for Dreamer4 Minecraft agent using VPT infrastructure.

Runs the trained Dreamer4 agent in MineRL HumanSurvival environment and
tracks diamond tech tree progression. Supports parallel evaluation using
multiple processes with Xvfb for headless rendering.

This is a corrected version of examples/evaluate_dreamer4_parallel.py that:
  - Uses the corrected Dreamer4MinecraftAgent (not the broken DreamerV4Agent)
  - Accepts a single checkpoint path (not three separate checkpoints)
  - Adds proper error handling and timeouts
  - Validates checkpoint existence before spawning workers

Usage:
  # Single process (for debugging):
  python evaluate_dreamer4_minecraft.py \
      --checkpoint ./checkpoints/dreamer4_minecraft.pt \
      --n_episodes 10 \
      --n_workers 1

  # Parallel headless evaluation:
  python evaluate_dreamer4_minecraft.py \
      --checkpoint ./checkpoints/dreamer4_minecraft.pt \
      --n_episodes 100 \
      --n_workers 8 \
      --max_steps 36000 \
      --output_json eval_results.json
"""

import os
import sys
import json
import time
import atexit
import signal
import argparse
import subprocess
import traceback
from dataclasses import dataclass, field, asdict
from typing import Dict, List
from multiprocessing import Queue, get_context

import numpy as np


# Sentinel result payload signalling a worker exited before producing results.
# Carries the worker id and the raw traceback so the parent can surface it
# instead of silently hanging on result_queue.get() forever.
@dataclass
class WorkerFailure:
    worker_id: int  # index of the worker process that failed
    stage: str      # "startup", "reset", "step", etc.
    error: str      # repr(exception)
    tb: str         # traceback.format_exc()


# Parent-side registry of live child processes so SIGTERM/SIGINT in the
# driver reaps them. Populated in main(), consumed in _terminate_children.
_CHILDREN = []


def _terminate_children(*_args):
    """Kill any still-running worker processes. Safe to call multiple times.

    Accepts (and ignores) positional args so it can double as a signal
    handler (signum, frame) and an atexit callback (no args).
    """
    # First pass: ask every live child to terminate (SIGTERM).
    for p in _CHILDREN:
        try:
            if p.is_alive():
                p.terminate()
        except Exception:
            pass
    # Give them a shared 5-second budget to die, then hard-kill stragglers.
    deadline = time.time() + 5
    for p in _CHILDREN:
        try:
            remaining = max(0.0, deadline - time.time())
            p.join(timeout=remaining)
            if p.is_alive():
                p.kill()
                p.join(timeout=2)
        except Exception:
            pass


# Add paths for imports — this script lives inside Video-Pre-Training/,
# so go up one level to find the dreamer4 package and minecraft_vpt_dataset.
_this_dir = os.path.dirname(os.path.abspath(__file__))  # .../Video-Pre-Training
_project_root = os.path.dirname(_this_dir)              # .../dreamer4 (repo root)
sys.path.insert(0, _this_dir)      # VPT local imports (agent.py, lib/)
sys.path.insert(0, _project_root)  # dreamer4 package + minecraft_vpt_dataset


# Diamond tech tree tasks in progression order.
# Matches the tasks reported in the Dreamer4 paper.
# Each entry is (task_name, inventory_item_key); here the two coincide.
TECH_TREE_TASKS = [
    ("log", "log"),
    ("planks", "planks"),
    ("crafting_table", "crafting_table"),
    ("wooden_pickaxe", "wooden_pickaxe"),
    ("cobblestone", "cobblestone"),
    ("stone_pickaxe", "stone_pickaxe"),
    ("iron_ore", "iron_ore"),
    ("furnace", "furnace"),
    ("iron_ingot", "iron_ingot"),
    ("iron_pickaxe", "iron_pickaxe"),
    ("diamond", "diamond"),
]


@dataclass
class EpisodeResult:
    """Results from a single evaluation episode."""
    episode_id: int             # index of the episode in the shared work queue
    worker_id: int              # which worker process ran it
    total_steps: int            # environment steps actually executed
    total_reward: float         # sum of per-step rewards
    wall_time_seconds: float    # real time from reset to episode end
    # task name -> step index at which the item first appeared (-1 = never)
    tech_tree_steps: Dict[str, int] = field(default_factory=dict)
    # task name -> whether the item was obtained at any point
    tech_tree_achieved: Dict[str, bool] = field(default_factory=dict)


def check_inventory(info: dict, item_name: str) -> bool:
    """Check if ``item_name`` is present in the player's MineRL inventory.

    Args:
        info: MineRL ``info`` dict returned by ``env.step``.
        item_name: Minecraft item id (e.g. ``"log"``, ``"iron_pickaxe"``).

    Returns:
        True if the inventory entry for ``item_name`` exists with a
        count greater than zero, False otherwise (including when the
        inventory field is missing or not a dict).
    """
    inventory = info.get("inventory", {})
    if isinstance(inventory, dict):
        return inventory.get(item_name, 0) > 0
    return False


def run_worker(
    worker_id: int,
    episode_queue: Queue,
    result_queue: Queue,
    args_dict: dict,
):
    """Worker process for parallel evaluation.

    Each worker:
      1. Starts its own Xvfb display for headless rendering
      2. Creates a MineRL HumanSurvival environment
      3. Loads the Dreamer4 agent
      4. Runs episodes from the queue until poison pill
      5. Reports results

    Args:
        worker_id: 0-based index; used to derive the X display number and
            the Malmo base port so workers do not collide.
        episode_queue: queue of episode ids; ``None`` is the poison pill.
        result_queue: receives EpisodeResult / WorkerFailure objects.
        args_dict: plain-dict copy of the CLI config (picklable for spawn).
    """
    # Give each worker its own X display (:100, :101, ...).
    display_num = 99 + worker_id
    os.environ["DISPLAY"] = f":{display_num}"

    xvfb_proc = None
    env = None

    def _emit_failure(stage, exc):
        """Send traceback back to parent so it lands in sbatch stdout."""
        try:
            result_queue.put(WorkerFailure(
                worker_id=worker_id,
                stage=stage,
                error=repr(exc),
                tb=traceback.format_exc(),
            ))
        except Exception:
            pass

    # Ensure SIGTERM from the parent kills the Minecraft JVMs before we exit.
    # Without this, an aborted run leaves a tree of Java processes on HPC.
    # Raising SystemExit lets the finally-block below run env.close().
    def _worker_signal_handler(signum, _frame):
        raise SystemExit(f"worker {worker_id} received signal {signum}")
    signal.signal(signal.SIGTERM, _worker_signal_handler)

    if args_dict.get("headless", True):
        try:
            xvfb_proc = subprocess.Popen(
                ["Xvfb", f":{display_num}", "-screen", "0", "1024x768x24", "-ac"],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            time.sleep(1)  # give Xvfb a moment to create the display
        except FileNotFoundError:
            print(f"[Worker {worker_id}] WARNING: Xvfb not found, using existing display")
            xvfb_proc = None

    try:
        try:
            from minerl.herobraine.env_specs.human_survival_specs import HumanSurvival
            from agent import ENV_KWARGS
            from dreamer4_minecraft_agent import Dreamer4MinecraftAgent
            from minerl.env.malmo import InstanceManager, MinecraftInstance
        except Exception as e:
            # Import-time failure (e.g. the Python 3.9 / PEP 604 issue this file
            # has tripped on before). Surface it through the result queue so the
            # parent can abort instead of waiting an hour per missing result.
            print(f"[Worker {worker_id}] Import error: {e}")
            traceback.print_exc()
            _emit_failure("import", e)
            return

        # 100-port stride per worker keeps Malmo instances from clashing.
        InstanceManager.configure_malmo_base_port(9000 + worker_id * 100)

        # Override the JVM heap per Minecraft instance (default 4G) with the
        # CLI-provided --minecraft_mem value, to prevent OOM kills when many
        # workers run on one node. Done by monkeypatching the constructor's
        # max_mem default, since InstanceManager instantiates it internally.
        _orig_mc_init = MinecraftInstance.__init__
        _mc_max_mem = args_dict["minecraft_mem"]

        def _mc_init_with_reduced_mem(self, port=None, existing=False,
                                      status_dir=None, seed=None,
                                      instance_id=None, max_mem=_mc_max_mem):
            _orig_mc_init(self, port, existing, status_dir, seed,
                          instance_id, max_mem)
        MinecraftInstance.__init__ = _mc_init_with_reduced_mem

        try:
            env = HumanSurvival(**ENV_KWARGS).make()
        except Exception as e:
            print(f"[Worker {worker_id}] env construction failed: {e}")
            traceback.print_exc()
            _emit_failure("env_make", e)
            return

        try:
            agent = Dreamer4MinecraftAgent(
                env,
                checkpoint_path=args_dict["checkpoint"],
                stochastic=not args_dict.get("deterministic", False),
            )
        except Exception as e:
            print(f"[Worker {worker_id}] agent construction failed: {e}")
            traceback.print_exc()
            _emit_failure("agent_init", e)
            return

        while True:
            episode_id = episode_queue.get()
            if episode_id is None:
                break  # poison pill: no more work

            print(f"[Worker {worker_id}] Starting episode {episode_id}")
            t0 = time.time()

            # MineRL can transiently fail, so give it a few tries with backoff.
            obs = None
            for attempt in range(5):
                try:
                    obs = env.reset()
                    break
                except Exception as e:
                    print(f"[Worker {worker_id}] env.reset() attempt {attempt+1}/5 failed: {e}")
                    if attempt < 4:
                        try:
                            env.close()  # have to make new or will just keep getting error
                        except Exception:
                            pass
                        time.sleep(10 * (attempt + 1))  # linear backoff
                        env = HumanSurvival(**ENV_KWARGS).make()
                    else:
                        raise  # out of retries: caught by the fatal handler below

            agent.reset()

            total_reward = 0.0
            tech_tree_steps = {name: -1 for name, _ in TECH_TREE_TASKS}
            tech_tree_achieved = {name: False for name, _ in TECH_TREE_TASKS}

            # Initialize so total_steps below is well-defined even when the
            # loop body never runs (e.g. --max_steps 0).
            step = -1
            for step in range(args_dict["max_steps"]):
                try:
                    action = agent.get_action(obs)
                    obs, reward, done, info = env.step(action)
                except Exception as e:
                    print(f"[Worker {worker_id}] Error at step {step}: {e}")
                    break

                # Detect MineRL returning a random obs due to a dead Minecraft process
                if 'error' in info:
                    print(f"[Worker {worker_id}] Minecraft connection lost at "
                          f"step {step}, ending episode")
                    break

                total_reward += reward

                # Check tech tree progression: record first step each item appears.
                for task_name, item_key in TECH_TREE_TASKS:
                    if not tech_tree_achieved[task_name]:
                        if check_inventory(info, item_key):
                            tech_tree_achieved[task_name] = True
                            tech_tree_steps[task_name] = step
                            print(
                                f"[Worker {worker_id}] Episode {episode_id}: "
                                f"achieved '{task_name}' at step {step}"
                            )

                if done:
                    break

            wall_time = time.time() - t0
            result = EpisodeResult(
                episode_id=episode_id,
                worker_id=worker_id,
                total_steps=step + 1,
                total_reward=total_reward,
                wall_time_seconds=wall_time,
                tech_tree_steps=tech_tree_steps,
                tech_tree_achieved=tech_tree_achieved,
            )
            result_queue.put(result)
            print(
                f"[Worker {worker_id}] Episode {episode_id} done: "
                f"{step+1} steps, reward={total_reward:.2f}, "
                f"time={wall_time:.1f}s"
            )

    except Exception as e:
        print(f"[Worker {worker_id}] Fatal error: {e}")
        traceback.print_exc()
        _emit_failure("fatal", e)
    finally:
        # Close the MineRL env first — this shuts down the Minecraft JVM.
        # Without it, the Java child of this worker survives and lingers
        # as a zombie until the Slurm job hits its wall time.
        if env is not None:
            try:
                env.close()
            except Exception:
                pass
        if xvfb_proc is not None:
            try:
                xvfb_proc.terminate()
                xvfb_proc.wait(timeout=5)
            except Exception:
                try:
                    xvfb_proc.kill()
                except Exception:
                    pass


def aggregate_results(results: List[EpisodeResult]) -> Dict:
    """Compute aggregate statistics across all episodes.

    Args:
        results: collected EpisodeResult objects (may be empty).

    Returns:
        Dict with episode-level means plus, per tech-tree task, the
        success rate, success count, and mean steps-to-success
        (``inf`` when the task was never achieved). Empty dict for
        empty input.
    """
    n = len(results)
    if n == 0:
        return {}

    stats = {
        "n_episodes": n,
        "mean_reward": float(np.mean([r.total_reward for r in results])),
        "std_reward": float(np.std([r.total_reward for r in results])),
        "mean_steps": float(np.mean([r.total_steps for r in results])),
        "tasks": {},
    }

    for task_name, _ in TECH_TREE_TASKS:
        achieved = [r.tech_tree_achieved[task_name] for r in results]
        success_rate = sum(achieved) / n
        # Steps-to-success only over episodes that actually achieved the task.
        success_steps = [
            r.tech_tree_steps[task_name]
            for r in results
            if r.tech_tree_achieved[task_name]
        ]
        mean_steps = float(np.mean(success_steps)) if success_steps else float("inf")

        stats["tasks"][task_name] = {
            "success_rate": success_rate,
            "n_successes": sum(achieved),
            "mean_steps_to_success": mean_steps,
        }

    return stats


def main():
    """CLI entry point: spawn worker processes and collect evaluation results.

    Builds a shared ``episode_queue`` of ``n_episodes`` work items (plus
    one ``None`` poison pill per worker), forks ``n_workers`` worker
    processes running :func:`run_worker`, drains the ``result_queue``
    with a 30-second poll (aborting early if every worker has died),
    aggregates stats via :func:`aggregate_results`, prints a tech-tree
    progression table, and writes the full JSON report to ``--output_json``.
    """
    parser = argparse.ArgumentParser("Dreamer4 Minecraft Evaluation")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to dreamer4_minecraft.pt")
    parser.add_argument("--n_episodes", type=int, default=100)
    parser.add_argument("--n_workers", type=int, default=1)
    parser.add_argument("--minecraft_mem", type=str, default="4G",
                        help="Max JVM heap per Minecraft instance "
                             "(e.g., '2G', '1G'). Enables parallelism "
                             "with low VRAM.")
    parser.add_argument("--max_steps", type=int, default=36000,
                        help="Max steps per episode (36000 = 30min at 20fps)")
    parser.add_argument("--deterministic", action="store_true")
    parser.add_argument("--no_headless", action="store_true",
                        help="Don't start Xvfb (use existing display)")
    parser.add_argument("--output_json", type=str, default="eval_results.json")
    args = parser.parse_args()

    # Validate checkpoint exists before paying the cost of spawning workers.
    if not os.path.exists(args.checkpoint):
        print(f"ERROR: Checkpoint not found: {args.checkpoint}")
        sys.exit(1)

    # "spawn" (not fork) so workers get clean interpreters — required for
    # the per-worker DISPLAY/port setup and safe with threaded libraries.
    ctx = get_context("spawn")
    episode_queue = ctx.Queue()
    result_queue = ctx.Queue()

    for i in range(args.n_episodes):
        episode_queue.put(i)
    for _ in range(args.n_workers):
        episode_queue.put(None)  # one poison pill per worker

    # Plain dict (not argparse.Namespace) so it pickles cleanly to workers.
    args_dict = {
        "checkpoint": os.path.abspath(args.checkpoint),
        "max_steps": args.max_steps,
        "minecraft_mem": args.minecraft_mem,
        "deterministic": args.deterministic,
        "headless": not args.no_headless,
    }

    # Install the global cleanup hooks BEFORE any child starts. atexit covers
    # normal exits, the signal handlers cover sbatch SIGTERM / user Ctrl-C.
    # Whatever happens below, worker processes (and therefore their Xvfb and
    # Minecraft JVM children) get reaped rather than becoming zombies.
    atexit.register(_terminate_children)
    signal.signal(signal.SIGTERM, _terminate_children)
    signal.signal(signal.SIGINT, _terminate_children)

    workers = []
    for w in range(args.n_workers):
        p = ctx.Process(target=run_worker, args=(w, episode_queue, result_queue, args_dict))
        p.start()
        workers.append(p)
        _CHILDREN.append(p)
        if w < args.n_workers - 1:
            time.sleep(10)  # Stagger Minecraft JVM launches to avoid resource contention

    print(f"Started {args.n_workers} workers for {args.n_episodes} episodes")

    results = []
    failures = []
    for _ in range(args.n_episodes):
        # Poll result_queue with a short timeout so we can notice that every
        # worker has died (e.g. import error) instead of hanging for an hour.
        # BUGFIX: item must be reset each iteration — otherwise, when the
        # abort path below breaks out of the inner loop, `item` is either
        # unbound (first iteration -> NameError) or stale from the previous
        # result. None is the explicit "no result, stop draining" marker.
        item = None
        while True:
            try:
                item = result_queue.get(timeout=30)
            except Exception:
                # Nothing came through in 30s — are any workers still running?
                if not any(p.is_alive() for p in workers):
                    print(
                        f"[main] All {len(workers)} workers have exited; "
                        f"aborting with {len(results)}/{args.n_episodes} results "
                        f"({len(failures)} worker failure(s) reported)."
                    )
                    break
                continue  # workers still alive, keep waiting
            break

        if isinstance(item, WorkerFailure):
            failures.append(item)
            print(
                f"[main] Worker {item.worker_id} failed during "
                f"'{item.stage}': {item.error}"
            )
            print(item.tb)
            # If every worker has died, no point draining further — abort now.
            if not any(p.is_alive() for p in workers):
                print(
                    f"[main] No workers remain; aborting after "
                    f"{len(failures)} failure(s)."
                )
                break
            continue
        if item is None:
            # Timeout with all workers dead (handled above): stop draining.
            break

        results.append(item)
        if len(results) % 10 == 0:
            print(f"Progress: {len(results)}/{args.n_episodes} episodes")

    # Best-effort graceful shutdown; _terminate_children (via atexit) will
    # force-kill anything still alive.
    for p in workers:
        p.join(timeout=30)
    _terminate_children()

    if failures and not results:
        print("\n" + "=" * 60)
        print("EVALUATION ABORTED — all workers failed before producing results")
        print("=" * 60)
        for f in failures[:3]:  # show first few tracebacks
            print(f"\n[Worker {f.worker_id}] stage={f.stage} error={f.error}")
            print(f.tb)
        sys.exit(2)

    if results:
        stats = aggregate_results(results)

        print("\n" + "=" * 60)
        print("EVALUATION RESULTS")
        print("=" * 60)
        print(f"Episodes: {stats['n_episodes']}")
        print(f"Mean reward: {stats['mean_reward']:.2f} +/- {stats['std_reward']:.2f}")
        print(f"Mean length: {stats['mean_steps']:.0f} steps")
        print()
        print(f"{'Task':<20} {'Success Rate':>12} {'N Success':>10} {'Mean Steps':>12}")
        print("-" * 56)
        for task_name, _ in TECH_TREE_TASKS:
            t = stats["tasks"][task_name]
            sr = f"{t['success_rate']*100:.1f}%"
            ns = str(t["n_successes"])
            mean_steps = t["mean_steps_to_success"]
            ms = f"{mean_steps:.0f}" if mean_steps != float("inf") else "N/A"
            print(f"{task_name:<20} {sr:>12} {ns:>10} {ms:>12}")

        output = {
            "aggregate": stats,
            "episodes": [asdict(r) for r in results],
            "config": vars(args),
        }
        with open(args.output_json, "w", encoding="utf-8") as f:
            json.dump(output, f, indent=2, default=str)
        print(f"\nResults saved to {args.output_json}")
    else:
        print("No results collected!")


if __name__ == "__main__":
    main()
@dataclass
class EpisodeResult:
    """Results from a single evaluation episode."""
    episode_id: int             # index of the episode in the shared work queue
    worker_id: int              # which worker process ran this episode
    total_steps: int            # environment steps actually executed
    total_reward: float         # sum of per-step rewards over the episode
    wall_time_seconds: float    # real time from reset to episode end
    # task name -> step index at which the item first appeared (-1 = never)
    tech_tree_steps: Dict[str, int] = field(default_factory=dict)
    # task name -> whether the item was obtained at any point in the episode
    tech_tree_achieved: Dict[str, bool] = field(default_factory=dict)
Results from a single evaluation episode.
def check_inventory(info: dict, item_name: str) -> bool:
    """Return True when the player holds at least one ``item_name``.

    Args:
        info: MineRL ``info`` dict as returned by ``env.step``.
        item_name: Minecraft item id such as ``"log"`` or ``"iron_pickaxe"``.

    Returns:
        True if the ``inventory`` entry for ``item_name`` exists with a
        count greater than zero; False otherwise (also when the inventory
        field is absent or not a dict).
    """
    held = info.get("inventory", {})
    if not isinstance(held, dict):
        # Defensive: MineRL occasionally hands back a non-dict here.
        return False
    return held.get(item_name, 0) > 0
Check if item_name is present in the player's MineRL inventory.
Arguments:
- info: MineRL ``info`` dict returned by ``env.step``.
- item_name: Minecraft item id (e.g. ``"log"``, ``"iron_pickaxe"``).
Returns:
True if the inventory entry for ``item_name`` exists with a count greater than zero, False otherwise (including when the inventory field is missing or not a dict).
def run_worker(
    worker_id: int,
    episode_queue: Queue,
    result_queue: Queue,
    args_dict: dict,
):
    """Worker process for parallel evaluation.

    Each worker:
      1. Starts its own Xvfb display for headless rendering
      2. Creates a MineRL HumanSurvival environment
      3. Loads the Dreamer4 agent
      4. Runs episodes from the queue until poison pill
      5. Reports results

    Args:
        worker_id: 0-based worker index; used to derive the X display
            number and the Malmo base port so workers do not collide.
        episode_queue: queue of episode ids; ``None`` is the poison pill.
        result_queue: receives EpisodeResult / WorkerFailure objects.
        args_dict: plain-dict copy of the CLI config (picklable for spawn).
    """
    # Give each worker its own X display (:100, :101, ...).
    display_num = 99 + worker_id
    os.environ["DISPLAY"] = f":{display_num}"

    xvfb_proc = None
    env = None

    def _emit_failure(stage, exc):
        """Send traceback back to parent so it lands in sbatch stdout."""
        try:
            result_queue.put(WorkerFailure(
                worker_id=worker_id,
                stage=stage,
                error=repr(exc),
                tb=traceback.format_exc(),
            ))
        except Exception:
            # Best-effort only: if the queue itself is broken there is
            # nothing more this worker can report.
            pass

    # Ensure SIGTERM from the parent kills the Minecraft JVMs before we exit.
    # Without this, an aborted run leaves a tree of Java processes on HPC.
    # Raising SystemExit unwinds into the finally-block so env.close() runs.
    def _worker_signal_handler(signum, _frame):
        raise SystemExit(f"worker {worker_id} received signal {signum}")
    signal.signal(signal.SIGTERM, _worker_signal_handler)

    if args_dict.get("headless", True):
        try:
            xvfb_proc = subprocess.Popen(
                ["Xvfb", f":{display_num}", "-screen", "0", "1024x768x24", "-ac"],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            time.sleep(1)  # give Xvfb a moment to create the display
        except FileNotFoundError:
            print(f"[Worker {worker_id}] WARNING: Xvfb not found, using existing display")
            xvfb_proc = None

    try:
        try:
            from minerl.herobraine.env_specs.human_survival_specs import HumanSurvival
            from agent import ENV_KWARGS
            from dreamer4_minecraft_agent import Dreamer4MinecraftAgent
            from minerl.env.malmo import InstanceManager, MinecraftInstance
        except Exception as e:
            # Import-time failure (e.g. the Python 3.9 / PEP 604 issue this file
            # has tripped on before). Surface it through the result queue so the
            # parent can abort instead of waiting an hour per missing result.
            print(f"[Worker {worker_id}] Import error: {e}")
            traceback.print_exc()
            _emit_failure("import", e)
            return

        # 100-port stride per worker keeps Malmo instances from clashing.
        InstanceManager.configure_malmo_base_port(9000 + worker_id * 100)

        # Override the JVM heap per Minecraft instance (MineRL default 4G)
        # with the CLI --minecraft_mem value, to prevent OOM kills when many
        # workers share a node. Monkeypatches the constructor's max_mem
        # default because InstanceManager instantiates MinecraftInstance
        # internally and offers no direct knob for it.
        _orig_mc_init = MinecraftInstance.__init__
        _mc_max_mem = args_dict["minecraft_mem"]

        def _mc_init_with_reduced_mem(self, port=None, existing=False,
                                      status_dir=None, seed=None,
                                      instance_id=None, max_mem=_mc_max_mem):
            _orig_mc_init(self, port, existing, status_dir, seed,
                          instance_id, max_mem)
        MinecraftInstance.__init__ = _mc_init_with_reduced_mem

        try:
            env = HumanSurvival(**ENV_KWARGS).make()
        except Exception as e:
            print(f"[Worker {worker_id}] env construction failed: {e}")
            traceback.print_exc()
            _emit_failure("env_make", e)
            return

        try:
            agent = Dreamer4MinecraftAgent(
                env,
                checkpoint_path=args_dict["checkpoint"],
                stochastic=not args_dict.get("deterministic", False),
            )
        except Exception as e:
            print(f"[Worker {worker_id}] agent construction failed: {e}")
            traceback.print_exc()
            _emit_failure("agent_init", e)
            return

        while True:
            episode_id = episode_queue.get()
            if episode_id is None:
                break  # poison pill: no more work

            print(f"[Worker {worker_id}] Starting episode {episode_id}")
            t0 = time.time()

            # MineRL can transiently fail, so give it a few tries with backoff
            obs = None
            for attempt in range(5):
                try:
                    obs = env.reset()
                    break
                except Exception as e:
                    print(f"[Worker {worker_id}] env.reset() attempt {attempt+1}/5 failed: {e}")
                    if attempt < 4:
                        try:
                            env.close()  # have to make new or will just keep getting error
                        except Exception:
                            pass
                        time.sleep(10 * (attempt + 1))  # linear backoff: 10s, 20s, ...
                        env = HumanSurvival(**ENV_KWARGS).make()
                    else:
                        raise  # out of retries: caught by the fatal handler below

            agent.reset()

            total_reward = 0.0
            tech_tree_steps = {name: -1 for name, _ in TECH_TREE_TASKS}
            tech_tree_achieved = {name: False for name, _ in TECH_TREE_TASKS}

            for step in range(args_dict["max_steps"]):
                try:
                    action = agent.get_action(obs)
                    obs, reward, done, info = env.step(action)
                except Exception as e:
                    print(f"[Worker {worker_id}] Error at step {step}: {e}")
                    break

                # Detect MineRL returning a random obs due to a dead Minecraft process
                if 'error' in info:
                    print(f"[Worker {worker_id}] Minecraft connection lost at "
                          f"step {step}, ending episode")
                    break

                total_reward += reward

                # Check tech tree progression: record first step each item appears.
                for task_name, item_key in TECH_TREE_TASKS:
                    if not tech_tree_achieved[task_name]:
                        if check_inventory(info, item_key):
                            tech_tree_achieved[task_name] = True
                            tech_tree_steps[task_name] = step
                            print(
                                f"[Worker {worker_id}] Episode {episode_id}: "
                                f"achieved '{task_name}' at step {step}"
                            )

                if done:
                    break

            # NOTE(review): `step` is only defined if the loop above ran at
            # least once — assumes max_steps > 0 (true for the CLI default).
            wall_time = time.time() - t0
            result = EpisodeResult(
                episode_id=episode_id,
                worker_id=worker_id,
                total_steps=step + 1,
                total_reward=total_reward,
                wall_time_seconds=wall_time,
                tech_tree_steps=tech_tree_steps,
                tech_tree_achieved=tech_tree_achieved,
            )
            result_queue.put(result)
            print(
                f"[Worker {worker_id}] Episode {episode_id} done: "
                f"{step+1} steps, reward={total_reward:.2f}, "
                f"time={wall_time:.1f}s"
            )

    except Exception as e:
        print(f"[Worker {worker_id}] Fatal error: {e}")
        traceback.print_exc()
        _emit_failure("fatal", e)
    finally:
        # Close the MineRL env first — this shuts down the Minecraft JVM.
        # Without it, the Java child of this worker survives and lingers
        # as a zombie until the Slurm job hits its wall time.
        if env is not None:
            try:
                env.close()
            except Exception:
                pass
        if xvfb_proc is not None:
            try:
                xvfb_proc.terminate()
                xvfb_proc.wait(timeout=5)
            except Exception:
                try:
                    xvfb_proc.kill()
                except Exception:
                    pass
Worker process for parallel evaluation.
Each worker:
- Starts its own Xvfb display for headless rendering
- Creates a MineRL HumanSurvival environment
- Loads the Dreamer4 agent
- Runs episodes from the queue until poison pill
- Reports results
340def aggregate_results(results: List[EpisodeResult]) -> Dict: 341 """Compute aggregate statistics across all episodes.""" 342 n = len(results) 343 if n == 0: 344 return {} 345 346 stats = { 347 "n_episodes": n, 348 "mean_reward": float(np.mean([r.total_reward for r in results])), 349 "std_reward": float(np.std([r.total_reward for r in results])), 350 "mean_steps": float(np.mean([r.total_steps for r in results])), 351 "tasks": {}, 352 } 353 354 for task_name, _ in TECH_TREE_TASKS: 355 achieved = [r.tech_tree_achieved[task_name] for r in results] 356 success_rate = sum(achieved) / n 357 success_steps = [ 358 r.tech_tree_steps[task_name] 359 for r in results 360 if r.tech_tree_achieved[task_name] 361 ] 362 mean_steps = float(np.mean(success_steps)) if success_steps else float("inf") 363 364 stats["tasks"][task_name] = { 365 "success_rate": success_rate, 366 "n_successes": sum(achieved), 367 "mean_steps_to_success": mean_steps, 368 } 369 370 return stats
Compute aggregate statistics across all episodes.
def main():
    """CLI entry point: spawn worker processes and collect evaluation results.

    Builds a shared ``episode_queue`` of ``n_episodes`` work items (plus
    one ``None`` poison pill per worker), forks ``n_workers`` worker
    processes running :func:`run_worker`, drains the ``result_queue``
    with a 30-second poll (aborting early if every worker has died),
    aggregates stats via :func:`aggregate_results`, prints a tech-tree
    progression table, and writes the full JSON report to ``--output_json``.
    """
    parser = argparse.ArgumentParser("Dreamer4 Minecraft Evaluation")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to dreamer4_minecraft.pt")
    parser.add_argument("--n_episodes", type=int, default=100)
    parser.add_argument("--n_workers", type=int, default=1)
    parser.add_argument("--minecraft_mem", type=str, default="4G",
                        help="Max JVM heap per Minecraft instance "
                             "(e.g., '2G', '1G'). Enables parallelism "
                             "with low VRAM.")
    parser.add_argument("--max_steps", type=int, default=36000,
                        help="Max steps per episode (36000 = 30min at 20fps)")
    parser.add_argument("--deterministic", action="store_true")
    parser.add_argument("--no_headless", action="store_true",
                        help="Don't start Xvfb (use existing display)")
    parser.add_argument("--output_json", type=str, default="eval_results.json")
    args = parser.parse_args()

    # Validate checkpoint exists before paying the cost of spawning workers.
    if not os.path.exists(args.checkpoint):
        print(f"ERROR: Checkpoint not found: {args.checkpoint}")
        sys.exit(1)

    # "spawn" (not fork) so workers get clean interpreters — required for
    # the per-worker DISPLAY/port setup and safe with threaded libraries.
    ctx = get_context("spawn")
    episode_queue = ctx.Queue()
    result_queue = ctx.Queue()

    for i in range(args.n_episodes):
        episode_queue.put(i)
    for _ in range(args.n_workers):
        episode_queue.put(None)  # one poison pill per worker

    # Plain dict (not argparse.Namespace) so it pickles cleanly to workers.
    args_dict = {
        "checkpoint": os.path.abspath(args.checkpoint),
        "max_steps": args.max_steps,
        "minecraft_mem": args.minecraft_mem,
        "deterministic": args.deterministic,
        "headless": not args.no_headless,
    }

    # Install the global cleanup hooks BEFORE any child starts. atexit covers
    # normal exits, the signal handlers cover sbatch SIGTERM / user Ctrl-C.
    # Whatever happens below, worker processes (and therefore their Xvfb and
    # Minecraft JVM children) get reaped rather than becoming zombies.
    atexit.register(_terminate_children)
    signal.signal(signal.SIGTERM, _terminate_children)
    signal.signal(signal.SIGINT, _terminate_children)

    workers = []
    for w in range(args.n_workers):
        p = ctx.Process(target=run_worker, args=(w, episode_queue, result_queue, args_dict))
        p.start()
        workers.append(p)
        _CHILDREN.append(p)
        if w < args.n_workers - 1:
            time.sleep(10)  # Stagger Minecraft JVM launches to avoid resource contention

    print(f"Started {args.n_workers} workers for {args.n_episodes} episodes")

    results = []
    failures = []
    for _ in range(args.n_episodes):
        # Poll result_queue with a short timeout so we can notice that every
        # worker has died (e.g. import error) instead of hanging for an hour.
        # BUGFIX: item must be reset each iteration — otherwise, when the
        # abort path below breaks out of the inner loop, `item` is either
        # unbound (first iteration -> NameError) or stale from the previous
        # result. None is the explicit "no result, stop draining" marker.
        item = None
        while True:
            try:
                item = result_queue.get(timeout=30)
            except Exception:
                # Nothing came through in 30s — are any workers still running?
                if not any(p.is_alive() for p in workers):
                    print(
                        f"[main] All {len(workers)} workers have exited; "
                        f"aborting with {len(results)}/{args.n_episodes} results "
                        f"({len(failures)} worker failure(s) reported)."
                    )
                    break
                continue  # workers still alive, keep waiting
            break

        if isinstance(item, WorkerFailure):
            failures.append(item)
            print(
                f"[main] Worker {item.worker_id} failed during "
                f"'{item.stage}': {item.error}"
            )
            print(item.tb)
            # If every worker has died, no point draining further — abort now.
            if not any(p.is_alive() for p in workers):
                print(
                    f"[main] No workers remain; aborting after "
                    f"{len(failures)} failure(s)."
                )
                break
            continue
        if item is None:
            # Timeout with all workers dead (handled above): stop draining.
            break

        results.append(item)
        if len(results) % 10 == 0:
            print(f"Progress: {len(results)}/{args.n_episodes} episodes")

    # Best-effort graceful shutdown; _terminate_children (via atexit) will
    # force-kill anything still alive.
    for p in workers:
        p.join(timeout=30)
    _terminate_children()

    if failures and not results:
        print("\n" + "=" * 60)
        print("EVALUATION ABORTED — all workers failed before producing results")
        print("=" * 60)
        for f in failures[:3]:  # show first few tracebacks
            print(f"\n[Worker {f.worker_id}] stage={f.stage} error={f.error}")
            print(f.tb)
        sys.exit(2)

    if results:
        stats = aggregate_results(results)

        print("\n" + "=" * 60)
        print("EVALUATION RESULTS")
        print("=" * 60)
        print(f"Episodes: {stats['n_episodes']}")
        print(f"Mean reward: {stats['mean_reward']:.2f} +/- {stats['std_reward']:.2f}")
        print(f"Mean length: {stats['mean_steps']:.0f} steps")
        print()
        print(f"{'Task':<20} {'Success Rate':>12} {'N Success':>10} {'Mean Steps':>12}")
        print("-" * 56)
        for task_name, _ in TECH_TREE_TASKS:
            t = stats["tasks"][task_name]
            sr = f"{t['success_rate']*100:.1f}%"
            ns = str(t["n_successes"])
            mean_steps = t["mean_steps_to_success"]
            ms = f"{mean_steps:.0f}" if mean_steps != float("inf") else "N/A"
            print(f"{task_name:<20} {sr:>12} {ns:>10} {ms:>12}")

        output = {
            "aggregate": stats,
            "episodes": [asdict(r) for r in results],
            "config": vars(args),
        }
        with open(args.output_json, "w", encoding="utf-8") as f:
            json.dump(output, f, indent=2, default=str)
        print(f"\nResults saved to {args.output_json}")
    else:
        print("No results collected!")
CLI entry point: spawn worker processes and collect evaluation results.
Builds a shared episode_queue of n_episodes work items (plus
one None poison pill per worker), forks n_workers worker
processes running run_worker(), drains the result_queue
with a 30-second polling timeout (aborting early if every worker has died), aggregates stats via
aggregate_results(), prints a tech-tree progression table,
and writes the full JSON report to --output_json.