Files
Mortdecai/training/scripts/train_lora.py
T
Seth a3d139e04f Mortdecai v4 pre-training: /no_think, dedup, 3,369 examples
- /no_think prepended to all system prompts (seed + tool training)
- Deduplicated seed dataset (435 dupes removed)
- Training script updated for Qwen3.5-9B + /no_think
- 2,210 seed + 1,159 tool-calling = 3,369 total examples

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-19 20:15:00 -04:00

292 lines
11 KiB
Python

#!/usr/bin/env python3
"""
LoRA fine-tuning script for Minecraft AI ops assistant.
Base model: Qwen/Qwen3-8B (dense, Apache 2.0)
Method: QLoRA (4-bit base + LoRA adapters in FP16)
Framework: Unsloth + HuggingFace TRL
Target GPU: RTX 3090 Ti (24GB VRAM)
Usage:
python train_lora.py
python train_lora.py --epochs 5 --lr 2e-4 --rank 32
"""
import argparse
import json
import os
from pathlib import Path
def determine_mode(example: dict) -> str:
    """Classify a seed example into the prompt mode that selects its system prompt.

    Rules are checked in priority order:
      1. the user message starts with "pray " (case-insensitive) -> "god"
      2. the id starts with "negative-" and the message mentions "god" -> "god_system"
      3. the example's "source" is "prayer_log" -> "god"
    Anything else is a plain command-translation request -> "sudo".
    """
    lowered = example["input"]["user_message"].lower()
    example_id = example.get("id", "")

    if lowered.startswith("pray "):
        return "god"
    if example_id.startswith("negative-") and "god" in lowered:
        return "god_system"
    if example.get("source") == "prayer_log":
        return "god"
    return "sudo"
def get_system_prompt(mode: str) -> str:
    """Return the training system prompt for ``mode``, prefixed with ``/no_think``.

    Tries the project's prompts (``agent.prompts.system_prompts.get_prompt``)
    first. Falls back to minimal inline prompts if that import raises
    ImportError (e.g. when run outside the repository).

    Args:
        mode: "god", "god_system", or anything else (the fallback treats any
            other value as the command-translator prompt).

    Returns:
        "/no_think\\n" followed by the prompt text. The prefix disables
        Qwen3/3.5 thinking tokens.
    """
    import sys

    # Project root is two directories above this script's directory.
    project_root = str(Path(__file__).resolve().parent.parent.parent)
    # This function is called once per training example; insert the path only
    # once instead of prepending a duplicate sys.path entry on every call.
    if project_root not in sys.path:
        sys.path.insert(0, project_root)
    try:
        from agent.prompts.system_prompts import get_prompt
        prompt = get_prompt(mode)
    except ImportError:
        # Minimal fallback prompts
        if mode == "god":
            prompt = "You are God in a Minecraft server. Return JSON: {\"message\": \"...\", \"commands\": [...], \"reasoning\": \"...\"}"
        elif mode == "god_system":
            prompt = "You are God performing an unprompted intervention. Return JSON: {\"message\": \"...\", \"commands\": [...]}"
        else:
            prompt = "You are a Minecraft 1.21 command translator. Return JSON: {\"commands\": [...], \"reasoning\": \"...\"}"
    # Disable thinking for Qwen3/3.5 models
    return "/no_think\n" + prompt
def _build_user_message(inp: dict) -> str:
    """Render a seed example's user turn: the request plus a server-context summary.

    Args:
        inp: the example's "input" dict. "user_message" is required and
            "server_context" is optional.

    Returns:
        Newline-joined text: the request line, then server type/version, then
        online players and player position when present.
    """
    ctx = inp.get("server_context", {})
    # NOTE(review): every seed request is attributed to the same player name;
    # confirm this matches how requests are phrased at inference time.
    lines = [
        f"Request from slingshooter08: {inp['user_message']}",
        f"\nContext:\nServer: {ctx.get('server_type', 'paper')} {ctx.get('version', '1.21.x')}",
    ]
    if ctx.get("online_players"):
        lines.append(f"Online: {', '.join(ctx['online_players'])}")
    pos = ctx.get("player_position")
    if pos:
        lines.append(f"Player position: ({pos['x']}, {pos['y']}, {pos['z']})")
    return "\n".join(lines)


def load_seed_dataset(path: str) -> list:
    """Load the seed dataset and format each example as a chat conversation for SFT.

    Each non-blank JSONL line must have "input" (with "user_message" and an
    optional "server_context") and "output" (with optional "reasoning",
    "commands" and "message"). determine_mode() picks the system prompt. The
    assistant turn is the JSON-encoded response: risk_level (from "metadata",
    default 3), reasoning and commands, plus "message" for the god modes.

    Returns:
        A list of ``{"conversations": [system, user, assistant]}`` dicts.
    """
    examples = []
    # get_system_prompt() does path setup and an import attempt, and every
    # example of a given mode gets the same prompt, so resolve it once per mode.
    prompts_by_mode = {}
    # Read as UTF-8 explicitly instead of relying on the platform's locale encoding.
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            ex = json.loads(line)
            mode = determine_mode(ex)
            if mode not in prompts_by_mode:
                prompts_by_mode[mode] = get_system_prompt(mode)
            user_msg = _build_user_message(ex["input"])
            out = ex["output"]
            # Assistant response as JSON — includes risk_level for decision transparency
            response = {
                "risk_level": ex.get("metadata", {}).get("risk_level", 3),
                "reasoning": out.get("reasoning", ""),
                "commands": out.get("commands", []),
            }
            # Include message field for god modes
            if mode in ("god", "god_system"):
                response["message"] = out.get("message") or ""
            examples.append({
                "conversations": [
                    {"role": "system", "content": prompts_by_mode[mode]},
                    {"role": "user", "content": user_msg},
                    {"role": "assistant", "content": json.dumps(response)},
                ]
            })
    return examples
def load_tool_dataset(path: str) -> list:
    """Load multi-turn tool-calling training data.

    Each non-blank JSONL line is one example in one of two shapes:
      - ``{"messages": [...]}``: returned as ``{"conversations": messages}`` so
        the tokenizer's chat template renders it (user -> assistant tool_call
        -> tool result -> ... -> final response).
      - ``{"qwen3_text": "..."}``: already rendered in Qwen3 chat format with
        tool_call tags. Returned as ``{"text": ...}`` and used verbatim.
    "messages" takes precedence when both keys are present. Lines with neither
    key are skipped, and a warning reports how many.
    """
    examples = []
    skipped = 0
    # Read as UTF-8 explicitly instead of relying on the platform's locale encoding.
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            ex = json.loads(line)
            if "messages" in ex:
                examples.append({"conversations": ex["messages"]})
            elif "qwen3_text" in ex:
                examples.append({"text": ex["qwen3_text"]})
            else:
                skipped += 1
    if skipped:
        print(f" WARNING: skipped {skipped} tool example(s) with neither 'messages' nor 'qwen3_text'")
    return examples
def load_dataset(seed_path: str, tool_path: str = None) -> list:
    """Combine the seed dataset with the optional tool-calling dataset.

    Seed examples come first, then tool examples. The tool dataset counts as
    empty when ``tool_path`` is falsy or the file does not exist. Per-source
    counts are printed. The returned list is the seed list, extended in place.
    """
    merged = load_seed_dataset(seed_path)
    print(f" Seed examples: {len(merged)}")
    if not (tool_path and os.path.exists(tool_path)):
        print(" Tool examples: 0 (no file)")
        return merged
    tool_examples = load_tool_dataset(tool_path)
    print(f" Tool examples: {len(tool_examples)}")
    merged.extend(tool_examples)
    return merged
def _parse_args() -> argparse.Namespace:
    """Parse the command-line options for a training run."""
    parser = argparse.ArgumentParser(description="LoRA fine-tuning for Minecraft AI")
    parser.add_argument("--model", default="Qwen/Qwen3-8B", help="Base model from HuggingFace")
    parser.add_argument("--dataset", default="", help="Dataset path (default: auto-detect)")
    parser.add_argument("--tool-dataset", default="",
                        help="Tool-calling dataset path (default: auto-detect)")
    parser.add_argument("--output", default="", help="Output directory for adapter")
    parser.add_argument("--rank", type=int, default=16, help="LoRA rank")
    parser.add_argument("--alpha", type=int, default=32, help="LoRA alpha")
    parser.add_argument("--lr", type=float, default=2e-4, help="Learning rate")
    parser.add_argument("--epochs", type=int, default=1, help="Training epochs")
    parser.add_argument("--batch-size", type=int, default=2, help="Per-device batch size")
    parser.add_argument("--grad-accum", type=int, default=4, help="Gradient accumulation steps")
    parser.add_argument("--max-seq-len", type=int, default=2048, help="Max sequence length")
    parser.add_argument("--dry-run", action="store_true",
                        help="Load the dataset but don't load the model or train")
    return parser.parse_args()


def _preview(ex: dict, width: int = 80) -> str:
    """Return the first ``width`` characters of an example, for dry-run display.

    Conversation examples show their first message's content. Pre-rendered
    "text" examples (which have no "conversations") show the raw string.
    """
    convos = ex.get("conversations")
    if convos:
        return str(convos[0].get("content") or "")[:width]
    return str(ex.get("text") or "")[:width]


def _render_texts(train_data: list, tokenizer) -> list:
    """Render every example to a single training string, skipping empty ones.

    Examples with a non-empty "text" (pre-rendered Qwen3 tool-calling data)
    pass through unchanged. Otherwise "conversations" goes through the
    tokenizer's chat template.

    This runs in plain Python rather than via ``Dataset.from_list`` + ``map``.
    ``Dataset.from_list`` takes its columns from the FIRST entry only. Seed
    examples come first, so the "text" column was dropped and every
    pre-rendered tool example trained as an empty string.
    """
    texts = []
    for ex in train_data:
        pre_text = ex.get("text")
        convos = ex.get("conversations")
        if pre_text:
            texts.append(pre_text)
        elif convos:
            texts.append(tokenizer.apply_chat_template(
                convos, tokenize=False, add_generation_prompt=False
            ))
    return texts


def main():
    """Train a QLoRA adapter on the seed + tool-calling datasets and save it.

    Unspecified dataset/output paths are resolved relative to the project root
    (two directories above this script's directory). ``--dry-run`` stops after
    loading the datasets.
    """
    args = _parse_args()
    # Auto-detect paths
    script_dir = Path(__file__).resolve().parent
    project_root = script_dir.parent.parent
    processed_dir = project_root / "data" / "processed"
    if not args.dataset:
        args.dataset = str(processed_dir / "seed_dataset.jsonl")
    if not args.tool_dataset:
        args.tool_dataset = str(processed_dir / "tool_training.jsonl")
    if not args.output:
        args.output = str(project_root / "training" / "checkpoints" / "qwen3-8b-mc-lora")
    print(f"Base model: {args.model}")
    print(f"Dataset: {args.dataset}")
    print(f"Tool dataset: {args.tool_dataset}")
    print(f"Output: {args.output}")
    print(f"LoRA rank: {args.rank}, alpha: {args.alpha}")
    print(f"LR: {args.lr}")
    print(f"Epochs: {args.epochs}")
    print(f"Batch: {args.batch_size} x {args.grad_accum} grad accum")
    print(f"Max seq len: {args.max_seq_len}")
    print()
    # Load dataset (seed + tool-calling)
    print("Loading datasets...")
    train_data = load_dataset(args.dataset, args.tool_dataset)
    print(f" Total: {len(train_data)} training examples")
    if args.dry_run:
        print("\n[DRY RUN] Would load model and train. Exiting.")
        for ex in train_data[:2]:
            print(f" Example: {_preview(ex)}...")
        return
    # Import Unsloth (heavy imports, only when actually training)
    from unsloth import FastLanguageModel
    from trl import SFTTrainer, SFTConfig
    from datasets import Dataset
    # Load model with 4-bit quantization
    print(f"\nLoading {args.model} in 4-bit...")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model,
        max_seq_length=args.max_seq_len,
        load_in_4bit=True,
        dtype=None,  # auto-detect
    )
    # Add LoRA adapters
    print(f"Adding LoRA adapters (rank={args.rank}, alpha={args.alpha})...")
    model = FastLanguageModel.get_peft_model(
        model,
        r=args.rank,
        lora_alpha=args.alpha,
        lora_dropout=0,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        bias="none",
        use_gradient_checkpointing="unsloth",
    )
    # Prepare dataset: one "text" column built from fully rendered strings.
    texts = _render_texts(train_data, tokenizer)
    skipped = len(train_data) - len(texts)
    if skipped:
        print(f" WARNING: skipped {skipped} example(s) with no text or conversations")
    if not texts:
        raise SystemExit("No training examples to train on.")
    dataset = Dataset.from_dict({"text": texts})
    # Training config
    # NOTE(review): newer TRL releases rename SFTConfig.max_seq_length to
    # max_length and SFTTrainer's tokenizer= to processing_class=; confirm
    # against the pinned TRL/Unsloth versions.
    training_args = SFTConfig(
        output_dir=args.output,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        weight_decay=0.01,
        bf16=True,
        logging_steps=1,
        save_strategy="epoch",
        seed=42,
        max_seq_length=args.max_seq_len,
        dataset_text_field="text",
        packing=True,
    )
    # Train
    print(f"\nStarting training ({args.epochs} epochs, {len(dataset)} examples)...")
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        args=training_args,
    )
    trainer.train()
    # Save adapter
    print(f"\nSaving LoRA adapter to {args.output}...")
    model.save_pretrained(args.output)
    tokenizer.save_pretrained(args.output)
    print("\nTraining complete!")
    print(f"Adapter saved to: {args.output}")
    print("To convert to GGUF for Ollama, use:")
    print(f" python -m unsloth.save --model {args.output} --output_type gguf")


if __name__ == "__main__":
    main()