0f043384e5
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
745 lines
28 KiB
Python
745 lines
28 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
self_play.py — Multi-tier self-play training data generator.
|
|
|
|
Three tiers of self-play, each teaching different skills:
|
|
|
|
Tier 1 — Command drills:
|
|
Feed known prompts, execute commands via RCON, validate syntax.
|
|
Teaches: accurate command generation.
|
|
Usage: --tier 1 --rounds 50
|
|
|
|
Tier 2 — Single-shot self-critique:
|
|
Model generates BOTH the prompt AND response in one call.
|
|
Teaches: edge-case awareness, self-evaluation.
|
|
Usage: --tier 2 --rounds 50
|
|
|
|
Tier 3 — Adversarial self-play:
|
|
Session A generates a challenging prompt. Fresh Session B responds.
|
|
RCON validates. Model can't cheat by knowing both sides.
|
|
Teaches: robustness, generalization, error correction.
|
|
Usage: --tier 3 --rounds 50
|
|
|
|
All tiers:
|
|
--tier all --rounds 50 (runs ~17 rounds of each)
|
|
|
|
No API cost — runs entirely on the local model with RCON as ground truth.
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import re
|
|
import random
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import requests
|
|
|
|
# The repository root is the third ancestor of this file's resolved path.
_SCRIPT_PATH = Path(__file__).resolve()
ROOT = _SCRIPT_PATH.parent.parent.parent

# Put the repo root first on the import path so project packages such as
# ``agent`` (imported below) resolve when this file is run as a script.
sys.path.insert(0, str(ROOT))

# Default JSONL file that generated training examples are appended to.
OUTPUT = ROOT.joinpath("data", "processed", "self_play.jsonl")

# Module-level API key, set from args in main()
_API_KEY = None
|
|
|
|
# --- RCON (persistent connection) ---
|
|
|
|
from agent.tools.persistent_rcon import get_rcon
|
|
|
|
# Case-insensitive substrings that mark an RCON reply as a failed command.
# Built (and lowercased) once at import instead of on every call. The bare
# "expected" entry already matches every "Expected ..." message, so
# "expected whitespace" is redundant but kept as an explicit example.
_RCON_ERROR_PATTERNS = tuple(p.lower() for p in (
    "Unknown or incomplete command",
    "Incorrect argument",
    "Expected whitespace",
    "Unknown item",
    "Invalid or unknown",
    "Expected",
    "<--[HERE]",
))

# Replies treated as success even if they also match an error pattern.
_RCON_BENIGN_PATTERNS = ("no player was found",)


def _is_rcon_error(text):
    """Return True when the RCON reply *text* looks like a command failure."""
    lowered = text.lower()
    if any(p in lowered for p in _RCON_BENIGN_PATTERNS):
        return False
    return any(p in lowered for p in _RCON_ERROR_PATTERNS)


def rcon_command(cmd, host, port, password):
    """Execute *cmd* via the persistent RCON connection and classify the reply.

    Args:
        cmd: command text, sent to the server unchanged.
        host, port, password: RCON endpoint, forwarded to ``get_rcon``.

    Returns:
        ``(success, text)`` where *text* is the stripped server reply.
        *success* is False when the reply matches one of the error patterns
        (unless it also contains "no player was found"), and also when
        anything raises, in which case *text* is ``"RCON error: <exc>"``.
    """
    try:
        conn = get_rcon(host, port, password)
        result = conn.command(cmd)
        return (not _is_rcon_error(result), result.strip())
    except Exception as e:
        # Best-effort: connection/protocol problems are reported as a failed
        # command so one bad command never aborts the whole self-play run.
        return (False, f"RCON error: {e}")
|
|
|
|
|
|
# --- LLM calls ---
|
|
|
|
# Complete reasoning blocks emitted by "thinking" models such as Qwen3.
_THINK_BLOCK_RE = re.compile(r'<think>[\s\S]*?</think>\s*')
# A <think> block left open because generation hit the num_predict limit;
# without this, brace-salvaging callers may parse JSON out of the reasoning.
_UNCLOSED_THINK_RE = re.compile(r'<think>[\s\S]*$')


def llm_call(model, system, user, ollama_url, temperature=0.7, max_tokens=500, fmt=None, api_key=None, timeout=120):
    """Call Ollama (or a compatible gateway) /api/chat for one non-streamed reply.

    Args:
        model: model name to run.
        system: system prompt.
        user: single user message.
        ollama_url: server base URL; a trailing "/" is tolerated.
        temperature: sampling temperature (Ollama ``options.temperature``).
        max_tokens: generation cap (Ollama ``options.num_predict``).
        fmt: optional Ollama ``format`` value, e.g. ``"json"``.
        api_key: optional bearer token for authenticated gateways.
        timeout: HTTP timeout in seconds (default 120, as before).

    Returns:
        The assistant message content with ``<think>...</think>`` blocks, and
        any trailing unterminated ``<think>`` block, removed; whitespace-stripped.

    Raises:
        requests.HTTPError on a non-2xx status, other requests exceptions on
        transport failure, and KeyError/ValueError if the response body is not
        the expected ``{"message": {"content": ...}}`` shape.
    """
    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
        "stream": False,
        "options": {"temperature": temperature, "num_predict": max_tokens},
    }
    if fmt:
        payload["format"] = fmt
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    url = f"{ollama_url.rstrip('/')}/api/chat"
    r = requests.post(url, json=payload, headers=headers, timeout=timeout)
    r.raise_for_status()
    content = r.json()["message"]["content"] or ""
    # Strip finished think blocks first, then any block truncated mid-thought.
    content = _THINK_BLOCK_RE.sub('', content)
    content = _UNCLOSED_THINK_RE.sub('', content)
    return content.strip()
|
|
|
|
|
|
# --- Prompt generation categories ---
|
|
|
|
# Edge-case prompt generators, keyed by category name. Each entry holds:
#   "instruction": the user message generate_prompts sends to the model,
#                  asking for 5 player chat messages as a JSON array of strings.
#   "temperature": the sampling temperature for that generation call.
# The keys are also the --focus CLI choices and the Tier 2 focus areas that
# main() samples from.
EXPLORATION_CATEGORIES = {
    "enchantment_combos": {
        "instruction": """Generate 5 Minecraft chat messages that test unusual or edge-case enchantment combinations.
Include: mutually exclusive enchants, max level exceeded, enchants on wrong items, multi-enchant syntax.
Every message must start with "sudo " or "pray ".
Return a JSON array of strings.""",
        "temperature": 1.0,
    },
    "entity_nbt": {
        "instruction": """Generate 5 Minecraft chat messages that test entity spawning with unusual NBT data.
Include: custom names, baby variants, colored sheep, armored mobs, riding/passengers, powered creepers.
Every message must start with "sudo " or "pray ".
Return a JSON array of strings.""",
        "temperature": 1.0,
    },
    # sudo-only: execute chains are command translation, not God-mode requests.
    "execute_chains": {
        "instruction": """Generate 5 Minecraft chat messages that require complex execute command chains.
Include: nested execute, conditional execution, store results, dimension switching, targeting by gamemode/team.
Every message must start with "sudo ".
Return a JSON array of strings.""",
        "temperature": 1.0,
    },
    "edge_items": {
        "instruction": """Generate 5 Minecraft chat messages requesting obscure or easily-confused items.
Include: items with color variants, items that changed names between versions, items with underscores,
items people misspell (like "wooden_sword" vs "wood_sword", "cooked_beef" vs "steak").
Every message must start with "sudo ".
Return a JSON array of strings.""",
        "temperature": 1.0,
    },
    "worldedit": {
        "instruction": """Generate 5 Minecraft chat messages requesting WorldEdit operations.
Include: shapes, selections, replacements, brushes, stacking, clipboard operations.
Every message must start with "sudo ".
Return a JSON array of strings.""",
        "temperature": 1.0,
    },
    "multiplayer": {
        "instruction": """Generate 5 Minecraft chat messages involving multiple players or complex selectors.
Include: @a with exclusions, team commands, scoreboard operations, targeting by distance/gamemode.
Use player names like: slingshooter08, SwiftWolf, DarkWolf, BraveWolf.
Every message must start with "sudo ".
Return a JSON array of strings.""",
        "temperature": 1.0,
    },
    # Mixes genuinely unsafe requests with harmless look-alikes.
    "boundary_testing": {
        "instruction": """Generate 5 Minecraft chat messages that test safety boundaries.
Include: requests for forbidden items, mass destruction, OP commands, but also requests that SEEM dangerous
but are actually fine (like giving TNT to yourself, or killing your own mobs).
Every message must start with "sudo " or "pray ".
Return a JSON array of strings.""",
        "temperature": 1.0,
    },
    # The only category above 1.0 (1.2), presumably to get more varied phrasing.
    "natural_language": {
        "instruction": """Generate 5 Minecraft chat messages phrased in unusual or creative natural language.
Include: typos, slang, roleplay, indirect requests, questions, sarcasm, mixed languages.
The AI should still be able to figure out what the player wants.
Every message must start with "sudo " or "pray ".
Return a JSON array of strings.""",
        "temperature": 1.2,
    },
    "cosmetic_effects": {
        "instruction": """Generate 5 Minecraft chat messages requesting cosmetic or dramatic effects.
Include: particles, sounds, titles, tellraw formatting, fireworks, combination effects.
Every message must start with "sudo " or "pray ".
Return a JSON array of strings.""",
        "temperature": 1.0,
    },
}
|
|
|
|
# System prompts
# Sudo mode (any prompt not starting with "pray "): plain command translation.
# The model must reply with JSON {"commands": [...], "reasoning": "..."}.
SUDO_SYSTEM = """You are a Minecraft 1.21 command translator. Return JSON: {"commands": ["cmd1", ...], "reasoning": "why"}
Commands use minecraft: prefix. Enchantments: item[enchantments={name:level}]. Effects: effect give <player> minecraft:<effect> <seconds> <amplifier>.
Do NOT start commands with /. Player name: slingshooter08."""

# God mode (prompts starting with "pray "): same contract plus an in-character
# "message" field.
GOD_SYSTEM = """You are God in a Minecraft server. Return JSON: {"message": "dramatic response", "commands": ["cmd1", ...], "reasoning": "why"}
Commands use minecraft: prefix. Be dramatic but use valid 1.21 syntax. Player: slingshooter08."""

# Self-correction prompt used by attempt_command's retry loop. The user turn
# carries the original request plus each failed command and its RCON error.
RETRY_SYSTEM = """You are a Minecraft 1.21 command translator. Your previous command failed with an error.
Analyze the error and return a corrected command.
Return JSON: {"commands": ["corrected_cmd"], "reasoning": "what was wrong and how you fixed it"}"""
|
|
|
|
|
|
# --- Tier 1: Command drills ---
|
|
|
|
# Parsed command_gen seed examples, keyed by seed-file path, so the dataset is
# read once per run instead of once per drill.
_SEED_DRILL_CACHE = {}


def _load_drill_examples(seed_path):
    """Return the usable command_gen examples from a JSONL seed file (cached per path)."""
    key = str(seed_path)
    cached = _SEED_DRILL_CACHE.get(key)
    if cached is not None:
        return cached
    examples = []
    with open(seed_path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                ex = json.loads(line)
            except json.JSONDecodeError:
                # One bad line should not disable drilling on the rest.
                continue
            # Only drill command_gen examples that actually carry a prompt.
            if not isinstance(ex, dict) or ex.get("category") != "command_gen":
                continue
            inp = ex.get("input")
            if isinstance(inp, dict) and inp.get("user_message"):
                examples.append(ex)
    _SEED_DRILL_CACHE[key] = examples
    return examples


def run_tier1_drill(model, ollama_url, rcon_host, rcon_port, rcon_pass, max_retries=2):
    """Tier 1: drill a random command_gen seed prompt through attempt_command.

    The seed file is ROOT/data/processed/seed_dataset.jsonl. The random choice
    is made among command_gen examples only, so every call produces a drill
    unless the file contains none.

    Returns:
        The attempt_command trace, extended with ``tier`` (1) and
        ``original_commands`` (the seed example's reference commands), or None
        when the seed file has no usable command_gen examples.

    Raises:
        OSError (e.g. FileNotFoundError) if the seed file cannot be opened.
    """
    seed_path = ROOT / "data" / "processed" / "seed_dataset.jsonl"
    examples = _load_drill_examples(seed_path)
    if not examples:
        return None

    ex = random.choice(examples)
    prompt = ex["input"]["user_message"]

    trace = attempt_command(model, ollama_url, prompt, rcon_host, rcon_port, rcon_pass, max_retries)
    trace["tier"] = 1
    trace["original_commands"] = (ex.get("output") or {}).get("commands", [])
    return trace
|
|
|
|
|
|
# --- Tier 2: Single-shot self-critique ---
|
|
|
|
SELF_CRITIQUE_SYSTEM = """You are a Minecraft 1.21 AI training data generator AND command translator.
|
|
|
|
Your task: generate a challenging Minecraft player request, then respond to it yourself.
|
|
Focus on edge cases you might get wrong: unusual items, complex enchantments, execute chains, ambiguous phrasing.
|
|
|
|
Return JSON:
|
|
{
|
|
"generated_prompt": "the player request you invented (must start with 'sudo ' or 'pray ')",
|
|
"difficulty": "what makes this tricky",
|
|
"commands": ["cmd1", "cmd2"],
|
|
"reasoning": "why these commands are correct",
|
|
"message": "God message if pray, empty string if sudo"
|
|
}
|
|
|
|
Commands use minecraft: prefix. Enchantments: item[enchantments={name:level}].
|
|
Effects: effect give <player> minecraft:<effect> <seconds> <amplifier>.
|
|
Player: slingshooter08. Do NOT start commands with /."""
|
|
|
|
|
|
def run_tier2_selfcritique(model, ollama_url, rcon_host, rcon_port, rcon_pass, category=None):
    """Tier 2: the model invents a request AND answers it in one call; RCON validates.

    Args:
        model, ollama_url: passed to llm_call (with the module-level _API_KEY).
        rcon_host, rcon_port, rcon_pass: RCON endpoint used for validation.
        category: optional focus area appended to the system prompt.

    Returns:
        A trace dict (prompt, mode, tier=2, difficulty_note, attempts,
        final_success, self_corrected=False) holding exactly one attempt, or
        None if the LLM call fails, no JSON object can be recovered from its
        reply, or the reply has no string "generated_prompt". There is no
        retry: a failing command leaves final_success False.
    """
    focus = ""
    if category:
        focus = f"\nFocus area: {category}. Generate a prompt specifically testing {category}."

    # Pre-bound so the salvage path does not depend on whether llm_call
    # got far enough to assign it.
    raw = ""
    try:
        raw = llm_call(
            model=model,
            system=SELF_CRITIQUE_SYSTEM + focus,
            user="Generate one challenging Minecraft request and your response. Be creative — pick something you might get wrong.",
            ollama_url=ollama_url,
            temperature=0.9,
            max_tokens=500,
            fmt="json",
            # Same gateway credentials attempt_command uses.
            api_key=_API_KEY,
        )
        result = json.loads(raw)
    except Exception:
        # Transport failure leaves raw == "" (no match). Otherwise try the
        # outermost {...} span of prose- or fence-wrapped output.
        match = re.search(r'\{[\s\S]*\}', raw)
        if match is None:
            return None
        try:
            result = json.loads(match.group())
        except ValueError:
            return None

    if not isinstance(result, dict):
        return None

    prompt = result.get("generated_prompt")
    if not isinstance(prompt, str) or not prompt:
        return None

    commands = result.get("commands") or []
    if isinstance(commands, str):
        # A bare string is one command; iterating it would send single characters.
        commands = [commands]
    if not isinstance(commands, list):
        commands = []
    commands = [c for c in commands if isinstance(c, str) and c.strip()]
    message = result.get("message") or ""
    reasoning = result.get("reasoning") or ""
    # main() slices this for display, so keep it a string.
    difficulty = str(result.get("difficulty") or "")

    trace = {
        "prompt": prompt,
        "mode": "god" if prompt.lower().startswith("pray ") else "sudo",
        "tier": 2,
        "difficulty_note": difficulty,
        "attempts": [],
        "final_success": False,
        "self_corrected": False,
    }

    if not commands:
        # A refusal / informational answer counts as valid.
        trace["attempts"].append({
            "commands": [], "reasoning": reasoning, "message": message,
            "rcon_results": [], "all_success": True,
        })
        trace["final_success"] = True
        return trace

    # Validate via RCON
    rcon_results = []
    for cmd in commands:
        success, rcon_result = rcon_command(cmd, rcon_host, rcon_port, rcon_pass)
        rcon_results.append({"command": cmd, "success": success, "result": rcon_result})
    all_success = all(r["success"] for r in rcon_results)

    trace["attempts"].append({
        "commands": commands, "reasoning": reasoning, "message": message,
        "rcon_results": rcon_results, "all_success": all_success,
    })
    trace["final_success"] = all_success
    return trace
|
|
|
|
|
|
# --- Tier 3: Adversarial self-play (original generate_prompts + attempt_command) ---
|
|
|
|
def generate_prompts(model, ollama_url, category=None):
    """Tier 3, session A: have the model invent edge-case player prompts for itself.

    Args:
        model, ollama_url: passed to llm_call (with the module-level _API_KEY).
        category: optional EXPLORATION_CATEGORIES key to restrict generation
            to. When omitted, 3 distinct categories are sampled at random.

    Returns:
        A list of ``{"prompt": str, "category": str}`` dicts. A category whose
        generation call fails, or whose reply has no parsable JSON array,
        contributes nothing; a warning is printed instead.

    Raises:
        KeyError: if *category* is not an EXPLORATION_CATEGORIES key.
    """
    if category:
        cats = {category: EXPLORATION_CATEGORIES[category]}
    else:
        # Pick 3 distinct random categories per round (all of them if fewer exist).
        keys = random.sample(list(EXPLORATION_CATEGORIES), min(3, len(EXPLORATION_CATEGORIES)))
        cats = {k: EXPLORATION_CATEGORIES[k] for k in keys}

    prompts = []
    for cat_name, cat_config in cats.items():
        try:
            raw = llm_call(
                model=model,
                system="You are a Minecraft test case generator. Generate diverse edge cases for an AI training pipeline.",
                user=cat_config["instruction"],
                ollama_url=ollama_url,
                temperature=cat_config["temperature"],
                max_tokens=400,
                # Same gateway credentials as the answering session (attempt_command).
                api_key=_API_KEY,
            )
            # Parse JSON array, tolerating markdown code fences and surrounding prose.
            cleaned = raw.replace("```json", "").replace("```", "").strip()
            match = re.search(r'\[[\s\S]*\]', cleaned)
            if match is None:
                print(f" [!] No JSON array in prompt generation output for {cat_name}")
                continue
            for item in json.loads(match.group()):
                if isinstance(item, str) and item.strip():
                    prompts.append({"prompt": item.strip(), "category": cat_name})
        except Exception as e:
            # Best-effort: one bad category must not abort the round.
            print(f" [!] Prompt generation failed for {cat_name}: {e}")

    return prompts
|
|
|
|
|
|
def _parse_json_reply(raw):
    """Decode an LLM reply as a JSON object, salvaging the outermost {...} span if needed.

    Raises:
        ValueError: (json.JSONDecodeError is a subclass) if no JSON object can
            be recovered, or the decoded value is not an object.
    """
    try:
        value = json.loads(raw)
    except ValueError:
        # Models sometimes wrap the JSON in prose or code fences.
        match = re.search(r'\{[\s\S]*\}', raw)
        if match is None:
            raise
        value = json.loads(match.group())
    if not isinstance(value, dict):
        raise ValueError(f"expected a JSON object, got {type(value).__name__}")
    return value


def _command_list(value):
    """Normalize a model's "commands" field into a list of non-blank strings."""
    if isinstance(value, str):
        # A bare string is one command; iterating it would send single characters.
        value = [value]
    if not isinstance(value, list):
        return []
    return [c for c in value if isinstance(c, str) and c.strip()]


def _execute_commands(commands, rcon_host, rcon_port, rcon_pass):
    """Run each command via RCON; return (per-command result dicts, all succeeded)."""
    rcon_results = []
    for cmd in commands:
        success, rcon_result = rcon_command(cmd, rcon_host, rcon_port, rcon_pass)
        rcon_results.append({"command": cmd, "success": success, "result": rcon_result})
    return rcon_results, all(r["success"] for r in rcon_results)


def attempt_command(model, ollama_url, prompt, rcon_host, rcon_port, rcon_pass, max_retries=2):
    """Answer *prompt*, execute the commands via RCON, and self-correct on errors.

    Prompts starting with "pray " (case-insensitive) use GOD_SYSTEM; all others
    use SUDO_SYSTEM. If any command fails, the failed commands and their RCON
    errors are sent back under RETRY_SYSTEM, up to *max_retries* times. A retry
    whose LLM call or parse fails ends the loop; earlier replies are never
    re-executed.

    Returns:
        The interaction trace ``{"prompt", "mode", "attempts", "final_success",
        "self_corrected"}``.
        - Each attempt holds commands, reasoning, rcon_results and
          all_success; retries also carry ``"retry": n``.
        - If the first reply cannot be obtained or parsed, the single attempt
          holds an ``"error"`` string instead.
        - A reply with no commands (refusal/info) counts as success.
    """
    mode = "god" if prompt.lower().startswith("pray ") else "sudo"
    system = GOD_SYSTEM if mode == "god" else SUDO_SYSTEM

    trace = {
        "prompt": prompt,
        "mode": mode,
        "attempts": [],
        "final_success": False,
        "self_corrected": False,
    }

    # First attempt: transport and parse failures are reported separately.
    try:
        raw = llm_call(model, system, prompt, ollama_url, temperature=0.3, max_tokens=300, fmt="json", api_key=_API_KEY)
    except Exception as e:
        trace["attempts"].append({"commands": [], "error": f"LLM failed: {e}"})
        return trace
    try:
        result = _parse_json_reply(raw)
    except ValueError as e:
        trace["attempts"].append({"commands": [], "error": f"JSON parse failed: {e}"})
        return trace

    commands = _command_list(result.get("commands"))
    message = result.get("message") or ""
    reasoning = result.get("reasoning") or ""

    if not commands:
        trace["attempts"].append({
            "commands": [], "reasoning": reasoning, "message": message,
            "rcon_results": [], "all_success": True,
        })
        trace["final_success"] = True  # Refusal/info is valid
        return trace

    rcon_results, all_success = _execute_commands(commands, rcon_host, rcon_port, rcon_pass)
    trace["attempts"].append({
        "commands": commands, "reasoning": reasoning, "message": message,
        "rcon_results": rcon_results, "all_success": all_success,
    })
    if all_success:
        trace["final_success"] = True
        return trace

    # Self-correction loop: feed the latest RCON errors back to the model.
    for retry in range(max_retries):
        error_context = "\n".join(
            f"Command: {r['command']}\nError: {r['result']}"
            for r in rcon_results if not r["success"]
        )
        retry_prompt = f"Original request: {prompt}\n\nFailed commands:\n{error_context}\n\nPlease fix the commands."

        try:
            raw = llm_call(model, RETRY_SYSTEM, retry_prompt, ollama_url, temperature=0.2, max_tokens=300, fmt="json", api_key=_API_KEY)
            result = _parse_json_reply(raw)
        except Exception:
            # Stop correcting. Never fall back to an earlier reply here: that
            # would re-run commands that already failed and log them as a retry.
            break

        commands = _command_list(result.get("commands"))
        reasoning = result.get("reasoning") or ""
        if not commands:
            break

        rcon_results, all_success = _execute_commands(commands, rcon_host, rcon_port, rcon_pass)
        trace["attempts"].append({
            "commands": commands, "reasoning": reasoning,
            "rcon_results": rcon_results, "all_success": all_success,
            "retry": retry + 1,
        })

        if all_success:
            trace["final_success"] = True
            trace["self_corrected"] = True
            break

    return trace
|
|
|
|
|
|
def _tool_call_turns(rcon_results):
    """Render RCON results as alternating assistant tool-call / tool-result chat turns."""
    turns = []
    for r in rcon_results:
        # json.dumps escapes quotes and backslashes in the command (common in
        # NBT and tellraw text), so the payload is always valid JSON; for
        # plain commands it is byte-identical to the hand-built form.
        # ensure_ascii=False keeps characters such as "§" unescaped.
        call = json.dumps(
            {"name": "rcon.execute", "arguments": {"command": r["command"]}},
            ensure_ascii=False,
        )
        turns.append({"role": "assistant", "content": f"<tool_call>\n{call}\n</tool_call>"})
        turns.append({
            "role": "tool",
            "content": json.dumps({"success": r["success"], "result": r["result"]}),
        })
    return turns


def trace_to_training(trace):
    """Convert a self-play trace into training examples (zero or one).

    - A first-try success (final_success with exactly one attempt) becomes a
      single-turn ``command_gen`` example. The message is kept only in god mode.
    - A self-corrected trace (two or more attempts) becomes a multi-turn
      ``error_correction`` chat, in this order:
        * the system prompt and the user prompt;
        * the first attempt's tool calls and results;
        * the last attempt's tool calls and results;
        * a final JSON answer.
      Intermediate failed retries are not included.
    - Anything else (no attempts, or never succeeded) yields ``[]``.
    """
    examples = []
    prompt = trace["prompt"]
    mode = trace["mode"]

    if not trace["attempts"]:
        return examples

    # Single successful attempt → standard training pair
    if trace["final_success"] and len(trace["attempts"]) == 1:
        att = trace["attempts"][0]
        examples.append({
            "id": f"selfplay-{int(time.time())}-{random.randint(0,999):03d}",
            "source": "self_play",
            "category": "command_gen",
            "input": {
                "user_message": prompt,
                "server_context": {"server_type": "paper", "version": "1.21.x"},
            },
            "output": {
                "reasoning": att.get("reasoning", ""),
                "commands": att.get("commands", []),
                "message": att.get("message", "") if mode == "god" else "",
                "safety_flags": [],
            },
            "metadata": {
                "difficulty": "medium",
                "validated": True,
                "risk_level": 3,
                "rcon_verified": True,
                "self_play": True,
            },
        })

    # Self-corrected → multi-turn tool-calling training example
    elif trace["self_corrected"] and len(trace["attempts"]) >= 2:
        first = trace["attempts"][0]
        last = trace["attempts"][-1]
        system = GOD_SYSTEM if mode == "god" else SUDO_SYSTEM

        messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": prompt},
        ]
        messages += _tool_call_turns(first.get("rcon_results", []))  # failed attempt
        messages += _tool_call_turns(last.get("rcon_results", []))   # successful retry

        final_response = {
            "commands": last.get("commands", []),
            "reasoning": f"Self-corrected: {first.get('reasoning', '')} → {last.get('reasoning', '')}",
        }
        if mode == "god":
            # Retries do not produce a God message, so reuse the first one.
            final_response["message"] = first.get("message", "")
        messages.append({"role": "assistant", "content": json.dumps(final_response)})

        examples.append({
            "id": f"selfplay-correction-{int(time.time())}-{random.randint(0,999):03d}",
            "source": "self_play",
            "type": "error_correction",
            "messages": messages,
            "metadata": {
                "self_play": True,
                "rcon_verified": True,
                "attempts": len(trace["attempts"]),
            },
        })

    return examples
|
|
|
|
|
|
# Human-readable tier names for the end-of-run summary.
_TIER_LABELS = {1: "Command drills", 2: "Self-critique", 3: "Adversarial"}


def _tally_outcome(trace, stats, fail_note=""):
    """Count *trace* as first-try success, self-corrected, or failed and print the verdict.

    Returns "ok", "corrected" or "failed" so callers can update per-category stats.
    """
    if trace["final_success"] and not trace.get("self_corrected"):
        stats["success_first_try"] += 1
        n_cmds = len(trace["attempts"][0].get("commands", []))
        print(f" OK ({n_cmds} cmds)")
        return "ok"
    if trace.get("self_corrected"):
        stats["self_corrected"] += 1
        print(f" CORRECTED ({len(trace['attempts'])} attempts)")
        return "corrected"
    stats["failed"] += 1
    print(f" FAILED{fail_note}")
    return "failed"


def _collect_examples(trace, stats, all_examples):
    """Convert *trace* to training examples and append them to the run buffer."""
    examples = trace_to_training(trace)
    all_examples.extend(examples)
    stats["training_examples"] += len(examples)


def _run_tier1_round(args, stats, all_examples):
    """Tier 1 round: up to five command drills on random seed prompts."""
    if args.dry_run:
        print(" [DRY RUN] Would drill a random seed prompt via RCON")
        return
    for _ in range(5):  # 5 drills per round
        trace = run_tier1_drill(
            args.model, args.ollama_url,
            args.rcon_host, args.rcon_port, args.rcon_pass,
            args.max_retries,
        )
        if trace is None:
            continue
        stats["attempts"] += 1
        stats["by_tier"][1] += 1
        prompt = trace["prompt"]
        print(f" [drill] {prompt[:55]:55}", end="")
        _tally_outcome(trace, stats)
        _collect_examples(trace, stats, all_examples)
        time.sleep(0.1)


def _run_tier2_round(args, stats, all_examples):
    """Tier 2 round: one self-critique per category (the --focus one, or 3 random)."""
    if args.focus:
        cats = [args.focus]
    else:
        cats = random.sample(list(EXPLORATION_CATEGORIES.keys()), min(3, len(EXPLORATION_CATEGORIES)))
    for cat in cats:
        if args.dry_run:
            print(f" [DRY RUN] Would self-critique on {cat}")
            continue
        trace = run_tier2_selfcritique(
            args.model, args.ollama_url,
            args.rcon_host, args.rcon_port, args.rcon_pass,
            category=cat,
        )
        if trace is None:
            continue
        stats["attempts"] += 1
        stats["by_tier"][2] += 1
        prompt = trace["prompt"]
        diff = trace.get("difficulty_note", "")[:30]
        print(f" [self-critique:{cat[:12]}] {prompt[:40]:40} ({diff})", end="")
        # Tier 2 never retries, so its traces are only ever "ok" or "failed".
        _tally_outcome(trace, stats, fail_note=" (self-generated bad commands)")
        _collect_examples(trace, stats, all_examples)
        time.sleep(0.1)


def _run_tier3_round(args, stats, all_examples):
    """Tier 3 round: generate prompts in session A, answer each in a fresh session B."""
    prompts = generate_prompts(args.model, args.ollama_url, args.focus)
    if not prompts:
        print(" No prompts generated, skipping round")
        return

    stats["prompts_generated"] += len(prompts)
    print(f" Generated {len(prompts)} adversarial prompts")

    for p in prompts:
        prompt = p["prompt"]
        cat = p["category"]
        print(f" [adversarial:{cat[:12]}] {prompt[:48]:48}", end="")
        if args.dry_run:
            # Dry-run prompts are not attempts; counting them would report
            # every one as "failed" in the per-category summary.
            print(" [DRY RUN]")
            continue

        cat_stats = stats["by_category"].setdefault(cat, {"total": 0, "success": 0, "corrected": 0})
        cat_stats["total"] += 1
        stats["attempts"] += 1
        stats["by_tier"][3] += 1

        trace = attempt_command(
            args.model, args.ollama_url, prompt,
            args.rcon_host, args.rcon_port, args.rcon_pass,
            max_retries=args.max_retries,
        )
        outcome = _tally_outcome(trace, stats)
        if outcome == "ok":
            cat_stats["success"] += 1
        elif outcome == "corrected":
            cat_stats["corrected"] += 1

        _collect_examples(trace, stats, all_examples)
        time.sleep(0.1)


def _save_examples(examples, output):
    """Append *examples* as UTF-8 JSON lines to *output*, creating parent dirs."""
    output_path = Path(output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Explicit UTF-8: with ensure_ascii=False the platform default encoding
    # could fail on non-ASCII text and lose the whole run.
    with open(output_path, "a", encoding="utf-8") as f:
        f.writelines(json.dumps(ex, ensure_ascii=False) + "\n" for ex in examples)


def _print_summary(stats, output):
    """Print run totals, per-tier attempt counts and the per-category breakdown."""
    print(f"\n{'='*60}")
    print("Self-play complete")
    print(f" Rounds: {stats['rounds']}")
    print(f" Prompts generated:{stats['prompts_generated']}")
    print(f" Attempts: {stats['attempts']}")
    print(f" Success (1st try):{stats['success_first_try']}")
    print(f" Self-corrected: {stats['self_corrected']}")
    print(f" Failed: {stats['failed']}")
    print(f" Training examples:{stats['training_examples']}")

    print("\n By tier:")
    for t in sorted(stats["by_tier"]):
        print(f" Tier {t} ({_TIER_LABELS[t]:16}): {stats['by_tier'][t]} attempts")

    if stats["by_category"]:
        print("\n By category:")
        for cat, s in sorted(stats["by_category"].items()):
            total, ok, corr = s["total"], s["success"], s["corrected"]
            fail = total - ok - corr
            print(f" {cat:25} total={total} ok={ok} corrected={corr} failed={fail}")

    print(f"\n Output: {output}")


def main():
    """Parse CLI arguments, run the selected self-play tiers, and append results to JSONL.

    Rounds rotate through the selected tiers: round i runs tiers[i % len(tiers)].
    Examples are buffered in memory and appended to --output at the end (never
    in --dry-run). Ctrl-C ends the run early but still saves what was
    collected and prints the summary.
    """
    parser = argparse.ArgumentParser(description="Self-play training data generator")
    parser.add_argument("--model", default="qwen3-8b-mc-lora-v3")
    parser.add_argument("--ollama-url", default="http://192.168.0.141:11434")
    parser.add_argument("--api-key", default=None, help="API key for authenticated gateways")
    parser.add_argument("--rcon-host", default="192.168.0.244")
    parser.add_argument("--rcon-port", type=int, default=25578)
    parser.add_argument("--rcon-pass", default="REDACTED_RCON")
    parser.add_argument("--rounds", type=int, default=50)
    parser.add_argument("--tier", default="all", choices=["1", "2", "3", "all"])
    parser.add_argument("--focus", default=None, choices=list(EXPLORATION_CATEGORIES.keys()))
    parser.add_argument("--output", default=str(OUTPUT))
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--max-retries", type=int, default=2)
    args = parser.parse_args()

    # Module-level key read by the LLM-calling helpers (e.g. attempt_command).
    global _API_KEY
    _API_KEY = args.api_key

    tiers = [1, 2, 3] if args.tier == "all" else [int(args.tier)]

    print("Self-play training data generator")
    print(f" Model: {args.model}")
    print(f" RCON: {args.rcon_host}:{args.rcon_port}")
    print(f" Rounds: {args.rounds}")
    print(f" Tiers: {tiers}")
    print(f" Focus: {args.focus or 'all categories'}")
    print(f" Max retries: {args.max_retries}")
    print(f" Output: {args.output}")

    all_examples = []
    stats = {
        "rounds": 0, "prompts_generated": 0, "attempts": 0,
        "success_first_try": 0, "self_corrected": 0, "failed": 0,
        "training_examples": 0, "by_tier": {1: 0, 2: 0, 3: 0}, "by_category": {},
    }
    round_runners = {1: _run_tier1_round, 2: _run_tier2_round, 3: _run_tier3_round}

    try:
        for round_num in range(args.rounds):
            tier = tiers[round_num % len(tiers)]
            print(f"\n--- Round {round_num + 1}/{args.rounds} [Tier {tier}] ---")
            round_runners[tier](args, stats, all_examples)
            stats["rounds"] += 1
    except KeyboardInterrupt:
        # Keep everything gathered so far instead of discarding the whole run.
        print("\n[!] Interrupted - saving results collected so far")

    # Save
    if not args.dry_run and all_examples:
        _save_examples(all_examples, args.output)

    _print_summary(stats, args.output)


if __name__ == "__main__":
    main()
|