WIP: parakeeet
This commit is contained in:
132
stt_voice_messages/models/stt.py
Normal file
132
stt_voice_messages/models/stt.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# This file is part of Gajim.
|
||||
#
|
||||
# Gajim is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# Gajim is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with Gajim. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import logging
|
||||
import pickle
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from gi.repository import GLib
|
||||
|
||||
from ..helper import Results
|
||||
from .model_settings import OnnxAsrSettings
|
||||
from .model_template import Model
|
||||
|
||||
log = logging.getLogger('gajim.p.sttvm_onnx_asr')
|
||||
|
||||
|
||||
_IDLE_UNLOAD_SECONDS = 300
|
||||
|
||||
class OnnxAsrModel(Model):
|
||||
def __init__(self):
|
||||
self._proc = None
|
||||
self._loaded = False
|
||||
self._config = OnnxAsrSettings()
|
||||
self._unload_source = None
|
||||
|
||||
@property
|
||||
def is_loaded(self) -> bool:
|
||||
return self._loaded
|
||||
|
||||
@property
|
||||
def will_download(self) -> bool:
|
||||
if self.is_loaded or self._config.model_path:
|
||||
return False
|
||||
from huggingface_hub import try_to_load_from_cache
|
||||
from onnx_asr.resolver import model_repos
|
||||
repo = model_repos.get(self._config.model_id, self._config.model_id)
|
||||
if '/' not in repo:
|
||||
return False
|
||||
return not isinstance(try_to_load_from_cache(repo, 'config.json'), str)
|
||||
|
||||
def load(self) -> None:
|
||||
if self._loaded:
|
||||
self._reschedule_unload()
|
||||
return
|
||||
log.debug('Loading model %s in worker', self._config.model_id)
|
||||
self._send({
|
||||
'op': 'load',
|
||||
'model_id': self._config.model_id,
|
||||
'model_path': self._config.model_path,
|
||||
})
|
||||
self._loaded = True
|
||||
self._reschedule_unload()
|
||||
|
||||
def recognize(self, result: Results, audio: np.ndarray) -> None:
|
||||
self.load()
|
||||
response = self._send({'op': 'recognize', 'audio': audio})
|
||||
result.text = response['text']
|
||||
self._reschedule_unload()
|
||||
|
||||
def set_config(self, config: OnnxAsrSettings) -> None:
|
||||
if (config.model_id != self._config.model_id
|
||||
or config.model_path != self._config.model_path):
|
||||
self.unload_now()
|
||||
self._config = OnnxAsrSettings(
|
||||
model_id=config.model_id, model_path=config.model_path)
|
||||
|
||||
def unload_now(self) -> None:
|
||||
if self._unload_source is not None:
|
||||
GLib.source_remove(self._unload_source)
|
||||
self._unload_source = None
|
||||
if self._proc is not None:
|
||||
log.debug('Terminating STT worker subprocess')
|
||||
try:
|
||||
self._proc.stdin.close()
|
||||
self._proc.wait(timeout=2)
|
||||
except subprocess.TimeoutExpired:
|
||||
self._proc.kill()
|
||||
self._proc.wait()
|
||||
self._proc = None
|
||||
self._loaded = False
|
||||
|
||||
def _ensure_proc(self) -> None:
|
||||
if self._proc is not None and self._proc.poll() is None:
|
||||
return
|
||||
log.debug('Starting STT worker subprocess')
|
||||
self._proc = subprocess.Popen(
|
||||
[sys.executable, str(Path(__file__).parent / 'stt_worker.py')],
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
)
|
||||
self._loaded = False
|
||||
|
||||
def _send(self, cmd: dict) -> dict:
|
||||
self._ensure_proc()
|
||||
pickle.dump(cmd, self._proc.stdin)
|
||||
self._proc.stdin.flush()
|
||||
try:
|
||||
response = pickle.load(self._proc.stdout)
|
||||
except EOFError as e:
|
||||
self._proc = None
|
||||
self._loaded = False
|
||||
raise RuntimeError('Worker subprocess exited unexpectedly') from e
|
||||
if not response.get('ok'):
|
||||
raise RuntimeError(response.get('error', 'unknown worker error'))
|
||||
return response
|
||||
|
||||
def _reschedule_unload(self) -> None:
|
||||
if self._unload_source is not None:
|
||||
GLib.source_remove(self._unload_source)
|
||||
self._unload_source = GLib.timeout_add_seconds(
|
||||
_IDLE_UNLOAD_SECONDS, self._on_idle_unload)
|
||||
|
||||
def _on_idle_unload(self) -> bool:
|
||||
self._unload_source = None
|
||||
log.debug('Idle unload after %ds', _IDLE_UNLOAD_SECONDS)
|
||||
self.unload_now()
|
||||
return GLib.SOURCE_REMOVE
|
||||
Reference in New Issue
Block a user