#!/usr/bin/env python3 """ LoRA fine-tuning script for Minecraft AI ops assistant. Base model: Qwen/Qwen3-8B (dense, Apache 2.0) Method: QLoRA (4-bit base + LoRA adapters in FP16) Framework: Unsloth + HuggingFace TRL Target GPU: RTX 3090 Ti (24GB VRAM) Usage: python train_lora.py python train_lora.py --epochs 5 --lr 2e-4 --rank 32 """ import argparse import json import os from pathlib import Path def determine_mode(example: dict) -> str: """Determine prompt mode from the example.""" query = example["input"]["user_message"] eid = example.get("id", "") if query.lower().startswith("pray "): return "god" elif eid.startswith("negative-") and "god" in query.lower(): return "god_system" elif example.get("source") == "prayer_log": return "god" return "sudo" def get_system_prompt(mode: str) -> str: """Get the system prompt for training. Import from project if available, fallback to inline. Prepends /no_think to disable Qwen3/3.5 thinking tokens.""" try: import sys script_dir = Path(__file__).resolve().parent project_root = script_dir.parent.parent sys.path.insert(0, str(project_root)) from agent.prompts.system_prompts import get_prompt prompt = get_prompt(mode) except ImportError: # Minimal fallback prompts if mode == "god": prompt = "You are God in a Minecraft server. Return JSON: {\"message\": \"...\", \"commands\": [...], \"reasoning\": \"...\"}" elif mode == "god_system": prompt = "You are God performing an unprompted intervention. Return JSON: {\"message\": \"...\", \"commands\": [...]}" else: prompt = "You are a Minecraft 1.21 command translator. Return JSON: {\"commands\": [...], \"reasoning\": \"...\"}" # Disable thinking for Qwen3/3.5 models return "/no_think\n" + prompt def load_seed_dataset(path: str) -> list: """Load seed dataset and format for SFT training with system prompts and mode awareness.""" examples = [] with open(path) as f: for line in f: if not line.strip(): continue ex = json.loads(line) # Determine mode and get system prompt mode = determine_mode(ex) system_prompt = get_system_prompt(mode) # Build the training conversation inp = ex["input"] out = ex["output"] query = inp["user_message"] ctx = inp.get("server_context", {}) # User message with context user_parts = [f"Request from slingshooter08: {query}"] user_parts.append(f"\nContext:\nServer: {ctx.get('server_type', 'paper')} {ctx.get('version', '1.21.x')}") if ctx.get("online_players"): user_parts.append(f"Online: {', '.join(ctx['online_players'])}") pos = ctx.get("player_position") if pos: user_parts.append(f"Player position: ({pos['x']}, {pos['y']}, {pos['z']})") user_msg = "\n".join(user_parts) # Assistant response as JSON — includes risk_level for decision transparency risk_level = ex.get("metadata", {}).get("risk_level", 3) response = { "risk_level": risk_level, "reasoning": out.get("reasoning", ""), "commands": out.get("commands", []), } # Include message field for god modes if mode in ("god", "god_system"): response["message"] = out.get("message") or "" examples.append({ "conversations": [ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_msg}, {"role": "assistant", "content": json.dumps(response)}, ] }) return examples def load_tool_dataset(path: str) -> list: """Load multi-turn tool-calling training data. These examples are already in Qwen3 chat format with tool_call tags. They contain multi-turn conversations: user → assistant tool_call → tool result → ... → final response. We pass them through as pre-formatted text (not as conversations for the chat template). """ examples = [] with open(path) as f: for line in f: if not line.strip(): continue ex = json.loads(line) # Tool training data has a 'messages' field with multi-turn conversations if "messages" in ex: examples.append({"conversations": ex["messages"]}) # Or pre-formatted qwen3_text elif "qwen3_text" in ex: examples.append({"text": ex["qwen3_text"]}) return examples def load_dataset(seed_path: str, tool_path: str = None) -> list: """Load and merge all training datasets.""" examples = load_seed_dataset(seed_path) print(f" Seed examples: {len(examples)}") if tool_path and os.path.exists(tool_path): tool_examples = load_tool_dataset(tool_path) print(f" Tool examples: {len(tool_examples)}") examples.extend(tool_examples) else: print(f" Tool examples: 0 (no file)") return examples def main(): parser = argparse.ArgumentParser(description="LoRA fine-tuning for Minecraft AI") parser.add_argument("--model", default="Qwen/Qwen3-8B", help="Base model from HuggingFace") parser.add_argument("--dataset", default="", help="Dataset path (default: auto-detect)") parser.add_argument("--output", default="", help="Output directory for adapter") parser.add_argument("--rank", type=int, default=16, help="LoRA rank") parser.add_argument("--alpha", type=int, default=32, help="LoRA alpha") parser.add_argument("--lr", type=float, default=2e-4, help="Learning rate") parser.add_argument("--epochs", type=int, default=1, help="Training epochs") parser.add_argument("--batch-size", type=int, default=2, help="Per-device batch size") parser.add_argument("--grad-accum", type=int, default=4, help="Gradient accumulation steps") parser.add_argument("--max-seq-len", type=int, default=2048, help="Max sequence length") parser.add_argument("--dry-run", action="store_true", help="Load model and dataset but don't train") args = parser.parse_args() # Auto-detect paths script_dir = Path(__file__).resolve().parent project_root = script_dir.parent.parent if not args.dataset: args.dataset = str(project_root / "data" / "processed" / "seed_dataset.jsonl") tool_dataset = str(project_root / "data" / "processed" / "tool_training.jsonl") if not args.output: args.output = str(project_root / "training" / "checkpoints" / "qwen3-8b-mc-lora") print(f"Base model: {args.model}") print(f"Dataset: {args.dataset}") print(f"Output: {args.output}") print(f"LoRA rank: {args.rank}, alpha: {args.alpha}") print(f"LR: {args.lr}") print(f"Epochs: {args.epochs}") print(f"Batch: {args.batch_size} x {args.grad_accum} grad accum") print(f"Max seq len: {args.max_seq_len}") print() # Load dataset (seed + tool-calling) print("Loading datasets...") train_data = load_dataset(args.dataset, tool_dataset) print(f" Total: {len(train_data)} training examples") if args.dry_run: print("\n[DRY RUN] Would load model and train. Exiting.") for ex in train_data[:2]: print(f" Example: {ex['conversations'][0]['content'][:80]}...") return # Import Unsloth (heavy imports, only when actually training) from unsloth import FastLanguageModel from trl import SFTTrainer, SFTConfig from datasets import Dataset # Load model with 4-bit quantization print(f"\nLoading {args.model} in 4-bit...") model, tokenizer = FastLanguageModel.from_pretrained( model_name=args.model, max_seq_length=args.max_seq_len, load_in_4bit=True, dtype=None, # auto-detect ) # Add LoRA adapters print(f"Adding LoRA adapters (rank={args.rank}, alpha={args.alpha})...") model = FastLanguageModel.get_peft_model( model, r=args.rank, lora_alpha=args.alpha, lora_dropout=0, target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], bias="none", use_gradient_checkpointing="unsloth", ) # Prepare dataset dataset = Dataset.from_list(train_data) def formatting_func(examples): """Format conversations for the chat template. Handles both: - 'conversations': list of role/content dicts → apply chat template - 'text': pre-formatted Qwen3 text (tool-calling examples) → pass through """ texts = [] convos_list = examples.get("conversations", []) text_list = examples.get("text", []) for i in range(len(convos_list)): convos = convos_list[i] pre_text = text_list[i] if i < len(text_list) else None if pre_text: # Pre-formatted tool-calling example texts.append(pre_text) elif convos: # Standard conversation → apply chat template text = tokenizer.apply_chat_template( convos, tokenize=False, add_generation_prompt=False ) texts.append(text) else: texts.append("") return {"text": texts} dataset = dataset.map(formatting_func, batched=True) # Training config training_args = SFTConfig( output_dir=args.output, num_train_epochs=args.epochs, per_device_train_batch_size=args.batch_size, gradient_accumulation_steps=args.grad_accum, learning_rate=args.lr, lr_scheduler_type="cosine", warmup_ratio=0.1, weight_decay=0.01, bf16=True, logging_steps=1, save_strategy="epoch", seed=42, max_seq_length=args.max_seq_len, dataset_text_field="text", packing=True, ) # Train print(f"\nStarting training ({args.epochs} epochs, {len(train_data)} examples)...") trainer = SFTTrainer( model=model, tokenizer=tokenizer, train_dataset=dataset, args=training_args, ) trainer.train() # Save adapter print(f"\nSaving LoRA adapter to {args.output}...") model.save_pretrained(args.output) tokenizer.save_pretrained(args.output) print("\nTraining complete!") print(f"Adapter saved to: {args.output}") print(f"To convert to GGUF for Ollama, use:") print(f" python -m unsloth.save --model {args.output} --output_type gguf") if __name__ == "__main__": main()