# ---------------------------------------------------------------------------
# File: server/voice_generator.py
# ---------------------------------------------------------------------------
"""XTTS v2 wrapper for voice cloning from non-voice audio samples."""

import os
import random
import tempfile
from pathlib import Path

try:
    from TTS.api import TTS
except ImportError:
    TTS = None  # Tests patch this; real runtime requires the TTS package

from server.config import config


class VoiceGenerator:
    """Generates speech cloned from arbitrary audio samples via XTTS v2.

    The model is loaded once in ``__init__`` and reused for every call to
    :meth:`generate`. Clone sources are ``*.wav`` files found directly inside
    ``samples_dir`` (no recursion into subdirectories).
    """

    def __init__(
        self,
        device: str | None = None,
        model_name: str | None = None,
        samples_dir: str | None = None,
    ):
        """Load the TTS model and move it to the target device.

        Args:
            device: Device string handed to the model's ``to()``; falls back
                to ``config.device`` when omitted.
            model_name: Model identifier passed to ``TTS``; falls back to
                ``config.models.xtts_model`` when omitted.
            samples_dir: Directory holding clone-source WAV files; falls back
                to ``config.samples_dir`` when omitted.

        Raises:
            RuntimeError: If the ``TTS`` package could not be imported.
        """
        self.device = device or config.device
        self.model_name = model_name or config.models.xtts_model
        self.samples_dir = Path(samples_dir or config.samples_dir)
        if TTS is None:
            raise RuntimeError(
                "TTS package is not installed; cannot instantiate VoiceGenerator"
            )
        self._tts = TTS(model_name=self.model_name)
        self._tts.to(self.device)

    def generate(self, text: str, speaker_wav: str | None = None) -> bytes:
        """Generate speech as WAV bytes.

        Uses a random clone source from the samples directory if
        ``speaker_wav`` is not specified.

        Args:
            text: The text to synthesize.
            speaker_wav: Path of the audio file whose voice is cloned.

        Returns:
            The raw contents of the WAV file written by the model.

        Raises:
            ValueError: If no ``speaker_wav`` is given and the samples
                directory contains no usable WAV file.
        """
        if speaker_wav is None:
            speaker_wav = self.random_clone_source()
        if speaker_wav is None:
            raise ValueError("No speaker WAV provided and no samples available")

        # XTTS writes to a file, so reserve a temp path for it. The handle is
        # closed immediately (delete=False keeps the file on disk) so the
        # model can reopen the path itself.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            out_path = tmp.name
        try:
            self._tts.tts_to_file(
                text=text,
                speaker_wav=speaker_wav,
                language=config.models.xtts_language,
                file_path=out_path,
            )
            return Path(out_path).read_bytes()
        finally:
            # Best-effort cleanup: a failure to remove the temp file must not
            # mask the real result or the real exception.
            try:
                os.unlink(out_path)
            except OSError:
                pass

    def list_clone_sources(self) -> list[str]:
        """List all WAV files in the samples directory.

        Returns:
            Sorted paths (as strings) of the regular files matching ``*.wav``
            directly inside ``samples_dir``; an empty list when the directory
            does not exist.
        """
        if not self.samples_dir.is_dir():
            return []
        # A directory that merely happens to be named "something.wav" is not
        # a usable speaker sample, so keep regular files only.
        return [
            str(path)
            for path in sorted(self.samples_dir.glob("*.wav"))
            if path.is_file()
        ]

    def random_clone_source(self) -> str | None:
        """Pick a random clone source WAV file.

        Returns:
            The path of one available clone source, or ``None`` when there
            are none.
        """
        sources = self.list_clone_sources()
        if not sources:
            return None
        return random.choice(sources)


# ---------------------------------------------------------------------------
# File: tests/test_voice_generator.py
# ---------------------------------------------------------------------------
import os
import shutil
import tempfile
import wave
from unittest.mock import MagicMock, patch

from server.voice_generator import VoiceGenerator


def _write_silent_wav(path: str, seconds: int = 1, rate: int = 22050) -> None:
    """Write a mono 16-bit WAV file of silence to *path*."""
    with wave.open(path, "wb") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(rate)
        wf.writeframes(b"\x00\x00" * rate * seconds)


def _touch(directory: str, *names: str) -> None:
    """Create small placeholder files called *names* inside *directory*."""
    for name in names:
        with open(os.path.join(directory, name), "wb") as f:
            f.write(b"fake")


class TestVoiceGenerator:
    """Unit tests for VoiceGenerator with the TTS backend mocked out."""

    @patch("server.voice_generator.TTS")
    def test_init_loads_model(self, mock_tts_cls):
        """Generator loads the XTTS v2 model on init and moves it to the device."""
        mock_tts = MagicMock()
        mock_tts_cls.return_value = mock_tts

        VoiceGenerator(device="cpu")

        mock_tts_cls.assert_called_once()
        mock_tts.to.assert_called_once_with("cpu")

    @patch("server.voice_generator.TTS")
    def test_generate_returns_wav_bytes(self, mock_tts_cls):
        """Generate returns WAV bytes and removes its temp output file."""
        mock_tts = MagicMock()
        mock_tts_cls.return_value = mock_tts

        # Create a real WAV file for the mock to "produce"
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            tmp_wav = f.name
        _write_silent_wav(tmp_wav)  # 1 second of silence

        written_paths = []

        try:
            # Mock tts_to_file to copy our test WAV
            def fake_tts_to_file(text, speaker_wav, language, file_path):
                written_paths.append(file_path)
                shutil.copy2(tmp_wav, file_path)

            mock_tts.tts_to_file = fake_tts_to_file

            gen = VoiceGenerator(device="cpu")
            data = gen.generate("hello", speaker_wav=tmp_wav)
            assert isinstance(data, bytes)
            assert len(data) > 0
            assert data.startswith(b"RIFF")
            # The intermediate file must not be left behind.
            assert len(written_paths) == 1
            assert not os.path.exists(written_paths[0])
        finally:
            os.unlink(tmp_wav)

    @patch("server.voice_generator.TTS")
    def test_generate_uses_random_source_when_unspecified(self, mock_tts_cls):
        """Generate falls back to a sample from the samples directory."""
        mock_tts = MagicMock()
        mock_tts_cls.return_value = mock_tts

        used_speakers = []

        def fake_tts_to_file(text, speaker_wav, language, file_path):
            used_speakers.append(speaker_wav)
            _write_silent_wav(file_path)

        mock_tts.tts_to_file = fake_tts_to_file

        with tempfile.TemporaryDirectory() as samples_dir:
            _touch(samples_dir, "only.wav")

            gen = VoiceGenerator(device="cpu", samples_dir=samples_dir)
            data = gen.generate("hello")

            assert len(data) > 0
            assert used_speakers == gen.list_clone_sources()

    @patch("server.voice_generator.TTS")
    def test_generate_without_sources_raises(self, mock_tts_cls):
        """Generate raises ValueError when no speaker and no samples exist."""
        mock_tts = MagicMock()
        mock_tts_cls.return_value = mock_tts

        with tempfile.TemporaryDirectory() as samples_dir:
            gen = VoiceGenerator(device="cpu", samples_dir=samples_dir)
            try:
                gen.generate("hello")
            except ValueError:
                pass
            else:
                raise AssertionError(
                    "generate() should raise ValueError without a speaker WAV"
                )
            mock_tts.tts_to_file.assert_not_called()

    @patch("server.voice_generator.TTS")
    def test_list_clone_sources(self, mock_tts_cls):
        """Lists available clone source files."""
        mock_tts = MagicMock()
        mock_tts_cls.return_value = mock_tts

        with tempfile.TemporaryDirectory() as samples_dir:
            # Create some fake sample files
            _touch(samples_dir, "dog.wav", "machine.wav", "wind.wav")

            gen = VoiceGenerator(device="cpu", samples_dir=samples_dir)
            sources = gen.list_clone_sources()
            assert len(sources) == 3
            assert all(s.endswith(".wav") for s in sources)
            assert [os.path.basename(s) for s in sources] == [
                "dog.wav",
                "machine.wav",
                "wind.wav",
            ]

    @patch("server.voice_generator.TTS")
    def test_list_clone_sources_skips_non_files(self, mock_tts_cls):
        """Directories named *.wav and non-WAV files are not clone sources."""
        mock_tts = MagicMock()
        mock_tts_cls.return_value = mock_tts

        with tempfile.TemporaryDirectory() as samples_dir:
            _touch(samples_dir, "real.wav", "notes.txt")
            os.mkdir(os.path.join(samples_dir, "folder.wav"))

            gen = VoiceGenerator(device="cpu", samples_dir=samples_dir)
            sources = gen.list_clone_sources()
            assert [os.path.basename(s) for s in sources] == ["real.wav"]

    @patch("server.voice_generator.TTS")
    def test_list_clone_sources_missing_dir(self, mock_tts_cls):
        """A samples directory that does not exist yields no sources."""
        mock_tts = MagicMock()
        mock_tts_cls.return_value = mock_tts

        with tempfile.TemporaryDirectory() as parent:
            missing = os.path.join(parent, "does-not-exist")
            gen = VoiceGenerator(device="cpu", samples_dir=missing)
            assert gen.list_clone_sources() == []
            assert gen.random_clone_source() is None

    @patch("server.voice_generator.TTS")
    def test_random_clone_source(self, mock_tts_cls):
        """Picks a random clone source from samples directory."""
        mock_tts = MagicMock()
        mock_tts_cls.return_value = mock_tts

        with tempfile.TemporaryDirectory() as samples_dir:
            _touch(samples_dir, "a.wav", "b.wav", "c.wav")

            gen = VoiceGenerator(device="cpu", samples_dir=samples_dir)
            source = gen.random_clone_source()
            assert source is not None
            assert source.endswith(".wav")
            assert source in gen.list_clone_sources()

    @patch("server.voice_generator.TTS")
    def test_empty_samples_dir(self, mock_tts_cls):
        """Empty samples dir returns None for random source."""
        mock_tts = MagicMock()
        mock_tts_cls.return_value = mock_tts

        with tempfile.TemporaryDirectory() as samples_dir:
            gen = VoiceGenerator(device="cpu", samples_dir=samples_dir)
            assert gen.random_clone_source() is None