diff --git a/training/scripts/self_play.py b/training/scripts/self_play.py index d05c26d..688f3d5 100644 --- a/training/scripts/self_play.py +++ b/training/scripts/self_play.py @@ -1,20 +1,29 @@ #!/usr/bin/env python3 """ -self_play.py — Self-play training data generator. +self_play.py — Multi-tier self-play training data generator. -The fine-tuned model generates its own training data by: -1. Generating diverse edge-case prompts it's uncertain about -2. Attempting commands via RCON -3. Self-correcting on errors -4. Saving successful sequences as training examples +Three tiers of self-play, each teaching different skills: -This creates a closed-loop learning system with RCON as ground truth. -No API cost — runs entirely on the local model. +Tier 1 — Command drills: + Feed known prompts, execute commands via RCON, validate syntax. + Teaches: accurate command generation. + Usage: --tier 1 --rounds 50 -Usage: - python3 training/scripts/self_play.py --rounds 100 --model qwen3.5-9b-mc-v4 - python3 training/scripts/self_play.py --rounds 50 --dry-run - python3 training/scripts/self_play.py --rounds 200 --focus enchantments +Tier 2 — Single-shot self-critique: + Model generates BOTH the prompt AND response in one call. + Teaches: edge-case awareness, self-evaluation. + Usage: --tier 2 --rounds 50 + +Tier 3 — Adversarial self-play: + Session A generates a challenging prompt. Fresh Session B responds. + RCON validates. Model can't cheat by knowing both sides. + Teaches: robustness, generalization, error correction. + Usage: --tier 3 --rounds 50 + +All tiers: + --tier all --rounds 50 (runs ~17 rounds of each) + +No API cost — runs entirely on the local model with RCON as ground truth. """ import argparse @@ -182,6 +191,120 @@ Analyze the error and return a corrected command. 
# --- Tier 1: Command drills ---

# Parsed drillable seed examples, keyed by seed-file path, so the JSONL file is
# read and decoded once per process instead of once per drill.
_DRILL_EXAMPLES_CACHE = {}


def _load_drill_examples(seed_path):
    """Return the ``command_gen`` seed examples that carry a non-empty prompt."""
    cache_key = str(seed_path)
    if cache_key in _DRILL_EXAMPLES_CACHE:
        return _DRILL_EXAMPLES_CACHE[cache_key]

    examples = []
    with open(seed_path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                ex = json.loads(line)
            except ValueError:
                # One corrupt row must not abort the whole drill session.
                continue
            if not isinstance(ex, dict) or ex.get("category") != "command_gen":
                continue
            user_input = ex.get("input")
            if isinstance(user_input, dict) and user_input.get("user_message"):
                examples.append(ex)

    _DRILL_EXAMPLES_CACHE[cache_key] = examples
    return examples


def run_tier1_drill(model, ollama_url, rcon_host, rcon_port, rcon_pass, max_retries=2):
    """Drill one random ``command_gen`` seed prompt and validate it via RCON.

    The seed dataset is filtered to ``command_gen`` examples *before* the
    random pick, so every call with a usable dataset produces a drill instead
    of returning ``None`` whenever the pick lands on another category.

    Args:
        model: Model name, forwarded to ``attempt_command``.
        ollama_url: Ollama endpoint, forwarded to ``attempt_command``.
        rcon_host, rcon_port, rcon_pass: RCON connection settings, forwarded
            to ``attempt_command``.
        max_retries: Retry budget, forwarded to ``attempt_command``.

    Returns:
        The trace returned by ``attempt_command`` with ``"tier": 1`` and
        ``"original_commands"`` (the seed example's reference commands) added,
        or ``None`` when the dataset has no drillable example.

    Raises:
        OSError: If the seed dataset cannot be opened. A missing dataset is a
            setup error, so it is deliberately not swallowed.
    """
    seed_path = ROOT / "data" / "processed" / "seed_dataset.jsonl"
    examples = _load_drill_examples(seed_path)
    if not examples:
        return None

    ex = random.choice(examples)
    prompt = ex["input"]["user_message"]

    trace = attempt_command(model, ollama_url, prompt, rcon_host, rcon_port, rcon_pass, max_retries)
    trace["tier"] = 1
    reference = ex.get("output")
    trace["original_commands"] = reference.get("commands", []) if isinstance(reference, dict) else []
    return trace


# --- Tier 2: Single-shot self-critique ---

# NOTE(review): the "Effects: effect give minecraft: ." line looks like it lost
# its placeholders (player / effect name / duration) somewhere upstream —
# confirm the intended wording before relying on it.
SELF_CRITIQUE_SYSTEM = """You are a Minecraft 1.21 AI training data generator AND command translator.

Your task: generate a challenging Minecraft player request, then respond to it yourself.
Focus on edge cases you might get wrong: unusual items, complex enchantments, execute chains, ambiguous phrasing.

Return JSON:
{
  "generated_prompt": "the player request you invented (must start with 'sudo ' or 'pray ')",
  "difficulty": "what makes this tricky",
  "commands": ["cmd1", "cmd2"],
  "reasoning": "why these commands are correct",
  "message": "God message if pray, empty string if sudo"
}

Commands use minecraft: prefix. Enchantments: item[enchantments={name:level}].
Effects: effect give minecraft: .
Player: slingshooter08. Do NOT start commands with /."""


def _parse_json_object(raw):
    """Decode ``raw`` into a dict, falling back to the outermost ``{...}`` span.

    Returns ``None`` when ``raw`` is not a string or no JSON *object* can be
    recovered (a bare list or scalar is rejected, since callers use ``.get``).
    """
    if not isinstance(raw, str):
        return None
    candidates = [raw]
    embedded = re.search(r'\{[\s\S]*\}', raw)
    if embedded:
        candidates.append(embedded.group())
    for text in candidates:
        try:
            parsed = json.loads(text)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            continue
        if isinstance(parsed, dict):
            return parsed
    return None


def _normalize_commands(value):
    """Coerce the model's ``commands`` field into a list of non-empty strings."""
    if isinstance(value, str):
        # A lone string must not be iterated character by character.
        value = [value]
    elif not isinstance(value, (list, tuple)):
        return []
    return [cmd.strip() for cmd in value if isinstance(cmd, str) and cmd.strip()]


def _as_text(value):
    """Return ``value`` as a string; ``None`` and other falsy values become ``""``."""
    if not value:
        return ""
    return value if isinstance(value, str) else str(value)


def run_tier2_selfcritique(model, ollama_url, rcon_host, rcon_port, rcon_pass, category=None):
    """Have the model invent a prompt AND answer it in one call, then validate via RCON.

    Args:
        model: Model name passed to ``llm_call``.
        ollama_url: Ollama endpoint passed to ``llm_call``.
        rcon_host, rcon_port, rcon_pass: RCON connection settings passed to
            ``rcon_command`` for every generated command.
        category: Optional focus area appended to the system prompt.

    Returns:
        A trace dict with keys ``prompt``, ``mode``, ``tier`` (always 2),
        ``difficulty_note``, ``attempts`` (a single attempt), ``final_success``
        and ``self_corrected`` (always ``False``; this tier never retries), or
        ``None`` when the LLM call fails, the reply holds no JSON object, or
        no prompt was generated.
    """
    focus = ""
    if category:
        focus = f"\nFocus area: {category}. Generate a prompt specifically testing {category}."

    try:
        raw = llm_call(
            model=model,
            system=SELF_CRITIQUE_SYSTEM + focus,
            user="Generate one challenging Minecraft request and your response. Be creative — pick something you might get wrong.",
            ollama_url=ollama_url,
            temperature=0.9,
            max_tokens=500,
            fmt="json",
        )
    except Exception as exc:
        # Best-effort: a failed generation skips this sample. Catching
        # Exception (not a bare except) lets Ctrl-C / SystemExit propagate.
        print(f"  [self-critique] LLM call failed: {exc}")
        return None

    result = _parse_json_object(raw)
    if result is None:
        return None

    prompt = result.get("generated_prompt")
    if not isinstance(prompt, str) or not prompt.strip():
        return None

    commands = _normalize_commands(result.get("commands"))
    message = _as_text(result.get("message"))
    reasoning = _as_text(result.get("reasoning"))
    difficulty = _as_text(result.get("difficulty"))

    trace = {
        "prompt": prompt,
        "mode": "god" if prompt.lower().startswith("pray ") else "sudo",
        "tier": 2,
        "difficulty_note": difficulty,
        "attempts": [],
        "final_success": False,
        "self_corrected": False,
    }

    if not commands:
        # A reply without commands has nothing to validate and is recorded as
        # a success. NOTE(review): presumably meant for message-only "pray"
        # replies — confirm this is also wanted for "sudo" prompts.
        trace["attempts"].append({
            "commands": [], "reasoning": reasoning, "message": message,
            "rcon_results": [], "all_success": True,
        })
        trace["final_success"] = True
        return trace

    # Validate via RCON: run every command, even after a failure, so the trace
    # records the outcome of each one.
    rcon_results = []
    all_success = True
    for cmd in commands:
        success, rcon_result = rcon_command(cmd, rcon_host, rcon_port, rcon_pass)
        rcon_results.append({"command": cmd, "success": success, "result": rcon_result})
        if not success:
            all_success = False

    trace["attempts"].append({
        "commands": commands, "reasoning": reasoning, "message": message,
        "rcon_results": rcon_results, "all_success": all_success,
    })
    trace["final_success"] = all_success
    return trace
parser.add_argument("--rcon-pass", default="REDACTED_RCON") parser.add_argument("--rounds", type=int, default=50) + parser.add_argument("--tier", default="all", choices=["1", "2", "3", "all"]) parser.add_argument("--focus", default=None, choices=list(EXPLORATION_CATEGORIES.keys())) parser.add_argument("--output", default=str(OUTPUT)) parser.add_argument("--dry-run", action="store_true") parser.add_argument("--max-retries", type=int, default=2) args = parser.parse_args() + tiers = [1, 2, 3] if args.tier == "all" else [int(args.tier)] + print(f"Self-play training data generator") print(f" Model: {args.model}") print(f" RCON: {args.rcon_host}:{args.rcon_port}") print(f" Rounds: {args.rounds}") + print(f" Tiers: {tiers}") print(f" Focus: {args.focus or 'all categories'}") print(f" Max retries: {args.max_retries}") print(f" Output: {args.output}") @@ -454,60 +581,129 @@ def main(): stats = { "rounds": 0, "prompts_generated": 0, "attempts": 0, "success_first_try": 0, "self_corrected": 0, "failed": 0, - "training_examples": 0, "by_category": {}, + "training_examples": 0, "by_tier": {1: 0, 2: 0, 3: 0}, "by_category": {}, } for round_num in range(args.rounds): - print(f"\n--- Round {round_num + 1}/{args.rounds} ---") - - # Generate prompts - prompts = generate_prompts(args.model, args.ollama_url, args.focus) - if not prompts: - print(" No prompts generated, skipping round") - continue - - stats["prompts_generated"] += len(prompts) - print(f" Generated {len(prompts)} prompts") - - for p in prompts: - prompt = p["prompt"] - cat = p["category"] - stats["by_category"].setdefault(cat, {"total": 0, "success": 0, "corrected": 0}) - stats["by_category"][cat]["total"] += 1 - stats["attempts"] += 1 - - print(f" [{cat}] {prompt[:60]:60}", end="") + tier = tiers[round_num % len(tiers)] + print(f"\n--- Round {round_num + 1}/{args.rounds} [Tier {tier}] ---") + if tier == 1: + # Command drill: pick random seed example, try to execute if args.dry_run: - print(" [DRY RUN]") + print(" [DRY 
RUN] Would drill a random seed prompt via RCON") + stats["rounds"] += 1 continue - trace = attempt_command( - args.model, args.ollama_url, prompt, - args.rcon_host, args.rcon_port, args.rcon_pass, - max_retries=args.max_retries, + for _ in range(5): # 5 drills per round + trace = run_tier1_drill( + args.model, args.ollama_url, + args.rcon_host, args.rcon_port, args.rcon_pass, + args.max_retries, + ) + if trace is None: + continue + stats["attempts"] += 1 + stats["by_tier"][1] += 1 + prompt = trace["prompt"] + print(f" [drill] {prompt[:55]:55}", end="") + if trace["final_success"] and not trace["self_corrected"]: + stats["success_first_try"] += 1 + n_cmds = len(trace["attempts"][0].get("commands", [])) + print(f" OK ({n_cmds} cmds)") + elif trace.get("self_corrected"): + stats["self_corrected"] += 1 + print(f" CORRECTED ({len(trace['attempts'])} attempts)") + else: + stats["failed"] += 1 + print(f" FAILED") + examples = trace_to_training(trace) + all_examples.extend(examples) + stats["training_examples"] += len(examples) + time.sleep(0.5) + + elif tier == 2: + # Self-critique: model generates prompt + response, RCON validates + cats = [args.focus] if args.focus else random.sample( + list(EXPLORATION_CATEGORIES.keys()), min(3, len(EXPLORATION_CATEGORIES)) ) + for cat in cats: + if args.dry_run: + print(f" [DRY RUN] Would self-critique on {cat}") + continue - if trace["final_success"] and not trace["self_corrected"]: - stats["success_first_try"] += 1 - stats["by_category"][cat]["success"] += 1 - n_cmds = len(trace["attempts"][0].get("commands", [])) - print(f" OK ({n_cmds} cmds)") - elif trace["self_corrected"]: - stats["self_corrected"] += 1 - stats["by_category"][cat]["corrected"] += 1 - print(f" CORRECTED ({len(trace['attempts'])} attempts)") - else: - stats["failed"] += 1 - print(f" FAILED") + trace = run_tier2_selfcritique( + args.model, args.ollama_url, + args.rcon_host, args.rcon_port, args.rcon_pass, + category=cat, + ) + if trace is None: + continue + 
stats["attempts"] += 1 + stats["by_tier"][2] += 1 + prompt = trace["prompt"] + diff = trace.get("difficulty_note", "")[:30] + print(f" [self-critique:{cat[:12]}] {prompt[:40]:40} ({diff})", end="") + if trace["final_success"]: + stats["success_first_try"] += 1 + n_cmds = len(trace["attempts"][0].get("commands", [])) + print(f" OK ({n_cmds} cmds)") + else: + stats["failed"] += 1 + print(f" FAILED (self-generated bad commands)") + examples = trace_to_training(trace) + all_examples.extend(examples) + stats["training_examples"] += len(examples) + time.sleep(1) - # Convert to training examples - examples = trace_to_training(trace) - all_examples.extend(examples) - stats["training_examples"] += len(examples) + elif tier == 3: + # Adversarial: generate prompts in Session A, respond in fresh Session B + prompts = generate_prompts(args.model, args.ollama_url, args.focus) + if not prompts: + print(" No prompts generated, skipping round") + stats["rounds"] += 1 + continue - # Brief pause between attempts - time.sleep(1) + stats["prompts_generated"] += len(prompts) + print(f" Generated {len(prompts)} adversarial prompts") + + for p in prompts: + prompt = p["prompt"] + cat = p["category"] + stats["by_category"].setdefault(cat, {"total": 0, "success": 0, "corrected": 0}) + stats["by_category"][cat]["total"] += 1 + stats["attempts"] += 1 + stats["by_tier"][3] += 1 + + print(f" [adversarial:{cat[:12]}] {prompt[:48]:48}", end="") + + if args.dry_run: + print(" [DRY RUN]") + continue + + trace = attempt_command( + args.model, args.ollama_url, prompt, + args.rcon_host, args.rcon_port, args.rcon_pass, + max_retries=args.max_retries, + ) + + if trace["final_success"] and not trace["self_corrected"]: + stats["success_first_try"] += 1 + stats["by_category"][cat]["success"] += 1 + n_cmds = len(trace["attempts"][0].get("commands", [])) + print(f" OK ({n_cmds} cmds)") + elif trace["self_corrected"]: + stats["self_corrected"] += 1 + stats["by_category"][cat]["corrected"] += 1 + print(f" 
CORRECTED ({len(trace['attempts'])} attempts)") + else: + stats["failed"] += 1 + print(f" FAILED") + + examples = trace_to_training(trace) + all_examples.extend(examples) + stats["training_examples"] += len(examples) + time.sleep(1) stats["rounds"] += 1 @@ -530,6 +726,11 @@ def main(): print(f" Failed: {stats['failed']}") print(f" Training examples:{stats['training_examples']}") + print(f"\n By tier:") + for t in sorted(stats["by_tier"]): + labels = {1: "Command drills", 2: "Self-critique", 3: "Adversarial"} + print(f" Tier {t} ({labels[t]:16}): {stats['by_tier'][t]} attempts") + if stats["by_category"]: print(f"\n By category:") for cat, s in sorted(stats["by_category"].items()):