Files
mortdecai-gateway/gateway.py
T
Seth 0b37d7de79 Add opt-in model update endpoint + API key support
Gateway: POST /admin/update-model downloads new GGUF and reloads.
Disabled by default — requires ALLOW_MODEL_UPDATES=true in .env.
Matt controls whether remote model updates are allowed.

Self-play: --api-key flag for authenticated gateway connections.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-20 19:39:50 -04:00

362 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Mortdecai Ollama Gateway — authenticated proxy with power metering.
Sits in front of Ollama, provides:
- API key authentication
- Power/cost tracking ((GPU TDP + system overhead) × inference time × electricity rate)
- Usage dashboard
- Spending cap enforcement
- Health check endpoint
Usage:
python3 gateway.py
OLLAMA_URL=http://localhost:11434 API_KEY=mk_test python3 gateway.py
"""
import json
import os
import time
import threading
import subprocess
from http.server import HTTPServer, BaseHTTPRequestHandler
from urllib.parse import urlparse, parse_qs
import requests
# --- Config ---
# Every setting is read from the environment once, at import time; the second
# argument to os.environ.get() is the value used when the variable is unset.
OLLAMA_URL = os.environ.get("OLLAMA_URL", "http://localhost:11434")
LISTEN_PORT = int(os.environ.get("GATEWAY_PORT", "8434"))
# NOTE(review): the fallback key is a fixed string visible in this source file;
# deployments should always set API_KEY explicitly.
API_KEY = os.environ.get("API_KEY", "mk_mortdecai_default")
ELECTRICITY_RATE = float(os.environ.get("ELECTRICITY_RATE", "0.15"))  # $/kWh
GPU_TDP_WATTS = float(os.environ.get("GPU_TDP_WATTS", "54"))  # Strix Halo iGPU
SYSTEM_OVERHEAD_WATTS = float(os.environ.get("SYSTEM_OVERHEAD_WATTS", "30"))  # CPU/RAM/etc idle draw during inference
SPENDING_CAP = float(os.environ.get("SPENDING_CAP", "10.00"))  # $ before refusing requests
STATS_FILE = os.environ.get("STATS_FILE", "/var/lib/mortdecai-gateway/stats.json")

# --- Stats tracking ---
# Lock intended to serialize access to the shared _stats dictionary.
_stats_lock = threading.Lock()
# Cumulative usage counters; merged with the contents of STATS_FILE on startup.
_stats = {
    "total_requests": 0,
    "total_tokens_in": 0,
    "total_tokens_out": 0,
    "total_inference_seconds": 0,
    "total_energy_wh": 0.0,
    "total_cost": 0.0,
    # The trailing "Z" denotes UTC, so format gmtime() rather than the
    # local time that time.strftime() uses by default.
    "started_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    "last_request_at": None,
    "requests_rejected": 0,
}
def _load_stats():
    """Merge previously persisted counters from STATS_FILE into ``_stats``.

    Loading is best-effort: a missing file is treated as a first run and is
    silent; an unreadable or malformed file is reported and otherwise ignored,
    leaving the in-memory defaults untouched.
    """
    try:
        with open(STATS_FILE) as f:
            saved = json.load(f)
    except FileNotFoundError:
        return  # nothing persisted yet
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and undecodable bytes.
        print(f"Warning: could not load stats from {STATS_FILE}: {exc}")
        return

    if not isinstance(saved, dict):
        print(f"Warning: ignoring stats file {STATS_FILE}: expected a JSON object")
        return
    _stats.update(saved)
def _save_stats():
    """Persist ``_stats`` to STATS_FILE as indented JSON (best-effort).

    The data is written to a sibling temporary file and then moved into place
    with os.replace(), so an interrupted write cannot leave a truncated stats
    file behind. Failures are reported but never raised.
    """
    directory = os.path.dirname(STATS_FILE)
    tmp_path = f"{STATS_FILE}.tmp"
    try:
        # dirname is "" for a bare filename, and os.makedirs("") raises.
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(tmp_path, "w") as f:
            json.dump(_stats, f, indent=2)
        os.replace(tmp_path, STATS_FILE)
    except (OSError, TypeError, ValueError) as exc:
        # TypeError/ValueError: _stats contained something JSON cannot encode.
        print(f"Warning: could not save stats to {STATS_FILE}: {exc}")
def _track_request(tokens_in, tokens_out, duration_seconds):
    """Record one completed inference request in the shared counters.

    Args:
        tokens_in: Prompt token count to add to the running total.
        tokens_out: Generated token count to add to the running total.
        duration_seconds: Wall-clock time the request took.

    Energy is estimated as (GPU_TDP_WATTS + SYSTEM_OVERHEAD_WATTS) drawn for
    the whole duration; cost is that energy in kWh times ELECTRICITY_RATE.
    Stats are flushed to disk on every 10th request.
    """
    # Pure arithmetic on constants and arguments; no need to hold the lock.
    total_watts = GPU_TDP_WATTS + SYSTEM_OVERHEAD_WATTS
    energy_wh = (total_watts * duration_seconds) / 3600
    cost = (energy_wh / 1000) * ELECTRICITY_RATE

    with _stats_lock:
        _stats["total_requests"] += 1
        _stats["total_tokens_in"] += tokens_in
        _stats["total_tokens_out"] += tokens_out
        _stats["total_inference_seconds"] += duration_seconds
        # The "Z" suffix means UTC, so format gmtime() instead of local time.
        _stats["last_request_at"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
        _stats["total_energy_wh"] += energy_wh
        _stats["total_cost"] += cost
        # Save every 10 requests
        if _stats["total_requests"] % 10 == 0:
            _save_stats()
def _check_budget():
    """Report whether accumulated cost is still strictly below SPENDING_CAP."""
    with _stats_lock:
        spent_so_far = _stats["total_cost"]
        return spent_so_far < SPENDING_CAP
def _query_nvidia_smi():
    """Return a reading dict from nvidia-smi, or None if it cannot be obtained."""
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=utilization.gpu,temperature.gpu,power.draw",
             "--format=csv,noheader,nounits"],
            capture_output=True, text=True, timeout=5
        )
    except (OSError, subprocess.SubprocessError):
        # Binary missing / not executable, or the call timed out.
        return None
    if result.returncode != 0:
        return None

    # Only the first output line is parsed, so extra lines (as printed for
    # additional GPUs) cannot corrupt the last field.
    lines = result.stdout.strip().splitlines()
    if not lines:
        return None
    parts = [p.strip() for p in lines[0].split(",")]
    try:
        return {
            "utilization": float(parts[0]),
            "temperature": float(parts[1]),
            "power_watts": float(parts[2]) if parts[2] != "[N/A]" else GPU_TDP_WATTS,
            "source": "nvidia-smi",
        }
    except (IndexError, ValueError):
        return None


def _query_rocm_smi():
    """Return a reading dict from rocm-smi, or None if it cannot be obtained."""
    try:
        result = subprocess.run(
            ["rocm-smi", "--showuse", "--showtemp", "--json"],
            capture_output=True, text=True, timeout=5
        )
    except (OSError, subprocess.SubprocessError):
        return None
    if result.returncode != 0:
        return None

    try:
        data = json.loads(result.stdout)
    except ValueError:
        return None
    if not isinstance(data, dict):
        return None

    # Parse rocm-smi JSON (format varies by version): use the first entry
    # whose value is an object.
    for card_data in data.values():
        if isinstance(card_data, dict):
            try:
                return {
                    "utilization": float(card_data.get("GPU use (%)", 0)),
                    "temperature": float(card_data.get("Temperature (Sensor edge) (C)", 0)),
                    "power_watts": GPU_TDP_WATTS,
                    "source": "rocm-smi",
                }
            except (TypeError, ValueError):
                return None
    return None


def _get_gpu_utilization():
    """Get current GPU utilization via nvidia-smi or rocm-smi.

    Returns:
        A dict with keys ``utilization``, ``temperature``, ``power_watts`` and
        ``source``. nvidia-smi is tried first, then rocm-smi; if neither
        yields a parseable reading, all numeric fields are 0 and ``source``
        is ``"unavailable"``. This function never raises for a missing or
        misbehaving tool.
    """
    for probe in (_query_nvidia_smi, _query_rocm_smi):
        reading = probe()
        if reading is not None:
            return reading
    return {"utilization": 0, "temperature": 0, "power_watts": 0, "source": "unavailable"}
# --- HTTP Handler ---
class GatewayHandler(BaseHTTPRequestHandler):
    """Request handler that authenticates, meters and proxies calls to Ollama.

    Routes:
        GET  /health              unauthenticated Ollama connectivity probe
        GET  /stats               authenticated JSON usage statistics
        GET  /dashboard           unauthenticated HTML usage page
        POST /admin/update-model  authenticated, opt-in model download
        anything else             authenticated pass-through to OLLAMA_URL
    """

    def log_message(self, fmt, *args):
        """Suppress the base class's per-request log output."""
        pass  # Quiet

    def _check_auth(self):
        """Validate the Authorization header against API_KEY.

        Both ``Bearer <key>`` and the bare key are accepted. On failure a 401
        JSON response is sent and False is returned, so callers only need to
        ``return`` when this yields False.
        """
        import hmac  # local import: only this method needs it

        supplied = self.headers.get("Authorization", "").encode("utf-8")
        accepted = (f"Bearer {API_KEY}".encode("utf-8"), API_KEY.encode("utf-8"))
        # compare_digest avoids leaking, through timing, how much of the key
        # matched. The list is built eagerly so every candidate is compared.
        outcomes = [hmac.compare_digest(supplied, candidate) for candidate in accepted]
        if any(outcomes):
            return True
        self._send_json(401, {"error": "Invalid API key"})
        return False

    def _send_json(self, status, data):
        """Send ``data`` serialized as JSON with the given HTTP status code."""
        body = json.dumps(data).encode()
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _proxy_to_ollama(self, path, body=None, method=None):
        """Proxy request to Ollama, track usage and relay the JSON reply.

        Args:
            path: Request path (with any query string) appended to OLLAMA_URL.
            body: Decoded JSON payload to forward, or None.
            method: ``"GET"`` or ``"POST"``. When None the method is inferred
                as before: POST if ``body`` is truthy, otherwise GET.

        Responds 402 when the spending cap is reached, 502/504 when Ollama is
        unreachable, times out or answers with non-JSON, and 500 on any other
        failure.
        """
        if not _check_budget():
            with _stats_lock:
                _stats["requests_rejected"] += 1
                spent = _stats["total_cost"]
            self._send_json(402, {
                "error": "Spending cap reached",
                "total_cost": spent,
                "cap": SPENDING_CAP,
            })
            return

        if method is None:
            method = "POST" if body else "GET"

        t0 = time.time()
        try:
            if method == "POST":
                r = requests.post(f"{OLLAMA_URL}{path}", json=body, timeout=120)
            else:
                r = requests.get(f"{OLLAMA_URL}{path}", timeout=10)
            duration = time.time() - t0
            data = r.json()

            # Only JSON objects can carry token counts or gateway metadata;
            # other JSON values (e.g. arrays) are relayed untouched.
            if isinstance(data, dict):
                # "or 0" also covers a key that is present but null.
                tokens_in = data.get("prompt_eval_count", 0) or 0
                tokens_out = data.get("eval_count", 0) or 0
                if tokens_in or tokens_out:
                    _track_request(tokens_in, tokens_out, duration)

                with _stats_lock:
                    spent = _stats["total_cost"]
                energy_wh = (GPU_TDP_WATTS + SYSTEM_OVERHEAD_WATTS) * duration / 3600
                # Add gateway metadata to response
                data["_gateway"] = {
                    "duration_seconds": round(duration, 2),
                    "energy_wh": round(energy_wh, 4),
                    "estimated_cost": round(energy_wh / 1000 * ELECTRICITY_RATE, 6),
                    "total_cost": round(spent, 4),
                    "budget_remaining": round(SPENDING_CAP - spent, 4),
                }
            self._send_json(r.status_code, data)
        except requests.exceptions.ConnectionError:
            self._send_json(502, {"error": "Ollama is not running"})
        except requests.exceptions.Timeout:
            self._send_json(504, {"error": "Ollama timeout"})
        except ValueError:
            # r.json() could not decode the upstream body as a single JSON
            # document. NOTE(review): presumably also the case for streamed
            # (newline-delimited) replies -- confirm against Ollama.
            self._send_json(502, {"error": "Ollama returned a non-JSON response"})
        except Exception as e:
            self._send_json(500, {"error": str(e)})

    def do_GET(self):
        """Serve /health, /stats and /dashboard; proxy every other GET."""
        parsed = urlparse(self.path)

        # Public endpoints (no auth)
        if parsed.path == "/health":
            try:
                r = requests.get(f"{OLLAMA_URL}/api/tags", timeout=5)
                models = [m["name"] for m in r.json().get("models", [])]
            except Exception:
                # Any failure to reach or parse Ollama means "not healthy".
                self._send_json(503, {"status": "error", "ollama": "disconnected"})
            else:
                self._send_json(200, {"status": "ok", "ollama": "connected", "models": models})
            return

        if parsed.path == "/stats":
            if not self._check_auth():
                return
            gpu = _get_gpu_utilization()
            with _stats_lock:
                stats_copy = dict(_stats)
            stats_copy["gpu"] = gpu
            stats_copy["config"] = {
                "gpu_tdp_watts": GPU_TDP_WATTS,
                "system_overhead_watts": SYSTEM_OVERHEAD_WATTS,
                "electricity_rate": ELECTRICITY_RATE,
                "spending_cap": SPENDING_CAP,
            }
            self._send_json(200, stats_copy)
            return

        if parsed.path == "/dashboard":
            self._serve_dashboard()
            return

        # Proxy everything else to Ollama
        if not self._check_auth():
            return
        self._proxy_to_ollama(self.path, method="GET")

    def do_POST(self):
        """Authenticate, decode the JSON body, then update the model or proxy."""
        if not self._check_auth():
            return

        try:
            length = int(self.headers.get("Content-Length", 0))
        except ValueError:
            self._send_json(400, {"error": "Invalid Content-Length header"})
            return

        body = None
        if length > 0:
            try:
                body = json.loads(self.rfile.read(length))
            except ValueError:
                # Covers malformed JSON and undecodable bytes alike.
                self._send_json(400, {"error": "Request body is not valid JSON"})
                return

        # Model update endpoint — downloads new GGUF and reloads
        if self.path == "/admin/update-model" and body:
            self._handle_model_update(body)
            return

        # Always forward as POST, even when the body is empty.
        self._proxy_to_ollama(self.path, body, method="POST")

    def _handle_model_update(self, body):
        """Download a new GGUF from a URL and reload the model.

        Request: {"url": "https://mortdec.ai/dl/...", "name": "mortdecai-v5"}
        This is opt-in — the gateway operator must enable ALLOW_MODEL_UPDATES=true.

        Responds 403 when updates are disabled, 400 for a malformed request
        (non-object body, missing or non-http(s) url, unsafe name), 500 if the
        download or the ``ollama create`` step fails, and 200 on success.
        """
        if os.environ.get("ALLOW_MODEL_UPDATES", "false").lower() != "true":
            self._send_json(403, {"error": "Model updates disabled. Set ALLOW_MODEL_UPDATES=true in .env to enable."})
            return
        if not isinstance(body, dict):
            self._send_json(400, {"error": "JSON object body is required"})
            return

        url = body.get("url")
        name = body.get("name", "mortdecai-latest")
        if not url:
            self._send_json(400, {"error": "url is required"})
            return
        if not isinstance(url, str) or urlparse(url).scheme not in ("http", "https"):
            self._send_json(400, {"error": "url must be an http(s) URL"})
            return
        # ``name`` becomes part of a filesystem path, so refuse anything that
        # could escape the models directory (path separators, leading dot).
        name_ok = (
            isinstance(name, str)
            and name != ""
            and name.isascii()
            and not name.startswith(".")
            and all(ch.isalnum() or ch in "._-:" for ch in name)
        )
        if not name_ok:
            self._send_json(400, {"error": "name may only contain letters, digits, '.', '_', '-' and ':'"})
            return

        # NOTE(review): the file is written relative to the gateway's working
        # directory while the create step reads /models/Modelfile inside the
        # container -- presumably ./models is mounted there; confirm.
        target_path = f"models/{name}.gguf"
        partial_path = f"{target_path}.part"
        try:
            # Download GGUF to a temporary name first so that a failed
            # transfer never replaces a previously downloaded model.
            print(f"Downloading model from {url}...")
            with requests.get(url, stream=True, timeout=600) as r:
                r.raise_for_status()
                with open(partial_path, "wb") as f:
                    for chunk in r.iter_content(chunk_size=1024 * 1024):
                        f.write(chunk)
            os.replace(partial_path, target_path)

            # Register the model with Ollama from the Modelfile in the container.
            subprocess.run(
                ["docker", "exec", "mortdecai-ollama", "ollama", "create", name, "-f", "/models/Modelfile"],
                timeout=120, check=True
            )
            self._send_json(200, {"status": "ok", "model": name, "message": "Model updated and loaded"})
        except Exception as e:
            self._send_json(500, {"error": f"Update failed: {e}"})
        finally:
            # Remove the leftover temporary file after a failed download.
            if os.path.exists(partial_path):
                try:
                    os.remove(partial_path)
                except OSError:
                    pass

    def _serve_dashboard(self):
        """Simple HTML dashboard showing usage stats (auto-refreshes every 10s)."""
        with _stats_lock:
            s = dict(_stats)
        gpu = _get_gpu_utilization()
        # Derived from the same snapshot as the numbers shown below.
        status = "ACTIVE" if s["total_cost"] < SPENDING_CAP else "PAUSED (cap reached)"
        html = f"""<!DOCTYPE html>
<html><head><title>Mortdecai Gateway</title>
<meta http-equiv="refresh" content="10">
<style>
body {{ font-family: monospace; background: #1a1a1a; color: #e0e0e0; padding: 2rem; }}
h1 {{ color: #D35400; }}
.stat {{ background: #252525; border: 1px solid #333; padding: 1rem; margin: 0.5rem 0; border-radius: 6px; }}
.label {{ color: #999; }}
.value {{ color: #D35400; font-size: 1.2rem; font-weight: bold; }}
</style></head><body>
<h1>Mortdecai Gateway</h1>
<div class="stat"><span class="label">Status:</span> <span class="value">{status}</span></div>
<div class="stat"><span class="label">Total Requests:</span> <span class="value">{s['total_requests']}</span></div>
<div class="stat"><span class="label">Tokens (in/out):</span> <span class="value">{s['total_tokens_in']:,} / {s['total_tokens_out']:,}</span></div>
<div class="stat"><span class="label">Inference Time:</span> <span class="value">{s['total_inference_seconds']:.0f}s</span></div>
<div class="stat"><span class="label">Energy Used:</span> <span class="value">{s['total_energy_wh']:.1f} Wh</span></div>
<div class="stat"><span class="label">Estimated Cost:</span> <span class="value">${s['total_cost']:.4f} / ${SPENDING_CAP:.2f}</span></div>
<div class="stat"><span class="label">Rejected (over cap):</span> <span class="value">{s['requests_rejected']}</span></div>
<div class="stat"><span class="label">GPU Utilization:</span> <span class="value">{gpu['utilization']}% ({gpu['source']})</span></div>
<div class="stat"><span class="label">GPU Temperature:</span> <span class="value">{gpu['temperature']}°C</span></div>
<div class="stat"><span class="label">Last Request:</span> <span class="value">{s['last_request_at'] or 'never'}</span></div>
<div class="stat"><span class="label">Config:</span> <span class="value">TDP={GPU_TDP_WATTS}W + {SYSTEM_OVERHEAD_WATTS}W overhead @ ${ELECTRICITY_RATE}/kWh</span></div>
</body></html>"""
        payload = html.encode("utf-8")
        self.send_response(200)
        # The page contains a non-ASCII character ("°"), so the encoding must
        # be declared for browsers to render it correctly.
        self.send_header("Content-Type", "text/html; charset=utf-8")
        self.send_header("Content-Length", str(len(payload)))
        self.end_headers()
        self.wfile.write(payload)
def main():
    """Load persisted stats, start the periodic saver and serve until stopped.

    The HTTP server listens on all interfaces at LISTEN_PORT. A daemon thread
    flushes stats to disk every 60 seconds; on shutdown (including Ctrl-C)
    the listening socket is closed and stats are saved one final time so that
    usage recorded since the last periodic flush is not lost.
    """
    _load_stats()
    print("Mortdecai Gateway starting")
    print(f" Ollama: {OLLAMA_URL}")
    print(f" Listen: 0.0.0.0:{LISTEN_PORT}")
    print(f" TDP: {GPU_TDP_WATTS}W + {SYSTEM_OVERHEAD_WATTS}W overhead")
    print(f" Rate: ${ELECTRICITY_RATE}/kWh")
    print(f" Cap: ${SPENDING_CAP}")
    print(f" Dashboard: http://localhost:{LISTEN_PORT}/dashboard")

    # Save stats periodically
    def _periodic_save():
        while True:
            time.sleep(60)
            with _stats_lock:
                _save_stats()

    # Daemon thread: it must not keep the process alive after the server stops.
    saver = threading.Thread(target=_periodic_save, daemon=True)
    saver.start()

    server = HTTPServer(("0.0.0.0", LISTEN_PORT), GatewayHandler)
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        print("Mortdecai Gateway shutting down")
    finally:
        server.server_close()
        with _stats_lock:
            _save_stats()


if __name__ == "__main__":
    main()