Files
ai-hell/tests/test_voice_generator.py

91 lines
3.3 KiB
Python

import os
import tempfile
import wave
from unittest.mock import MagicMock, patch
from server.voice_generator import VoiceGenerator
class TestVoiceGenerator:
    """Unit tests for ``server.voice_generator.VoiceGenerator``.

    Every test patches ``server.voice_generator.TTS`` so the real TTS engine
    is never constructed; only the wrapper's own logic is exercised.
    """

    @staticmethod
    def _write_silent_wav(path, seconds=1, framerate=22050):
        """Write a mono, 16-bit PCM WAV file containing silence to *path*."""
        with wave.open(path, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)  # 2 bytes per sample -> 16-bit PCM
            wf.setframerate(framerate)
            wf.writeframes(b"\x00\x00" * (framerate * seconds))

    @staticmethod
    def _make_fake_samples(directory, names):
        """Create placeholder files named *names* inside *directory*.

        The contents are not valid audio; the tests using this helper only
        exercise file discovery, not decoding.
        """
        for name in names:
            with open(os.path.join(directory, name), "wb") as f:
                f.write(b"fake")

    @patch("server.voice_generator.TTS")
    def test_init_loads_model(self, mock_tts_cls):
        """Generator instantiates the (patched) TTS class exactly once on init.

        NOTE(review): the specific model name (XTTS v2) is not asserted here.
        """
        VoiceGenerator(device="cpu")
        mock_tts_cls.assert_called_once()

    @patch("server.voice_generator.TTS")
    def test_generate_returns_wav_bytes(self, mock_tts_cls):
        """Generate returns the synthesized audio as WAV-encoded bytes."""
        import shutil

        # @patch makes TTS(...) return this MagicMock.
        mock_tts = mock_tts_cls.return_value

        # TemporaryDirectory removes the reference WAV even if setup or the
        # call under test raises.
        with tempfile.TemporaryDirectory() as tmp_dir:
            reference_wav = os.path.join(tmp_dir, "speaker.wav")
            self._write_silent_wav(reference_wav)

            def fake_tts_to_file(
                text, speaker_wav, language, file_path, **_kwargs
            ):
                # Stand in for synthesis: copy the known-good WAV to the
                # output path the generator requested. Extra keyword
                # arguments are accepted and ignored.
                shutil.copy2(reference_wav, file_path)

            mock_tts.tts_to_file.side_effect = fake_tts_to_file

            gen = VoiceGenerator(device="cpu")
            data = gen.generate("hello", speaker_wav=reference_wav)

            mock_tts.tts_to_file.assert_called()
            assert isinstance(data, bytes)
            assert len(data) > 0
            # WAV container header: bytes 0-3 are "RIFF", bytes 8-11 "WAVE".
            assert data[:4] == b"RIFF"
            assert data[8:12] == b"WAVE"

    @patch("server.voice_generator.TTS")
    def test_list_clone_sources(self, mock_tts_cls):
        """Lists every clone source file present in the samples directory."""
        names = {"dog.wav", "machine.wav", "wind.wav"}
        with tempfile.TemporaryDirectory() as samples_dir:
            self._make_fake_samples(samples_dir, names)
            gen = VoiceGenerator(device="cpu", samples_dir=samples_dir)
            sources = gen.list_clone_sources()
            assert len(sources) == 3
            assert all(s.endswith(".wav") for s in sources)
            # basename() accepts either bare file names or full paths.
            assert {os.path.basename(s) for s in sources} == names

    @patch("server.voice_generator.TTS")
    def test_random_clone_source(self, mock_tts_cls):
        """Picks a random clone source from the samples directory."""
        names = {"a.wav", "b.wav", "c.wav"}
        with tempfile.TemporaryDirectory() as samples_dir:
            self._make_fake_samples(samples_dir, names)
            gen = VoiceGenerator(device="cpu", samples_dir=samples_dir)
            source = gen.random_clone_source()
            assert source is not None
            assert source.endswith(".wav")
            # The pick must be one of the files actually created above.
            assert os.path.basename(source) in names

    @patch("server.voice_generator.TTS")
    def test_empty_samples_dir(self, mock_tts_cls):
        """An empty samples directory yields None for a random source."""
        with tempfile.TemporaryDirectory() as samples_dir:
            gen = VoiceGenerator(device="cpu", samples_dir=samples_dir)
            assert gen.random_clone_source() is None