133 lines
4.4 KiB
Python
133 lines
4.4 KiB
Python
# This file is part of Gajim.
#
# Gajim is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Gajim is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Gajim. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
import logging
|
|
import pickle
|
|
import subprocess
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
from gi.repository import GLib
|
|
|
|
from ..helper import Results
|
|
from .model_settings import OnnxAsrSettings
|
|
from .model_template import Model
|
|
|
|
# Module-level logger used by this speech-to-text model backend.
log = logging.getLogger('gajim.p.sttvm_onnx_asr')
|
|
|
|
|
|
# Idle timeout in seconds: the worker subprocess is shut down once this much
# time passes without the unload timer being rescheduled by load()/recognize().
_IDLE_UNLOAD_SECONDS = 300
|
|
|
|
class OnnxAsrModel(Model):
    """Speech-to-text backend that delegates work to a worker subprocess.

    Commands and replies are exchanged with ``stt_worker.py`` as pickled
    dicts over the worker's stdin/stdout pipes. The worker is started lazily
    on the first request and shut down again after ``_IDLE_UNLOAD_SECONDS``
    without activity.
    """

    # Seconds a worker gets to exit on its own after its stdin was closed
    # before it is killed.
    _SHUTDOWN_GRACE_SECONDS = 2

    def __init__(self):
        # subprocess.Popen handle of the worker, or None when not running.
        self._proc = None
        # True once the current worker acknowledged a 'load' command.
        self._loaded = False
        self._config = OnnxAsrSettings()
        # GLib source id of the pending idle-unload timer, or None.
        self._unload_source = None

    @property
    def is_loaded(self) -> bool:
        """Whether a 'load' command has been acknowledged by the worker."""
        return self._loaded

    @property
    def will_download(self) -> bool:
        """Whether loading the configured model is expected to download it.

        Returns ``False`` when the model is already loaded, when an explicit
        ``model_path`` is configured, or when the resolved repo name contains
        no ``/``. Otherwise returns ``True`` unless ``config.json`` of the
        repo is found in the local Hugging Face cache.
        """
        if self.is_loaded or self._config.model_path:
            return False
        # Imported lazily so these packages are only needed when this
        # property is actually evaluated.
        from huggingface_hub import try_to_load_from_cache
        from onnx_asr.resolver import model_repos
        # Fall back to the raw model id when it has no entry in model_repos.
        repo = model_repos.get(self._config.model_id, self._config.model_id)
        if '/' not in repo:
            # NOTE(review): presumably not a Hugging Face repo id - confirm.
            return False
        # A str result is a path to the cached file; anything else means the
        # file is not available locally.
        return not isinstance(try_to_load_from_cache(repo, 'config.json'), str)

    def load(self) -> None:
        """Make sure the worker has the configured model loaded.

        Raises:
            RuntimeError: If the worker dies or reports a failure.
        """
        # The flag alone is not enough: if the worker died since the last
        # call, the replacement spawned by _send() would have no model and
        # every following 'recognize' command would fail.
        if self._loaded and self._worker_alive():
            self._reschedule_unload()
            return
        log.debug('Loading model %s in worker', self._config.model_id)
        self._send({
            'op': 'load',
            'model_id': self._config.model_id,
            'model_path': self._config.model_path,
        })
        self._loaded = True
        self._reschedule_unload()

    def recognize(self, result: Results, audio: np.ndarray) -> None:
        """Transcribe *audio* in the worker and store the text on *result*.

        Args:
            result: Object whose ``text`` attribute receives the transcript.
            audio: Audio samples forwarded unchanged to the worker.

        Raises:
            RuntimeError: If the worker dies or reports a failure.
        """
        self.load()
        response = self._send({'op': 'recognize', 'audio': audio})
        result.text = response['text']
        self._reschedule_unload()

    def set_config(self, config: OnnxAsrSettings) -> None:
        """Adopt *config*, unloading the worker if the model selection changed."""
        if (config.model_id != self._config.model_id
                or config.model_path != self._config.model_path):
            self.unload_now()
        # Store a copy of the two relevant fields rather than the caller's
        # object.
        self._config = OnnxAsrSettings(
            model_id=config.model_id, model_path=config.model_path)

    def unload_now(self) -> None:
        """Cancel the idle timer and shut the worker subprocess down."""
        if self._unload_source is not None:
            GLib.source_remove(self._unload_source)
            self._unload_source = None
        # Clear the state first so a failure during shutdown cannot leave a
        # half-closed worker behind as the current one.
        proc, self._proc = self._proc, None
        self._loaded = False
        if proc is not None:
            log.debug('Terminating STT worker subprocess')
            self._shutdown_proc(proc)

    def _worker_alive(self) -> bool:
        """Return True if a worker subprocess exists and has not exited."""
        return self._proc is not None and self._proc.poll() is None

    def _shutdown_proc(self, proc: subprocess.Popen) -> None:
        """Close the pipes of *proc* and reap it, killing it if necessary."""
        try:
            # Closing stdin is how the worker is asked to exit; it is killed
            # below if it does not do so in time.
            proc.stdin.close()
        except OSError:
            # Flushing buffered data to an already-dead worker fails with
            # BrokenPipeError; the pipe is gone either way.
            log.debug('Worker stdin was already broken on close')
        try:
            proc.wait(timeout=self._SHUTDOWN_GRACE_SECONDS)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()
        finally:
            # Release the read end as well so no file descriptor lingers
            # until garbage collection.
            proc.stdout.close()

    def _ensure_proc(self) -> None:
        """Start the worker subprocess unless one is already running."""
        if self._worker_alive():
            return
        stale, self._proc = self._proc, None
        if stale is not None:
            # The previous worker exited on its own: reap it and close its
            # pipes before replacing it.
            self._shutdown_proc(stale)
        log.debug('Starting STT worker subprocess')
        self._proc = subprocess.Popen(
            [sys.executable, str(Path(__file__).parent / 'stt_worker.py')],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
        )
        # A fresh worker has no model loaded yet.
        self._loaded = False

    def _send(self, cmd: dict) -> dict:
        """Send *cmd* to the worker and return its reply.

        Blocks until the worker answers.

        Raises:
            RuntimeError: If the worker went away mid-exchange, or if its
                reply does not have a truthy ``'ok'`` entry (the message is
                the reply's ``'error'`` entry when present).
        """
        self._ensure_proc()
        proc = self._proc
        try:
            pickle.dump(cmd, proc.stdin)
            proc.stdin.flush()
            # The peer is our own worker script, so its reply is trusted.
            # Never feed pickle.load() data from an external source.
            response = pickle.load(proc.stdout)
        except (EOFError, BrokenPipeError, pickle.UnpicklingError) as e:
            # Dead worker or a truncated/corrupt reply: the stream cannot be
            # resynchronised, so dispose of the process completely.
            self._proc = None
            self._loaded = False
            self._shutdown_proc(proc)
            raise RuntimeError('Worker subprocess exited unexpectedly') from e
        if not response.get('ok'):
            raise RuntimeError(response.get('error', 'unknown worker error'))
        return response

    def _reschedule_unload(self) -> None:
        """(Re)start the idle timer that unloads the worker."""
        if self._unload_source is not None:
            GLib.source_remove(self._unload_source)
        self._unload_source = GLib.timeout_add_seconds(
            _IDLE_UNLOAD_SECONDS, self._on_idle_unload)

    def _on_idle_unload(self) -> bool:
        """Timer callback: unload the worker after the idle period."""
        # This source is removed by returning SOURCE_REMOVE below; forget its
        # id first so unload_now() does not try to remove it a second time.
        self._unload_source = None
        log.debug('Idle unload after %ds', _IDLE_UNLOAD_SECONDS)
        self.unload_now()
        return GLib.SOURCE_REMOVE
|