feat: add FastAPI app with WebSocket streaming and escalation loop
This commit is contained in:
+254
@@ -0,0 +1,254 @@
|
|||||||
|
"""FastAPI application — WebSocket streaming, REST endpoints, background workers."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||||
|
from fastapi.responses import FileResponse, HTMLResponse
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
|
from server.config import config
|
||||||
|
from server.escalation import EscalationEngine
|
||||||
|
from server.asset_pool import AssetPool
|
||||||
|
from server.streaming import StreamManager
|
||||||
|
from server.prompts import get_image_prompt, get_voice_text, get_direct_address_text
|
||||||
|
|
||||||
|
logger = logging.getLogger("ai-hell")
|
||||||
|
|
||||||
|
# Global instances (set during lifespan or create_app)
|
||||||
|
escalation: EscalationEngine | None = None
|
||||||
|
pool: AssetPool | None = None
|
||||||
|
stream: StreamManager | None = None
|
||||||
|
asset_gen = None # AssetGenerator (lazy, needs GPU)
|
||||||
|
voice_gen = None # VoiceGenerator (lazy, needs GPU)
|
||||||
|
_workers: list[asyncio.Task] = []
|
||||||
|
|
||||||
|
|
||||||
|
def create_app(skip_models: bool = False) -> FastAPI:
|
||||||
|
"""Create the FastAPI app. skip_models=True for testing without GPU."""
|
||||||
|
global escalation, pool, stream, asset_gen, voice_gen
|
||||||
|
|
||||||
|
escalation = EscalationEngine()
|
||||||
|
pool = AssetPool()
|
||||||
|
stream = StreamManager()
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
global asset_gen, voice_gen
|
||||||
|
escalation.start_session()
|
||||||
|
|
||||||
|
if not skip_models:
|
||||||
|
from server.asset_generator import AssetGenerator
|
||||||
|
from server.voice_generator import VoiceGenerator
|
||||||
|
|
||||||
|
logger.info("Loading SDXL Turbo...")
|
||||||
|
asset_gen = AssetGenerator()
|
||||||
|
logger.info("Loading XTTS v2...")
|
||||||
|
voice_gen = VoiceGenerator()
|
||||||
|
logger.info("Models loaded. Generating initial batch...")
|
||||||
|
|
||||||
|
# Generate initial asset batch in background
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
_workers.append(asyncio.create_task(_initial_batch(loop)))
|
||||||
|
_workers.append(asyncio.create_task(_background_generator(loop)))
|
||||||
|
_workers.append(asyncio.create_task(_escalation_loop()))
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
# Shutdown workers
|
||||||
|
for task in _workers:
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
|
the_app = FastAPI(title="AI Hell", lifespan=lifespan)
|
||||||
|
|
||||||
|
# Mount assets directory for static serving
|
||||||
|
assets_dir = Path(config.assets_dir)
|
||||||
|
assets_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
(assets_dir / "img").mkdir(exist_ok=True)
|
||||||
|
(assets_dir / "audio").mkdir(exist_ok=True)
|
||||||
|
the_app.mount("/assets", StaticFiles(directory=str(assets_dir)), name="assets")
|
||||||
|
|
||||||
|
# --- REST endpoints ---
|
||||||
|
|
||||||
|
@the_app.get("/")
|
||||||
|
async def index():
|
||||||
|
html_path = Path(__file__).parent.parent / "frontend" / "index.html"
|
||||||
|
if html_path.exists():
|
||||||
|
return FileResponse(html_path, media_type="text/html")
|
||||||
|
return HTMLResponse("<h1>AI Hell</h1><p>Frontend not found.</p>")
|
||||||
|
|
||||||
|
@the_app.get("/status")
|
||||||
|
async def status():
|
||||||
|
return {
|
||||||
|
"intensity": round(escalation.get_intensity(), 2),
|
||||||
|
"connected_clients": stream.client_count,
|
||||||
|
**pool.get_status(),
|
||||||
|
}
|
||||||
|
|
||||||
|
@the_app.post("/reset")
|
||||||
|
async def reset():
|
||||||
|
escalation.reset()
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
# --- WebSocket ---
|
||||||
|
|
||||||
|
@the_app.websocket("/stream")
|
||||||
|
async def stream_ws(ws: WebSocket):
|
||||||
|
await ws.accept()
|
||||||
|
stream.add_client(ws)
|
||||||
|
# Send current state immediately
|
||||||
|
intensity = escalation.get_intensity()
|
||||||
|
params = escalation.get_phase_params(intensity)
|
||||||
|
await ws.send_text(
|
||||||
|
__import__("json").dumps({
|
||||||
|
"type": "phase",
|
||||||
|
"intensity": round(intensity, 2),
|
||||||
|
"params": params,
|
||||||
|
})
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
await ws.receive_text() # Keep alive, ignore pings
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
stream.remove_client(ws)
|
||||||
|
|
||||||
|
# Serve frontend shader files
|
||||||
|
@the_app.get("/shaders/{filename}")
|
||||||
|
async def serve_shader(filename: str):
|
||||||
|
shader_path = Path(__file__).parent.parent / "frontend" / "shaders" / filename
|
||||||
|
if shader_path.exists():
|
||||||
|
return FileResponse(shader_path, media_type="text/plain")
|
||||||
|
return HTMLResponse("Not found", status_code=404)
|
||||||
|
|
||||||
|
return the_app
|
||||||
|
|
||||||
|
|
||||||
|
async def _initial_batch(loop: asyncio.AbstractEventLoop) -> None:
|
||||||
|
"""Generate the initial pool of images and audio clips."""
|
||||||
|
batch_size = config.escalation.initial_batch_size
|
||||||
|
img_count = int(batch_size * 0.75)
|
||||||
|
audio_count = batch_size - img_count
|
||||||
|
|
||||||
|
for i in range(img_count):
|
||||||
|
severity = (i / max(1, img_count - 1)) * 4.0 # Spread across severity range
|
||||||
|
prompt = get_image_prompt(severity)
|
||||||
|
try:
|
||||||
|
data = await asyncio.to_thread(asset_gen.generate, prompt)
|
||||||
|
pool.add_image(data, severity=severity)
|
||||||
|
logger.info(f"Initial image {i+1}/{img_count} (severity={severity:.1f})")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to generate initial image: {e}")
|
||||||
|
|
||||||
|
for i in range(audio_count):
|
||||||
|
severity = (i / max(1, audio_count - 1)) * 4.0
|
||||||
|
text = get_voice_text()
|
||||||
|
try:
|
||||||
|
data = await asyncio.to_thread(voice_gen.generate, text)
|
||||||
|
pool.add_audio(data, severity=severity)
|
||||||
|
logger.info(f"Initial audio {i+1}/{audio_count} (severity={severity:.1f})")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to generate initial audio: {e}")
|
||||||
|
|
||||||
|
logger.info("Initial batch complete.")
|
||||||
|
|
||||||
|
|
||||||
|
async def _background_generator(loop: asyncio.AbstractEventLoop) -> None:
|
||||||
|
"""Continuously generate new assets biased toward current viewer needs."""
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(random.uniform(10, 30))
|
||||||
|
if stream.client_count == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
intensity = escalation.get_intensity()
|
||||||
|
severity = escalation.select_severity(intensity)
|
||||||
|
|
||||||
|
# Alternate between images and audio
|
||||||
|
if random.random() < 0.7: # 70% images, 30% audio
|
||||||
|
prompt = get_image_prompt(severity)
|
||||||
|
try:
|
||||||
|
data = await asyncio.to_thread(asset_gen.generate, prompt)
|
||||||
|
pool.add_image(data, severity=severity)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Background image gen failed: {e}")
|
||||||
|
else:
|
||||||
|
text = get_voice_text()
|
||||||
|
try:
|
||||||
|
data = await asyncio.to_thread(voice_gen.generate, text)
|
||||||
|
pool.add_audio(data, severity=severity)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Background audio gen failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
async def _escalation_loop() -> None:
|
||||||
|
"""Main escalation loop — pushes phase updates and triggers events."""
|
||||||
|
while True:
|
||||||
|
if stream.client_count == 0:
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
intensity = escalation.get_intensity()
|
||||||
|
params = escalation.get_phase_params(intensity)
|
||||||
|
|
||||||
|
# Phase update
|
||||||
|
await stream.broadcast_phase(intensity=intensity, params=params)
|
||||||
|
|
||||||
|
# Asset swap
|
||||||
|
severity = escalation.select_severity(intensity)
|
||||||
|
url = pool.select_image(target_severity=severity)
|
||||||
|
if url:
|
||||||
|
transition = _pick_transition(intensity)
|
||||||
|
await stream.broadcast_asset(url=url, severity=severity, transition=transition)
|
||||||
|
|
||||||
|
# Whisper check
|
||||||
|
voice_interval = escalation.get_voice_interval(intensity)
|
||||||
|
if random.random() < (2.0 / max(1.0, voice_interval)):
|
||||||
|
audio_url = pool.select_audio(target_severity=severity)
|
||||||
|
if audio_url:
|
||||||
|
await stream.broadcast_whisper(
|
||||||
|
url=audio_url,
|
||||||
|
pan=random.uniform(-1.0, 1.0),
|
||||||
|
volume=random.uniform(0.1, 0.8),
|
||||||
|
reverb=random.uniform(0.3, 0.9),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Direct address check (rarer)
|
||||||
|
if intensity > 1.5 and random.random() < params["voice_frequency"] * 0.1:
|
||||||
|
text = get_direct_address_text()
|
||||||
|
if voice_gen:
|
||||||
|
try:
|
||||||
|
data = await asyncio.to_thread(voice_gen.generate, text)
|
||||||
|
audio_b64 = base64.b64encode(data).decode("ascii")
|
||||||
|
await stream.broadcast_address(audio_b64=audio_b64, text=text)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Direct address gen failed: {e}")
|
||||||
|
|
||||||
|
# Surprise scare check
|
||||||
|
if random.random() < params["surprise_chance"] * 0.05:
|
||||||
|
effect = random.choice(["face_flash", "white_out", "inversion", "glitch_burst"])
|
||||||
|
duration = random.randint(50, 300)
|
||||||
|
await stream.broadcast_scare(effect=effect, duration_ms=duration)
|
||||||
|
|
||||||
|
# Wait for next cycle
|
||||||
|
swap_interval = escalation.get_asset_swap_interval(intensity)
|
||||||
|
await asyncio.sleep(swap_interval)
|
||||||
|
|
||||||
|
|
||||||
|
def _pick_transition(intensity: float) -> str:
|
||||||
|
"""Pick transition mode based on intensity."""
|
||||||
|
if intensity < 1.0:
|
||||||
|
return "crossfade"
|
||||||
|
elif intensity < 2.5:
|
||||||
|
return random.choice(["crossfade", "dissolve", "melt_morph"])
|
||||||
|
else:
|
||||||
|
return random.choice(["glitch_cut", "melt_morph", "dissolve", "crossfade"])
|
||||||
|
|
||||||
|
|
||||||
|
# Default app instance for uvicorn
|
||||||
|
app = create_app(skip_models=False)
|
||||||
@@ -0,0 +1,32 @@
|
|||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from server.main import create_app
|
||||||
|
|
||||||
|
|
||||||
|
class TestRESTEndpoints:
|
||||||
|
def test_status_endpoint(self):
|
||||||
|
"""GET /status returns intensity and pool info."""
|
||||||
|
test_app = create_app(skip_models=True)
|
||||||
|
with TestClient(test_app) as client:
|
||||||
|
resp = client.get("/status")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert "intensity" in data
|
||||||
|
assert "connected_clients" in data
|
||||||
|
assert "image_pool_size" in data
|
||||||
|
|
||||||
|
def test_reset_endpoint(self):
|
||||||
|
"""POST /reset restarts escalation."""
|
||||||
|
test_app = create_app(skip_models=True)
|
||||||
|
with TestClient(test_app) as client:
|
||||||
|
resp = client.post("/reset")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["status"] == "ok"
|
||||||
|
|
||||||
|
def test_index_serves_html(self):
|
||||||
|
"""GET / serves the frontend HTML (or fallback)."""
|
||||||
|
test_app = create_app(skip_models=True)
|
||||||
|
with TestClient(test_app) as client:
|
||||||
|
resp = client.get("/")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert "text/html" in resp.headers["content-type"]
|
||||||
Reference in New Issue
Block a user