diff --git a/tests/test_transcriber.py b/tests/test_transcriber.py index d43c434..2d193b1 100644 --- a/tests/test_transcriber.py +++ b/tests/test_transcriber.py @@ -10,7 +10,7 @@ class TestTranscriber: @patch("whisper_local.transcriber.WhisperModel") def test_init_loads_model(self, mock_model_class): t = Transcriber(model_name="small", compute_type="int8", language="de") - mock_model_class.assert_called_once_with("small", compute_type="int8") + mock_model_class.assert_called_once_with("small", compute_type="int8", download_root=None) assert t.language == "de" @patch("whisper_local.transcriber.WhisperModel") @@ -55,3 +55,9 @@ class TestTranscriber: result = t.transcribe(audio) assert result == "Hallo Welt" + + def test_init_with_preloaded_model(self): + mock_model = MagicMock() + t = Transcriber(language="de", model=mock_model) + assert t.model is mock_model + assert t.language == "de" diff --git a/whisper_local/transcriber.py b/whisper_local/transcriber.py index 55b0c1b..1aca66e 100644 --- a/whisper_local/transcriber.py +++ b/whisper_local/transcriber.py @@ -24,11 +24,22 @@ def _model_cache_dir() -> str | None: class Transcriber: - def __init__(self, model_name: str = "small", compute_type: str = "int8", language: str = "de"): + def __init__( + self, + model_name: str = "small", + compute_type: str = "int8", + language: str = "de", + model: WhisperModel | None = None, + ): self.language = language - logger.info("Lade Whisper-Modell '%s' (compute_type=%s)...", model_name, compute_type) - self.model = WhisperModel(model_name, compute_type=compute_type, download_root=_model_cache_dir()) - logger.info("Modell geladen") + if model is not None: + self.model = model + else: + logger.info("Lade Whisper-Modell '%s' (compute_type=%s)...", model_name, compute_type) + self.model = WhisperModel( + model_name, compute_type=compute_type, download_root=_model_cache_dir() + ) + logger.info("Modell geladen") def transcribe(self, audio: np.ndarray) -> str: """Transkribiert Audio-Array zu Text."""