Files
Mortdecai/training/scripts/distill.py
T
Seth 65ee146043 Swarm bots, RCON validation, Haiku distillation complete
Swarm bots (ingame/swarm_bots.js):
- 10 survival bots with generated names (SwiftWolf, DarkWolf, etc.)
- All bots wander, take damage, auto-respawn, pray when hurt
- Gemini + Dolphin(5%) + Multilingual(3%) prompt generation
- 20-60s interaction interval per bot

Distillation results:
- 222 sudo examples via Haiku ($0.28)
- 122 god examples via Haiku ($0.37) — with God Soul personality
- Total: 344 distilled, $0.65 spent of $5 budget
- RCON validation: 74.7% fully valid, 30 real errors out of ~1000 commands

validate_distilled.py:
- Executes distilled commands on live server via RCON
- Distinguishes real errors from benign (no player online)
- Tags each example with validation status

Dev server switched to Claude Haiku via Anthropic API:
- llm_provider: anthropic with $5 budget cap
- Auto-fallback to Ollama when budget exhausted
- Cost tracking with logging

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-18 19:18:19 -04:00

254 lines
8.7 KiB
Python

#!/usr/bin/env python3
"""
distill.py — Use Claude to generate gold-standard training responses.
Takes existing dataset examples, sends each one through Claude with the
God Soul / sudo system prompts, and replaces the output with Claude's
higher-quality response. This teaches the small model to approximate
Claude's judgment within the Minecraft domain.
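Every mode (sudo, god, god_system) is currently mapped to Claude Haiku in the
MODELS table below (cheap, and sudo just needs accurate commands). Point the
god entries at Sonnet there if god mode needs more personality, creativity,
and character.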
Usage:
python3 training/scripts/distill.py # distill all
python3 training/scripts/distill.py --dry-run # estimate cost
python3 training/scripts/distill.py --mode god # only god examples
python3 training/scripts/distill.py --mode sudo # only sudo examples
python3 training/scripts/distill.py --budget 5.00 # max spend in USD
python3 training/scripts/distill.py --output data/processed/distilled.jsonl
"""
import argparse
import json
import os
import re
import sys
import time
from pathlib import Path

import requests
# Repository root: three `.parent` steps up from this file
# (training/scripts/distill.py -> repo root).
ROOT = Path(__file__).resolve().parent.parent.parent
# Put the repo root on sys.path BEFORE the project import below, so the
# `agent` package resolves when this file is run directly as a script.
sys.path.insert(0, str(ROOT))
from agent.prompts.system_prompts import get_prompt
# Input dataset (read line by line as JSON in main()) and the default
# destination for distilled results.
DATASET = ROOT / "data" / "processed" / "seed_dataset.jsonl"
OUTPUT_DEFAULT = ROOT / "data" / "processed" / "distilled.jsonl"
# The Anthropic API key is read from the ANTHROPIC_API_KEY environment variable
# so a real credential never has to be committed to the repository. The
# original literal is kept as the fallback so behavior is unchanged when the
# variable is unset.
API_KEY = os.environ.get("ANTHROPIC_API_KEY", "REDACTED_ANTHROPIC_KEY_2")
API_URL = "https://api.anthropic.com/v1/messages"

# Model selection and pricing (USD per million tokens), keyed by the mode
# string that determine_mode() returns.
# NOTE(review): all three modes currently share one Haiku model and one set of
# rates; confirm the rates against Anthropic's published pricing for this
# model before trusting the cost totals.
MODELS = {
    "sudo": {"model": "claude-haiku-4-5-20251001", "input_per_m": 0.80, "output_per_m": 4.00},
    "god": {"model": "claude-haiku-4-5-20251001", "input_per_m": 0.80, "output_per_m": 4.00},
    "god_system": {"model": "claude-haiku-4-5-20251001", "input_per_m": 0.80, "output_per_m": 4.00},
}
def determine_mode(example: dict) -> str:
    """Classify a dataset example as ``"god"``, ``"god_system"`` or ``"sudo"``.

    Rules, checked in order:

    1. The user message starts with ``"pray "`` (case-insensitive), or the
       example's ``source`` is ``"prayer_log"`` -> ``"god"``.
    2. The example id starts with ``"negative-"`` and the message contains
       ``"god"`` (case-insensitive) -> ``"god_system"``.
    3. Anything else -> ``"sudo"``.

    Args:
        example: Dataset record with ``input.user_message`` and optional
            ``id`` / ``source`` keys.

    Returns:
        The mode name as a string.
    """
    lowered = example["input"]["user_message"].lower()

    if lowered.startswith("pray "):
        return "god"
    if example.get("source") == "prayer_log":
        return "god"

    example_id = example.get("id", "")
    if example_id.startswith("negative-") and "god" in lowered:
        return "god_system"

    return "sudo"
def build_user_message(example: dict, player: str = "slingshooter08") -> str:
    """Render a dataset example as the user-turn text sent to the model.

    The message contains the request line, a context header with server type
    and version, and - when present in the example's ``server_context`` - the
    online player list and the requesting player's position.

    Args:
        example: Dataset record with ``input.user_message`` and an optional
            ``input.server_context`` mapping.
        player: Name shown as the requester. Defaults to the name that was
            previously hard-coded, so existing callers are unaffected.

    Returns:
        The newline-joined prompt text.

    Raises:
        KeyError: If ``input`` / ``user_message`` is missing, or a present
            ``player_position`` lacks ``x``, ``y`` or ``z``.
    """
    inp = example["input"]
    query = inp["user_message"]
    # `or {}` also covers an explicit JSON null, which .get()'s default does not.
    ctx = inp.get("server_context") or {}

    lines = [
        f"Request from {player}: {query}",
        f"\nContext:\nServer: {ctx.get('server_type', 'paper')} {ctx.get('version', '1.21.x')}",
    ]

    online = ctx.get("online_players")
    if online:
        # str() each entry so a non-string name cannot break the join.
        lines.append(f"Online: {', '.join(str(name) for name in online)}")

    pos = ctx.get("player_position")
    if pos:
        lines.append(f"Player position: ({pos['x']}, {pos['y']}, {pos['z']})")

    return "\n".join(lines)
def _parse_model_json(text: str) -> dict:
    """Best-effort extraction of a JSON object from the model's reply text.

    Tries the whole text first, then the outermost ``{...}`` span (covers
    replies wrapped in markdown fences or prose). Always returns a dict; when
    nothing parses to a JSON object the ``parse_failed`` sentinel is returned.
    """
    candidates = [text]
    match = re.search(r'\{[\s\S]*\}', text)
    if match:
        candidates.append(match.group())

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            # Includes the case where the greedy brace match is not valid
            # JSON; previously that raised out of call_claude and the
            # already-billed token usage was lost.
            continue
        if isinstance(parsed, dict):
            return parsed

    return {"commands": [], "message": "", "reasoning": "parse_failed"}


def call_claude(model: str, system: str, user: str) -> dict:
    """Call the Claude Messages API and return the parsed JSON response.

    Args:
        model: Model identifier to send in the request body.
        system: System prompt text.
        user: Content of the single user message.

    Returns:
        A dict with keys ``parsed`` (dict decoded from the reply, or the
        ``parse_failed`` sentinel), ``input_tokens``, ``output_tokens`` and
        ``raw`` (the reply text).

    Raises:
        requests.HTTPError: On a non-2xx response (via ``raise_for_status``).
        requests.RequestException: On connection errors or the 60 s timeout.
        KeyError: If the response body has no ``usage`` token counts.
    """
    headers = {
        "x-api-key": API_KEY,
        "anthropic-version": "2023-06-01",
        "content-type": "application/json",
    }
    body = {
        "model": model,
        "max_tokens": 500,
        "system": system,
        "messages": [{"role": "user", "content": user}],
    }
    resp = requests.post(API_URL, headers=headers, json=body, timeout=60)
    resp.raise_for_status()
    data = resp.json()

    # Concatenate every text block rather than indexing content[0], so an
    # empty content list or a non-text first block does not raise.
    text = "".join(
        block.get("text", "")
        for block in (data.get("content") or [])
        if block.get("type") == "text"
    )
    usage = data["usage"]

    return {
        "parsed": _parse_model_json(text),
        "input_tokens": usage["input_tokens"],
        "output_tokens": usage["output_tokens"],
        "raw": text,
    }
def estimate_cost(examples: list) -> dict:
    """Estimate the API spend for ``examples`` without making any calls.

    Each example is bucketed by ``determine_mode`` and priced with that
    mode's ``MODELS`` rates, assuming a flat per-example token budget.

    Args:
        examples: Dataset records accepted by ``determine_mode``.

    Returns:
        ``{"total_est": <rounded total>, "details": {mode: {"count",
        "model", "est_cost"}}}``; modes with zero examples are omitted from
        ``details``.
    """
    tally = dict.fromkeys(("sudo", "god", "god_system"), 0)
    for example in examples:
        tally[determine_mode(example)] += 1

    running_total = 0
    breakdown = {}
    for mode_name, n_examples in tally.items():
        if not n_examples:
            continue
        pricing = MODELS[mode_name]
        # Estimate ~600 input tokens (system + user), ~150 output tokens
        cost_in = (n_examples * 600 / 1_000_000) * pricing["input_per_m"]
        cost_out = (n_examples * 150 / 1_000_000) * pricing["output_per_m"]
        subtotal = cost_in + cost_out
        running_total += subtotal
        breakdown[mode_name] = {
            "count": n_examples,
            "model": pricing["model"],
            "est_cost": round(subtotal, 4),
        }

    return {"total_est": round(running_total, 4), "details": breakdown}
def main():
    """Run the distillation pipeline from the command line.

    Loads the seed dataset, filters it by ``--mode`` and drops ``abstain-``
    examples, prints a cost estimate, and (unless ``--dry-run``) sends each
    example to the model. Each usable response replaces the example's
    ``output`` and is appended to the output JSONL file immediately, so
    results already paid for survive an interrupt or crash.

    Responses that could not be parsed as JSON are counted as errors and not
    written (their token cost is still tracked), so the ``parse_failed``
    placeholder never ends up in the training data.
    """
    parser = argparse.ArgumentParser(description="Claude distillation pipeline")
    parser.add_argument("--dry-run", action="store_true")
    # "god_system" is selectable too, since determine_mode() can return it.
    parser.add_argument("--mode", choices=["sudo", "god", "god_system", "all"], default="all")
    parser.add_argument("--budget", type=float, default=5.00)
    parser.add_argument("--output", default=str(OUTPUT_DEFAULT))
    args = parser.parse_args()

    # Explicit UTF-8: the dataset may contain non-ASCII prompts and the
    # platform default encoding is not guaranteed to handle them.
    with open(DATASET, encoding="utf-8") as f:
        examples = [json.loads(line) for line in f if line.strip()]

    # Filter by mode
    if args.mode != "all":
        examples = [ex for ex in examples if determine_mode(ex) == args.mode]

    # Skip examples that are just abstention/empty (no useful distillation target)
    examples = [ex for ex in examples if not ex.get("id", "").startswith("abstain-")]

    print("Distillation pipeline")
    print(f" Dataset: {len(examples)} examples")
    print(f" Budget: ${args.budget:.2f}")
    print(f" Output: {args.output}")

    cost_est = estimate_cost(examples)
    print(f"\n Estimated cost: ${cost_est['total_est']:.4f}")
    for mode, d in cost_est["details"].items():
        print(f" {mode}: {d['count']} examples via {d['model']} (${d['est_cost']:.4f})")

    if args.dry_run:
        print(f"\n[DRY RUN] Would process {len(examples)} examples for ~${cost_est['total_est']:.4f}")
        return

    if cost_est["total_est"] > args.budget:
        print(f"\nEstimated cost ${cost_est['total_est']:.4f} exceeds budget ${args.budget:.2f}. Reduce examples or increase budget.")
        return

    # Process
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    total_input = 0
    total_output = 0
    total_cost = 0.0
    processed = 0
    errors = 0

    # The output file stays open for the whole run and is flushed after every
    # example; UTF-8 is required because records are dumped with
    # ensure_ascii=False.
    with open(output_path, "w", encoding="utf-8") as out:
        for i, ex in enumerate(examples):
            # Check budget before spending anything on this example.
            if total_cost >= args.budget:
                print(f"\n Budget reached at ${total_cost:.4f} after {processed} examples")
                break

            mode = determine_mode(ex)
            cfg = MODELS[mode]
            system_prompt = get_prompt(mode)
            user_msg = build_user_message(ex)
            # Rate limit: ~50 req/min for Haiku, ~20 for Sonnet
            pause = 0.5 if mode == "sudo" else 1.5

            try:
                result = call_claude(cfg["model"], system_prompt, user_msg)
            except Exception as e:
                # Deliberate best-effort: one failed request must not abort
                # the whole run.
                print(f" [{i+1}/{len(examples)}] ERROR: {e}")
                errors += 1
                time.sleep(1)
                continue

            parsed = result["parsed"]
            total_input += result["input_tokens"]
            total_output += result["output_tokens"]

            # Calculate cost (charged even when the reply turns out unusable)
            cost = (result["input_tokens"] / 1_000_000) * cfg["input_per_m"] + \
                (result["output_tokens"] / 1_000_000) * cfg["output_per_m"]
            total_cost += cost

            # An unparseable reply carries no distillation signal; skip it
            # rather than writing an empty "gold" example.
            if not isinstance(parsed, dict) or parsed.get("reasoning") == "parse_failed":
                print(f" [{i+1}/{len(examples)}] ERROR: unparseable model response (skipped)")
                errors += 1
                time.sleep(pause)
                continue

            processed += 1
            commands = parsed.get("commands") or []

            # Build distilled example
            distilled = dict(ex)
            distilled["output"] = {
                "reasoning": parsed.get("reasoning", ""),
                "commands": commands,
                "message": parsed.get("message") if mode in ("god", "god_system") else None,
                # Tolerate seed examples that have no "output" section.
                "safety_flags": (ex.get("output") or {}).get("safety_flags", []),
            }
            distilled["metadata"] = dict(ex.get("metadata", {}))
            distilled["metadata"]["distilled_by"] = cfg["model"]
            distilled["metadata"]["distilled_at"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
            distilled["id"] = f"distill-{ex.get('id', f'ex-{i}')}"

            out.write(json.dumps(distilled, ensure_ascii=False) + "\n")
            out.flush()

            cmds = len(commands)
            msg_preview = str(parsed.get("message") or "")[:40]
            print(f" [{i+1}/{len(examples)}] ({mode}) {ex['input']['user_message'][:45]:45} [{cmds} cmds] ${cost:.4f} {msg_preview}")

            time.sleep(pause)

    print(f"\n{'='*60}")
    print("Distillation complete")
    print(f" Processed: {processed}")
    print(f" Errors: {errors}")
    print(f" Input tokens: {total_input:,}")
    print(f" Output tokens: {total_output:,}")
    print(f" Total cost: ${total_cost:.4f}")
    print(f" Output: {output_path}")


if __name__ == "__main__":
    main()