Mortdecai Gateway — authenticated Ollama proxy with power metering

- API key auth on all inference endpoints
- Power/cost tracking: GPU TDP × inference time × electricity rate
- Spending cap enforcement
- Web dashboard with live stats
- Docker compose for AMD ROCm (Strix Halo) or NVIDIA
- Auto-setup script with GGUF loading
- Tested against local Ollama

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-20 19:26:43 -04:00
commit c5865feb35
7 changed files with 561 additions and 0 deletions
+319
View File
@@ -0,0 +1,319 @@
#!/usr/bin/env python3
"""
Mortdecai Ollama Gateway — authenticated proxy with power metering.
Sits in front of Ollama, provides:
- API key authentication
- Power/cost tracking (GPU utilization × TDP × electricity rate)
- Usage dashboard
- Spending cap enforcement
- Health check endpoint
Usage:
python3 gateway.py
OLLAMA_URL=http://localhost:11434 API_KEY=mk_test python3 gateway.py
"""
import json
import os
import time
import threading
import subprocess
from http.server import HTTPServer, BaseHTTPRequestHandler
from urllib.parse import urlparse, parse_qs
import requests
# --- Config ---
OLLAMA_URL = os.environ.get("OLLAMA_URL", "http://localhost:11434")
LISTEN_PORT = int(os.environ.get("GATEWAY_PORT", "8434"))
API_KEY = os.environ.get("API_KEY", "mk_mortdecai_default")
ELECTRICITY_RATE = float(os.environ.get("ELECTRICITY_RATE", "0.15")) # $/kWh
GPU_TDP_WATTS = float(os.environ.get("GPU_TDP_WATTS", "54")) # Strix Halo iGPU
SYSTEM_OVERHEAD_WATTS = float(os.environ.get("SYSTEM_OVERHEAD_WATTS", "30")) # CPU/RAM/etc idle draw during inference
SPENDING_CAP = float(os.environ.get("SPENDING_CAP", "10.00")) # $ before refusing requests
STATS_FILE = os.environ.get("STATS_FILE", "/var/lib/mortdecai-gateway/stats.json")
# --- Stats tracking ---
_stats_lock = threading.Lock()
_stats = {
"total_requests": 0,
"total_tokens_in": 0,
"total_tokens_out": 0,
"total_inference_seconds": 0,
"total_energy_wh": 0.0,
"total_cost": 0.0,
"started_at": time.strftime("%Y-%m-%dT%H:%M:%SZ"),
"last_request_at": None,
"requests_rejected": 0,
}
def _load_stats():
global _stats
try:
with open(STATS_FILE) as f:
saved = json.load(f)
_stats.update(saved)
except:
pass
def _save_stats():
try:
os.makedirs(os.path.dirname(STATS_FILE), exist_ok=True)
with open(STATS_FILE, "w") as f:
json.dump(_stats, f, indent=2)
except:
pass
def _track_request(tokens_in, tokens_out, duration_seconds):
"""Track a completed inference request."""
with _stats_lock:
_stats["total_requests"] += 1
_stats["total_tokens_in"] += tokens_in
_stats["total_tokens_out"] += tokens_out
_stats["total_inference_seconds"] += duration_seconds
_stats["last_request_at"] = time.strftime("%Y-%m-%dT%H:%M:%SZ")
# Power calculation
# GPU draws TDP watts during inference, plus system overhead
total_watts = GPU_TDP_WATTS + SYSTEM_OVERHEAD_WATTS
energy_wh = (total_watts * duration_seconds) / 3600
cost = (energy_wh / 1000) * ELECTRICITY_RATE
_stats["total_energy_wh"] += energy_wh
_stats["total_cost"] += cost
# Save every 10 requests
if _stats["total_requests"] % 10 == 0:
_save_stats()
def _check_budget():
"""Returns True if under spending cap."""
with _stats_lock:
return _stats["total_cost"] < SPENDING_CAP
def _get_gpu_utilization():
"""Get current GPU utilization via nvidia-smi or rocm-smi."""
try:
# Try nvidia-smi first
result = subprocess.run(
["nvidia-smi", "--query-gpu=utilization.gpu,temperature.gpu,power.draw",
"--format=csv,noheader,nounits"],
capture_output=True, text=True, timeout=5
)
if result.returncode == 0:
parts = [p.strip() for p in result.stdout.strip().split(",")]
return {
"utilization": float(parts[0]),
"temperature": float(parts[1]),
"power_watts": float(parts[2]) if parts[2] != "[N/A]" else GPU_TDP_WATTS,
"source": "nvidia-smi"
}
except:
pass
try:
# Try rocm-smi for AMD
result = subprocess.run(
["rocm-smi", "--showuse", "--showtemp", "--json"],
capture_output=True, text=True, timeout=5
)
if result.returncode == 0:
data = json.loads(result.stdout)
# Parse rocm-smi JSON (format varies by version)
for card_id, card_data in data.items():
if isinstance(card_data, dict):
return {
"utilization": float(card_data.get("GPU use (%)", 0)),
"temperature": float(card_data.get("Temperature (Sensor edge) (C)", 0)),
"power_watts": GPU_TDP_WATTS,
"source": "rocm-smi"
}
except:
pass
return {"utilization": 0, "temperature": 0, "power_watts": 0, "source": "unavailable"}
# --- HTTP Handler ---
class GatewayHandler(BaseHTTPRequestHandler):
def log_message(self, fmt, *args):
pass # Quiet
def _check_auth(self):
auth = self.headers.get("Authorization", "")
if auth == f"Bearer {API_KEY}" or auth == API_KEY:
return True
self._send_json(401, {"error": "Invalid API key"})
return False
def _send_json(self, status, data):
body = json.dumps(data).encode()
self.send_response(status)
self.send_header("Content-Type", "application/json")
self.send_header("Content-Length", len(body))
self.end_headers()
self.wfile.write(body)
def _proxy_to_ollama(self, path, body=None):
"""Proxy request to Ollama and track usage."""
if not _check_budget():
with _stats_lock:
_stats["requests_rejected"] += 1
self._send_json(402, {
"error": "Spending cap reached",
"total_cost": _stats["total_cost"],
"cap": SPENDING_CAP,
})
return
t0 = time.time()
try:
if body:
r = requests.post(f"{OLLAMA_URL}{path}", json=body, timeout=120)
else:
r = requests.get(f"{OLLAMA_URL}{path}", timeout=10)
duration = time.time() - t0
data = r.json()
# Track token usage from response
tokens_in = data.get("prompt_eval_count", 0)
tokens_out = data.get("eval_count", 0)
if tokens_in or tokens_out:
_track_request(tokens_in, tokens_out, duration)
# Add gateway metadata to response
if isinstance(data, dict):
data["_gateway"] = {
"duration_seconds": round(duration, 2),
"energy_wh": round((GPU_TDP_WATTS + SYSTEM_OVERHEAD_WATTS) * duration / 3600, 4),
"estimated_cost": round(((GPU_TDP_WATTS + SYSTEM_OVERHEAD_WATTS) * duration / 3600 / 1000) * ELECTRICITY_RATE, 6),
"total_cost": round(_stats["total_cost"], 4),
"budget_remaining": round(SPENDING_CAP - _stats["total_cost"], 4),
}
self._send_json(r.status_code, data)
except requests.exceptions.ConnectionError:
self._send_json(502, {"error": "Ollama is not running"})
except requests.exceptions.Timeout:
self._send_json(504, {"error": "Ollama timeout"})
except Exception as e:
self._send_json(500, {"error": str(e)})
def do_GET(self):
parsed = urlparse(self.path)
# Public endpoints (no auth)
if parsed.path == "/health":
try:
r = requests.get(f"{OLLAMA_URL}/api/tags", timeout=5)
models = [m["name"] for m in r.json().get("models", [])]
self._send_json(200, {"status": "ok", "ollama": "connected", "models": models})
except:
self._send_json(503, {"status": "error", "ollama": "disconnected"})
return
if parsed.path == "/stats":
if not self._check_auth():
return
gpu = _get_gpu_utilization()
with _stats_lock:
stats_copy = dict(_stats)
stats_copy["gpu"] = gpu
stats_copy["config"] = {
"gpu_tdp_watts": GPU_TDP_WATTS,
"system_overhead_watts": SYSTEM_OVERHEAD_WATTS,
"electricity_rate": ELECTRICITY_RATE,
"spending_cap": SPENDING_CAP,
}
self._send_json(200, stats_copy)
return
if parsed.path == "/dashboard":
self._serve_dashboard()
return
# Proxy everything else to Ollama
if not self._check_auth():
return
self._proxy_to_ollama(self.path)
def do_POST(self):
if not self._check_auth():
return
length = int(self.headers.get("Content-Length", 0))
body = json.loads(self.rfile.read(length)) if length > 0 else None
self._proxy_to_ollama(self.path, body)
def _serve_dashboard(self):
"""Simple HTML dashboard showing usage stats."""
with _stats_lock:
s = dict(_stats)
gpu = _get_gpu_utilization()
html = f"""<!DOCTYPE html>
<html><head><title>Mortdecai Gateway</title>
<meta http-equiv="refresh" content="10">
<style>
body {{ font-family: monospace; background: #1a1a1a; color: #e0e0e0; padding: 2rem; }}
h1 {{ color: #D35400; }}
.stat {{ background: #252525; border: 1px solid #333; padding: 1rem; margin: 0.5rem 0; border-radius: 6px; }}
.label {{ color: #999; }}
.value {{ color: #D35400; font-size: 1.2rem; font-weight: bold; }}
</style></head><body>
<h1>Mortdecai Gateway</h1>
<div class="stat"><span class="label">Status:</span> <span class="value">{"ACTIVE" if _check_budget() else "PAUSED (cap reached)"}</span></div>
<div class="stat"><span class="label">Total Requests:</span> <span class="value">{s['total_requests']}</span></div>
<div class="stat"><span class="label">Tokens (in/out):</span> <span class="value">{s['total_tokens_in']:,} / {s['total_tokens_out']:,}</span></div>
<div class="stat"><span class="label">Inference Time:</span> <span class="value">{s['total_inference_seconds']:.0f}s</span></div>
<div class="stat"><span class="label">Energy Used:</span> <span class="value">{s['total_energy_wh']:.1f} Wh</span></div>
<div class="stat"><span class="label">Estimated Cost:</span> <span class="value">${s['total_cost']:.4f} / ${SPENDING_CAP:.2f}</span></div>
<div class="stat"><span class="label">Rejected (over cap):</span> <span class="value">{s['requests_rejected']}</span></div>
<div class="stat"><span class="label">GPU Utilization:</span> <span class="value">{gpu['utilization']}% ({gpu['source']})</span></div>
<div class="stat"><span class="label">GPU Temperature:</span> <span class="value">{gpu['temperature']}°C</span></div>
<div class="stat"><span class="label">Last Request:</span> <span class="value">{s['last_request_at'] or 'never'}</span></div>
<div class="stat"><span class="label">Config:</span> <span class="value">TDP={GPU_TDP_WATTS}W + {SYSTEM_OVERHEAD_WATTS}W overhead @ ${ELECTRICITY_RATE}/kWh</span></div>
</body></html>"""
self.send_response(200)
self.send_header("Content-Type", "text/html")
self.end_headers()
self.wfile.write(html.encode())
def main():
_load_stats()
print(f"Mortdecai Gateway starting")
print(f" Ollama: {OLLAMA_URL}")
print(f" Listen: 0.0.0.0:{LISTEN_PORT}")
print(f" TDP: {GPU_TDP_WATTS}W + {SYSTEM_OVERHEAD_WATTS}W overhead")
print(f" Rate: ${ELECTRICITY_RATE}/kWh")
print(f" Cap: ${SPENDING_CAP}")
print(f" Dashboard: http://localhost:{LISTEN_PORT}/dashboard")
# Save stats periodically
def _periodic_save():
while True:
time.sleep(60)
with _stats_lock:
_save_stats()
t = threading.Thread(target=_periodic_save, daemon=True)
t.start()
server = HTTPServer(("0.0.0.0", LISTEN_PORT), GatewayHandler)
server.serve_forever()
if __name__ == "__main__":
main()