Self-play: add --api-key option for authenticated gateway connections
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -42,6 +42,9 @@ sys.path.insert(0, str(ROOT))
|
|||||||
|
|
||||||
OUTPUT = ROOT / "data" / "processed" / "self_play.jsonl"
|
OUTPUT = ROOT / "data" / "processed" / "self_play.jsonl"
|
||||||
|
|
||||||
|
# Module-level API key, set from args in main()
|
||||||
|
_API_KEY = None
|
||||||
|
|
||||||
# --- RCON (persistent connection) ---
|
# --- RCON (persistent connection) ---
|
||||||
|
|
||||||
from agent.tools.persistent_rcon import get_rcon
|
from agent.tools.persistent_rcon import get_rcon
|
||||||
@@ -72,8 +75,8 @@ def rcon_command(cmd, host, port, password):
|
|||||||
|
|
||||||
# --- LLM calls ---
|
# --- LLM calls ---
|
||||||
|
|
||||||
def llm_call(model, system, user, ollama_url, temperature=0.7, max_tokens=500, fmt=None):
|
def llm_call(model, system, user, ollama_url, temperature=0.7, max_tokens=500, fmt=None, api_key=None):
|
||||||
"""Call Ollama and return content with think blocks stripped."""
|
"""Call Ollama (or gateway) and return content with think blocks stripped."""
|
||||||
payload = {
|
payload = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"messages": [
|
"messages": [
|
||||||
@@ -85,7 +88,10 @@ def llm_call(model, system, user, ollama_url, temperature=0.7, max_tokens=500, f
|
|||||||
}
|
}
|
||||||
if fmt:
|
if fmt:
|
||||||
payload["format"] = fmt
|
payload["format"] = fmt
|
||||||
r = requests.post(f"{ollama_url}/api/chat", json=payload, timeout=120)
|
headers = {"Content-Type": "application/json"}
|
||||||
|
if api_key:
|
||||||
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
|
r = requests.post(f"{ollama_url}/api/chat", json=payload, headers=headers, timeout=120)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
content = r.json()["message"]["content"]
|
content = r.json()["message"]["content"]
|
||||||
# Strip think blocks
|
# Strip think blocks
|
||||||
@@ -345,7 +351,7 @@ def attempt_command(model, ollama_url, prompt, rcon_host, rcon_port, rcon_pass,
|
|||||||
|
|
||||||
# First attempt
|
# First attempt
|
||||||
try:
|
try:
|
||||||
raw = llm_call(model, system, prompt, ollama_url, temperature=0.3, max_tokens=300, fmt="json")
|
raw = llm_call(model, system, prompt, ollama_url, temperature=0.3, max_tokens=300, fmt="json", api_key=_API_KEY)
|
||||||
result = json.loads(raw)
|
result = json.loads(raw)
|
||||||
except (json.JSONDecodeError, Exception) as e:
|
except (json.JSONDecodeError, Exception) as e:
|
||||||
# Try extracting JSON
|
# Try extracting JSON
|
||||||
@@ -401,7 +407,7 @@ def attempt_command(model, ollama_url, prompt, rcon_host, rcon_port, rcon_pass,
|
|||||||
retry_prompt = f"Original request: {prompt}\n\nFailed commands:\n{error_context}\n\nPlease fix the commands."
|
retry_prompt = f"Original request: {prompt}\n\nFailed commands:\n{error_context}\n\nPlease fix the commands."
|
||||||
|
|
||||||
try:
|
try:
|
||||||
raw = llm_call(model, RETRY_SYSTEM, retry_prompt, ollama_url, temperature=0.2, max_tokens=300, fmt="json")
|
raw = llm_call(model, RETRY_SYSTEM, retry_prompt, ollama_url, temperature=0.2, max_tokens=300, fmt="json", api_key=_API_KEY)
|
||||||
result = json.loads(raw)
|
result = json.loads(raw)
|
||||||
except:
|
except:
|
||||||
match = re.search(r'\{[\s\S]*\}', raw if 'raw' in dir() else '')
|
match = re.search(r'\{[\s\S]*\}', raw if 'raw' in dir() else '')
|
||||||
@@ -542,6 +548,7 @@ def main():
|
|||||||
parser = argparse.ArgumentParser(description="Self-play training data generator")
|
parser = argparse.ArgumentParser(description="Self-play training data generator")
|
||||||
parser.add_argument("--model", default="qwen3-8b-mc-lora-v3")
|
parser.add_argument("--model", default="qwen3-8b-mc-lora-v3")
|
||||||
parser.add_argument("--ollama-url", default="http://192.168.0.141:11434")
|
parser.add_argument("--ollama-url", default="http://192.168.0.141:11434")
|
||||||
|
parser.add_argument("--api-key", default=None, help="API key for authenticated gateways")
|
||||||
parser.add_argument("--rcon-host", default="192.168.0.244")
|
parser.add_argument("--rcon-host", default="192.168.0.244")
|
||||||
parser.add_argument("--rcon-port", type=int, default=25578)
|
parser.add_argument("--rcon-port", type=int, default=25578)
|
||||||
parser.add_argument("--rcon-pass", default="REDACTED_RCON")
|
parser.add_argument("--rcon-pass", default="REDACTED_RCON")
|
||||||
@@ -553,6 +560,9 @@ def main():
|
|||||||
parser.add_argument("--max-retries", type=int, default=2)
|
parser.add_argument("--max-retries", type=int, default=2)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
global _API_KEY
|
||||||
|
_API_KEY = args.api_key
|
||||||
|
|
||||||
tiers = [1, 2, 3] if args.tier == "all" else [int(args.tier)]
|
tiers = [1, 2, 3] if args.tier == "all" else [int(args.tier)]
|
||||||
|
|
||||||
print(f"Self-play training data generator")
|
print(f"Self-play training data generator")
|
||||||
|
|||||||
Reference in New Issue
Block a user