feat: add asset pool with severity tagging and rotation
This commit is contained in:
@@ -0,0 +1,130 @@
|
||||
"""Asset pool — manages generated images and audio on disk with severity tagging."""
|
||||
|
||||
import random
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
|
||||
from server.config import config
|
||||
|
||||
|
||||
@dataclass
|
||||
class Asset:
|
||||
"""A generated asset with metadata."""
|
||||
filename: str
|
||||
severity: float
|
||||
created_at: float
|
||||
asset_type: str # "image" or "audio"
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
subdir = "img" if self.asset_type == "image" else "audio"
|
||||
return f"/assets/{subdir}/{self.filename}"
|
||||
|
||||
|
||||
class AssetPool:
|
||||
"""Thread-safe pool of generated assets with severity-based selection and rotation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_dir: str | None = None,
|
||||
max_images: int | None = None,
|
||||
max_audio: int | None = None,
|
||||
):
|
||||
self.base_dir = Path(base_dir or config.assets_dir)
|
||||
self.max_images = max_images if max_images is not None else config.escalation.max_images
|
||||
self.max_audio = max_audio if max_audio is not None else config.escalation.max_audio_clips
|
||||
self._images: list[Asset] = []
|
||||
self._audio: list[Asset] = []
|
||||
self._lock = Lock()
|
||||
|
||||
(self.base_dir / "img").mkdir(parents=True, exist_ok=True)
|
||||
(self.base_dir / "audio").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@property
|
||||
def image_count(self) -> int:
|
||||
with self._lock:
|
||||
return len(self._images)
|
||||
|
||||
@property
|
||||
def audio_count(self) -> int:
|
||||
with self._lock:
|
||||
return len(self._audio)
|
||||
|
||||
def add_image(self, data: bytes, severity: float) -> str:
|
||||
"""Save image data to disk and add to pool. Returns URL path."""
|
||||
filename = f"{uuid.uuid4().hex[:12]}.png"
|
||||
path = self.base_dir / "img" / filename
|
||||
path.write_bytes(data)
|
||||
|
||||
asset = Asset(
|
||||
filename=filename,
|
||||
severity=severity,
|
||||
created_at=time.monotonic(),
|
||||
asset_type="image",
|
||||
)
|
||||
|
||||
with self._lock:
|
||||
self._images.append(asset)
|
||||
self._rotate(self._images, self.max_images, "img")
|
||||
|
||||
return asset.url
|
||||
|
||||
def add_audio(self, data: bytes, severity: float) -> str:
|
||||
"""Save audio data to disk and add to pool. Returns URL path."""
|
||||
filename = f"{uuid.uuid4().hex[:12]}.wav"
|
||||
path = self.base_dir / "audio" / filename
|
||||
path.write_bytes(data)
|
||||
|
||||
asset = Asset(
|
||||
filename=filename,
|
||||
severity=severity,
|
||||
created_at=time.monotonic(),
|
||||
asset_type="audio",
|
||||
)
|
||||
|
||||
with self._lock:
|
||||
self._audio.append(asset)
|
||||
self._rotate(self._audio, self.max_audio, "audio")
|
||||
|
||||
return asset.url
|
||||
|
||||
def select_image(self, target_severity: float) -> str | None:
|
||||
"""Select an image near the target severity. Weighted random, biased toward close matches."""
|
||||
with self._lock:
|
||||
return self._select(self._images, target_severity)
|
||||
|
||||
def select_audio(self, target_severity: float) -> str | None:
|
||||
"""Select an audio clip near the target severity."""
|
||||
with self._lock:
|
||||
return self._select(self._audio, target_severity)
|
||||
|
||||
def _select(self, assets: list[Asset], target: float) -> str | None:
|
||||
"""Weighted selection: closer severity = higher weight."""
|
||||
if not assets:
|
||||
return None
|
||||
weights = []
|
||||
for a in assets:
|
||||
distance = abs(a.severity - target)
|
||||
weights.append(1.0 / (1.0 + distance))
|
||||
chosen = random.choices(assets, weights=weights, k=1)[0]
|
||||
return chosen.url
|
||||
|
||||
def _rotate(self, assets: list[Asset], max_count: int, subdir: str) -> None:
|
||||
"""Remove oldest assets when pool exceeds capacity. Must hold lock."""
|
||||
while len(assets) > max_count:
|
||||
old = assets.pop(0)
|
||||
path = self.base_dir / subdir / old.filename
|
||||
try:
|
||||
path.unlink(missing_ok=True)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def get_status(self) -> dict:
|
||||
with self._lock:
|
||||
return {
|
||||
"image_pool_size": len(self._images),
|
||||
"audio_pool_size": len(self._audio),
|
||||
}
|
||||
@@ -0,0 +1,112 @@
|
||||
from pathlib import Path
|
||||
|
||||
from server.asset_pool import AssetPool
|
||||
|
||||
|
||||
def _make_pool(tmp_path: Path, max_images: int = 10, max_audio: int = 5) -> AssetPool:
|
||||
return AssetPool(
|
||||
base_dir=str(tmp_path),
|
||||
max_images=max_images,
|
||||
max_audio=max_audio,
|
||||
)
|
||||
|
||||
|
||||
def _fake_image(tmp_path: Path, pool: AssetPool, severity: float) -> str:
|
||||
content = b"fake png data"
|
||||
return pool.add_image(content, severity=severity)
|
||||
|
||||
|
||||
def _fake_audio(tmp_path: Path, pool: AssetPool, severity: float) -> str:
|
||||
content = b"fake wav data"
|
||||
return pool.add_audio(content, severity=severity)
|
||||
|
||||
|
||||
class TestAssetPoolInit:
|
||||
def test_creates_directories(self, tmp_path):
|
||||
"""Pool creates img/ and audio/ subdirectories."""
|
||||
pool = _make_pool(tmp_path)
|
||||
assert (tmp_path / "img").is_dir()
|
||||
assert (tmp_path / "audio").is_dir()
|
||||
|
||||
def test_empty_pool(self, tmp_path):
|
||||
"""New pool has no assets."""
|
||||
pool = _make_pool(tmp_path)
|
||||
assert pool.image_count == 0
|
||||
assert pool.audio_count == 0
|
||||
|
||||
|
||||
class TestAddAssets:
|
||||
def test_add_image(self, tmp_path):
|
||||
"""Adding an image increments count and returns a URL path."""
|
||||
pool = _make_pool(tmp_path)
|
||||
url = _fake_image(tmp_path, pool, severity=1.0)
|
||||
assert pool.image_count == 1
|
||||
assert url.startswith("/assets/img/")
|
||||
assert url.endswith(".png")
|
||||
|
||||
def test_add_audio(self, tmp_path):
|
||||
"""Adding audio increments count and returns a URL path."""
|
||||
pool = _make_pool(tmp_path)
|
||||
url = _fake_audio(tmp_path, pool, severity=1.0)
|
||||
assert pool.audio_count == 1
|
||||
assert url.startswith("/assets/audio/")
|
||||
assert url.endswith(".wav")
|
||||
|
||||
def test_file_exists_on_disk(self, tmp_path):
|
||||
"""Added assets exist as real files."""
|
||||
pool = _make_pool(tmp_path)
|
||||
url = _fake_image(tmp_path, pool, severity=1.0)
|
||||
filename = url.split("/")[-1]
|
||||
assert (tmp_path / "img" / filename).exists()
|
||||
|
||||
|
||||
class TestSelectAssets:
|
||||
def test_select_image_by_severity(self, tmp_path):
|
||||
"""Selects an image closest to target severity."""
|
||||
pool = _make_pool(tmp_path)
|
||||
_fake_image(tmp_path, pool, severity=0.5)
|
||||
_fake_image(tmp_path, pool, severity=2.0)
|
||||
_fake_image(tmp_path, pool, severity=4.0)
|
||||
url = pool.select_image(target_severity=1.8)
|
||||
assert url is not None
|
||||
|
||||
def test_select_audio_by_severity(self, tmp_path):
|
||||
"""Selects an audio clip closest to target severity."""
|
||||
pool = _make_pool(tmp_path)
|
||||
_fake_audio(tmp_path, pool, severity=0.5)
|
||||
_fake_audio(tmp_path, pool, severity=3.0)
|
||||
url = pool.select_audio(target_severity=2.5)
|
||||
assert url is not None
|
||||
|
||||
def test_select_from_empty_returns_none(self, tmp_path):
|
||||
"""Selecting from empty pool returns None."""
|
||||
pool = _make_pool(tmp_path)
|
||||
assert pool.select_image(target_severity=1.0) is None
|
||||
assert pool.select_audio(target_severity=1.0) is None
|
||||
|
||||
|
||||
class TestRotation:
|
||||
def test_image_rotation(self, tmp_path):
|
||||
"""Oldest images are removed when pool exceeds max."""
|
||||
pool = _make_pool(tmp_path, max_images=3)
|
||||
for i in range(5):
|
||||
_fake_image(tmp_path, pool, severity=float(i))
|
||||
assert pool.image_count == 3
|
||||
|
||||
def test_audio_rotation(self, tmp_path):
|
||||
"""Oldest audio clips are removed when pool exceeds max."""
|
||||
pool = _make_pool(tmp_path, max_audio=2)
|
||||
for i in range(4):
|
||||
_fake_audio(tmp_path, pool, severity=float(i))
|
||||
assert pool.audio_count == 2
|
||||
|
||||
|
||||
class TestStatus:
|
||||
def test_status_dict(self, tmp_path):
|
||||
"""Status returns pool sizes."""
|
||||
pool = _make_pool(tmp_path)
|
||||
_fake_image(tmp_path, pool, severity=1.0)
|
||||
_fake_audio(tmp_path, pool, severity=1.0)
|
||||
status = pool.get_status()
|
||||
assert status["image_pool_size"] == 1
|
||||
assert status["audio_pool_size"] == 1
|
||||
Reference in New Issue
Block a user