feat: Transcriber akzeptiert optionales vorgeladenes WhisperModel

This commit is contained in:
2026-04-12 12:26:20 +02:00
parent e0893917c1
commit 3a580990ea
2 changed files with 22 additions and 5 deletions
+7 -1
View File
@@ -10,7 +10,7 @@ class TestTranscriber:
@patch("whisper_local.transcriber.WhisperModel") @patch("whisper_local.transcriber.WhisperModel")
def test_init_loads_model(self, mock_model_class): def test_init_loads_model(self, mock_model_class):
t = Transcriber(model_name="small", compute_type="int8", language="de") t = Transcriber(model_name="small", compute_type="int8", language="de")
mock_model_class.assert_called_once_with("small", compute_type="int8") mock_model_class.assert_called_once_with("small", compute_type="int8", download_root=None)
assert t.language == "de" assert t.language == "de"
@patch("whisper_local.transcriber.WhisperModel") @patch("whisper_local.transcriber.WhisperModel")
@@ -55,3 +55,9 @@ class TestTranscriber:
result = t.transcribe(audio) result = t.transcribe(audio)
assert result == "Hallo Welt" assert result == "Hallo Welt"
def test_init_with_preloaded_model(self):
mock_model = MagicMock()
t = Transcriber(language="de", model=mock_model)
assert t.model is mock_model
assert t.language == "de"
+13 -2
View File
@@ -24,10 +24,21 @@ def _model_cache_dir() -> str | None:
class Transcriber: class Transcriber:
def __init__(self, model_name: str = "small", compute_type: str = "int8", language: str = "de"): def __init__(
self,
model_name: str = "small",
compute_type: str = "int8",
language: str = "de",
model: WhisperModel | None = None,
):
self.language = language self.language = language
if model is not None:
self.model = model
else:
logger.info("Lade Whisper-Modell '%s' (compute_type=%s)...", model_name, compute_type) logger.info("Lade Whisper-Modell '%s' (compute_type=%s)...", model_name, compute_type)
self.model = WhisperModel(model_name, compute_type=compute_type, download_root=_model_cache_dir()) self.model = WhisperModel(
model_name, compute_type=compute_type, download_root=_model_cache_dir()
)
logger.info("Modell geladen") logger.info("Modell geladen")
def transcribe(self, audio: np.ndarray) -> str: def transcribe(self, audio: np.ndarray) -> str: