From ca52b94ffde3064f9ae8cfeb1e92fd544bab1056 Mon Sep 17 00:00:00 2001 From: Mortdecai Date: Fri, 10 Apr 2026 01:27:52 -0400 Subject: [PATCH] feat: add FastAPI app with WebSocket streaming and escalation loop --- server/main.py | 254 +++++++++++++++++++++++++++++++++++++++++++++ tests/test_main.py | 32 ++++++ 2 files changed, 286 insertions(+) create mode 100644 server/main.py create mode 100644 tests/test_main.py diff --git a/server/main.py b/server/main.py new file mode 100644 index 0000000..96d36f0 --- /dev/null +++ b/server/main.py @@ -0,0 +1,254 @@ +"""FastAPI application — WebSocket streaming, REST endpoints, background workers.""" + +import asyncio +import base64 +import logging +import random +import time +from contextlib import asynccontextmanager +from pathlib import Path + +from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from fastapi.responses import FileResponse, HTMLResponse +from fastapi.staticfiles import StaticFiles + +from server.config import config +from server.escalation import EscalationEngine +from server.asset_pool import AssetPool +from server.streaming import StreamManager +from server.prompts import get_image_prompt, get_voice_text, get_direct_address_text + +logger = logging.getLogger("ai-hell") + +# Global instances (set during lifespan or create_app) +escalation: EscalationEngine | None = None +pool: AssetPool | None = None +stream: StreamManager | None = None +asset_gen = None # AssetGenerator (lazy, needs GPU) +voice_gen = None # VoiceGenerator (lazy, needs GPU) +_workers: list[asyncio.Task] = [] + + +def create_app(skip_models: bool = False) -> FastAPI: + """Create the FastAPI app. skip_models=True for testing without GPU.""" + global escalation, pool, stream, asset_gen, voice_gen + + escalation = EscalationEngine() + pool = AssetPool() + stream = StreamManager() + + @asynccontextmanager + async def lifespan(app: FastAPI): + global asset_gen, voice_gen + escalation.start_session() + + if not skip_models: + from server.asset_generator import AssetGenerator + from server.voice_generator import VoiceGenerator + + logger.info("Loading SDXL Turbo...") + asset_gen = AssetGenerator() + logger.info("Loading XTTS v2...") + voice_gen = VoiceGenerator() + logger.info("Models loaded. Generating initial batch...") + + # Generate initial asset batch in background + loop = asyncio.get_running_loop() + _workers.append(asyncio.create_task(_initial_batch(loop))) + _workers.append(asyncio.create_task(_background_generator(loop))) + _workers.append(asyncio.create_task(_escalation_loop())) + + yield + + # Shutdown workers + for task in _workers: + task.cancel() + + the_app = FastAPI(title="AI Hell", lifespan=lifespan) + + # Mount assets directory for static serving + assets_dir = Path(config.assets_dir) + assets_dir.mkdir(parents=True, exist_ok=True) + (assets_dir / "img").mkdir(exist_ok=True) + (assets_dir / "audio").mkdir(exist_ok=True) + the_app.mount("/assets", StaticFiles(directory=str(assets_dir)), name="assets") + + # --- REST endpoints --- + + @the_app.get("/") + async def index(): + html_path = Path(__file__).parent.parent / "frontend" / "index.html" + if html_path.exists(): + return FileResponse(html_path, media_type="text/html") + return HTMLResponse("

AI Hell

Frontend not found.

") + + @the_app.get("/status") + async def status(): + return { + "intensity": round(escalation.get_intensity(), 2), + "connected_clients": stream.client_count, + **pool.get_status(), + } + + @the_app.post("/reset") + async def reset(): + escalation.reset() + return {"status": "ok"} + + # --- WebSocket --- + + @the_app.websocket("/stream") + async def stream_ws(ws: WebSocket): + await ws.accept() + stream.add_client(ws) + # Send current state immediately + intensity = escalation.get_intensity() + params = escalation.get_phase_params(intensity) + await ws.send_text( + __import__("json").dumps({ + "type": "phase", + "intensity": round(intensity, 2), + "params": params, + }) + ) + try: + while True: + await ws.receive_text() # Keep alive, ignore pings + except WebSocketDisconnect: + pass + finally: + stream.remove_client(ws) + + # Serve frontend shader files + @the_app.get("/shaders/{filename}") + async def serve_shader(filename: str): + shader_path = Path(__file__).parent.parent / "frontend" / "shaders" / filename + if shader_path.exists(): + return FileResponse(shader_path, media_type="text/plain") + return HTMLResponse("Not found", status_code=404) + + return the_app + + +async def _initial_batch(loop: asyncio.AbstractEventLoop) -> None: + """Generate the initial pool of images and audio clips.""" + batch_size = config.escalation.initial_batch_size + img_count = int(batch_size * 0.75) + audio_count = batch_size - img_count + + for i in range(img_count): + severity = (i / max(1, img_count - 1)) * 4.0 # Spread across severity range + prompt = get_image_prompt(severity) + try: + data = await asyncio.to_thread(asset_gen.generate, prompt) + pool.add_image(data, severity=severity) + logger.info(f"Initial image {i+1}/{img_count} (severity={severity:.1f})") + except Exception as e: + logger.error(f"Failed to generate initial image: {e}") + + for i in range(audio_count): + severity = (i / max(1, audio_count - 1)) * 4.0 + text = get_voice_text() + try: + data = await asyncio.to_thread(voice_gen.generate, text) + pool.add_audio(data, severity=severity) + logger.info(f"Initial audio {i+1}/{audio_count} (severity={severity:.1f})") + except Exception as e: + logger.error(f"Failed to generate initial audio: {e}") + + logger.info("Initial batch complete.") + + +async def _background_generator(loop: asyncio.AbstractEventLoop) -> None: + """Continuously generate new assets biased toward current viewer needs.""" + while True: + await asyncio.sleep(random.uniform(10, 30)) + if stream.client_count == 0: + continue + + intensity = escalation.get_intensity() + severity = escalation.select_severity(intensity) + + # Alternate between images and audio + if random.random() < 0.7: # 70% images, 30% audio + prompt = get_image_prompt(severity) + try: + data = await asyncio.to_thread(asset_gen.generate, prompt) + pool.add_image(data, severity=severity) + except Exception as e: + logger.error(f"Background image gen failed: {e}") + else: + text = get_voice_text() + try: + data = await asyncio.to_thread(voice_gen.generate, text) + pool.add_audio(data, severity=severity) + except Exception as e: + logger.error(f"Background audio gen failed: {e}") + + +async def _escalation_loop() -> None: + """Main escalation loop — pushes phase updates and triggers events.""" + while True: + if stream.client_count == 0: + await asyncio.sleep(1) + continue + + intensity = escalation.get_intensity() + params = escalation.get_phase_params(intensity) + + # Phase update + await stream.broadcast_phase(intensity=intensity, params=params) + + # Asset swap + severity = escalation.select_severity(intensity) + url = pool.select_image(target_severity=severity) + if url: + transition = _pick_transition(intensity) + await stream.broadcast_asset(url=url, severity=severity, transition=transition) + + # Whisper check + voice_interval = escalation.get_voice_interval(intensity) + if random.random() < (2.0 / max(1.0, voice_interval)): + audio_url = pool.select_audio(target_severity=severity) + if audio_url: + await stream.broadcast_whisper( + url=audio_url, + pan=random.uniform(-1.0, 1.0), + volume=random.uniform(0.1, 0.8), + reverb=random.uniform(0.3, 0.9), + ) + + # Direct address check (rarer) + if intensity > 1.5 and random.random() < params["voice_frequency"] * 0.1: + text = get_direct_address_text() + if voice_gen: + try: + data = await asyncio.to_thread(voice_gen.generate, text) + audio_b64 = base64.b64encode(data).decode("ascii") + await stream.broadcast_address(audio_b64=audio_b64, text=text) + except Exception as e: + logger.error(f"Direct address gen failed: {e}") + + # Surprise scare check + if random.random() < params["surprise_chance"] * 0.05: + effect = random.choice(["face_flash", "white_out", "inversion", "glitch_burst"]) + duration = random.randint(50, 300) + await stream.broadcast_scare(effect=effect, duration_ms=duration) + + # Wait for next cycle + swap_interval = escalation.get_asset_swap_interval(intensity) + await asyncio.sleep(swap_interval) + + +def _pick_transition(intensity: float) -> str: + """Pick transition mode based on intensity.""" + if intensity < 1.0: + return "crossfade" + elif intensity < 2.5: + return random.choice(["crossfade", "dissolve", "melt_morph"]) + else: + return random.choice(["glitch_cut", "melt_morph", "dissolve", "crossfade"]) + + +# Default app instance for uvicorn +app = create_app(skip_models=False) diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000..64993a0 --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,32 @@ +from fastapi.testclient import TestClient + +from server.main import create_app + + +class TestRESTEndpoints: + def test_status_endpoint(self): + """GET /status returns intensity and pool info.""" + test_app = create_app(skip_models=True) + with TestClient(test_app) as client: + resp = client.get("/status") + assert resp.status_code == 200 + data = resp.json() + assert "intensity" in data + assert "connected_clients" in data + assert "image_pool_size" in data + + def test_reset_endpoint(self): + """POST /reset restarts escalation.""" + test_app = create_app(skip_models=True) + with TestClient(test_app) as client: + resp = client.post("/reset") + assert resp.status_code == 200 + assert resp.json()["status"] == "ok" + + def test_index_serves_html(self): + """GET / serves the frontend HTML (or fallback).""" + test_app = create_app(skip_models=True) + with TestClient(test_app) as client: + resp = client.get("/") + assert resp.status_code == 200 + assert "text/html" in resp.headers["content-type"]