feat: add SDXL Turbo image generator wrapper
This commit is contained in:
@@ -0,0 +1,57 @@
|
|||||||
|
"""SDXL Turbo wrapper for horror image generation."""
|
||||||
|
|
||||||
|
import io
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
try:
|
||||||
|
from diffusers import AutoPipelineForText2Image
|
||||||
|
except ImportError: # pragma: no cover - exercised only in test environments
|
||||||
|
AutoPipelineForText2Image = None # Tests patch this attribute directly.
|
||||||
|
|
||||||
|
from server.config import config
|
||||||
|
from server.prompts import NEGATIVE_PROMPT
|
||||||
|
|
||||||
|
|
||||||
|
class AssetGenerator:
|
||||||
|
"""Generates horror images via SDXL Turbo."""
|
||||||
|
|
||||||
|
def __init__(self, device: str | None = None, model_id: str | None = None):
|
||||||
|
self.device = device or config.device
|
||||||
|
self.model_id = model_id or config.models.sdxl_model_id
|
||||||
|
|
||||||
|
if AutoPipelineForText2Image is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"diffusers is not installed; install diffusers to use AssetGenerator"
|
||||||
|
)
|
||||||
|
|
||||||
|
use_fp16 = self.device == "cuda"
|
||||||
|
self._pipe = AutoPipelineForText2Image.from_pretrained(
|
||||||
|
self.model_id,
|
||||||
|
torch_dtype=torch.float16 if use_fp16 else torch.float32,
|
||||||
|
variant="fp16" if use_fp16 else None,
|
||||||
|
)
|
||||||
|
if self.device == "cuda":
|
||||||
|
self._pipe = self._pipe.to("cuda")
|
||||||
|
|
||||||
|
def generate(self, prompt: str, seed: int | None = None) -> bytes:
|
||||||
|
"""Generate a 512x512 PNG image from a horror prompt. Returns PNG bytes."""
|
||||||
|
if seed is None:
|
||||||
|
seed = torch.randint(0, 2**32, (1,)).item()
|
||||||
|
|
||||||
|
generator = torch.Generator(device=self.device).manual_seed(seed)
|
||||||
|
|
||||||
|
result = self._pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=NEGATIVE_PROMPT,
|
||||||
|
num_inference_steps=config.models.sdxl_steps,
|
||||||
|
guidance_scale=config.models.sdxl_guidance_scale,
|
||||||
|
width=config.models.sdxl_width,
|
||||||
|
height=config.models.sdxl_height,
|
||||||
|
generator=generator,
|
||||||
|
)
|
||||||
|
|
||||||
|
image = result.images[0]
|
||||||
|
buf = io.BytesIO()
|
||||||
|
image.save(buf, format="PNG")
|
||||||
|
return buf.getvalue()
|
||||||
@@ -0,0 +1,52 @@
|
|||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from server.asset_generator import AssetGenerator
|
||||||
|
|
||||||
|
|
||||||
|
class TestAssetGenerator:
|
||||||
|
@patch("server.asset_generator.AutoPipelineForText2Image")
|
||||||
|
def test_init_loads_model(self, mock_pipeline_cls):
|
||||||
|
"""Generator loads the SDXL Turbo pipeline on init."""
|
||||||
|
mock_pipe = MagicMock()
|
||||||
|
mock_pipeline_cls.from_pretrained.return_value = mock_pipe
|
||||||
|
mock_pipe.to.return_value = mock_pipe
|
||||||
|
|
||||||
|
gen = AssetGenerator(device="cpu")
|
||||||
|
mock_pipeline_cls.from_pretrained.assert_called_once()
|
||||||
|
|
||||||
|
@patch("server.asset_generator.AutoPipelineForText2Image")
|
||||||
|
def test_generate_returns_bytes(self, mock_pipeline_cls):
|
||||||
|
"""Generate returns PNG bytes."""
|
||||||
|
mock_pipe = MagicMock()
|
||||||
|
mock_pipeline_cls.from_pretrained.return_value = mock_pipe
|
||||||
|
mock_pipe.to.return_value = mock_pipe
|
||||||
|
|
||||||
|
# Mock pipeline output
|
||||||
|
fake_image = Image.new("RGB", (512, 512), color="black")
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.images = [fake_image]
|
||||||
|
mock_pipe.return_value = mock_result
|
||||||
|
|
||||||
|
gen = AssetGenerator(device="cpu")
|
||||||
|
data = gen.generate("dark void, horror")
|
||||||
|
assert isinstance(data, bytes)
|
||||||
|
assert len(data) > 0
|
||||||
|
|
||||||
|
@patch("server.asset_generator.AutoPipelineForText2Image")
|
||||||
|
def test_generate_uses_negative_prompt(self, mock_pipeline_cls):
|
||||||
|
"""Generate passes the negative prompt to the pipeline."""
|
||||||
|
mock_pipe = MagicMock()
|
||||||
|
mock_pipeline_cls.from_pretrained.return_value = mock_pipe
|
||||||
|
mock_pipe.to.return_value = mock_pipe
|
||||||
|
|
||||||
|
fake_image = Image.new("RGB", (512, 512), color="black")
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.images = [fake_image]
|
||||||
|
mock_pipe.return_value = mock_result
|
||||||
|
|
||||||
|
gen = AssetGenerator(device="cpu")
|
||||||
|
gen.generate("test prompt")
|
||||||
|
|
||||||
|
call_kwargs = mock_pipe.call_args
|
||||||
|
assert "negative_prompt" in call_kwargs.kwargs
|
||||||
Reference in New Issue
Block a user