feat: Transcriber akzeptiert optionales vorgeladenes WhisperModel
This commit is contained in:
@@ -10,7 +10,7 @@ class TestTranscriber:
|
|||||||
@patch("whisper_local.transcriber.WhisperModel")
|
@patch("whisper_local.transcriber.WhisperModel")
|
||||||
def test_init_loads_model(self, mock_model_class):
|
def test_init_loads_model(self, mock_model_class):
|
||||||
t = Transcriber(model_name="small", compute_type="int8", language="de")
|
t = Transcriber(model_name="small", compute_type="int8", language="de")
|
||||||
mock_model_class.assert_called_once_with("small", compute_type="int8")
|
mock_model_class.assert_called_once_with("small", compute_type="int8", download_root=None)
|
||||||
assert t.language == "de"
|
assert t.language == "de"
|
||||||
|
|
||||||
@patch("whisper_local.transcriber.WhisperModel")
|
@patch("whisper_local.transcriber.WhisperModel")
|
||||||
@@ -55,3 +55,9 @@ class TestTranscriber:
|
|||||||
result = t.transcribe(audio)
|
result = t.transcribe(audio)
|
||||||
|
|
||||||
assert result == "Hallo Welt"
|
assert result == "Hallo Welt"
|
||||||
|
|
||||||
|
def test_init_with_preloaded_model(self):
|
||||||
|
mock_model = MagicMock()
|
||||||
|
t = Transcriber(language="de", model=mock_model)
|
||||||
|
assert t.model is mock_model
|
||||||
|
assert t.language == "de"
|
||||||
|
|||||||
@@ -24,11 +24,22 @@ def _model_cache_dir() -> str | None:
|
|||||||
|
|
||||||
|
|
||||||
class Transcriber:
|
class Transcriber:
|
||||||
def __init__(self, model_name: str = "small", compute_type: str = "int8", language: str = "de"):
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str = "small",
|
||||||
|
compute_type: str = "int8",
|
||||||
|
language: str = "de",
|
||||||
|
model: WhisperModel | None = None,
|
||||||
|
):
|
||||||
self.language = language
|
self.language = language
|
||||||
logger.info("Lade Whisper-Modell '%s' (compute_type=%s)...", model_name, compute_type)
|
if model is not None:
|
||||||
self.model = WhisperModel(model_name, compute_type=compute_type, download_root=_model_cache_dir())
|
self.model = model
|
||||||
logger.info("Modell geladen")
|
else:
|
||||||
|
logger.info("Lade Whisper-Modell '%s' (compute_type=%s)...", model_name, compute_type)
|
||||||
|
self.model = WhisperModel(
|
||||||
|
model_name, compute_type=compute_type, download_root=_model_cache_dir()
|
||||||
|
)
|
||||||
|
logger.info("Modell geladen")
|
||||||
|
|
||||||
def transcribe(self, audio: np.ndarray) -> str:
|
def transcribe(self, audio: np.ndarray) -> str:
|
||||||
"""Transkribiert Audio-Array zu Text."""
|
"""Transkribiert Audio-Array zu Text."""
|
||||||
|
|||||||
Reference in New Issue
Block a user