Self-play: add --api-key option for authenticated gateway connections

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-20 19:40:01 -04:00
parent aa5400e31e
commit 0f043384e5
+15 -5
View File
@@ -42,6 +42,9 @@ sys.path.insert(0, str(ROOT))
OUTPUT = ROOT / "data" / "processed" / "self_play.jsonl" OUTPUT = ROOT / "data" / "processed" / "self_play.jsonl"
# Module-level API key, set from --api-key in main() and passed to llm_call(); None means no Authorization header is sent
_API_KEY = None
# --- RCON (persistent connection) --- # --- RCON (persistent connection) ---
from agent.tools.persistent_rcon import get_rcon from agent.tools.persistent_rcon import get_rcon
@@ -72,8 +75,8 @@ def rcon_command(cmd, host, port, password):
# --- LLM calls --- # --- LLM calls ---
def llm_call(model, system, user, ollama_url, temperature=0.7, max_tokens=500, fmt=None): def llm_call(model, system, user, ollama_url, temperature=0.7, max_tokens=500, fmt=None, api_key=None):
"""Call Ollama and return content with think blocks stripped.""" """Call Ollama (or gateway) and return content with think blocks stripped."""
payload = { payload = {
"model": model, "model": model,
"messages": [ "messages": [
@@ -85,7 +88,10 @@ def llm_call(model, system, user, ollama_url, temperature=0.7, max_tokens=500, f
} }
if fmt: if fmt:
payload["format"] = fmt payload["format"] = fmt
r = requests.post(f"{ollama_url}/api/chat", json=payload, timeout=120) headers = {"Content-Type": "application/json"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
r = requests.post(f"{ollama_url}/api/chat", json=payload, headers=headers, timeout=120)
r.raise_for_status() r.raise_for_status()
content = r.json()["message"]["content"] content = r.json()["message"]["content"]
# Strip think blocks # Strip think blocks
@@ -345,7 +351,7 @@ def attempt_command(model, ollama_url, prompt, rcon_host, rcon_port, rcon_pass,
# First attempt # First attempt
try: try:
raw = llm_call(model, system, prompt, ollama_url, temperature=0.3, max_tokens=300, fmt="json") raw = llm_call(model, system, prompt, ollama_url, temperature=0.3, max_tokens=300, fmt="json", api_key=_API_KEY)
result = json.loads(raw) result = json.loads(raw)
except (json.JSONDecodeError, Exception) as e: except (json.JSONDecodeError, Exception) as e:
# Try extracting JSON # Try extracting JSON
@@ -401,7 +407,7 @@ def attempt_command(model, ollama_url, prompt, rcon_host, rcon_port, rcon_pass,
retry_prompt = f"Original request: {prompt}\n\nFailed commands:\n{error_context}\n\nPlease fix the commands." retry_prompt = f"Original request: {prompt}\n\nFailed commands:\n{error_context}\n\nPlease fix the commands."
try: try:
raw = llm_call(model, RETRY_SYSTEM, retry_prompt, ollama_url, temperature=0.2, max_tokens=300, fmt="json") raw = llm_call(model, RETRY_SYSTEM, retry_prompt, ollama_url, temperature=0.2, max_tokens=300, fmt="json", api_key=_API_KEY)
result = json.loads(raw) result = json.loads(raw)
except: except:
match = re.search(r'\{[\s\S]*\}', raw if 'raw' in dir() else '') match = re.search(r'\{[\s\S]*\}', raw if 'raw' in dir() else '')
@@ -542,6 +548,7 @@ def main():
parser = argparse.ArgumentParser(description="Self-play training data generator") parser = argparse.ArgumentParser(description="Self-play training data generator")
parser.add_argument("--model", default="qwen3-8b-mc-lora-v3") parser.add_argument("--model", default="qwen3-8b-mc-lora-v3")
parser.add_argument("--ollama-url", default="http://192.168.0.141:11434") parser.add_argument("--ollama-url", default="http://192.168.0.141:11434")
parser.add_argument("--api-key", default=None, help="API key for authenticated gateways")
parser.add_argument("--rcon-host", default="192.168.0.244") parser.add_argument("--rcon-host", default="192.168.0.244")
parser.add_argument("--rcon-port", type=int, default=25578) parser.add_argument("--rcon-port", type=int, default=25578)
parser.add_argument("--rcon-pass", default="REDACTED_RCON") parser.add_argument("--rcon-pass", default="REDACTED_RCON")
@@ -553,6 +560,9 @@ def main():
parser.add_argument("--max-retries", type=int, default=2) parser.add_argument("--max-retries", type=int, default=2)
args = parser.parse_args() args = parser.parse_args()
global _API_KEY
_API_KEY = args.api_key
tiers = [1, 2, 3] if args.tier == "all" else [int(args.tier)] tiers = [1, 2, 3] if args.tier == "all" else [int(args.tier)]
print(f"Self-play training data generator") print(f"Self-play training data generator")