diff --git a/tests/test_transcriber.py b/tests/test_transcriber.py
new file mode 100644
index 0000000..d43c434
--- /dev/null
+++ b/tests/test_transcriber.py
@@ -0,0 +1,57 @@
+from unittest.mock import MagicMock, patch
+
+import numpy as np
+import pytest
+
+from whisper_local.transcriber import Transcriber
+
+
+class TestTranscriber:
+    @patch("whisper_local.transcriber.WhisperModel")
+    def test_init_loads_model(self, mock_model_class):
+        t = Transcriber(model_name="small", compute_type="int8", language="de")
+        mock_model_class.assert_called_once_with("small", compute_type="int8")
+        assert t.language == "de"
+
+    @patch("whisper_local.transcriber.WhisperModel")
+    def test_transcribe_returns_text(self, mock_model_class):
+        mock_model = MagicMock()
+        mock_segment = MagicMock()
+        mock_segment.text = " Hallo Welt "
+        mock_model.transcribe.return_value = ([mock_segment], None)
+        mock_model_class.return_value = mock_model
+
+        t = Transcriber(model_name="small", compute_type="int8", language="de")
+        audio = np.zeros(16000, dtype=np.float32)
+        result = t.transcribe(audio)
+
+        assert result == "Hallo Welt"
+        mock_model.transcribe.assert_called_once_with(audio, language="de")
+
+    @patch("whisper_local.transcriber.WhisperModel")
+    def test_transcribe_empty_segments(self, mock_model_class):
+        mock_model = MagicMock()
+        mock_model.transcribe.return_value = ([], None)
+        mock_model_class.return_value = mock_model
+
+        t = Transcriber(model_name="small", compute_type="int8", language="de")
+        audio = np.zeros(16000, dtype=np.float32)
+        result = t.transcribe(audio)
+
+        assert result == ""
+
+    @patch("whisper_local.transcriber.WhisperModel")
+    def test_transcribe_multiple_segments(self, mock_model_class):
+        mock_model = MagicMock()
+        seg1 = MagicMock()
+        seg1.text = " Hallo "
+        seg2 = MagicMock()
+        seg2.text = " Welt "
+        mock_model.transcribe.return_value = ([seg1, seg2], None)
+        mock_model_class.return_value = mock_model
+
+        t = Transcriber(model_name="small", compute_type="int8", language="de")
+        audio = np.zeros(16000, dtype=np.float32)
+        result = t.transcribe(audio)
+
+        assert result == "Hallo Welt"
diff --git a/whisper_local/transcriber.py b/whisper_local/transcriber.py
new file mode 100644
index 0000000..614a13b
--- /dev/null
+++ b/whisper_local/transcriber.py
@@ -0,0 +1,26 @@
+"""Whisper-Transkription via faster-whisper."""
+
+import logging
+
+import numpy as np
+from faster_whisper import WhisperModel
+
+logger = logging.getLogger(__name__)
+
+
+class Transcriber:
+    def __init__(self, model_name: str = "small", compute_type: str = "int8", language: str = "de"):
+        self.language = language
+        logger.info("Lade Whisper-Modell '%s' (compute_type=%s)...", model_name, compute_type)
+        self.model = WhisperModel(model_name, compute_type=compute_type)
+        logger.info("Modell geladen")
+
+    def transcribe(self, audio: np.ndarray) -> str:
+        """Transkribiert Audio-Array zu Text."""
+        segments, _ = self.model.transcribe(audio, language=self.language)
+        text = " ".join(segment.text.strip() for segment in segments if segment.text.strip())
+        if text:
+            logger.info("Transkribiert: %s", text)
+        else:
+            logger.info("Keine Sprache erkannt")
+        return text