#!/usr/bin/env python3 """ Mortdecai Ollama Gateway — authenticated proxy with power metering. Sits in front of Ollama, provides: - API key authentication - Power/cost tracking (GPU utilization × TDP × electricity rate) - Usage dashboard - Spending cap enforcement - Health check endpoint Usage: python3 gateway.py OLLAMA_URL=http://localhost:11434 API_KEY=mk_test python3 gateway.py """ import json import os import time import threading import subprocess from http.server import HTTPServer, BaseHTTPRequestHandler from urllib.parse import urlparse, parse_qs import requests # --- Config --- OLLAMA_URL = os.environ.get("OLLAMA_URL", "http://localhost:11434") LISTEN_PORT = int(os.environ.get("GATEWAY_PORT", "8434")) API_KEY = os.environ.get("API_KEY", "mk_mortdecai_default") ELECTRICITY_RATE = float(os.environ.get("ELECTRICITY_RATE", "0.15")) # $/kWh GPU_TDP_WATTS = float(os.environ.get("GPU_TDP_WATTS", "54")) # Strix Halo iGPU SYSTEM_OVERHEAD_WATTS = float(os.environ.get("SYSTEM_OVERHEAD_WATTS", "30")) # CPU/RAM/etc idle draw during inference SPENDING_CAP = float(os.environ.get("SPENDING_CAP", "10.00")) # $ before refusing requests STATS_FILE = os.environ.get("STATS_FILE", "/var/lib/mortdecai-gateway/stats.json") # --- Stats tracking --- _stats_lock = threading.Lock() _stats = { "total_requests": 0, "total_tokens_in": 0, "total_tokens_out": 0, "total_inference_seconds": 0, "total_energy_wh": 0.0, "total_cost": 0.0, "started_at": time.strftime("%Y-%m-%dT%H:%M:%SZ"), "last_request_at": None, "requests_rejected": 0, } def _load_stats(): global _stats try: with open(STATS_FILE) as f: saved = json.load(f) _stats.update(saved) except: pass def _save_stats(): try: os.makedirs(os.path.dirname(STATS_FILE), exist_ok=True) with open(STATS_FILE, "w") as f: json.dump(_stats, f, indent=2) except: pass def _track_request(tokens_in, tokens_out, duration_seconds): """Track a completed inference request.""" with _stats_lock: _stats["total_requests"] += 1 _stats["total_tokens_in"] += tokens_in _stats["total_tokens_out"] += tokens_out _stats["total_inference_seconds"] += duration_seconds _stats["last_request_at"] = time.strftime("%Y-%m-%dT%H:%M:%SZ") # Power calculation # GPU draws TDP watts during inference, plus system overhead total_watts = GPU_TDP_WATTS + SYSTEM_OVERHEAD_WATTS energy_wh = (total_watts * duration_seconds) / 3600 cost = (energy_wh / 1000) * ELECTRICITY_RATE _stats["total_energy_wh"] += energy_wh _stats["total_cost"] += cost # Save every 10 requests if _stats["total_requests"] % 10 == 0: _save_stats() def _check_budget(): """Returns True if under spending cap.""" with _stats_lock: return _stats["total_cost"] < SPENDING_CAP def _get_gpu_utilization(): """Get current GPU utilization via nvidia-smi or rocm-smi.""" try: # Try nvidia-smi first result = subprocess.run( ["nvidia-smi", "--query-gpu=utilization.gpu,temperature.gpu,power.draw", "--format=csv,noheader,nounits"], capture_output=True, text=True, timeout=5 ) if result.returncode == 0: parts = [p.strip() for p in result.stdout.strip().split(",")] return { "utilization": float(parts[0]), "temperature": float(parts[1]), "power_watts": float(parts[2]) if parts[2] != "[N/A]" else GPU_TDP_WATTS, "source": "nvidia-smi" } except: pass try: # Try rocm-smi for AMD result = subprocess.run( ["rocm-smi", "--showuse", "--showtemp", "--json"], capture_output=True, text=True, timeout=5 ) if result.returncode == 0: data = json.loads(result.stdout) # Parse rocm-smi JSON (format varies by version) for card_id, card_data in data.items(): if isinstance(card_data, dict): return { "utilization": float(card_data.get("GPU use (%)", 0)), "temperature": float(card_data.get("Temperature (Sensor edge) (C)", 0)), "power_watts": GPU_TDP_WATTS, "source": "rocm-smi" } except: pass return {"utilization": 0, "temperature": 0, "power_watts": 0, "source": "unavailable"} # --- HTTP Handler --- class GatewayHandler(BaseHTTPRequestHandler): def log_message(self, fmt, *args): pass # Quiet def _check_auth(self): auth = self.headers.get("Authorization", "") if auth == f"Bearer {API_KEY}" or auth == API_KEY: return True self._send_json(401, {"error": "Invalid API key"}) return False def _send_json(self, status, data): body = json.dumps(data).encode() self.send_response(status) self.send_header("Content-Type", "application/json") self.send_header("Content-Length", len(body)) self.end_headers() self.wfile.write(body) def _proxy_to_ollama(self, path, body=None): """Proxy request to Ollama and track usage.""" if not _check_budget(): with _stats_lock: _stats["requests_rejected"] += 1 self._send_json(402, { "error": "Spending cap reached", "total_cost": _stats["total_cost"], "cap": SPENDING_CAP, }) return t0 = time.time() try: if body: r = requests.post(f"{OLLAMA_URL}{path}", json=body, timeout=120) else: r = requests.get(f"{OLLAMA_URL}{path}", timeout=10) duration = time.time() - t0 data = r.json() # Track token usage from response tokens_in = data.get("prompt_eval_count", 0) tokens_out = data.get("eval_count", 0) if tokens_in or tokens_out: _track_request(tokens_in, tokens_out, duration) # Add gateway metadata to response if isinstance(data, dict): data["_gateway"] = { "duration_seconds": round(duration, 2), "energy_wh": round((GPU_TDP_WATTS + SYSTEM_OVERHEAD_WATTS) * duration / 3600, 4), "estimated_cost": round(((GPU_TDP_WATTS + SYSTEM_OVERHEAD_WATTS) * duration / 3600 / 1000) * ELECTRICITY_RATE, 6), "total_cost": round(_stats["total_cost"], 4), "budget_remaining": round(SPENDING_CAP - _stats["total_cost"], 4), } self._send_json(r.status_code, data) except requests.exceptions.ConnectionError: self._send_json(502, {"error": "Ollama is not running"}) except requests.exceptions.Timeout: self._send_json(504, {"error": "Ollama timeout"}) except Exception as e: self._send_json(500, {"error": str(e)}) def do_GET(self): parsed = urlparse(self.path) # Public endpoints (no auth) if parsed.path == "/health": try: r = requests.get(f"{OLLAMA_URL}/api/tags", timeout=5) models = [m["name"] for m in r.json().get("models", [])] self._send_json(200, {"status": "ok", "ollama": "connected", "models": models}) except: self._send_json(503, {"status": "error", "ollama": "disconnected"}) return if parsed.path == "/stats": if not self._check_auth(): return gpu = _get_gpu_utilization() with _stats_lock: stats_copy = dict(_stats) stats_copy["gpu"] = gpu stats_copy["config"] = { "gpu_tdp_watts": GPU_TDP_WATTS, "system_overhead_watts": SYSTEM_OVERHEAD_WATTS, "electricity_rate": ELECTRICITY_RATE, "spending_cap": SPENDING_CAP, } self._send_json(200, stats_copy) return if parsed.path == "/dashboard": self._serve_dashboard() return # Proxy everything else to Ollama if not self._check_auth(): return self._proxy_to_ollama(self.path) def do_POST(self): if not self._check_auth(): return length = int(self.headers.get("Content-Length", 0)) body = json.loads(self.rfile.read(length)) if length > 0 else None self._proxy_to_ollama(self.path, body) def _serve_dashboard(self): """Simple HTML dashboard showing usage stats.""" with _stats_lock: s = dict(_stats) gpu = _get_gpu_utilization() html = f"""