#!/usr/bin/env python3
"""
distill.py — Use Claude to generate gold-standard training responses.

Takes existing dataset examples, sends each one through Claude with the
God Soul / sudo system prompts, and replaces the output with Claude's
higher-quality response. This teaches the small model to approximate
Claude's judgment within the Minecraft domain.

The design intent is Haiku for sudo (cheap, just needs accurate commands)
and Sonnet for god mode (needs personality, creativity, character).
NOTE: the MODELS table below currently maps every mode to Haiku.

The API key is read from the ANTHROPIC_API_KEY environment variable; it is
only required for real runs, not for --dry-run.

Usage:
    python3 training/scripts/distill.py                  # distill all
    python3 training/scripts/distill.py --dry-run        # estimate cost
    python3 training/scripts/distill.py --mode god       # only god examples
    python3 training/scripts/distill.py --mode sudo      # only sudo examples
    python3 training/scripts/distill.py --budget 5.00    # max spend in USD
    python3 training/scripts/distill.py --output data/processed/distilled.jsonl
"""

import argparse
import json
import os
import re
import sys
import time
from pathlib import Path

import requests

ROOT = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(ROOT))

# Must come after the sys.path tweak above so the project package resolves.
from agent.prompts.system_prompts import get_prompt

DATASET = ROOT / "data" / "processed" / "seed_dataset.jsonl"
OUTPUT_DEFAULT = ROOT / "data" / "processed" / "distilled.jsonl"

# Never commit credentials: the key is taken from the environment. An empty
# value is tolerated at import time so that --dry-run works without a key.
API_KEY = os.environ.get("ANTHROPIC_API_KEY", "")
API_URL = "https://api.anthropic.com/v1/messages"

# Model selection and pricing (per million tokens)
# NOTE(review): all three modes use the same Haiku model and prices, while the
# module docstring describes Sonnet for god mode — confirm which is intended,
# and confirm the per-million prices against current Anthropic pricing.
MODELS = {
    "sudo": {"model": "claude-haiku-4-5-20251001", "input_per_m": 0.80, "output_per_m": 4.00},
    "god": {"model": "claude-haiku-4-5-20251001", "input_per_m": 0.80, "output_per_m": 4.00},
    "god_system": {"model": "claude-haiku-4-5-20251001", "input_per_m": 0.80, "output_per_m": 4.00},
}


def determine_mode(example: dict) -> str:
    """Classify an example as ``"god"``, ``"god_system"`` or ``"sudo"``.

    An example is ``"god"`` when its user message starts with ``"pray "``
    (case-insensitive) or its ``source`` is ``"prayer_log"``. It is
    ``"god_system"`` when its id starts with ``"negative-"`` and the message
    mentions ``"god"``. Everything else is ``"sudo"``.

    Raises:
        KeyError: if ``example["input"]["user_message"]`` is missing.
    """
    query = example["input"]["user_message"]
    eid = example.get("id", "")
    lowered = query.lower()
    if lowered.startswith("pray ") or example.get("source") == "prayer_log":
        return "god"
    if eid.startswith("negative-") and "god" in lowered:
        return "god_system"
    return "sudo"


def build_user_message(example: dict) -> str:
    """Render the user turn sent to Claude for one dataset example.

    The message contains the request text followed by a context section with
    the server type/version and, when present in ``server_context``, the
    online players and the player position.
    """
    inp = example["input"]
    query = inp["user_message"]
    # A missing or null server_context is treated as "no context available".
    ctx = inp.get("server_context") or {}

    parts = [f"Request from slingshooter08: {query}"]
    parts.append(
        f"\nContext:\nServer: {ctx.get('server_type', 'paper')} {ctx.get('version', '1.21.x')}"
    )
    if ctx.get("online_players"):
        parts.append(f"Online: {', '.join(ctx['online_players'])}")
    pos = ctx.get("player_position")
    if pos:
        parts.append(f"Player position: ({pos['x']}, {pos['y']}, {pos['z']})")
    return "\n".join(parts)


def _parse_response_json(text: str) -> dict:
    """Extract a JSON object from model output, tolerating markdown wrappers.

    Tries the raw text first, then the outermost ``{...}`` span. Anything that
    does not decode to a JSON object yields the ``parse_failed`` placeholder
    instead of raising, so tokens already billed are still accounted for.
    """
    candidates = [text]
    match = re.search(r'\{[\s\S]*\}', text)
    if match:
        candidates.append(match.group())

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        # Callers use dict methods on the result; lists/strings are rejected.
        if isinstance(parsed, dict):
            return parsed
    return {"commands": [], "message": "", "reasoning": "parse_failed"}


def call_claude(model: str, system: str, user: str) -> dict:
    """Call Claude API and return parsed JSON response.

    Args:
        model: Model identifier sent as the request's ``model`` field.
        system: System prompt.
        user: Content of the single user message.

    Returns:
        A dict with ``parsed`` (the decoded JSON object, or the
        ``parse_failed`` placeholder), ``input_tokens``, ``output_tokens``
        and ``raw`` (the response text).

    Raises:
        requests.RequestException: on network failure or a non-2xx status.
        KeyError: if the response carries no ``usage`` token counts.
    """
    headers = {
        "x-api-key": API_KEY,
        "anthropic-version": "2023-06-01",
        "content-type": "application/json",
    }
    body = {
        "model": model,
        "max_tokens": 500,
        "system": system,
        "messages": [{"role": "user", "content": user}],
    }

    resp = requests.post(API_URL, headers=headers, json=body, timeout=60)
    resp.raise_for_status()
    data = resp.json()

    # Join every text block rather than assuming the first block is text;
    # an empty content list simply yields "" and a parse_failed result.
    text = "".join(
        block.get("text", "")
        for block in data.get("content", [])
        if block.get("type") == "text"
    )
    input_tokens = data["usage"]["input_tokens"]
    output_tokens = data["usage"]["output_tokens"]

    return {
        "parsed": _parse_response_json(text),
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "raw": text,
    }


def estimate_cost(examples: list) -> dict:
    """Estimate API cost without making calls.

    Assumes roughly 600 input tokens (system + user) and 150 output tokens
    per example, priced with the per-mode rates in ``MODELS``.

    Returns:
        ``{"total_est": float, "details": {mode: {"count", "model",
        "est_cost"}}}``; modes with no examples are omitted from ``details``.
    """
    counts = {"sudo": 0, "god": 0, "god_system": 0}
    for ex in examples:
        counts[determine_mode(ex)] += 1

    total = 0
    details = {}
    for mode, count in counts.items():
        if count == 0:
            continue
        cfg = MODELS[mode]
        # Estimate ~600 input tokens (system + user), ~150 output tokens
        input_cost = (count * 600 / 1_000_000) * cfg["input_per_m"]
        output_cost = (count * 150 / 1_000_000) * cfg["output_per_m"]
        mode_cost = input_cost + output_cost
        total += mode_cost
        details[mode] = {"count": count, "model": cfg["model"], "est_cost": round(mode_cost, 4)}

    return {"total_est": round(total, 4), "details": details}


def _load_examples(path: Path) -> list:
    """Read a JSONL dataset, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _build_distilled(ex: dict, index: int, mode: str, model: str, parsed: dict) -> dict:
    """Return a copy of ``ex`` whose output is replaced by Claude's answer."""
    distilled = dict(ex)
    # Missing/null "output" or "metadata" must not crash a run that has
    # already spent money on earlier examples.
    original_output = ex.get("output") or {}
    distilled["output"] = {
        "reasoning": parsed.get("reasoning", ""),
        "commands": parsed.get("commands") or [],
        "message": parsed.get("message") if mode in ("god", "god_system") else None,
        "safety_flags": original_output.get("safety_flags", []),
    }
    distilled["metadata"] = dict(ex.get("metadata") or {})
    distilled["metadata"]["distilled_by"] = model
    distilled["metadata"]["distilled_at"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    distilled["id"] = f"distill-{ex.get('id', f'ex-{index}')}"
    return distilled


def _write_jsonl(path: Path, rows: list) -> None:
    """Write ``rows`` as UTF-8 JSONL, one object per line."""
    with open(path, "w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")


def main():
    """Run the distillation pipeline from the command line.

    Loads the seed dataset, filters it by ``--mode``, prints a cost estimate
    and (unless ``--dry-run``) calls Claude for every example until the
    budget is reached, then writes the distilled examples to ``--output``.
    """
    parser = argparse.ArgumentParser(description="Claude distillation pipeline")
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--mode", choices=["sudo", "god", "all"], default="all")
    parser.add_argument("--budget", type=float, default=5.00)
    parser.add_argument("--output", default=str(OUTPUT_DEFAULT))
    args = parser.parse_args()

    examples = _load_examples(DATASET)

    # Filter by mode
    if args.mode != "all":
        examples = [ex for ex in examples if determine_mode(ex) == args.mode]

    # Skip examples that are just abstention/empty (no useful distillation target)
    examples = [ex for ex in examples if not ex.get("id", "").startswith("abstain-")]

    print("Distillation pipeline")
    print(f"  Dataset: {len(examples)} examples")
    print(f"  Budget: ${args.budget:.2f}")
    print(f"  Output: {args.output}")

    cost_est = estimate_cost(examples)
    print(f"\n  Estimated cost: ${cost_est['total_est']:.4f}")
    for mode, d in cost_est["details"].items():
        print(f"    {mode}: {d['count']} examples via {d['model']} (${d['est_cost']:.4f})")

    if args.dry_run:
        print(f"\n[DRY RUN] Would process {len(examples)} examples for ~${cost_est['total_est']:.4f}")
        return

    if not API_KEY:
        # Fail before any work is done instead of erroring once per example.
        raise SystemExit("ANTHROPIC_API_KEY is not set; export it before running distillation.")

    if cost_est["total_est"] > args.budget:
        print(f"\nEstimated cost ${cost_est['total_est']:.4f} exceeds budget ${args.budget:.2f}. Reduce examples or increase budget.")
        return

    # Process
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    total_input = 0
    total_output = 0
    total_cost = 0.0
    processed = 0
    errors = 0
    results = []

    for i, ex in enumerate(examples):
        # Check budget before doing any per-example work.
        if total_cost >= args.budget:
            print(f"\n  Budget reached at ${total_cost:.4f} after {processed} examples")
            break

        mode = determine_mode(ex)
        cfg = MODELS[mode]
        system_prompt = get_prompt(mode)
        user_msg = build_user_message(ex)

        try:
            result = call_claude(cfg["model"], system_prompt, user_msg)
        except Exception as e:
            # Per-example boundary: one failed request must not abort the run.
            print(f"  [{i+1}/{len(examples)}] ERROR: {e}")
            errors += 1
            time.sleep(1)
            continue

        parsed = result["parsed"]
        total_input += result["input_tokens"]
        total_output += result["output_tokens"]

        # Calculate cost
        cost = (result["input_tokens"] / 1_000_000) * cfg["input_per_m"] + \
               (result["output_tokens"] / 1_000_000) * cfg["output_per_m"]
        total_cost += cost
        processed += 1

        # Build distilled example
        results.append(_build_distilled(ex, i, mode, cfg["model"], parsed))

        cmds = len(parsed.get("commands") or [])
        msg_preview = str(parsed.get("message") or "")[:40]
        print(f"  [{i+1}/{len(examples)}] ({mode}) {ex['input']['user_message'][:45]:45} [{cmds} cmds] ${cost:.4f} {msg_preview}")

        # Rate limit: ~50 req/min for Haiku, ~20 for Sonnet
        time.sleep(0.5 if mode == "sudo" else 1.5)

    # Write results
    _write_jsonl(output_path, results)

    print(f"\n{'='*60}")
    print("Distillation complete")
    print(f"  Processed: {processed}")
    print(f"  Errors: {errors}")
    print(f"  Input tokens: {total_input:,}")
    print(f"  Output tokens: {total_output:,}")
    print(f"  Total cost: ${total_cost:.4f}")
    print(f"  Output: {output_path}")


if __name__ == "__main__":
    main()