feat: add XTTS v2 voice generator with clone source management
This commit is contained in:
@@ -0,0 +1,73 @@
|
||||
"""XTTS v2 wrapper for voice cloning from non-voice audio samples."""
|
||||
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
from TTS.api import TTS
|
||||
except ImportError:
|
||||
TTS = None # Tests patch this; real runtime requires the TTS package
|
||||
|
||||
from server.config import config
|
||||
|
||||
|
||||
class VoiceGenerator:
|
||||
"""Generates speech cloned from arbitrary audio samples via XTTS v2."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: str | None = None,
|
||||
model_name: str | None = None,
|
||||
samples_dir: str | None = None,
|
||||
):
|
||||
self.device = device or config.device
|
||||
self.model_name = model_name or config.models.xtts_model
|
||||
self.samples_dir = Path(samples_dir or config.samples_dir)
|
||||
if TTS is None:
|
||||
raise RuntimeError(
|
||||
"TTS package is not installed; cannot instantiate VoiceGenerator"
|
||||
)
|
||||
self._tts = TTS(model_name=self.model_name)
|
||||
self._tts.to(self.device)
|
||||
|
||||
def generate(self, text: str, speaker_wav: str | None = None) -> bytes:
|
||||
"""Generate speech as WAV bytes. Uses a random clone source if none specified."""
|
||||
if speaker_wav is None:
|
||||
speaker_wav = self.random_clone_source()
|
||||
if speaker_wav is None:
|
||||
raise ValueError("No speaker WAV provided and no samples available")
|
||||
|
||||
# XTTS writes to file, so use a temp file
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
||||
tmp.close()
|
||||
try:
|
||||
self._tts.tts_to_file(
|
||||
text=text,
|
||||
speaker_wav=speaker_wav,
|
||||
language=config.models.xtts_language,
|
||||
file_path=tmp.name,
|
||||
)
|
||||
with open(tmp.name, "rb") as f:
|
||||
return f.read()
|
||||
finally:
|
||||
try:
|
||||
os.unlink(tmp.name)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def list_clone_sources(self) -> list[str]:
|
||||
"""List all WAV files in the samples directory."""
|
||||
if not self.samples_dir.is_dir():
|
||||
return []
|
||||
return [
|
||||
str(p) for p in sorted(self.samples_dir.glob("*.wav"))
|
||||
]
|
||||
|
||||
def random_clone_source(self) -> str | None:
|
||||
"""Pick a random clone source WAV file."""
|
||||
sources = self.list_clone_sources()
|
||||
if not sources:
|
||||
return None
|
||||
return random.choice(sources)
|
||||
@@ -0,0 +1,90 @@
|
||||
import os
|
||||
import tempfile
|
||||
import wave
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from server.voice_generator import VoiceGenerator
|
||||
|
||||
|
||||
class TestVoiceGenerator:
|
||||
@patch("server.voice_generator.TTS")
|
||||
def test_init_loads_model(self, mock_tts_cls):
|
||||
"""Generator loads the XTTS v2 model on init."""
|
||||
mock_tts = MagicMock()
|
||||
mock_tts_cls.return_value = mock_tts
|
||||
|
||||
gen = VoiceGenerator(device="cpu")
|
||||
mock_tts_cls.assert_called_once()
|
||||
|
||||
@patch("server.voice_generator.TTS")
|
||||
def test_generate_returns_wav_bytes(self, mock_tts_cls):
|
||||
"""Generate returns WAV bytes."""
|
||||
mock_tts = MagicMock()
|
||||
mock_tts_cls.return_value = mock_tts
|
||||
|
||||
# Create a real WAV file for the mock to "produce"
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
|
||||
tmp_wav = f.name
|
||||
with wave.open(f, "wb") as wf:
|
||||
wf.setnchannels(1)
|
||||
wf.setsampwidth(2)
|
||||
wf.setframerate(22050)
|
||||
wf.writeframes(b"\x00\x00" * 22050) # 1 second of silence
|
||||
|
||||
try:
|
||||
# Mock tts_to_file to copy our test WAV
|
||||
def fake_tts_to_file(text, speaker_wav, language, file_path):
|
||||
import shutil
|
||||
shutil.copy2(tmp_wav, file_path)
|
||||
|
||||
mock_tts.tts_to_file = fake_tts_to_file
|
||||
|
||||
gen = VoiceGenerator(device="cpu")
|
||||
data = gen.generate("hello", speaker_wav=tmp_wav)
|
||||
assert isinstance(data, bytes)
|
||||
assert len(data) > 0
|
||||
finally:
|
||||
os.unlink(tmp_wav)
|
||||
|
||||
@patch("server.voice_generator.TTS")
|
||||
def test_list_clone_sources(self, mock_tts_cls):
|
||||
"""Lists available clone source files."""
|
||||
mock_tts = MagicMock()
|
||||
mock_tts_cls.return_value = mock_tts
|
||||
|
||||
with tempfile.TemporaryDirectory() as samples_dir:
|
||||
# Create some fake sample files
|
||||
for name in ["dog.wav", "machine.wav", "wind.wav"]:
|
||||
with open(os.path.join(samples_dir, name), "wb") as f:
|
||||
f.write(b"fake")
|
||||
|
||||
gen = VoiceGenerator(device="cpu", samples_dir=samples_dir)
|
||||
sources = gen.list_clone_sources()
|
||||
assert len(sources) == 3
|
||||
assert all(s.endswith(".wav") for s in sources)
|
||||
|
||||
@patch("server.voice_generator.TTS")
|
||||
def test_random_clone_source(self, mock_tts_cls):
|
||||
"""Picks a random clone source from samples directory."""
|
||||
mock_tts = MagicMock()
|
||||
mock_tts_cls.return_value = mock_tts
|
||||
|
||||
with tempfile.TemporaryDirectory() as samples_dir:
|
||||
for name in ["a.wav", "b.wav", "c.wav"]:
|
||||
with open(os.path.join(samples_dir, name), "wb") as f:
|
||||
f.write(b"fake")
|
||||
|
||||
gen = VoiceGenerator(device="cpu", samples_dir=samples_dir)
|
||||
source = gen.random_clone_source()
|
||||
assert source is not None
|
||||
assert source.endswith(".wav")
|
||||
|
||||
@patch("server.voice_generator.TTS")
|
||||
def test_empty_samples_dir(self, mock_tts_cls):
|
||||
"""Empty samples dir returns None for random source."""
|
||||
mock_tts = MagicMock()
|
||||
mock_tts_cls.return_value = mock_tts
|
||||
|
||||
with tempfile.TemporaryDirectory() as samples_dir:
|
||||
gen = VoiceGenerator(device="cpu", samples_dir=samples_dir)
|
||||
assert gen.random_clone_source() is None
|
||||
Reference in New Issue
Block a user