3-tier self-play: command drills, self-critique, adversarial

Tier 1 — Command drills:
  Random seed prompts → generate commands → RCON validates
  Teaches: accurate command syntax

Tier 2 — Single-shot self-critique:
  Model invents a tricky prompt AND responds in one call
  RCON validates the self-generated commands
  Teaches: edge-case awareness, self-evaluation

Tier 3 — Adversarial self-play:
  Session A generates challenging prompts
  Fresh Session B responds cold (can't cheat)
  RCON validates, self-corrects on errors
  Teaches: robustness, generalization

Usage: --tier 1|2|3|all --rounds N --focus category

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-19 19:39:33 -04:00
parent c947fc3fa9
commit 9abf9238c5
+256 -55
View File
@@ -1,20 +1,29 @@
#!/usr/bin/env python3
"""
self_play.py — Self-play training data generator.
self_play.py — Multi-tier self-play training data generator.
The fine-tuned model generates its own training data by:
1. Generating diverse edge-case prompts it's uncertain about
2. Attempting commands via RCON
3. Self-correcting on errors
4. Saving successful sequences as training examples
Three tiers of self-play, each teaching different skills:
This creates a closed-loop learning system with RCON as ground truth.
No API cost — runs entirely on the local model.
Tier 1 — Command drills:
Feed known prompts, execute commands via RCON, validate syntax.
Teaches: accurate command generation.
Usage: --tier 1 --rounds 50
Usage:
python3 training/scripts/self_play.py --rounds 100 --model qwen3.5-9b-mc-v4
python3 training/scripts/self_play.py --rounds 50 --dry-run
python3 training/scripts/self_play.py --rounds 200 --focus enchantments
Tier 2 — Single-shot self-critique:
Model generates BOTH the prompt AND response in one call.
Teaches: edge-case awareness, self-evaluation.
Usage: --tier 2 --rounds 50
Tier 3 — Adversarial self-play:
Session A generates a challenging prompt. Fresh Session B responds.
RCON validates. Model can't cheat by knowing both sides.
Teaches: robustness, generalization, error correction.
Usage: --tier 3 --rounds 50
All tiers:
--tier all --rounds 50 (runs ~17 rounds of each)
No API cost — runs entirely on the local model with RCON as ground truth.
"""
import argparse
@@ -182,6 +191,120 @@ Analyze the error and return a corrected command.
Return JSON: {"commands": ["corrected_cmd"], "reasoning": "what was wrong and how you fixed it"}"""
# --- Tier 1: Command drills ---
def run_tier1_drill(model, ollama_url, rcon_host, rcon_port, rcon_pass, max_retries=2):
"""Pick a random prompt from seed dataset, generate commands, validate via RCON."""
seed_path = ROOT / "data" / "processed" / "seed_dataset.jsonl"
with open(seed_path) as f:
lines = [l for l in f if l.strip()]
line = random.choice(lines)
ex = json.loads(line)
prompt = ex["input"]["user_message"]
# Only drill command_gen examples
if ex.get("category") not in ("command_gen",):
return None
trace = attempt_command(model, ollama_url, prompt, rcon_host, rcon_port, rcon_pass, max_retries)
trace["tier"] = 1
trace["original_commands"] = ex.get("output", {}).get("commands", [])
return trace
# --- Tier 2: Single-shot self-critique ---
SELF_CRITIQUE_SYSTEM = """You are a Minecraft 1.21 AI training data generator AND command translator.
Your task: generate a challenging Minecraft player request, then respond to it yourself.
Focus on edge cases you might get wrong: unusual items, complex enchantments, execute chains, ambiguous phrasing.
Return JSON:
{
"generated_prompt": "the player request you invented (must start with 'sudo ' or 'pray ')",
"difficulty": "what makes this tricky",
"commands": ["cmd1", "cmd2"],
"reasoning": "why these commands are correct",
"message": "God message if pray, empty string if sudo"
}
Commands use minecraft: prefix. Enchantments: item[enchantments={name:level}].
Effects: effect give <player> minecraft:<effect> <seconds> <amplifier>.
Player: slingshooter08. Do NOT start commands with /."""
def run_tier2_selfcritique(model, ollama_url, rcon_host, rcon_port, rcon_pass, category=None):
"""Model generates a prompt AND responds in one shot, then RCON validates."""
focus = ""
if category:
focus = f"\nFocus area: {category}. Generate a prompt specifically testing {category}."
try:
raw = llm_call(
model=model,
system=SELF_CRITIQUE_SYSTEM + focus,
user="Generate one challenging Minecraft request and your response. Be creative — pick something you might get wrong.",
ollama_url=ollama_url,
temperature=0.9,
max_tokens=500,
fmt="json",
)
result = json.loads(raw)
except:
match = re.search(r'\{[\s\S]*\}', raw if 'raw' in dir() else '')
if match:
try:
result = json.loads(match.group())
except:
return None
else:
return None
prompt = result.get("generated_prompt", "")
commands = result.get("commands") or []
message = result.get("message") or ""
reasoning = result.get("reasoning") or ""
difficulty = result.get("difficulty") or ""
if not prompt:
return None
trace = {
"prompt": prompt,
"mode": "god" if prompt.lower().startswith("pray ") else "sudo",
"tier": 2,
"difficulty_note": difficulty,
"attempts": [],
"final_success": False,
"self_corrected": False,
}
if not commands:
trace["attempts"].append({
"commands": [], "reasoning": reasoning, "message": message,
"rcon_results": [], "all_success": True,
})
trace["final_success"] = True
return trace
# Validate via RCON
rcon_results = []
all_success = True
for cmd in commands:
success, rcon_result = rcon_command(cmd, rcon_host, rcon_port, rcon_pass)
rcon_results.append({"command": cmd, "success": success, "result": rcon_result})
if not success:
all_success = False
trace["attempts"].append({
"commands": commands, "reasoning": reasoning, "message": message,
"rcon_results": rcon_results, "all_success": all_success,
})
trace["final_success"] = all_success
return trace
# --- Tier 3: Adversarial self-play (original generate_prompts + attempt_command) ---
def generate_prompts(model, ollama_url, category=None):
"""Use the model to generate edge-case prompts for itself."""
if category:
@@ -436,16 +559,20 @@ def main():
parser.add_argument("--rcon-port", type=int, default=25578)
parser.add_argument("--rcon-pass", default="REDACTED_RCON")
parser.add_argument("--rounds", type=int, default=50)
parser.add_argument("--tier", default="all", choices=["1", "2", "3", "all"])
parser.add_argument("--focus", default=None, choices=list(EXPLORATION_CATEGORIES.keys()))
parser.add_argument("--output", default=str(OUTPUT))
parser.add_argument("--dry-run", action="store_true")
parser.add_argument("--max-retries", type=int, default=2)
args = parser.parse_args()
tiers = [1, 2, 3] if args.tier == "all" else [int(args.tier)]
print(f"Self-play training data generator")
print(f" Model: {args.model}")
print(f" RCON: {args.rcon_host}:{args.rcon_port}")
print(f" Rounds: {args.rounds}")
print(f" Tiers: {tiers}")
print(f" Focus: {args.focus or 'all categories'}")
print(f" Max retries: {args.max_retries}")
print(f" Output: {args.output}")
@@ -454,60 +581,129 @@ def main():
stats = {
"rounds": 0, "prompts_generated": 0, "attempts": 0,
"success_first_try": 0, "self_corrected": 0, "failed": 0,
"training_examples": 0, "by_category": {},
"training_examples": 0, "by_tier": {1: 0, 2: 0, 3: 0}, "by_category": {},
}
for round_num in range(args.rounds):
print(f"\n--- Round {round_num + 1}/{args.rounds} ---")
# Generate prompts
prompts = generate_prompts(args.model, args.ollama_url, args.focus)
if not prompts:
print(" No prompts generated, skipping round")
continue
stats["prompts_generated"] += len(prompts)
print(f" Generated {len(prompts)} prompts")
for p in prompts:
prompt = p["prompt"]
cat = p["category"]
stats["by_category"].setdefault(cat, {"total": 0, "success": 0, "corrected": 0})
stats["by_category"][cat]["total"] += 1
stats["attempts"] += 1
print(f" [{cat}] {prompt[:60]:60}", end="")
tier = tiers[round_num % len(tiers)]
print(f"\n--- Round {round_num + 1}/{args.rounds} [Tier {tier}] ---")
if tier == 1:
# Command drill: pick random seed example, try to execute
if args.dry_run:
print(" [DRY RUN]")
print(" [DRY RUN] Would drill a random seed prompt via RCON")
stats["rounds"] += 1
continue
trace = attempt_command(
args.model, args.ollama_url, prompt,
args.rcon_host, args.rcon_port, args.rcon_pass,
max_retries=args.max_retries,
for _ in range(5): # 5 drills per round
trace = run_tier1_drill(
args.model, args.ollama_url,
args.rcon_host, args.rcon_port, args.rcon_pass,
args.max_retries,
)
if trace is None:
continue
stats["attempts"] += 1
stats["by_tier"][1] += 1
prompt = trace["prompt"]
print(f" [drill] {prompt[:55]:55}", end="")
if trace["final_success"] and not trace["self_corrected"]:
stats["success_first_try"] += 1
n_cmds = len(trace["attempts"][0].get("commands", []))
print(f" OK ({n_cmds} cmds)")
elif trace.get("self_corrected"):
stats["self_corrected"] += 1
print(f" CORRECTED ({len(trace['attempts'])} attempts)")
else:
stats["failed"] += 1
print(f" FAILED")
examples = trace_to_training(trace)
all_examples.extend(examples)
stats["training_examples"] += len(examples)
time.sleep(0.5)
elif tier == 2:
# Self-critique: model generates prompt + response, RCON validates
cats = [args.focus] if args.focus else random.sample(
list(EXPLORATION_CATEGORIES.keys()), min(3, len(EXPLORATION_CATEGORIES))
)
for cat in cats:
if args.dry_run:
print(f" [DRY RUN] Would self-critique on {cat}")
continue
if trace["final_success"] and not trace["self_corrected"]:
stats["success_first_try"] += 1
stats["by_category"][cat]["success"] += 1
n_cmds = len(trace["attempts"][0].get("commands", []))
print(f" OK ({n_cmds} cmds)")
elif trace["self_corrected"]:
stats["self_corrected"] += 1
stats["by_category"][cat]["corrected"] += 1
print(f" CORRECTED ({len(trace['attempts'])} attempts)")
else:
stats["failed"] += 1
print(f" FAILED")
trace = run_tier2_selfcritique(
args.model, args.ollama_url,
args.rcon_host, args.rcon_port, args.rcon_pass,
category=cat,
)
if trace is None:
continue
stats["attempts"] += 1
stats["by_tier"][2] += 1
prompt = trace["prompt"]
diff = trace.get("difficulty_note", "")[:30]
print(f" [self-critique:{cat[:12]}] {prompt[:40]:40} ({diff})", end="")
if trace["final_success"]:
stats["success_first_try"] += 1
n_cmds = len(trace["attempts"][0].get("commands", []))
print(f" OK ({n_cmds} cmds)")
else:
stats["failed"] += 1
print(f" FAILED (self-generated bad commands)")
examples = trace_to_training(trace)
all_examples.extend(examples)
stats["training_examples"] += len(examples)
time.sleep(1)
# Convert to training examples
examples = trace_to_training(trace)
all_examples.extend(examples)
stats["training_examples"] += len(examples)
elif tier == 3:
# Adversarial: generate prompts in Session A, respond in fresh Session B
prompts = generate_prompts(args.model, args.ollama_url, args.focus)
if not prompts:
print(" No prompts generated, skipping round")
stats["rounds"] += 1
continue
# Brief pause between attempts
time.sleep(1)
stats["prompts_generated"] += len(prompts)
print(f" Generated {len(prompts)} adversarial prompts")
for p in prompts:
prompt = p["prompt"]
cat = p["category"]
stats["by_category"].setdefault(cat, {"total": 0, "success": 0, "corrected": 0})
stats["by_category"][cat]["total"] += 1
stats["attempts"] += 1
stats["by_tier"][3] += 1
print(f" [adversarial:{cat[:12]}] {prompt[:48]:48}", end="")
if args.dry_run:
print(" [DRY RUN]")
continue
trace = attempt_command(
args.model, args.ollama_url, prompt,
args.rcon_host, args.rcon_port, args.rcon_pass,
max_retries=args.max_retries,
)
if trace["final_success"] and not trace["self_corrected"]:
stats["success_first_try"] += 1
stats["by_category"][cat]["success"] += 1
n_cmds = len(trace["attempts"][0].get("commands", []))
print(f" OK ({n_cmds} cmds)")
elif trace["self_corrected"]:
stats["self_corrected"] += 1
stats["by_category"][cat]["corrected"] += 1
print(f" CORRECTED ({len(trace['attempts'])} attempts)")
else:
stats["failed"] += 1
print(f" FAILED")
examples = trace_to_training(trace)
all_examples.extend(examples)
stats["training_examples"] += len(examples)
time.sleep(1)
stats["rounds"] += 1
@@ -530,6 +726,11 @@ def main():
print(f" Failed: {stats['failed']}")
print(f" Training examples:{stats['training_examples']}")
print(f"\n By tier:")
for t in sorted(stats["by_tier"]):
labels = {1: "Command drills", 2: "Self-critique", 3: "Adversarial"}
print(f" Tier {t} ({labels[t]:16}): {stats['by_tier'][t]} attempts")
if stats["by_category"]:
print(f"\n By category:")
for cat, s in sorted(stats["by_category"].items()):