Files
mortdecai-gateway/gateway.py
T
Seth c5865feb35 Mortdecai Gateway — authenticated Ollama proxy with power metering
- API key auth on all inference endpoints
- Power/cost tracking: GPU TDP × inference time × electricity rate
- Spending cap enforcement
- Web dashboard with live stats
- Docker compose for AMD ROCm (Strix Halo) or NVIDIA
- Auto-setup script with GGUF loading
- Tested against local Ollama

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-20 19:26:43 -04:00

320 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Mortdecai Ollama Gateway — authenticated proxy with power metering.
Sits in front of Ollama, provides:
- API key authentication
- Power/cost tracking (inference time × (GPU TDP + system overhead) × electricity rate)
- Usage dashboard
- Spending cap enforcement
- Health check endpoint
Usage:
python3 gateway.py
OLLAMA_URL=http://localhost:11434 API_KEY=mk_test python3 gateway.py
"""
import hmac
import html
import json
import os
import subprocess
import threading
import time
from http.server import HTTPServer, BaseHTTPRequestHandler
from urllib.parse import urlparse, parse_qs

import requests
# --- Config ---
# Every setting is read once, at import time, from the process environment.
# The second argument of each lookup is the default used when the variable
# is not set.
_getenv = os.environ.get

# Base URL of the upstream Ollama server that requests are forwarded to.
OLLAMA_URL = _getenv("OLLAMA_URL", "http://localhost:11434")
# TCP port the gateway listens on.
LISTEN_PORT = int(_getenv("GATEWAY_PORT", "8434"))
# Shared secret clients must present in the Authorization header.
API_KEY = _getenv("API_KEY", "mk_mortdecai_default")
# Electricity price in $/kWh.
ELECTRICITY_RATE = float(_getenv("ELECTRICITY_RATE", "0.15"))
# GPU power draw assumed during inference, in watts (default: Strix Halo iGPU).
GPU_TDP_WATTS = float(_getenv("GPU_TDP_WATTS", "54"))
# CPU/RAM/etc. draw added on top of the GPU during inference, in watts.
SYSTEM_OVERHEAD_WATTS = float(_getenv("SYSTEM_OVERHEAD_WATTS", "30"))
# Accumulated cost in $ at which the gateway starts refusing requests.
SPENDING_CAP = float(_getenv("SPENDING_CAP", "10.00"))
# Location of the JSON file the usage counters are persisted to.
STATS_FILE = _getenv("STATS_FILE", "/var/lib/mortdecai-gateway/stats.json")
# --- Stats tracking ---
# Lock taken around updates to (and snapshots of) the counters below, which
# are touched both by request handling and by the periodic saver thread.
_stats_lock = threading.Lock()
# Cumulative usage counters; persisted to STATS_FILE as JSON.
_stats = {
    "total_requests": 0,
    "total_tokens_in": 0,
    "total_tokens_out": 0,
    "total_inference_seconds": 0,
    "total_energy_wh": 0.0,  # watt-hours
    "total_cost": 0.0,  # dollars
    # Formatted from UTC so that the trailing "Z" designator is truthful
    # (time.strftime() without a time tuple formats *local* time).
    "started_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    "last_request_at": None,
    "requests_rejected": 0,
}
def _load_stats():
    """Merge counters previously persisted in STATS_FILE into ``_stats``.

    This is best-effort: when the file is missing, unreadable, not valid
    JSON, or does not contain a JSON object, the in-memory counters are left
    exactly as they are and no exception is raised.
    """
    try:
        with open(STATS_FILE, encoding="utf-8") as f:
            saved = json.load(f)
    except (OSError, ValueError):
        # OSError: no file yet / permission problem.
        # ValueError: malformed JSON or undecodable bytes.
        return
    # Only a JSON object is a valid stats snapshot; merging a list or scalar
    # would either raise or inject garbage keys.
    if isinstance(saved, dict):
        _stats.update(saved)
def _save_stats():
    """Write ``_stats`` to STATS_FILE as indented JSON (best-effort).

    The parent directory is created when STATS_FILE has one. The data is
    first written to ``STATS_FILE + ".tmp"`` and then moved into place with
    ``os.replace`` so that an interrupted write cannot leave a truncated
    stats file behind. Failures are swallowed on purpose: losing a stats
    snapshot must not take the gateway down.

    This function does not take ``_stats_lock`` itself.
    """
    directory = os.path.dirname(STATS_FILE)
    tmp_path = f"{STATS_FILE}.tmp"
    try:
        # A bare filename has an empty dirname, and os.makedirs("") raises;
        # previously that error was swallowed and nothing was ever saved.
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(_stats, f, indent=2)
        os.replace(tmp_path, STATS_FILE)
    except (OSError, TypeError, ValueError):
        # OSError: filesystem problem; TypeError/ValueError: a value in
        # _stats that json cannot serialize.
        pass
def _track_request(tokens_in, tokens_out, duration_seconds):
    """Record one completed inference request in the global counters.

    Args:
        tokens_in: Number of prompt tokens to add to ``total_tokens_in``.
        tokens_out: Number of generated tokens to add to ``total_tokens_out``.
        duration_seconds: Wall-clock time attributed to the request.

    The energy estimate assumes the machine draws
    ``GPU_TDP_WATTS + SYSTEM_OVERHEAD_WATTS`` for the whole duration; the
    cost is that energy (converted to kWh) times ``ELECTRICITY_RATE``.
    Every 10th request the counters are also written to disk.
    """
    # Pure arithmetic: done before taking the lock to keep the critical
    # section short.
    total_watts = GPU_TDP_WATTS + SYSTEM_OVERHEAD_WATTS
    energy_wh = (total_watts * duration_seconds) / 3600
    cost = (energy_wh / 1000) * ELECTRICITY_RATE
    # Format from UTC so the trailing "Z" is accurate; time.strftime()
    # without a time tuple would format local time.
    timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())

    with _stats_lock:
        _stats["total_requests"] += 1
        _stats["total_tokens_in"] += tokens_in
        _stats["total_tokens_out"] += tokens_out
        _stats["total_inference_seconds"] += duration_seconds
        _stats["last_request_at"] = timestamp
        _stats["total_energy_wh"] += energy_wh
        _stats["total_cost"] += cost
        # Save every 10 requests
        if _stats["total_requests"] % 10 == 0:
            _save_stats()
def _check_budget():
    """Tell whether the accumulated cost is still below ``SPENDING_CAP``.

    Returns:
        True while ``_stats["total_cost"]`` is strictly less than the cap,
        False once the cap has been reached or exceeded.
    """
    with _stats_lock:
        spent_so_far = _stats["total_cost"]
    under_cap = spent_so_far < SPENDING_CAP
    return under_cap
def _query_nvidia_smi():
    """Return a GPU reading from ``nvidia-smi``, or None if it cannot be obtained."""
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=utilization.gpu,temperature.gpu,power.draw",
             "--format=csv,noheader,nounits"],
            capture_output=True, text=True, timeout=5,
        )
    except (OSError, subprocess.SubprocessError):
        # Tool not installed, not executable, or timed out.
        return None
    if result.returncode != 0:
        return None

    rows = result.stdout.strip().splitlines()
    if not rows:
        return None
    # One CSV row is printed per GPU; report the first one. Splitting the
    # whole output on "," would glue rows together on multi-GPU machines.
    parts = [p.strip() for p in rows[0].split(",")]
    try:
        utilization = float(parts[0])
        temperature = float(parts[1])
        power_text = parts[2]
    except (IndexError, ValueError):
        return None
    try:
        power_watts = float(power_text)
    except ValueError:
        # Power draw not reported as a number (e.g. "[N/A]"): assume TDP.
        power_watts = GPU_TDP_WATTS
    return {
        "utilization": utilization,
        "temperature": temperature,
        "power_watts": power_watts,
        "source": "nvidia-smi",
    }


def _query_rocm_smi():
    """Return a GPU reading from ``rocm-smi``, or None if it cannot be obtained."""
    try:
        result = subprocess.run(
            ["rocm-smi", "--showuse", "--showtemp", "--json"],
            capture_output=True, text=True, timeout=5,
        )
    except (OSError, subprocess.SubprocessError):
        return None
    if result.returncode != 0:
        return None
    try:
        data = json.loads(result.stdout)
    except ValueError:
        return None
    if not isinstance(data, dict):
        return None

    # Parse rocm-smi JSON (format varies by version): use the first entry
    # whose value is an object and read the two fields by their label.
    for card_data in data.values():
        if not isinstance(card_data, dict):
            continue
        try:
            return {
                "utilization": float(card_data.get("GPU use (%)", 0)),
                "temperature": float(card_data.get("Temperature (Sensor edge) (C)", 0)),
                # Power draw is not read from rocm-smi here; TDP is assumed.
                "power_watts": GPU_TDP_WATTS,
                "source": "rocm-smi",
            }
        except (TypeError, ValueError):
            return None
    return None


def _get_gpu_utilization():
    """Get current GPU utilization via nvidia-smi or rocm-smi.

    ``nvidia-smi`` is tried first, then ``rocm-smi``.

    Returns:
        A dict with the keys ``utilization``, ``temperature``,
        ``power_watts`` and ``source``. ``source`` names the tool that
        produced the reading, or is ``"unavailable"`` (with all numeric
        fields 0) when neither tool yields a usable result.
    """
    for probe in (_query_nvidia_smi, _query_rocm_smi):
        reading = probe()
        if reading is not None:
            return reading
    return {"utilization": 0, "temperature": 0, "power_watts": 0, "source": "unavailable"}
# --- HTTP Handler ---
class GatewayHandler(BaseHTTPRequestHandler):
    """HTTP request handler for the gateway.

    Routes:
        GET  /health     - unauthenticated probe of the upstream Ollama server.
        GET  /stats      - authenticated JSON dump of counters, GPU reading, config.
        GET  /dashboard  - HTML page showing the same counters.
        anything else    - authenticated, forwarded to Ollama with usage tracking.
    """

    def log_message(self, fmt, *args):
        """Suppress the default per-request log line."""
        pass  # Quiet

    def _check_auth(self):
        """Validate the Authorization header; send a 401 response on failure.

        Both ``Bearer <API_KEY>`` and the bare key are accepted.

        Returns:
            True when the request is authorized; False after a 401 response
            has already been written.
        """
        supplied = self.headers.get("Authorization", "").encode("utf-8")
        accepted = (f"Bearer {API_KEY}".encode("utf-8"), API_KEY.encode("utf-8"))
        # compare_digest avoids leaking, through response timing, how many
        # leading characters of the key were guessed correctly. Both forms
        # are always evaluated (list, not a short-circuiting generator).
        matches = [hmac.compare_digest(supplied, candidate) for candidate in accepted]
        if any(matches):
            return True
        self._send_json(401, {"error": "Invalid API key"})
        return False

    def _send_bytes(self, status, content_type, body):
        """Write a complete response with the given status, type and raw body."""
        self.send_response(status)
        self.send_header("Content-Type", content_type)
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _send_json(self, status, data):
        """Serialize *data* as JSON and send it with HTTP status *status*."""
        self._send_bytes(status, "application/json", json.dumps(data).encode("utf-8"))

    @staticmethod
    def _last_ndjson_object(text):
        """Return the last non-blank line of *text* parsed as a JSON object.

        Returns an empty dict when there is no such line or it is not a
        JSON object.
        """
        for line in reversed(text.splitlines()):
            line = line.strip()
            if not line:
                continue
            try:
                parsed = json.loads(line)
            except ValueError:
                return {}
            return parsed if isinstance(parsed, dict) else {}
        return {}

    def _proxy_to_ollama(self, path, body=None, method=None):
        """Proxy request to Ollama and track usage.

        Args:
            path: Request path (including any query string) appended to
                ``OLLAMA_URL``.
            body: Decoded JSON body to forward, or None.
            method: ``"POST"`` or ``"GET"``. When None it is inferred:
                POST if *body* is not None, otherwise GET.

        Responds 402 without contacting Ollama once the spending cap is
        reached; 502 / 504 / 500 on upstream connection error / timeout /
        any other failure. Otherwise the upstream status code is relayed.
        A JSON-object reply gets a ``_gateway`` member with per-request and
        cumulative cost figures; any other reply body is relayed unchanged.
        """
        # Decide and count the rejection atomically.
        with _stats_lock:
            spent = _stats["total_cost"]
            over_cap = not spent < SPENDING_CAP
            if over_cap:
                _stats["requests_rejected"] += 1
        if over_cap:
            self._send_json(402, {
                "error": "Spending cap reached",
                "total_cost": spent,
                "cap": SPENDING_CAP,
            })
            return

        if method is None:
            # `is not None` rather than truthiness: an empty JSON object is
            # still a body and must not be turned into a GET.
            method = "POST" if body is not None else "GET"
        url = f"{OLLAMA_URL}{path}"

        t0 = time.monotonic()
        try:
            if method == "POST":
                r = requests.post(url, json=body, timeout=120)
            else:
                r = requests.get(url, timeout=10)
            duration = time.monotonic() - t0

            try:
                data = r.json()
            except ValueError:
                # Not a single JSON document, e.g. a newline-delimited JSON
                # stream (buffered in full by requests) or plain text.
                data = None

            # Track token usage from response. For a multi-line reply the
            # counters, when present, are looked for in the last line.
            if isinstance(data, dict):
                usage = data
            elif data is None:
                usage = self._last_ndjson_object(r.text)
            else:
                usage = {}
            tokens_in = usage.get("prompt_eval_count") or 0
            tokens_out = usage.get("eval_count") or 0
            if tokens_in or tokens_out:
                _track_request(tokens_in, tokens_out, duration)

            if isinstance(data, dict):
                # Add gateway metadata to response
                energy_wh = (GPU_TDP_WATTS + SYSTEM_OVERHEAD_WATTS) * duration / 3600
                with _stats_lock:
                    total_cost = _stats["total_cost"]
                data["_gateway"] = {
                    "duration_seconds": round(duration, 2),
                    "energy_wh": round(energy_wh, 4),
                    "estimated_cost": round((energy_wh / 1000) * ELECTRICITY_RATE, 6),
                    "total_cost": round(total_cost, 4),
                    "budget_remaining": round(SPENDING_CAP - total_cost, 4),
                }
                self._send_json(r.status_code, data)
            elif data is not None:
                # Valid JSON that is not an object (array, scalar): relay as is.
                self._send_json(r.status_code, data)
            else:
                content_type = r.headers.get("Content-Type", "application/octet-stream")
                self._send_bytes(r.status_code, content_type, r.content)
        except requests.exceptions.ConnectionError:
            self._send_json(502, {"error": "Ollama is not running"})
        except requests.exceptions.Timeout:
            self._send_json(504, {"error": "Ollama timeout"})
        except Exception as e:
            self._send_json(500, {"error": str(e)})

    def do_GET(self):
        """Handle GET: built-in endpoints first, everything else is proxied."""
        parsed = urlparse(self.path)

        # Public endpoints (no auth)
        if parsed.path == "/health":
            try:
                r = requests.get(f"{OLLAMA_URL}/api/tags", timeout=5)
                models = [m["name"] for m in r.json().get("models", [])]
            except (requests.exceptions.RequestException, ValueError,
                    KeyError, TypeError, AttributeError):
                # Upstream unreachable, or its reply is not the expected
                # {"models": [{"name": ...}, ...]} shape.
                self._send_json(503, {"status": "error", "ollama": "disconnected"})
            else:
                self._send_json(200, {"status": "ok", "ollama": "connected", "models": models})
            return

        if parsed.path == "/stats":
            if not self._check_auth():
                return
            gpu = _get_gpu_utilization()
            with _stats_lock:
                stats_copy = dict(_stats)
            stats_copy["gpu"] = gpu
            stats_copy["config"] = {
                "gpu_tdp_watts": GPU_TDP_WATTS,
                "system_overhead_watts": SYSTEM_OVERHEAD_WATTS,
                "electricity_rate": ELECTRICITY_RATE,
                "spending_cap": SPENDING_CAP,
            }
            self._send_json(200, stats_copy)
            return

        if parsed.path == "/dashboard":
            # NOTE(review): served without authentication although it shows
            # the same counters that /stats protects - confirm this is intended.
            self._serve_dashboard()
            return

        # Proxy everything else to Ollama
        if not self._check_auth():
            return
        self._proxy_to_ollama(self.path)

    def do_POST(self):
        """Handle POST: authenticate, decode the JSON body, forward to Ollama.

        Responds 400 when Content-Length is not an integer or the body is
        not valid JSON, instead of raising inside the handler.
        """
        if not self._check_auth():
            return
        try:
            length = int(self.headers.get("Content-Length", 0))
        except (TypeError, ValueError):
            self._send_json(400, {"error": "Invalid Content-Length header"})
            return

        body = None
        if length > 0:
            try:
                body = json.loads(self.rfile.read(length))
            except ValueError:
                self._send_json(400, {"error": "Request body is not valid JSON"})
                return
        # Force POST so a request without a body is not forwarded as a GET.
        self._proxy_to_ollama(self.path, body, method="POST")

    def _serve_dashboard(self):
        """Simple HTML dashboard showing usage stats.

        The page is rendered from one snapshot of the counters and asks the
        browser to reload itself every 10 seconds.
        """
        with _stats_lock:
            s = dict(_stats)
        gpu = _get_gpu_utilization()

        # Derive the status from the same snapshot the page displays.
        status = "ACTIVE" if s["total_cost"] < SPENDING_CAP else "PAUSED (cap reached)"
        # Free-form strings are escaped before being placed in the markup.
        last_request = html.escape(str(s["last_request_at"] or "never"))
        gpu_source = html.escape(str(gpu["source"]))

        page = f"""<!DOCTYPE html>
<html><head><title>Mortdecai Gateway</title>
<meta http-equiv="refresh" content="10">
<style>
body {{ font-family: monospace; background: #1a1a1a; color: #e0e0e0; padding: 2rem; }}
h1 {{ color: #D35400; }}
.stat {{ background: #252525; border: 1px solid #333; padding: 1rem; margin: 0.5rem 0; border-radius: 6px; }}
.label {{ color: #999; }}
.value {{ color: #D35400; font-size: 1.2rem; font-weight: bold; }}
</style></head><body>
<h1>Mortdecai Gateway</h1>
<div class="stat"><span class="label">Status:</span> <span class="value">{status}</span></div>
<div class="stat"><span class="label">Total Requests:</span> <span class="value">{s['total_requests']}</span></div>
<div class="stat"><span class="label">Tokens (in/out):</span> <span class="value">{s['total_tokens_in']:,} / {s['total_tokens_out']:,}</span></div>
<div class="stat"><span class="label">Inference Time:</span> <span class="value">{s['total_inference_seconds']:.0f}s</span></div>
<div class="stat"><span class="label">Energy Used:</span> <span class="value">{s['total_energy_wh']:.1f} Wh</span></div>
<div class="stat"><span class="label">Estimated Cost:</span> <span class="value">${s['total_cost']:.4f} / ${SPENDING_CAP:.2f}</span></div>
<div class="stat"><span class="label">Rejected (over cap):</span> <span class="value">{s['requests_rejected']}</span></div>
<div class="stat"><span class="label">GPU Utilization:</span> <span class="value">{gpu['utilization']}% ({gpu_source})</span></div>
<div class="stat"><span class="label">GPU Temperature:</span> <span class="value">{gpu['temperature']}°C</span></div>
<div class="stat"><span class="label">Last Request:</span> <span class="value">{last_request}</span></div>
<div class="stat"><span class="label">Config:</span> <span class="value">TDP={GPU_TDP_WATTS}W + {SYSTEM_OVERHEAD_WATTS}W overhead @ ${ELECTRICITY_RATE}/kWh</span></div>
</body></html>"""
        # The page contains a non-ASCII character and is encoded as UTF-8,
        # so the charset must be declared for browsers to decode it right.
        self._send_bytes(200, "text/html; charset=utf-8", page.encode("utf-8"))
def main():
    """Start the gateway and serve requests until interrupted.

    Loads persisted counters, prints the effective configuration, starts a
    daemon thread that saves the counters every 60 seconds, then serves
    HTTP on ``0.0.0.0:LISTEN_PORT``. On exit (including Ctrl-C) the
    listening socket is closed and the counters are saved one last time, so
    usage recorded since the previous periodic save is not lost.
    """
    _load_stats()
    print("Mortdecai Gateway starting")
    print(f" Ollama: {OLLAMA_URL}")
    print(f" Listen: 0.0.0.0:{LISTEN_PORT}")
    print(f" TDP: {GPU_TDP_WATTS}W + {SYSTEM_OVERHEAD_WATTS}W overhead")
    print(f" Rate: ${ELECTRICITY_RATE}/kWh")
    print(f" Cap: ${SPENDING_CAP}")
    print(f" Dashboard: http://localhost:{LISTEN_PORT}/dashboard")

    # Save stats periodically
    def _periodic_save():
        while True:
            time.sleep(60)
            with _stats_lock:
                _save_stats()

    # Daemon thread: it must not keep the process alive after the server stops.
    t = threading.Thread(target=_periodic_save, daemon=True)
    t.start()

    server = HTTPServer(("0.0.0.0", LISTEN_PORT), GatewayHandler)
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        print("Mortdecai Gateway shutting down")
    finally:
        server.server_close()
        with _stats_lock:
            _save_stats()
# Start the gateway only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()