WIP: parakeet

This commit is contained in:
hueso
2026-05-18 23:10:13 -03:00
parent 2e4aeb3b6f
commit aec56abe73
10 changed files with 549 additions and 654 deletions

View File

@@ -18,11 +18,6 @@ from dataclasses import dataclass, field
@dataclass
class OpenAIWhisperSettings:
model_size: str = field(default='tiny', init=True)
translate_to_english: bool = field(default=False, init=True)
@dataclass
class FasterWhisperSettings:
model_size: str = field(default='tiny', init=True)
translate_to_english: bool = field(default=False, init=True)
class OnnxAsrSettings:
model_id: str = field(default='nemo-parakeet-tdt-0.6b-v3', init=True)
model_path: str = ''

View File

@@ -14,16 +14,26 @@
# along with Gajim. If not, see <http://www.gnu.org/licenses/>.
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any
import numpy as np
from ..helper import Results
class Model(ABC):
@property
@abstractmethod
def transcribe(self, result: Results, audio_file: Path) -> None:
def is_loaded(self) -> bool:
pass
@abstractmethod
def load(self) -> None:
pass
@abstractmethod
def recognize(self, result: Results, audio: np.ndarray) -> None:
pass
@abstractmethod

View File

@@ -0,0 +1,132 @@
# This file is part of Gajim.
#
# Gajim is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Gajim is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Gajim. If not, see <http://www.gnu.org/licenses/>.
import logging
import pickle
import subprocess
import sys
from pathlib import Path
import numpy as np
from gi.repository import GLib
from ..helper import Results
from .model_settings import OnnxAsrSettings
from .model_template import Model
# Logger used by the ONNX ASR speech-to-text backend in this module.
log = logging.getLogger('gajim.p.sttvm_onnx_asr')
# Seconds without a load/recognize call after which the worker subprocess
# is shut down (see OnnxAsrModel._reschedule_unload).
_IDLE_UNLOAD_SECONDS = 300
class OnnxAsrModel(Model):
    """Speech-to-text model that is hosted in a worker subprocess.

    Each request is a pickled dict written to the worker's stdin and is
    answered by exactly one pickled dict read from the worker's stdout.
    After ``_IDLE_UNLOAD_SECONDS`` without a ``load`` or ``recognize`` call
    a GLib timer shuts the worker down; it is started again on demand.
    """

    # Seconds to wait for the worker to exit by itself before killing it.
    _EXIT_TIMEOUT_SECONDS = 2

    def __init__(self):
        # Worker process handle; None while no worker is running.
        self._proc: subprocess.Popen | None = None
        # True once the current worker has acknowledged a 'load' request.
        self._loaded = False
        self._config = OnnxAsrSettings()
        # GLib source id of the pending idle-unload timer, if any.
        self._unload_source: int | None = None

    @property
    def is_loaded(self) -> bool:
        """Whether a live worker has acknowledged a 'load' request.

        The liveness check matters: a worker that crashed after loading
        must not be reported as loaded, otherwise the next ``recognize``
        request would reach a fresh worker that has no model.
        """
        return self._loaded and self._worker_alive()

    @property
    def will_download(self) -> bool:
        """Whether ``config.json`` of the model is missing from the cache.

        Always ``False`` when the model is already loaded, when an explicit
        ``model_path`` is configured, or when the resolved repository name
        contains no ``/``.
        """
        if self.is_loaded or self._config.model_path:
            return False

        # Imported on demand: this is the only place in this module that
        # uses these packages.
        from huggingface_hub import try_to_load_from_cache
        from onnx_asr.resolver import model_repos

        model_id = self._config.model_id
        repo = model_repos.get(model_id, model_id)
        if '/' not in repo:
            return False
        # Only a str result counts as "cached"; any other return value is
        # treated as "has to be downloaded".
        cached = try_to_load_from_cache(repo, 'config.json')
        return not isinstance(cached, str)

    def load(self) -> None:
        """Load the configured model in the worker unless already loaded.

        Restarts the idle-unload timer in either case.

        Raises:
            RuntimeError: if the worker died or reported an error.
        """
        if self.is_loaded:
            self._reschedule_unload()
            return

        log.debug('Loading model %s in worker', self._config.model_id)
        self._send({
            'op': 'load',
            'model_id': self._config.model_id,
            'model_path': self._config.model_path,
        })
        self._loaded = True
        self._reschedule_unload()

    def recognize(self, result: Results, audio: np.ndarray) -> None:
        """Transcribe *audio* in the worker and store the text in *result*.

        The model is loaded first when necessary.

        Raises:
            RuntimeError: if the worker died or reported an error.
        """
        self.load()
        response = self._send({'op': 'recognize', 'audio': audio})
        result.text = response['text']
        self._reschedule_unload()

    def set_config(self, config: OnnxAsrSettings) -> None:
        """Adopt *config*; unload the worker when the model selection changed.

        A private copy of the relevant fields is stored so that later
        changes to *config* by the caller do not affect this object.
        """
        changed = (config.model_id != self._config.model_id
                   or config.model_path != self._config.model_path)
        if changed:
            self.unload_now()
        self._config = OnnxAsrSettings(
            model_id=config.model_id, model_path=config.model_path)

    def unload_now(self) -> None:
        """Cancel the idle timer and shut the worker down, if one exists."""
        self._cancel_unload_timer()
        if self._proc is not None:
            log.debug('Terminating STT worker subprocess')
        self._discard_proc()

    def _worker_alive(self) -> bool:
        """Return True when a worker process exists and has not exited."""
        return self._proc is not None and self._proc.poll() is None

    def _discard_proc(self) -> None:
        """Forget the current worker and release everything it holds.

        Closing stdin makes the worker see end-of-file and exit; if it has
        not exited within ``_EXIT_TIMEOUT_SECONDS`` it is killed. The stdout
        pipe is closed as well so that no file descriptor is leaked.
        """
        proc, self._proc = self._proc, None
        self._loaded = False
        if proc is None:
            return

        try:
            proc.stdin.close()
        except OSError:
            # The worker is already gone; flushing on close hit a broken
            # pipe. There is nothing left to signal.
            pass
        try:
            proc.wait(timeout=self._EXIT_TIMEOUT_SECONDS)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()
        finally:
            proc.stdout.close()

    def _ensure_proc(self) -> None:
        """Start a worker unless a live one already exists."""
        if self._worker_alive():
            return

        # A worker that died by itself still owns pipes and an exit status;
        # release them before the handle is replaced.
        self._discard_proc()

        log.debug('Starting STT worker subprocess')
        self._proc = subprocess.Popen(
            [sys.executable, str(Path(__file__).parent / 'stt_worker.py')],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
        )
        self._loaded = False

    def _send(self, cmd: dict) -> dict:
        """Send *cmd* to the worker and return its response dict.

        Raises:
            RuntimeError: if the worker exited, sent an unreadable response,
                or answered with ``ok`` unset (the worker's ``error`` text
                becomes the exception message).
        """
        self._ensure_proc()
        try:
            pickle.dump(cmd, self._proc.stdin)
            self._proc.stdin.flush()
            # NOTE: unpickling is acceptable here only because the peer is
            # a child process started by this object. Never point this at
            # a stream that somebody else can write to.
            response = pickle.load(self._proc.stdout)
        except (EOFError, OSError) as e:
            # EOF on read or a broken pipe on write: the worker is gone.
            self._discard_proc()
            raise RuntimeError('Worker subprocess exited unexpectedly') from e
        except pickle.UnpicklingError as e:
            # The stream is out of sync and cannot be trusted any more.
            self._discard_proc()
            raise RuntimeError(
                'Worker subprocess sent a malformed response') from e

        if not response.get('ok'):
            worker_traceback = response.get('traceback')
            if worker_traceback:
                log.debug('STT worker traceback:\n%s', worker_traceback)
            raise RuntimeError(response.get('error', 'unknown worker error'))
        return response

    def _cancel_unload_timer(self) -> None:
        """Remove the pending idle-unload timer, if there is one."""
        if self._unload_source is not None:
            GLib.source_remove(self._unload_source)
            self._unload_source = None

    def _reschedule_unload(self) -> None:
        """Restart the idle-unload countdown from zero."""
        self._cancel_unload_timer()
        self._unload_source = GLib.timeout_add_seconds(
            _IDLE_UNLOAD_SECONDS, self._on_idle_unload)

    def _on_idle_unload(self) -> bool:
        """Timer callback: unload the worker after the idle period."""
        # This source is removed by returning SOURCE_REMOVE below; clearing
        # the id first keeps unload_now() from removing it a second time.
        self._unload_source = None
        log.debug('Idle unload after %ds', _IDLE_UNLOAD_SECONDS)
        self.unload_now()
        return GLib.SOURCE_REMOVE

View File

@@ -0,0 +1,54 @@
# This file is part of Gajim.
#
# Gajim is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Gajim is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Gajim. If not, see <http://www.gnu.org/licenses/>.
import pickle
import sys
import traceback
def _respond(response: dict, stream=None) -> None:
    """Write one pickled *response* to the parent process and flush it.

    Args:
        response: the reply dict to send.
        stream: binary stream to write to; defaults to ``sys.stdout.buffer``.

    The reply is serialised completely before anything is written, so a
    pickling failure cannot leave half a message on the stream.
    """
    if stream is None:
        stream = sys.stdout.buffer
    payload = pickle.dumps(response)
    stream.write(payload)
    stream.flush()


def main() -> None:
    """Serve pickled requests from stdin until stdin reaches end-of-file.

    Every request is a dict with an ``op`` key and gets exactly one reply:

    * ``load`` (``model_id``, optional ``model_path``) -> ``{'ok': True}``
    * ``recognize`` (``audio``) -> ``{'ok': True, 'text': ...}``
    * anything else, or any exception -> ``{'ok': False, 'error': ...}``;
      replies generated from an exception also carry a ``traceback``.
    """
    import contextlib

    requests = sys.stdin.buffer
    # Keep a private handle on the real stdout: it carries the reply stream.
    replies = sys.stdout.buffer
    model = None

    # Anything written to sys.stdout by imported libraries would corrupt the
    # pickle stream read by the parent, so send such output to stderr.
    # This covers Python-level writes only, not writes to file descriptor 1.
    with contextlib.redirect_stdout(sys.stderr):
        while True:
            try:
                # NOTE: requests are unpickled without validation; stdin must
                # only ever be connected to the trusted parent process.
                cmd = pickle.load(requests)
            except EOFError:
                # The parent closed the pipe: normal shutdown.
                return

            try:
                op = cmd['op']
                if op == 'load':
                    # Imported on first use; nothing else needs it.
                    import onnx_asr
                    model = onnx_asr.load_model(
                        cmd['model_id'], cmd.get('model_path') or None)
                    _respond({'ok': True}, replies)
                elif op == 'recognize':
                    if model is None:
                        raise RuntimeError('model is not loaded')
                    text = model.recognize(cmd['audio'])
                    _respond({'ok': True, 'text': text}, replies)
                else:
                    _respond(
                        {'ok': False, 'error': f'unknown op: {op}'}, replies)
            except Exception as e:
                # Top-level boundary of the worker: report the failure to the
                # parent instead of dying, so the next request can be served.
                _respond({
                    'ok': False,
                    'error': f'{type(e).__name__}: {e}',
                    'traceback': traceback.format_exc(),
                }, replies)


if __name__ == '__main__':
    main()