diff --git a/server/asset_generator.py b/server/asset_generator.py
new file mode 100644
index 0000000..639e542
--- /dev/null
+++ b/server/asset_generator.py
@@ -0,0 +1,89 @@
+"""SDXL Turbo wrapper for horror image generation."""
+
+import io
+
+import torch
+
+try:
+    from diffusers import AutoPipelineForText2Image
+except ImportError:  # pragma: no cover - exercised only in test environments
+    AutoPipelineForText2Image = None  # Tests patch this attribute directly.
+
+from server.config import config
+from server.prompts import NEGATIVE_PROMPT
+
+
+class AssetGenerator:
+    """Generates horror images via SDXL Turbo.
+
+    The diffusers pipeline is loaded once in the constructor and reused by
+    every call to generate().
+    """
+
+    def __init__(self, device: str | None = None, model_id: str | None = None):
+        """Load the text-to-image pipeline and place it on the target device.
+
+        Args:
+            device: Torch device string such as "cpu", "cuda" or "cuda:0".
+                Falls back to config.device when omitted.
+            model_id: Model identifier passed to from_pretrained. Falls back
+                to config.models.sdxl_model_id when omitted.
+
+        Raises:
+            RuntimeError: If diffusers could not be imported.
+        """
+        self.device = device or config.device
+        self.model_id = model_id or config.models.sdxl_model_id
+
+        if AutoPipelineForText2Image is None:
+            raise RuntimeError(
+                "diffusers is not installed; install diffusers to use AssetGenerator"
+            )
+
+        # Branch on the device *type* so indexed devices ("cuda:0") are
+        # treated the same as the bare "cuda" string.
+        device_type = torch.device(self.device).type
+        use_fp16 = device_type == "cuda"
+        self._pipe = AutoPipelineForText2Image.from_pretrained(
+            self.model_id,
+            torch_dtype=torch.float16 if use_fp16 else torch.float32,
+            variant="fp16" if use_fp16 else None,
+        )
+        # Move the pipeline for every non-CPU device so it lives on the same
+        # device as the torch.Generator built in generate().
+        if device_type != "cpu":
+            self._pipe = self._pipe.to(self.device)
+
+    def generate(self, prompt: str, seed: int | None = None) -> bytes:
+        """Generate an image from a horror prompt and return it as PNG bytes.
+
+        Image size, step count and guidance scale are read from
+        config.models.
+
+        Args:
+            prompt: Text prompt describing the image.
+            seed: Seed for the random generator; a random value below 2**32
+                is drawn when omitted.
+
+        Returns:
+            The first generated image, encoded as PNG.
+        """
+        if seed is None:
+            seed = torch.randint(0, 2**32, (1,)).item()
+
+        generator = torch.Generator(device=self.device).manual_seed(seed)
+
+        result = self._pipe(
+            prompt=prompt,
+            negative_prompt=NEGATIVE_PROMPT,
+            num_inference_steps=config.models.sdxl_steps,
+            guidance_scale=config.models.sdxl_guidance_scale,
+            width=config.models.sdxl_width,
+            height=config.models.sdxl_height,
+            generator=generator,
+        )
+
+        image = result.images[0]
+        buf = io.BytesIO()
+        image.save(buf, format="PNG")
+        return buf.getvalue()
diff --git a/tests/test_asset_generator.py b/tests/test_asset_generator.py
new file mode 100644
index 0000000..050d766
--- /dev/null
+++ b/tests/test_asset_generator.py
@@ -0,0 +1,66 @@
+from unittest.mock import MagicMock, patch
+
+import torch
+from PIL import Image
+
+from server.asset_generator import AssetGenerator
+from server.prompts import NEGATIVE_PROMPT
+
+PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
+
+
+def _make_mock_pipe(mock_pipeline_cls):
+    """Wire the patched pipeline class to return a pipe yielding one image."""
+    mock_pipe = MagicMock()
+    mock_pipeline_cls.from_pretrained.return_value = mock_pipe
+    mock_pipe.to.return_value = mock_pipe
+
+    mock_result = MagicMock()
+    mock_result.images = [Image.new("RGB", (512, 512), color="black")]
+    mock_pipe.return_value = mock_result
+    return mock_pipe
+
+
+class TestAssetGenerator:
+    @patch("server.asset_generator.AutoPipelineForText2Image")
+    def test_init_loads_model(self, mock_pipeline_cls):
+        """Generator loads the SDXL Turbo pipeline on init."""
+        mock_pipe = _make_mock_pipe(mock_pipeline_cls)
+
+        AssetGenerator(device="cpu")
+
+        mock_pipeline_cls.from_pretrained.assert_called_once()
+        mock_pipe.to.assert_not_called()
+
+    @patch("server.asset_generator.AutoPipelineForText2Image")
+    def test_init_handles_indexed_cuda_device(self, mock_pipeline_cls):
+        """Indexed CUDA devices load in fp16 and get the pipeline moved."""
+        mock_pipe = _make_mock_pipe(mock_pipeline_cls)
+
+        AssetGenerator(device="cuda:0")
+
+        load_kwargs = mock_pipeline_cls.from_pretrained.call_args.kwargs
+        assert load_kwargs["torch_dtype"] == torch.float16
+        assert load_kwargs["variant"] == "fp16"
+        mock_pipe.to.assert_called_once_with("cuda:0")
+
+    @patch("server.asset_generator.AutoPipelineForText2Image")
+    def test_generate_returns_bytes(self, mock_pipeline_cls):
+        """Generate returns PNG bytes."""
+        _make_mock_pipe(mock_pipeline_cls)
+
+        gen = AssetGenerator(device="cpu")
+        data = gen.generate("dark void, horror")
+
+        assert isinstance(data, bytes)
+        assert data.startswith(PNG_SIGNATURE)
+
+    @patch("server.asset_generator.AutoPipelineForText2Image")
+    def test_generate_uses_negative_prompt(self, mock_pipeline_cls):
+        """Generate passes the negative prompt to the pipeline."""
+        mock_pipe = _make_mock_pipe(mock_pipeline_cls)
+
+        gen = AssetGenerator(device="cpu")
+        gen.generate("test prompt")
+
+        assert mock_pipe.call_args.kwargs["negative_prompt"] == NEGATIVE_PROMPT