#!/usr/bin/env python3
"""
train_combat.py — Train a small policy network for Minecraft combat via PPO.

The agent learns to fight, flee, eat, and survive in a hostile Minecraft world.
Uses the MinecraftCombatEnv gymnasium wrapper which controls a mineflayer bot.

Usage:
    # Install deps first:
    pip install gymnasium stable-baselines3 torch

    # Train (on steel141 with mc-train conda env):
    python3 training/rl/train_combat.py

    # Train with custom settings:
    python3 training/rl/train_combat.py --timesteps 50000 --host 192.168.0.244 --port 25568

    # Evaluate a trained model:
    python3 training/rl/train_combat.py --eval --model training/rl/checkpoints/combat_ppo_final.zip
"""

import argparse
import os
import sys
from pathlib import Path

# Add project root to path so `training.rl.*` imports resolve when run as a script.
ROOT = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(ROOT))

# Directory for periodic checkpoints, the final model and TensorBoard logs.
CKPT_DIR = ROOT / "training" / "rl" / "checkpoints"
# File-name prefix shared by periodic checkpoints and the final model.
CKPT_PREFIX = "combat_ppo"
DEFAULT_LEARNING_RATE = 3e-4


def _latest_checkpoint(ckpt_dir, prefix=CKPT_PREFIX):
    """Return the most recently modified ``<prefix>_*.zip`` in *ckpt_dir*, or None.

    Args:
        ckpt_dir: Directory to search (str or Path). A missing directory yields None.
        prefix: Checkpoint file-name prefix.

    Returns:
        The newest matching file as a ``Path``, or ``None`` if there is none.
    """
    ckpt_dir = Path(ckpt_dir)
    if not ckpt_dir.is_dir():
        return None
    ckpts = list(ckpt_dir.glob(f"{prefix}_*.zip"))
    if not ckpts:
        return None
    return max(ckpts, key=lambda p: p.stat().st_mtime)


def _build_model(env, ckpt_dir, learning_rate):
    """Resume from the newest checkpoint in *ckpt_dir*, or create a fresh PPO model.

    Args:
        env: The gymnasium environment to attach the model to.
        ckpt_dir: Checkpoint directory (Path); TensorBoard logs go to ``ckpt_dir/tb_logs``.
        learning_rate: Learning rate for both fresh and resumed models.

    Returns:
        ``(model, resumed)`` where *resumed* is True when a checkpoint was loaded.
    """
    from stable_baselines3 import PPO

    tb_log = str(ckpt_dir / "tb_logs")
    latest_ckpt = _latest_checkpoint(ckpt_dir)

    if latest_ckpt is not None:
        print(f"RESUMING from: {latest_ckpt}")
        # Overrides must be passed to load(): SB3 applies these kwargs *before*
        # building the learning-rate schedule. Assigning model.learning_rate after
        # loading is silently ignored by the optimizer.
        model = PPO.load(
            str(latest_ckpt),
            env=env,
            tensorboard_log=tb_log,
            learning_rate=learning_rate,
        )
        return model, True

    model = PPO(
        "MlpPolicy",
        env,
        verbose=1,
        learning_rate=learning_rate,
        n_steps=256,  # collect 256 steps before each update
        batch_size=64,
        n_epochs=4,
        gamma=0.99,  # discount factor
        gae_lambda=0.95,
        clip_range=0.2,
        ent_coef=0.01,  # entropy bonus for exploration
        policy_kwargs={
            "net_arch": [64, 64],  # 2 hidden layers of 64 units
        },
        tensorboard_log=tb_log,
    )
    return model, False


def train(args):
    """Train (or resume training) the PPO combat policy.

    Reads from *args*: ``host``, ``port``, ``timesteps``, ``max_steps``, ``save_freq``,
    ``verbose`` and, optionally, ``learning_rate`` (defaults to 3e-4 when absent).
    Resumes from the newest ``combat_ppo_*.zip`` in the checkpoint directory when one
    exists, saves periodic checkpoints, and writes ``combat_ppo_final.zip`` when
    training finishes or is interrupted with Ctrl-C. The environment is always closed.
    """
    from stable_baselines3.common.callbacks import CheckpointCallback
    from training.rl.minecraft_env import MinecraftCombatEnv

    learning_rate = getattr(args, "learning_rate", DEFAULT_LEARNING_RATE)

    print("=== Minecraft RL Combat Training ===")
    print(f"Host: {args.host}:{args.port}")
    print(f"Timesteps: {args.timesteps}")
    print("Policy: MlpPolicy (3-layer MLP)")
    print()

    env = MinecraftCombatEnv(
        host=args.host,
        port=args.port,
        username=f"RLBot_{os.getpid() % 100}",
        max_steps=args.max_steps,
        render_mode="human" if args.verbose else None,
    )
    try:
        CKPT_DIR.mkdir(parents=True, exist_ok=True)
        checkpoint_cb = CheckpointCallback(
            save_freq=args.save_freq,
            save_path=str(CKPT_DIR),
            name_prefix=CKPT_PREFIX,
        )

        model, resumed = _build_model(env, CKPT_DIR, learning_rate)

        print(f"Policy network params: {sum(p.numel() for p in model.policy.parameters()):,}")
        print(f"Training for {args.timesteps} timesteps...")
        print()

        try:
            model.learn(
                total_timesteps=args.timesteps,
                callback=checkpoint_cb,
                progress_bar=True,
                # When resuming, continue the step counter from the checkpoint. Then
                # checkpoint file names, which embed the step count, don't overwrite
                # earlier ones, and TensorBoard curves continue instead of restarting.
                reset_num_timesteps=not resumed,
            )
        except KeyboardInterrupt:
            print("\nTraining interrupted.")

        final_path = CKPT_DIR / f"{CKPT_PREFIX}_final.zip"
        model.save(str(final_path))
        print(f"\nModel saved to {final_path}")
    finally:
        env.close()


def evaluate(args):
    """Run a trained model deterministically and print per-episode and average stats.

    Reads from *args*: ``model`` (path to a saved PPO zip), ``host``, ``port``,
    ``max_steps`` and ``eval_episodes``. The environment is always closed.
    """
    from stable_baselines3 import PPO
    from training.rl.minecraft_env import MinecraftCombatEnv

    print(f"=== Evaluating {args.model} ===")

    # Load before connecting the bot so a bad --model path fails without
    # leaving an open environment behind.
    model = PPO.load(args.model)

    env = MinecraftCombatEnv(
        host=args.host,
        port=args.port,
        username="RLBot_eval",
        max_steps=args.max_steps,
        render_mode="human",
    )

    total_reward = 0
    total_kills = 0
    episodes = args.eval_episodes
    try:
        for ep in range(episodes):
            obs, info = env.reset()
            ep_reward = 0
            done = False
            while not done:
                action, _ = model.predict(obs, deterministic=True)
                obs, reward, terminated, truncated, info = env.step(action)
                ep_reward += reward
                done = terminated or truncated
            total_reward += ep_reward
            total_kills += info.get("kills", 0)
            print(f"  Episode {ep+1}: reward={ep_reward:.1f} kills={info.get('kills', 0)} steps={info.get('step', 0)}")

        if episodes > 0:
            print(f"\nAverage: reward={total_reward/episodes:.1f} kills={total_kills/episodes:.1f}")
        else:
            # Guard against ZeroDivisionError for --eval-episodes 0 (or negative).
            print("\nNo episodes run.")
    finally:
        env.close()


def _build_parser():
    """Build the command-line parser for training and evaluation."""
    parser = argparse.ArgumentParser(description="Minecraft RL Combat Training")
    parser.add_argument("--host", default="192.168.0.244")
    parser.add_argument("--port", type=int, default=25568)
    parser.add_argument("--timesteps", type=int, default=10000)
    parser.add_argument("--max-steps", type=int, default=300)
    parser.add_argument("--save-freq", type=int, default=2000)
    parser.add_argument("--learning-rate", type=float, default=DEFAULT_LEARNING_RATE)
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--eval", action="store_true")
    parser.add_argument("--eval-episodes", type=int, default=5)
    # Resolved against the project root (where train() saves), not the cwd.
    parser.add_argument("--model", default=str(CKPT_DIR / f"{CKPT_PREFIX}_final.zip"))
    return parser


def main():
    """Parse CLI arguments and dispatch to training or evaluation."""
    args = _build_parser().parse_args()
    if args.eval:
        evaluate(args)
    else:
        train(args)


if __name__ == "__main__":
    main()