feat: add FastAPI app with WebSocket streaming and escalation loop
This commit is contained in:
+254
@@ -0,0 +1,254 @@
|
||||
"""FastAPI application — WebSocket streaming, REST endpoints, background workers."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import FileResponse, HTMLResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from server.config import config
|
||||
from server.escalation import EscalationEngine
|
||||
from server.asset_pool import AssetPool
|
||||
from server.streaming import StreamManager
|
||||
from server.prompts import get_image_prompt, get_voice_text, get_direct_address_text
|
||||
|
||||
logger = logging.getLogger("ai-hell")
|
||||
|
||||
# Global instances (set during lifespan or create_app)
|
||||
escalation: EscalationEngine | None = None
|
||||
pool: AssetPool | None = None
|
||||
stream: StreamManager | None = None
|
||||
asset_gen = None # AssetGenerator (lazy, needs GPU)
|
||||
voice_gen = None # VoiceGenerator (lazy, needs GPU)
|
||||
_workers: list[asyncio.Task] = []
|
||||
|
||||
|
||||
def create_app(skip_models: bool = False) -> FastAPI:
|
||||
"""Create the FastAPI app. skip_models=True for testing without GPU."""
|
||||
global escalation, pool, stream, asset_gen, voice_gen
|
||||
|
||||
escalation = EscalationEngine()
|
||||
pool = AssetPool()
|
||||
stream = StreamManager()
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global asset_gen, voice_gen
|
||||
escalation.start_session()
|
||||
|
||||
if not skip_models:
|
||||
from server.asset_generator import AssetGenerator
|
||||
from server.voice_generator import VoiceGenerator
|
||||
|
||||
logger.info("Loading SDXL Turbo...")
|
||||
asset_gen = AssetGenerator()
|
||||
logger.info("Loading XTTS v2...")
|
||||
voice_gen = VoiceGenerator()
|
||||
logger.info("Models loaded. Generating initial batch...")
|
||||
|
||||
# Generate initial asset batch in background
|
||||
loop = asyncio.get_running_loop()
|
||||
_workers.append(asyncio.create_task(_initial_batch(loop)))
|
||||
_workers.append(asyncio.create_task(_background_generator(loop)))
|
||||
_workers.append(asyncio.create_task(_escalation_loop()))
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown workers
|
||||
for task in _workers:
|
||||
task.cancel()
|
||||
|
||||
the_app = FastAPI(title="AI Hell", lifespan=lifespan)
|
||||
|
||||
# Mount assets directory for static serving
|
||||
assets_dir = Path(config.assets_dir)
|
||||
assets_dir.mkdir(parents=True, exist_ok=True)
|
||||
(assets_dir / "img").mkdir(exist_ok=True)
|
||||
(assets_dir / "audio").mkdir(exist_ok=True)
|
||||
the_app.mount("/assets", StaticFiles(directory=str(assets_dir)), name="assets")
|
||||
|
||||
# --- REST endpoints ---
|
||||
|
||||
@the_app.get("/")
|
||||
async def index():
|
||||
html_path = Path(__file__).parent.parent / "frontend" / "index.html"
|
||||
if html_path.exists():
|
||||
return FileResponse(html_path, media_type="text/html")
|
||||
return HTMLResponse("<h1>AI Hell</h1><p>Frontend not found.</p>")
|
||||
|
||||
@the_app.get("/status")
|
||||
async def status():
|
||||
return {
|
||||
"intensity": round(escalation.get_intensity(), 2),
|
||||
"connected_clients": stream.client_count,
|
||||
**pool.get_status(),
|
||||
}
|
||||
|
||||
@the_app.post("/reset")
|
||||
async def reset():
|
||||
escalation.reset()
|
||||
return {"status": "ok"}
|
||||
|
||||
# --- WebSocket ---
|
||||
|
||||
@the_app.websocket("/stream")
|
||||
async def stream_ws(ws: WebSocket):
|
||||
await ws.accept()
|
||||
stream.add_client(ws)
|
||||
# Send current state immediately
|
||||
intensity = escalation.get_intensity()
|
||||
params = escalation.get_phase_params(intensity)
|
||||
await ws.send_text(
|
||||
__import__("json").dumps({
|
||||
"type": "phase",
|
||||
"intensity": round(intensity, 2),
|
||||
"params": params,
|
||||
})
|
||||
)
|
||||
try:
|
||||
while True:
|
||||
await ws.receive_text() # Keep alive, ignore pings
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
finally:
|
||||
stream.remove_client(ws)
|
||||
|
||||
# Serve frontend shader files
|
||||
@the_app.get("/shaders/{filename}")
|
||||
async def serve_shader(filename: str):
|
||||
shader_path = Path(__file__).parent.parent / "frontend" / "shaders" / filename
|
||||
if shader_path.exists():
|
||||
return FileResponse(shader_path, media_type="text/plain")
|
||||
return HTMLResponse("Not found", status_code=404)
|
||||
|
||||
return the_app
|
||||
|
||||
|
||||
async def _initial_batch(loop: asyncio.AbstractEventLoop) -> None:
|
||||
"""Generate the initial pool of images and audio clips."""
|
||||
batch_size = config.escalation.initial_batch_size
|
||||
img_count = int(batch_size * 0.75)
|
||||
audio_count = batch_size - img_count
|
||||
|
||||
for i in range(img_count):
|
||||
severity = (i / max(1, img_count - 1)) * 4.0 # Spread across severity range
|
||||
prompt = get_image_prompt(severity)
|
||||
try:
|
||||
data = await asyncio.to_thread(asset_gen.generate, prompt)
|
||||
pool.add_image(data, severity=severity)
|
||||
logger.info(f"Initial image {i+1}/{img_count} (severity={severity:.1f})")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate initial image: {e}")
|
||||
|
||||
for i in range(audio_count):
|
||||
severity = (i / max(1, audio_count - 1)) * 4.0
|
||||
text = get_voice_text()
|
||||
try:
|
||||
data = await asyncio.to_thread(voice_gen.generate, text)
|
||||
pool.add_audio(data, severity=severity)
|
||||
logger.info(f"Initial audio {i+1}/{audio_count} (severity={severity:.1f})")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate initial audio: {e}")
|
||||
|
||||
logger.info("Initial batch complete.")
|
||||
|
||||
|
||||
async def _background_generator(loop: asyncio.AbstractEventLoop) -> None:
|
||||
"""Continuously generate new assets biased toward current viewer needs."""
|
||||
while True:
|
||||
await asyncio.sleep(random.uniform(10, 30))
|
||||
if stream.client_count == 0:
|
||||
continue
|
||||
|
||||
intensity = escalation.get_intensity()
|
||||
severity = escalation.select_severity(intensity)
|
||||
|
||||
# Alternate between images and audio
|
||||
if random.random() < 0.7: # 70% images, 30% audio
|
||||
prompt = get_image_prompt(severity)
|
||||
try:
|
||||
data = await asyncio.to_thread(asset_gen.generate, prompt)
|
||||
pool.add_image(data, severity=severity)
|
||||
except Exception as e:
|
||||
logger.error(f"Background image gen failed: {e}")
|
||||
else:
|
||||
text = get_voice_text()
|
||||
try:
|
||||
data = await asyncio.to_thread(voice_gen.generate, text)
|
||||
pool.add_audio(data, severity=severity)
|
||||
except Exception as e:
|
||||
logger.error(f"Background audio gen failed: {e}")
|
||||
|
||||
|
||||
async def _escalation_loop() -> None:
|
||||
"""Main escalation loop — pushes phase updates and triggers events."""
|
||||
while True:
|
||||
if stream.client_count == 0:
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
intensity = escalation.get_intensity()
|
||||
params = escalation.get_phase_params(intensity)
|
||||
|
||||
# Phase update
|
||||
await stream.broadcast_phase(intensity=intensity, params=params)
|
||||
|
||||
# Asset swap
|
||||
severity = escalation.select_severity(intensity)
|
||||
url = pool.select_image(target_severity=severity)
|
||||
if url:
|
||||
transition = _pick_transition(intensity)
|
||||
await stream.broadcast_asset(url=url, severity=severity, transition=transition)
|
||||
|
||||
# Whisper check
|
||||
voice_interval = escalation.get_voice_interval(intensity)
|
||||
if random.random() < (2.0 / max(1.0, voice_interval)):
|
||||
audio_url = pool.select_audio(target_severity=severity)
|
||||
if audio_url:
|
||||
await stream.broadcast_whisper(
|
||||
url=audio_url,
|
||||
pan=random.uniform(-1.0, 1.0),
|
||||
volume=random.uniform(0.1, 0.8),
|
||||
reverb=random.uniform(0.3, 0.9),
|
||||
)
|
||||
|
||||
# Direct address check (rarer)
|
||||
if intensity > 1.5 and random.random() < params["voice_frequency"] * 0.1:
|
||||
text = get_direct_address_text()
|
||||
if voice_gen:
|
||||
try:
|
||||
data = await asyncio.to_thread(voice_gen.generate, text)
|
||||
audio_b64 = base64.b64encode(data).decode("ascii")
|
||||
await stream.broadcast_address(audio_b64=audio_b64, text=text)
|
||||
except Exception as e:
|
||||
logger.error(f"Direct address gen failed: {e}")
|
||||
|
||||
# Surprise scare check
|
||||
if random.random() < params["surprise_chance"] * 0.05:
|
||||
effect = random.choice(["face_flash", "white_out", "inversion", "glitch_burst"])
|
||||
duration = random.randint(50, 300)
|
||||
await stream.broadcast_scare(effect=effect, duration_ms=duration)
|
||||
|
||||
# Wait for next cycle
|
||||
swap_interval = escalation.get_asset_swap_interval(intensity)
|
||||
await asyncio.sleep(swap_interval)
|
||||
|
||||
|
||||
def _pick_transition(intensity: float) -> str:
|
||||
"""Pick transition mode based on intensity."""
|
||||
if intensity < 1.0:
|
||||
return "crossfade"
|
||||
elif intensity < 2.5:
|
||||
return random.choice(["crossfade", "dissolve", "melt_morph"])
|
||||
else:
|
||||
return random.choice(["glitch_cut", "melt_morph", "dissolve", "crossfade"])
|
||||
|
||||
|
||||
# Default app instance for uvicorn
|
||||
app = create_app(skip_models=False)
|
||||
Reference in New Issue
Block a user