diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000..95a0986 --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,67 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import numpy as np +import pytest + +from whisper_local.__main__ import App + + +class TestApp: + @patch("whisper_local.__main__.Transcriber") + @patch("whisper_local.__main__.HotkeyListener") + def test_app_init(self, mock_hotkey_class, mock_transcriber_class): + app = App() + assert app.recorder is not None + assert app.inserter is not None + mock_transcriber_class.assert_called_once() + mock_hotkey_class.assert_called_once() + + @patch("whisper_local.__main__.Transcriber") + @patch("whisper_local.__main__.HotkeyListener") + def test_on_press_starts_recording(self, mock_hotkey_class, mock_transcriber_class): + app = App() + app.recorder = MagicMock() + + import asyncio + asyncio.run(app.on_press()) + + app.recorder.start.assert_called_once() + + @patch("whisper_local.__main__.Transcriber") + @patch("whisper_local.__main__.HotkeyListener") + def test_on_release_stops_and_transcribes(self, mock_hotkey_class, mock_transcriber_class): + mock_transcriber = MagicMock() + mock_transcriber.transcribe.return_value = "Hallo" + mock_transcriber_class.return_value = mock_transcriber + + app = App() + app.recorder = MagicMock() + audio = np.zeros(16000, dtype=np.float32) + app.recorder.stop.return_value = audio + app.inserter = MagicMock() + app.inserter.insert = AsyncMock() + + import asyncio + asyncio.run(app.on_release()) + + app.recorder.stop.assert_called_once() + mock_transcriber.transcribe.assert_called_once_with(audio) + app.inserter.insert.assert_awaited_once_with("Hallo") + + @patch("whisper_local.__main__.Transcriber") + @patch("whisper_local.__main__.HotkeyListener") + def test_on_release_no_audio_skips(self, mock_hotkey_class, mock_transcriber_class): + mock_transcriber = MagicMock() + mock_transcriber_class.return_value = mock_transcriber + + app = App() + app.recorder = MagicMock() + app.recorder.stop.return_value = None + app.inserter = MagicMock() + app.inserter.insert = AsyncMock() + + import asyncio + asyncio.run(app.on_release()) + + mock_transcriber.transcribe.assert_not_called() + app.inserter.insert.assert_not_awaited() diff --git a/whisper_local/__main__.py b/whisper_local/__main__.py new file mode 100644 index 0000000..826a18a --- /dev/null +++ b/whisper_local/__main__.py @@ -0,0 +1,73 @@ +"""Entry-Point für whisper-local.""" + +import asyncio +import logging +import sys + +from whisper_local.config import Config, load_config +from whisper_local.hotkey import HotkeyListener +from whisper_local.inserter import Inserter +from whisper_local.recorder import Recorder +from whisper_local.transcriber import Transcriber + +logger = logging.getLogger(__name__) + + +class App: + def __init__(self, config: Config | None = None): + if config is None: + config = load_config() + + self.recorder = Recorder( + sample_rate=config.sample_rate, + channels=config.channels, + min_duration=config.min_duration, + ) + self.transcriber = Transcriber( + model_name=config.whisper_model, + compute_type=config.compute_type, + language=config.language, + ) + self.inserter = Inserter() + self.hotkey = HotkeyListener(key_name=config.hotkey) + self.hotkey.on_press = self.on_press + self.hotkey.on_release = self.on_release + + async def on_press(self) -> None: + """Callback: Hotkey gedrückt — Aufnahme starten.""" + logger.info("Aufnahme startet...") + self.recorder.start() + + async def on_release(self) -> None: + """Callback: Hotkey losgelassen — Aufnahme stoppen, transkribieren, einfügen.""" + audio = self.recorder.stop() + if audio is None: + logger.info("Keine Audio-Daten, übersprungen") + return + + logger.info("Transkribiere...") + text = self.transcriber.transcribe(audio) + if text: + await self.inserter.insert(text) + + async def run(self) -> None: + """Startet den Hauptloop.""" + logger.info("whisper-local gestartet, warte auf Hotkey...") + await self.hotkey.listen() + + +def main(): + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + ) + app = App() + try: + asyncio.run(app.run()) + except KeyboardInterrupt: + logger.info("Beendet") + sys.exit(0) + + +if __name__ == "__main__": + main()