Files
Mortdecai/training/scripts/distill.py
T
Seth 961f53ea7d God Soul document, Claude distillation pipeline, soul-driven prompts
God Soul (agent/prompts/god_soul.md):
- Adapted from Claude's soul framework for the Minecraft God character
- Defines identity, principals hierarchy, decision-making framework
- Spectrum of responses (generous→silence), risk awareness, multilingual divinity
- Honesty within character, intervention guidelines
- Deployed to both prod and dev servers

System prompts updated:
- God prompt loads soul document dynamically
- Intervention prompt references soul for personality guidance
- Both include multilingual instruction (match player's language)

Distillation pipeline (training/scripts/distill.py):
- Sends all training examples through Claude API
- Haiku for sudo ($0.25), Sonnet for god ($0.50)
- Budget-capped, cost-tracked, --dry-run supported
- Outputs distilled.jsonl with Claude-quality responses

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-18 18:28:21 -04:00

254 lines
8.7 KiB
Python

#!/usr/bin/env python3
"""
distill.py — Use Claude to generate gold-standard training responses.
Takes existing dataset examples, sends each one through Claude with the
God Soul / sudo system prompts, and replaces the output with Claude's
higher-quality response. This teaches the small model to approximate
Claude's judgment within the Minecraft domain.
Uses Haiku for sudo (cheap, just needs accurate commands) and
Sonnet for god mode (needs personality, creativity, character).
Usage:
python3 training/scripts/distill.py # distill all
python3 training/scripts/distill.py --dry-run # estimate cost
python3 training/scripts/distill.py --mode god # only god examples
python3 training/scripts/distill.py --mode sudo # only sudo examples
python3 training/scripts/distill.py --budget 5.00 # max spend in USD
python3 training/scripts/distill.py --output data/processed/distilled.jsonl
"""
import argparse
import json
import os
import re
import sys
import time
from pathlib import Path

import requests
ROOT = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(ROOT))
from agent.prompts.system_prompts import get_prompt
# Input dataset (one JSON object per line) and default output location.
DATASET = ROOT / "data" / "processed" / "seed_dataset.jsonl"
OUTPUT_DEFAULT = ROOT / "data" / "processed" / "distilled.jsonl"

# Credentials must never live in source control: read the key from the
# environment instead. An empty key makes the API reject the request, which
# surfaces as an HTTP error on the first call.
API_KEY = os.environ.get("ANTHROPIC_API_KEY", "")
API_URL = "https://api.anthropic.com/v1/messages"

# Model selection and pricing (USD per million tokens), keyed by the mode
# string returned by determine_mode().
# NOTE(review): confirm the model IDs and per-token prices against the
# provider's current model list and pricing page before a paid run.
MODELS = {
    "sudo": {"model": "claude-haiku-4-5-20251001", "input_per_m": 0.80, "output_per_m": 4.00},
    "god": {"model": "claude-sonnet-4-6-20250514", "input_per_m": 3.00, "output_per_m": 15.00},
    "god_system": {"model": "claude-sonnet-4-6-20250514", "input_per_m": 3.00, "output_per_m": 15.00},
}
def determine_mode(example: dict) -> str:
    """Classify a dataset example as ``"god"``, ``"god_system"`` or ``"sudo"``.

    An example is ``"god"`` when its user message starts with ``"pray "``
    (case-insensitive) or its ``source`` field is ``"prayer_log"``. It is
    ``"god_system"`` when its id starts with ``"negative-"`` and the message
    mentions "god" (case-insensitive). Everything else is ``"sudo"``.
    """
    message = example["input"]["user_message"]
    example_id = example.get("id", "")
    lowered = message.lower()

    # Prayers are recognised either by their wording or by where they came from.
    is_prayer = lowered.startswith("pray ") or example.get("source") == "prayer_log"
    if is_prayer:
        return "god"

    if example_id.startswith("negative-") and "god" in lowered:
        return "god_system"

    return "sudo"
def build_user_message(example: dict) -> str:
    """Render the user turn sent to the model for one dataset example.

    The message contains the request text followed by a short context
    section: server type/version (defaulting to ``paper`` / ``1.21.x``), the
    online player list when non-empty, and the player position when all
    three coordinates are available.

    Args:
        example: Dataset row with an ``input`` dict holding ``user_message``
            and, optionally, ``server_context``.

    Returns:
        The newline-joined prompt text.

    Raises:
        KeyError: If ``input`` or ``user_message`` is missing.
    """
    inp = example["input"]
    query = inp["user_message"]
    # A JSONL row may carry an explicit null here; treat it like a missing key
    # instead of crashing on None.get(...).
    ctx = inp.get("server_context") or {}

    parts = [
        f"Request from slingshooter08: {query}",
        f"\nContext:\nServer: {ctx.get('server_type', 'paper')} {ctx.get('version', '1.21.x')}",
    ]

    online = ctx.get("online_players")
    if online:
        parts.append(f"Online: {', '.join(online)}")

    pos = ctx.get("player_position")
    # Only emit the position when it is complete; a partial position would
    # otherwise raise KeyError and abort the whole run.
    if isinstance(pos, dict) and all(axis in pos for axis in ("x", "y", "z")):
        parts.append(f"Player position: ({pos['x']}, {pos['y']}, {pos['z']})")

    return "\n".join(parts)
def _parse_response_json(text: str) -> dict:
    """Extract a JSON object from model output, or return a parse_failed stub.

    Tries the whole text first, then the outermost ``{...}`` span (covers
    markdown-fenced or prose-wrapped JSON). Only a JSON *object* is accepted,
    because callers read the result with ``.get``.
    """
    candidates = [text]
    match = re.search(r'\{[\s\S]*\}', text)
    if match:
        candidates.append(match.group())

    for candidate in candidates:
        try:
            value = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(value, dict):
            return value

    return {"commands": [], "message": "", "reasoning": "parse_failed"}


def call_claude(model: str, system: str, user: str) -> dict:
    """Call the Claude Messages API once and return the parsed reply.

    Args:
        model: Model identifier to request.
        system: System prompt text.
        user: Content of the single user message.

    Returns:
        A dict with ``parsed`` (the JSON object found in the reply, or a stub
        whose ``reasoning`` is ``"parse_failed"``), ``input_tokens``,
        ``output_tokens`` and ``raw`` (the reply text).

    Raises:
        requests.HTTPError: On a non-2xx response.
        requests.RequestException: On connection problems or the 60 s timeout.
    """
    headers = {
        "x-api-key": API_KEY,
        "anthropic-version": "2023-06-01",
        "content-type": "application/json",
    }
    body = {
        "model": model,
        "max_tokens": 500,
        "system": system,
        "messages": [{"role": "user", "content": user}],
    }
    resp = requests.post(API_URL, headers=headers, json=body, timeout=60)
    resp.raise_for_status()
    data = resp.json()

    # Join every text block rather than assuming exactly one at index 0, so an
    # empty content list or a non-text first block does not raise.
    text = "".join(
        block.get("text", "")
        for block in data.get("content", [])
        if isinstance(block, dict) and block.get("type", "text") == "text"
    )
    input_tokens = data["usage"]["input_tokens"]
    output_tokens = data["usage"]["output_tokens"]

    # Parsing never raises: the request has already been billed, so the token
    # counts must reach the caller even when the reply is not valid JSON.
    parsed = _parse_response_json(text)

    return {
        "parsed": parsed,
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "raw": text,
    }
def estimate_cost(examples: list) -> dict:
    """Estimate the API spend for *examples* without making any calls.

    Every example is assumed to use roughly 600 input tokens (system prompt
    plus user message) and 150 output tokens, priced with the per-million
    rates of the model configured for its mode.

    Returns:
        ``{"total_est": <rounded total>, "details": {mode: {"count",
        "model", "est_cost"}}}`` where modes with no examples are omitted.
    """
    per_mode = dict.fromkeys(("sudo", "god", "god_system"), 0)
    for example in examples:
        per_mode[determine_mode(example)] += 1

    breakdown = {}
    running_total = 0
    for mode_name in per_mode:
        n_examples = per_mode[mode_name]
        if not n_examples:
            continue
        pricing = MODELS[mode_name]
        in_cost = (n_examples * 600 / 1_000_000) * pricing["input_per_m"]
        out_cost = (n_examples * 150 / 1_000_000) * pricing["output_per_m"]
        subtotal = in_cost + out_cost
        running_total += subtotal
        breakdown[mode_name] = {
            "count": n_examples,
            "model": pricing["model"],
            "est_cost": round(subtotal, 4),
        }

    return {"total_est": round(running_total, 4), "details": breakdown}
def _write_jsonl(path: Path, records: list) -> None:
    """Write *records* to *path* as UTF-8 JSON Lines, one object per line."""
    # Explicit UTF-8: ensure_ascii=False emits raw non-ASCII characters, which
    # would fail under a non-UTF-8 locale default encoding.
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


def main():
    """Run the distillation pipeline from the command line.

    Loads the seed dataset, optionally filters it by mode, prints a cost
    estimate, and (unless ``--dry-run``) sends each example to the model,
    replacing its ``output`` with the model's answer. Processing stops once
    the accumulated cost reaches ``--budget``. Results collected so far are
    written to ``--output`` even when the run is interrupted or fails.
    """
    parser = argparse.ArgumentParser(description="Claude distillation pipeline")
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--mode", choices=["sudo", "god", "all"], default="all")
    parser.add_argument("--budget", type=float, default=5.00)
    parser.add_argument("--output", default=str(OUTPUT_DEFAULT))
    args = parser.parse_args()

    with open(DATASET, encoding="utf-8") as f:
        examples = [json.loads(line) for line in f if line.strip()]

    # Filter by mode.
    # NOTE(review): "god_system" examples are only reachable with --mode all;
    # confirm whether --mode god is meant to include them.
    if args.mode != "all":
        examples = [ex for ex in examples if determine_mode(ex) == args.mode]

    # Skip examples that are just abstention/empty (no useful distillation target)
    examples = [ex for ex in examples if not ex.get("id", "").startswith("abstain-")]

    print("Distillation pipeline")
    print(f" Dataset: {len(examples)} examples")
    print(f" Budget: ${args.budget:.2f}")
    print(f" Output: {args.output}")

    cost_est = estimate_cost(examples)
    print(f"\n Estimated cost: ${cost_est['total_est']:.4f}")
    for mode, d in cost_est["details"].items():
        print(f" {mode}: {d['count']} examples via {d['model']} (${d['est_cost']:.4f})")

    if args.dry_run:
        print(f"\n[DRY RUN] Would process {len(examples)} examples for ~${cost_est['total_est']:.4f}")
        return

    if cost_est["total_est"] > args.budget:
        print(f"\nEstimated cost ${cost_est['total_est']:.4f} exceeds budget ${args.budget:.2f}. Reduce examples or increase budget.")
        return

    # Process
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    total_input = 0
    total_output = 0
    total_cost = 0.0
    processed = 0
    errors = 0
    results = []
    finished = False

    try:
        for i, ex in enumerate(examples):
            # Check the budget first so no work is done for an example that
            # will not be sent.
            if total_cost >= args.budget:
                print(f"\n Budget reached at ${total_cost:.4f} after {processed} examples")
                break

            mode = determine_mode(ex)
            cfg = MODELS[mode]
            system_prompt = get_prompt(mode)
            user_msg = build_user_message(ex)

            try:
                result = call_claude(cfg["model"], system_prompt, user_msg)
            except Exception as e:
                # Best effort: report, back off briefly and move on to the
                # next example.
                print(f" [{i+1}/{len(examples)}] ERROR: {e}")
                errors += 1
                time.sleep(1)
                continue

            parsed = result["parsed"]
            if not isinstance(parsed, dict):
                # The model may answer with a JSON list or scalar; treat that
                # as a failed parse rather than crashing on .get().
                parsed = {"commands": [], "message": "", "reasoning": "parse_failed"}

            total_input += result["input_tokens"]
            total_output += result["output_tokens"]

            # Calculate cost from the actual token usage of this call.
            cost = (result["input_tokens"] / 1_000_000) * cfg["input_per_m"] + \
                (result["output_tokens"] / 1_000_000) * cfg["output_per_m"]
            total_cost += cost
            processed += 1

            # Build distilled example: copy the source row and swap in the
            # model's answer. Only god-style modes keep a chat message.
            commands = parsed.get("commands") or []
            original_output = ex.get("output") or {}
            distilled = dict(ex)
            distilled["output"] = {
                "reasoning": parsed.get("reasoning", ""),
                "commands": commands,
                "message": parsed.get("message") if mode in ("god", "god_system") else None,
                "safety_flags": original_output.get("safety_flags", []),
            }
            distilled["metadata"] = dict(ex.get("metadata", {}))
            distilled["metadata"]["distilled_by"] = cfg["model"]
            distilled["metadata"]["distilled_at"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
            distilled["id"] = f"distill-{ex.get('id', f'ex-{i}')}"
            results.append(distilled)

            cmds = len(commands)
            msg_preview = (parsed.get("message", "") or "")[:40]
            print(f" [{i+1}/{len(examples)}] ({mode}) {ex['input']['user_message'][:45]:45} [{cmds} cmds] ${cost:.4f} {msg_preview}")

            # Rate limit: ~50 req/min for Haiku, ~20 for Sonnet
            time.sleep(0.5 if mode == "sudo" else 1.5)
        finished = True
    except KeyboardInterrupt:
        print("\n Interrupted; saving results collected so far")
    finally:
        # Every entry in `results` was paid for, so persist them even if the
        # loop was interrupted or raised. When nothing was collected and the
        # run did not finish, leave any existing output file untouched.
        if results or finished:
            _write_jsonl(output_path, results)

    print(f"\n{'='*60}")
    print("Distillation complete" if finished else "Distillation interrupted")
    print(f" Processed: {processed}")
    print(f" Errors: {errors}")
    print(f" Input tokens: {total_input:,}")
    print(f" Output tokens: {total_output:,}")
    print(f" Total cost: ${total_cost:.4f}")
    print(f" Output: {output_path}")
if __name__ == "__main__":
main()