"""Tests for ``server.asset_generator.AssetGenerator``."""
from unittest.mock import MagicMock, patch
|
|
from PIL import Image
|
|
|
|
from server.asset_generator import AssetGenerator
|
|
|
|
|
|
class TestAssetGenerator:
    """Unit tests for ``AssetGenerator`` with the text-to-image pipeline mocked.

    Every test patches ``server.asset_generator.AutoPipelineForText2Image`` so
    no model weights are loaded and no real inference runs.
    """

    # First eight bytes of every PNG stream (the PNG file signature).
    _PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"

    @staticmethod
    def _stub_pipeline(mock_pipeline_cls, *, with_image=False):
        """Wire a fake pipeline into the patched pipeline class.

        Args:
            mock_pipeline_cls: The mock injected by ``@patch`` in place of
                ``AutoPipelineForText2Image``.
            with_image: When true, calling the fake pipeline returns a result
                object whose ``images`` list holds one black 512x512 RGB image.

        Returns:
            The ``MagicMock`` that ``from_pretrained`` will hand back.
        """
        mock_pipe = MagicMock()
        mock_pipeline_cls.from_pretrained.return_value = mock_pipe
        # ``.to(...)`` returns the same mock, so the generator ends up holding
        # this object whether or not it keeps the value returned by ``.to()``.
        mock_pipe.to.return_value = mock_pipe

        if with_image:
            mock_result = MagicMock()
            mock_result.images = [Image.new("RGB", (512, 512), color="black")]
            mock_pipe.return_value = mock_result

        return mock_pipe

    @patch("server.asset_generator.AutoPipelineForText2Image")
    def test_init_loads_model(self, mock_pipeline_cls):
        """Generator loads the SDXL Turbo pipeline on init."""
        self._stub_pipeline(mock_pipeline_cls)

        # Only construction side effects matter here; the instance is unused.
        AssetGenerator(device="cpu")

        mock_pipeline_cls.from_pretrained.assert_called_once()

    @patch("server.asset_generator.AutoPipelineForText2Image")
    def test_generate_returns_bytes(self, mock_pipeline_cls):
        """Generate returns PNG bytes."""
        self._stub_pipeline(mock_pipeline_cls, with_image=True)

        gen = AssetGenerator(device="cpu")
        data = gen.generate("dark void, horror")

        assert isinstance(data, bytes)
        assert len(data) > 0
        # The contract above is PNG specifically, so check the file signature
        # rather than accepting any non-empty byte string.
        assert data.startswith(self._PNG_SIGNATURE)

    @patch("server.asset_generator.AutoPipelineForText2Image")
    def test_generate_uses_negative_prompt(self, mock_pipeline_cls):
        """Generate passes the negative prompt to the pipeline."""
        mock_pipe = self._stub_pipeline(mock_pipeline_cls, with_image=True)

        gen = AssetGenerator(device="cpu")
        gen.generate("test prompt")

        # Fail with a clear message if the pipeline was never invoked;
        # otherwise ``call_args`` is None and ``.kwargs`` raises AttributeError.
        mock_pipe.assert_called()
        last_call = mock_pipe.call_args
        assert "negative_prompt" in last_call.kwargs