#!/usr/bin/env python3
"""
Validate a training dataset (JSONL) against its schema and print summary statistics.

Usage:
    python3 validate_dataset.py [path_to_jsonl]

With no argument, DEFAULT_PATH is validated. The script exits with status 1
when the file is missing, when any line fails to parse as JSON, or when any
example fails schema validation; otherwise it prints "All examples valid."
and exits with status 0.
"""

import json
import sys
from collections import Counter
from collections.abc import Hashable
from pathlib import Path

DEFAULT_PATH = "data/processed/seed_dataset.jsonl"

# Top-level keys every example must carry.
REQUIRED_FIELDS = ("id", "source", "category", "input", "output")

VALID_SOURCES = {"repair_pattern", "prayer_log", "sudo_log", "bug_report",
                 "session_history", "manual", "synthetic"}
VALID_CATEGORIES = {"command_gen", "troubleshoot", "info", "safety", "negative"}
VALID_DIFFICULTIES = {"easy", "medium", "hard"}
VALID_SAFETY_FLAGS = {"destructive", "teleport", "op_required", "affects_all_players"}


def _is_member(value, choices) -> bool:
    """Return True if *value* is a string contained in *choices*.

    Every allowed value is a string, so non-strings are rejected up front;
    this also keeps unhashable JSON values (lists/dicts) from making the
    ``in set`` test raise TypeError.
    """
    return isinstance(value, str) and value in choices


def _fmt_choices(choices) -> str:
    """Render allowed values in sorted order so error messages are deterministic."""
    return ", ".join(sorted(choices))


def _as_dict(value) -> dict:
    """Return *value* if it is a dict, otherwise an empty dict (for tolerant lookups)."""
    return value if isinstance(value, dict) else {}


def _label(value) -> str:
    """Turn any JSON value into a string usable as a Counter key (lists/dicts are unhashable)."""
    return value if isinstance(value, str) else repr(value)


def _validate_output(out: dict, prefix: str) -> list:
    """Check the ``output`` object: its ``commands`` list and optional ``safety_flags``.

    Args:
        out: The example's ``output`` value, already known to be a dict.
        prefix: Message prefix identifying the example (line number and id).

    Returns:
        A list of error strings (empty when the output is valid).
    """
    errors = []

    if "commands" not in out:
        errors.append(f"{prefix}: 'output.commands' is required (can be empty list)")
    else:
        cmds = out["commands"]
        if not isinstance(cmds, list):
            # Do not iterate a non-list: a string would yield one error per
            # character, and None/int would raise TypeError.
            errors.append(f"{prefix}: 'output.commands' must be a list")
        else:
            for i, cmd in enumerate(cmds):
                if not isinstance(cmd, str):
                    errors.append(f"{prefix}: command[{i}] must be a string")
                elif cmd.startswith("/"):
                    errors.append(
                        f"{prefix}: command[{i}] starts with '/' -- should have no leading slash"
                    )

    # safety_flags is optional; when present it must be a list of known flags.
    flags = out.get("safety_flags", [])
    if not isinstance(flags, list):
        errors.append(f"{prefix}: 'output.safety_flags' must be a list")
    else:
        for flag in flags:
            if not _is_member(flag, VALID_SAFETY_FLAGS):
                errors.append(f"{prefix}: invalid safety_flag '{flag}'")

    return errors


def validate_example(ex: dict, line_num: int) -> list:
    """Check one decoded JSONL record against the dataset schema.

    Args:
        ex: The JSON value decoded from one line. It is expected to be an
            object; any other JSON type is reported as a single error.
        line_num: 1-based line number in the source file, used in messages.

    Returns:
        A list of human-readable error strings; empty when the record is
        valid. Malformed values are reported rather than raised on.
    """
    if not isinstance(ex, dict):
        return [f"line {line_num}: example must be a JSON object, got {type(ex).__name__}"]

    errors = []
    prefix = f"line {line_num} (id={ex.get('id', '?')})"

    # Required fields. The value checks below run only for fields that are
    # present, so a missing field is reported exactly once.
    for field in REQUIRED_FIELDS:
        if field not in ex:
            errors.append(f"{prefix}: missing required field '{field}'")

    if "source" in ex and not _is_member(ex["source"], VALID_SOURCES):
        errors.append(
            f"{prefix}: invalid source '{ex['source']}' "
            f"(valid: {_fmt_choices(VALID_SOURCES)})"
        )

    if "category" in ex and not _is_member(ex["category"], VALID_CATEGORIES):
        errors.append(
            f"{prefix}: invalid category '{ex['category']}' "
            f"(valid: {_fmt_choices(VALID_CATEGORIES)})"
        )

    if "input" in ex:
        inp = ex["input"]
        if not isinstance(inp, dict):
            errors.append(f"{prefix}: 'input' must be an object")
        else:
            msg = inp.get("user_message")
            # Whitespace-only text counts as empty.
            if not msg or (isinstance(msg, str) and not msg.strip()):
                errors.append(f"{prefix}: 'input.user_message' is required and non-empty")
            elif not isinstance(msg, str):
                errors.append(f"{prefix}: 'input.user_message' must be a string")

    if "output" in ex:
        out = ex["output"]
        if not isinstance(out, dict):
            errors.append(f"{prefix}: 'output' must be an object")
        else:
            errors.extend(_validate_output(out, prefix))

    # metadata is optional; when present it must be an object. An empty or
    # falsy difficulty is allowed, anything else must be a known level.
    meta = ex.get("metadata", {})
    if not isinstance(meta, dict):
        errors.append(f"{prefix}: 'metadata' must be an object")
    elif meta.get("difficulty") and not _is_member(meta["difficulty"], VALID_DIFFICULTIES):
        errors.append(f"{prefix}: invalid difficulty '{meta['difficulty']}'")

    return errors


def _load_jsonl(p: Path) -> tuple:
    """Read a UTF-8 JSONL file.

    Blank lines are skipped; line numbers stay 1-based relative to the file.

    Returns:
        ``(examples, parse_errors)``, where *examples* is a list of
        ``(line_number, decoded_value)`` pairs and *parse_errors* is a list of
        message strings for lines that are not valid JSON.
    """
    examples = []
    parse_errors = []
    with p.open(encoding="utf-8") as f:
        for i, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                examples.append((i, json.loads(line)))
            except json.JSONDecodeError as e:
                parse_errors.append(f"line {i}: JSON parse error: {e}")
    return examples, parse_errors


def _print_counts(title: str, counts: Counter) -> None:
    """Print one "By ...:" section, most common values first, followed by a blank line."""
    print(f"{title}:")
    for key, n in counts.most_common():
        print(f" {key}: {n}")
    print()


def main(argv=None):
    """Validate a JSONL dataset file and print summary statistics.

    Args:
        argv: Command-line arguments excluding the program name; defaults to
            ``sys.argv[1:]``. The first argument, if given, is the dataset
            path; otherwise DEFAULT_PATH is used.

    Exits via ``sys.exit(1)`` when the file is missing, when any line fails
    to parse as JSON, or when any example fails validation; returns normally
    when everything is valid.
    """
    args = sys.argv[1:] if argv is None else argv
    path = args[0] if args else DEFAULT_PATH
    p = Path(path)
    # is_file() (not exists()) so that a directory is also rejected cleanly.
    if not p.is_file():
        print(f"File not found: {path}", file=sys.stderr)
        sys.exit(1)

    examples, parse_errors = _load_jsonl(p)

    if parse_errors:
        print("JSON PARSE ERRORS:")
        for e in parse_errors:
            print(f" {e}")
        print()

    all_errors = []
    first_seen = {}  # example id -> line number of its first occurrence
    sources = Counter()
    categories = Counter()
    difficulties = Counter()
    has_negative = 0
    has_reasoning = 0
    total_commands = 0

    for line_num, ex in examples:
        all_errors.extend(validate_example(ex, line_num))
        if not isinstance(ex, dict):
            continue  # already reported by validate_example; nothing to tally

        # Only present, hashable ids take part in duplicate detection. A
        # missing id is already a "missing required field" error and must
        # not also be reported as a duplicate of ''.
        eid = ex.get("id")
        if "id" in ex and isinstance(eid, Hashable):
            if eid in first_seen:
                all_errors.append(
                    f"line {line_num}: duplicate id '{eid}' "
                    f"(first seen on line {first_seen[eid]})"
                )
            else:
                first_seen[eid] = line_num

        # Tolerate malformed sub-objects here; validate_example reports them.
        out = _as_dict(ex.get("output"))
        meta = _as_dict(ex.get("metadata"))

        sources[_label(ex.get("source", "?"))] += 1
        categories[_label(ex.get("category", "?"))] += 1
        difficulties[_label(meta.get("difficulty", "?"))] += 1
        if ex.get("negative_output"):
            has_negative += 1
        if out.get("reasoning"):
            has_reasoning += 1
        cmds = out.get("commands")
        if isinstance(cmds, list):
            total_commands += len(cmds)

    print(f"=== Dataset Validation: {p.name} ===")
    print(f"Total examples: {len(examples)}")
    print(f"Total commands: {total_commands}")
    print(f"With negative_output (wrong->right pairs): {has_negative}")
    print(f"With reasoning (chain-of-thought): {has_reasoning}")
    print()
    _print_counts("By source", sources)
    _print_counts("By category", categories)
    _print_counts("By difficulty", difficulties)

    if all_errors:
        print(f"VALIDATION ERRORS ({len(all_errors)}):")
        for e in all_errors:
            print(f" {e}")
    if parse_errors:
        # Unparseable lines are dataset defects too; never report success.
        print(f"{len(parse_errors)} line(s) failed to parse as JSON (listed above).")
    if all_errors or parse_errors:
        sys.exit(1)

    print("All examples valid.")


if __name__ == "__main__":
    main()