diff --git a/stt_voice_messages/gtk/config_dialog.py b/stt_voice_messages/gtk/config_dialog.py index f29d2eb..7fa36b5 100644 --- a/stt_voice_messages/gtk/config_dialog.py +++ b/stt_voice_messages/gtk/config_dialog.py @@ -18,33 +18,23 @@ from __future__ import annotations import logging import typing from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any try: - import whisper + import onnx_asr except ModuleNotFoundError: if typing.TYPE_CHECKING: - import whisper + import onnx_asr -try: - import faster_whisper as fwhisper -except ModuleNotFoundError: - if typing.TYPE_CHECKING: - import faster_whisper as fwhisper +from gi.repository import Adw, Gtk -from gi.repository import Gtk - -from gajim.common import app -from gajim.common.app import Any -from gajim.gtk.builder import get_builder from gajim.gtk.const import Setting, SettingKind, SettingType -from gajim.gtk.settings import SettingsBox -from gajim.gtk.sidebar_switcher import SideBarSwitcher -from gajim.plugins.helpers import get_builder +from gajim.gtk.filechoosers import Filter +from gajim.gtk.settings import GajimPreferencesGroup, SettingsDialog from gajim.plugins.plugins_i18n import _ -from ..models import faster_whisper, openai_whisper -from ..models.model_settings import * +from ..models import stt +from ..models.model_settings import OnnxAsrSettings if TYPE_CHECKING: from ..stt_voice_messages import STTVoiceMessagesPlugin @@ -52,271 +42,250 @@ if TYPE_CHECKING: log = logging.getLogger('gajim.p.sttvm_config_dialog') -@dataclass -class Model: - name: str - required_moduls: list[str] - klass: object - config: Any - instance: typing.Optional[object] = None - - -SUPPORTED_MODELS: dict[str, Model] = { - 'model_openaiwhisper': Model('OpenAI Whisper', - ['whisper'], - openai_whisper.WhisperModel, - OpenAIWhisperSettings), - 'model_faster-whisper': Model('Faster-Whisper', - ['faster_whisper'], - faster_whisper.FasterWhisperModel, - FasterWhisperSettings) -} - - class Configuration: def __init__(self, plugin: STTVoiceMessagesPlugin): self._plugin = plugin - - self._available_models: dict[str, Model] = {} - self.check_available_moduls() - - log.debug('config = %s', self._plugin.config) + self._instance = None + self._main_model_row = None + self._preset_model_picker = None + self._custom_model_id_entry = None + self._local_model_file_picker = None + self._status_group = None + self._model_data: dict[str, str] = {} + self._instance = stt.OnnxAsrModel() + self._instance.set_config(OnnxAsrSettings( + model_id=self.plugin.config['model_id'], + model_path=self.plugin.config['model_path'] + )) + self._model_data = self._steal_model_list() @property def plugin(self) -> STTVoiceMessagesPlugin: return self._plugin @property - def available_models(self) -> dict[str, Model]: - return self._available_models + def is_available(self) -> bool: + return self._instance is not None + + def unload_model(self) -> None: + if self._instance is not None: + self._instance.unload_now() + + def _steal_model_list(self) -> dict[str, str]: + # UGLY: Extract available model choices from onnx_asr type hints. + ann = onnx_asr.load_model.__annotations__.get('model') + return { + v: v for arg in typing.get_args(ann) + for v in typing.get_args(arg) + if isinstance(v, str) + } def on_setting(self, value: Any, data: Any) -> None: if isinstance(value, str): - value.strip() - - log.debug('plugin config before:\n %s', self.plugin.config.data) + value = value.strip() self.plugin.config[data] = value - log.debug('plugin config after:\n %s', self.plugin.config.data) - def on_config_model(self, model: str, value: Any, data: Any) -> None: - if isinstance(value, str): - value.strip() + def on_preset_changed(self, value: str, data: Any) -> None: + if self._custom_model_id_entry is not None: + entry_text = self._custom_model_id_entry.entry.get_text().strip() + if entry_text: + self._update_model_status() + return # custom entry overrides; ignore preset change + self._write_model_id(value) + self._update_model_status() - log.debug('plugin config before:\n %s', self.plugin.config.data[model]) - setattr(self.plugin.config.data[model], data, value) - log.debug('plugin config after:\n %s', self.plugin.config.data[model]) + def on_custom_model_id_changed(self, value: str, data: Any) -> None: + value = value.strip() + if value: + self._write_model_id(value) + elif self._preset_model_picker is not None: + preset_key = self._preset_model_picker._dropdown.get_selected_key() + if preset_key is not None: + self._write_model_id(preset_key) + self._apply_sensitivity_state() + self._update_model_status() - self._plugin.config.data[model].instance.set_config(self.plugin.config.data[model]) + def on_model_file_picked(self, value: str, data: Any) -> None: + self._write_model_path(str(Path(value).parent) if value else '') + self._apply_sensitivity_state() + self._update_model_status() - def create_model(self, model: Any) -> None: - if (self.plugin.config.data[model].instance is None and - self._available_models[model].klass is not None): - self.plugin.config.data[model].instance = \ - self._available_models[model].klass() + def _write_model_id(self, model_id: str) -> None: + if self.plugin.config['model_id'] == model_id: + return + self.plugin.config['model_id'] = model_id + if self._instance is not None: + self._instance.set_config(OnnxAsrSettings( + model_id=self.plugin.config['model_id'], + model_path=self.plugin.config['model_path'] + )) + + def _write_model_path(self, model_path: str) -> None: + if self.plugin.config['model_path'] == model_path: + return + self.plugin.config['model_path'] = model_path + if self._instance is not None: + self._instance.set_config(OnnxAsrSettings( + model_id=self.plugin.config['model_id'], + model_path=self.plugin.config['model_path'] + )) + + def sync_model_path_from_widget(self) -> None: + if self._local_model_file_picker is None: + return + button = self._local_model_file_picker.get_activatable_widget() + path = button.get_path() + new_path = str(path.parent) if path else '' + self._write_model_path(new_path) + + def _apply_sensitivity_state(self) -> None: + if self._preset_model_picker is None: + return + has_local = bool(self.plugin.config['model_path']) + entry_text = (self._custom_model_id_entry.entry.get_text().strip() + if self._custom_model_id_entry else '') + has_entry = bool(entry_text) + self._custom_model_id_entry.set_sensitive(not has_local) + self._preset_model_picker.set_sensitive(not has_local and not has_entry) + + def _update_model_status(self) -> None: + if self._main_model_row is None: + return + entry_text = (self._custom_model_id_entry.entry.get_text().strip() + if self._custom_model_id_entry else '') + + if self.plugin.config['model_path']: + path = Path(self.plugin.config['model_path']) + summary = _('Local: {}').format(path.name or str(path)) + description = _('Loading model files from {}').format(path) + if not (path / 'config.json').exists(): + description += '\n' + _( + 'config.json not found in this directory — onnx-asr will' + ' fall back to Model preset or Custom Model ID for the' + ' architecture.') + elif entry_text: + summary = _('Custom: {}').format(entry_text) + description = _('Using custom model: {}').format(entry_text) else: - log.debug('Could not create model %s', model) + preset_key = (self._preset_model_picker._dropdown.get_selected_key() + if self._preset_model_picker else '') + summary = preset_key or _('(none)') + description = (_('Using preset: {}').format(preset_key) + if preset_key else '') - def on_set_model(self, model: Any, data: str = 'model') -> None: - if isinstance(model, str): - model.strip() - - self.plugin.config['model'] = model - log.debug('Created model %s with config %s', model, self.plugin.config.data[model]) - - def check_available_moduls(self): - def is_module_available(module: str) -> bool: - try: - __import__(module) - return True - except ModuleNotFoundError: - log.debug('Could not find module %s', module) - return False - except ImportError as ex: - log.debug(str(ex)) - return False - - for model in SUPPORTED_MODELS: - available = True - for modul in SUPPORTED_MODELS[model].required_moduls: - if not is_module_available(modul): - available = False - continue - if available: - self._available_models[model] = SUPPORTED_MODELS[model] - if SUPPORTED_MODELS[model].config is not None: - log.debug('created config for model = %s: %s', model, self._available_models[model]) - log.debug('plugin config for model = %s', self.plugin.config[model]) - self.plugin.config.data[model].instance = None - self._available_models[model].config = self.plugin.config[model] - self.create_model(model) - - self.on_set_model(self._plugin.config['model']) - - log.debug('models = %s', self._available_models) + self._main_model_row._label.set_text(summary) + if self._status_group is not None: + self._status_group.set_description(description) -class PreferenceBox(SettingsBox): - def __init__(self, settings: list[Setting]) -> None: - SettingsBox.__init__(self, None) - self.get_style_context().add_class('border') - self.set_selection_mode(Gtk.SelectionMode.NONE) - self.set_vexpand(False) - self.set_valign(Gtk.Align.END) - - for setting in settings: - self.add_setting(setting) - self.update_states() - - -class STTVoiceMessagesConfigDialog(Gtk.ApplicationWindow): +class STTVoiceMessagesConfigDialog(SettingsDialog): def __init__(self, config: Configuration, parent: Gtk.Window) -> None: - Gtk.ApplicationWindow.__init__(self) - - self.set_application(app.app) - self.set_position(Gtk.WindowPosition.CENTER) - self.set_show_menubar(False) - self.set_name('PreferencesWindow') - self.set_default_size(900, 650) - self.set_resizable(True) - self.set_title(_('STT Voice Messages - Preferences')) - - ui_path = Path(__file__).parent - self._ui = get_builder(str(ui_path.resolve() / 'config_dialog.ui')) - - self._prefs: dict[str, PreferenceBox] = {} - prefs: list[tuple[str, type[PreferenceBox]]] = [ - ('stt_behaviour', self.STTBehaviour), - ('models', self.Models), - ] - - if 'model_openaiwhisper' in config.available_models: - prefs.append(('openaiwhisper_general', self.OpenAIWhisperGeneral)) - else: - self._ui.stack.remove(getattr(self._ui, 'openai-whisper')) - - if 'model_faster-whisper' in config.available_models: - prefs.append(('fasterwhisper_general', self.FasterWhisperGeneral)) - else: - self._ui.stack.remove(getattr(self._ui, 'faster-whisper')) - - side_bar_switcher = SideBarSwitcher() - side_bar_switcher.set_stack(self._ui.stack) - self._ui.grid.attach(side_bar_switcher, 0, 0, 1, 1) - self.add(self._ui.grid) - self.config = config self.plugin = self.config.plugin - self._add_prefs(prefs) + if not config.is_available: + return - self.show_all() + rows = [ + Setting(SettingKind.SWITCH, + _('Auto Transcribe'), + SettingType.VALUE, + value=self.plugin.config['auto_transcribe'], + data='auto_transcribe', + callback=config.on_setting, + desc=_('Transcribe messages as they appear')), + Setting(SettingKind.SUBPAGE, + _('Model'), + SettingType.VALUE, + value=None, + name='main_model', + props={'subpage': 'sttvm-model'}), + ] - def _add_prefs(self, prefs: list[tuple[str, type[PreferenceBox]]]): - for ui_name, klass in prefs: - pref_box = getattr(self._ui, ui_name) - pref = klass(self) # pyright: ignore - log.debug('ui_name = %s, klass = %s, pref_box = %s', ui_name, klass, pref_box) - pref_box.add(pref) - self._prefs[ui_name] = pref + SettingsDialog.__init__( + self, + parent, + _('STT Voice Messages'), + Gtk.DialogFlags.MODAL, + rows, + '', + ) + + config._main_model_row = self.get_setting('main_model') + + use_custom = self.plugin.config['model_id'] not in config._model_data - ############################################################################ - # General Settings - ############################################################################ - class STTBehaviour(PreferenceBox): - def __init__(self, config_dialog: STTVoiceMessagesConfigDialog) -> None: - settings = [ - Setting(SettingKind.SWITCH, - _('Auto Transcribe'), - SettingType.VALUE, - value=config_dialog.plugin.config['auto_transcribe'], - data='auto_transcribe', - callback=config_dialog.config.on_setting) - ] + subpage_rows: list[Setting] = [ + Setting(SettingKind.DROPDOWN, + _('Model'), + SettingType.VALUE, + value=self.plugin.config['model_id'], + name='preset_model', + callback=config.on_preset_changed, + props={'data': config._model_data}), + Setting(SettingKind.ENTRY, + _('Custom Model'), + SettingType.VALUE, + value=self.plugin.config['model_id'] if use_custom else '', + name='custom_model', + callback=config.on_custom_model_id_changed, + desc=_('Custom HF model path or model ID')), + Setting(SettingKind.FILECHOOSER, + _('Local File'), + SettingType.VALUE, + value='', + name='local_model_file', + callback=config.on_model_file_picked, + desc=_('Model ID is taken from config.json if not set'), + props={'filefilters': [ + Filter(_('ONNX model'), suffixes=['onnx'], default=True), + ]}), + ] - PreferenceBox.__init__(self, settings) + controls_group = GajimPreferencesGroup('model_controls') + for s in subpage_rows: + controls_group.add_setting(s) - class Models(PreferenceBox): - def __init__(self, config_dialog: STTVoiceMessagesConfigDialog) -> None: - models: list[tuple[str, str]] = [] - for key, value in config_dialog.config.available_models.items(): - models.append( - (key, str(value.name)) - ) + status_group = Adw.PreferencesGroup() - settings = [ - Setting(SettingKind.COMBO, - _('Speech To Text Model'), - SettingType.VALUE, - value=config_dialog.plugin.config['model'], - data='model', - callback=config_dialog.config.on_set_model, - props={'combo_items': models}, - desc=_('Choose Model to use')), - ] + pref_page = Adw.PreferencesPage() + pref_page.add(controls_group) + pref_page.add(status_group) - PreferenceBox.__init__(self, settings) + toolbar = Adw.ToolbarView(content=pref_page) + toolbar.add_top_bar(Adw.HeaderBar()) - ############################################################################ - # OpenAI Whisper Settings - ############################################################################ - class OpenAIWhisperGeneral(PreferenceBox): - def __init__(self, config_dialog: STTVoiceMessagesConfigDialog) -> None: + page = Adw.NavigationPage( + tag='sttvm-model', title=_('Model'), child=toolbar) + self._nav.add(page) - self._model = 'model_openaiwhisper' - self._config_dialog = config_dialog + config._preset_model_picker = controls_group.get_setting('preset_model') + config._custom_model_id_entry = controls_group.get_setting('custom_model') + config._local_model_file_picker = controls_group.get_setting( + 'local_model_file') + config._status_group = status_group - settings = [ - Setting(SettingKind.POPOVER, - _('Language Model Size'), - SettingType.VALUE, - value=config_dialog.config.available_models[self._model].config.model_size, - data='model_size', - callback=self._set_config, - props={'entries': whisper.available_models()}), + config._custom_model_id_entry.entry.set_placeholder_text( + _('onnx-community/whisper-large-v3-turbo')) - Setting(SettingKind.SWITCH, - _('Translate'), - SettingType.VALUE, - value=config_dialog.config.available_models[self._model].config.translate_to_english, - data='translate_to_english', - callback=self._set_config) - ] + button = config._local_model_file_picker.get_activatable_widget() + button._label_text = _('.oonx') + button.reset() - PreferenceBox.__init__(self, settings) + if self.plugin.config['model_path']: + onnx_in_dir = next(iter(Path(self.plugin.config['model_path']).glob('*.onnx')), + None) + if onnx_in_dir is not None: + button.set_path(onnx_in_dir) - def _set_config(self, value: Any, data: Any): - self._config_dialog.config.on_config_model(self._model, value, data) + config._update_model_status() + config._apply_sensitivity_state() - ############################################################################ - # Faster Whisper Settings - ############################################################################ - class FasterWhisperGeneral(PreferenceBox): - def __init__(self, - config_dialog: STTVoiceMessagesConfigDialog) -> None: - self._model = 'model_faster-whisper' - self._config_dialog = config_dialog - - settings = [ - Setting(SettingKind.POPOVER, - _('Language Model Size'), - SettingType.VALUE, - value=config_dialog.config.available_models[ - self._model].config.model_size, - data='model_size', - callback=self._set_config, - props={'entries': fwhisper.available_models()}), - - Setting(SettingKind.SWITCH, - _('Translate'), - SettingType.VALUE, - value=config_dialog.config.available_models[ - self._model].config.translate_to_english, - data='translate_to_english', - callback=self._set_config) - ] - - PreferenceBox.__init__(self, settings) - - def _set_config(self, value: Any, data: Any): - self._config_dialog.config.on_config_model(self._model, value, - data) + def _cleanup(self) -> None: + self.config.sync_model_path_from_widget() + self.config._main_model_row = None + self.config._preset_model_picker = None + self.config._custom_model_id_entry = None + self.config._local_model_file_picker = None + self.config._status_group = None + SettingsDialog._cleanup(self) diff --git a/stt_voice_messages/gtk/config_dialog.ui b/stt_voice_messages/gtk/config_dialog.ui deleted file mode 100644 index 6aec4bf..0000000 --- a/stt_voice_messages/gtk/config_dialog.ui +++ /dev/null @@ -1,349 +0,0 @@ - - - - - - - True - False - - - True - False - True - - - True - True - never - in - False - - - True - False - - - True - False - vertical - 24 - - - - True - False - vertical - 12 - - - True - False - Behaviour of STT Voice Messages - 0 - - - - 0 - 0 - - - - - False - True - 0 - - - - - - True - False - vertical - 12 - - - True - False - General Model Configuration - 0 - - - - 0 - 0 - - - - - False - True - 1 - - - - - - True - False - vertical - 12 - - - True - False - Preview UI - 0 - - - - 0 - 0 - - - - - False - True - 2 - - - - - - - - - general - General - computer-symbolic - - - - - True - True - never - in - False - - - True - False - - - True - False - vertical - 24 - - - - True - False - vertical - 12 - - - True - False - General - 0 - - - - 0 - 0 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - False - True - 0 - - - - - - - - - - - - - - - - - - - - - openai-whisper - OpenAI Whisper - 1 - - - - - True - True - never - in - False - - - True - False - - - True - False - vertical - 24 - - - - True - False - vertical - 12 - - - True - False - General - 0 - - - - 0 - 0 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - False - True - 0 - - - - - - - - - - - - - - - - - - - - - faster-whisper - Faster Whisper - 2 - - - - - - 1 - 0 - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/stt_voice_messages/gtk/sttbox.py b/stt_voice_messages/gtk/sttbox.py index dbb2666..1d94d46 100644 --- a/stt_voice_messages/gtk/sttbox.py +++ b/stt_voice_messages/gtk/sttbox.py @@ -13,9 +13,12 @@ # You should have received a copy of the GNU General Public License # along with Gajim. If not, see . -import logging +from __future__ import annotations -from gi.repository import Gtk +import logging +from pathlib import Path + +from gi.repository import Gtk, Adw from gajim.plugins.gajimplugin import GajimPluginConfig from gajim.plugins.plugins_i18n import _ @@ -26,49 +29,62 @@ log = logging.getLogger('gajim.p.stt_voice_messages_sttbox') class STTBox(Gtk.Box): def __init__(self, - preview_audio_widget: Gtk.Box, config: GajimPluginConfig, - audio_file: str, + audio_file: Path, ) -> None: - Gtk.Box.__init__(self, orientation=Gtk.Orientation.VERTICAL, spacing=12) + Gtk.Box.__init__(self, orientation=Gtk.Orientation.HORIZONTAL, spacing=6) self._config = config - self._preview_audio = preview_audio_widget self._model = None self._audio_file = audio_file self._text = '' - self._transcribe_button = Gtk.Button(label=_('Transcribe')) + self._transcribe_button = Gtk.Button.new_from_icon_name("lucide-captions-symbolic") + self._transcribe_button.set_tooltip_text(_('Transcribe voice message')) + + self._spinner = Adw.Spinner(valign=Gtk.Align.START, visible=False) self._transcription_label = Gtk.Label( label=_('Nothing transcribed yet')) self._transcription_label.set_max_width_chars(40) - self._transcription_label.set_line_wrap(True) - - self.add(self._transcribe_button) - self.add(self._transcription_label) + self._transcription_label.set_wrap(True) + self.append(self._spinner) + self.append(self._transcription_label) + self._transcribe_button.connect('clicked', self._on_transcribe_clicked) self._result = helper.Results('') - self._transcribe_button.connect('clicked', self._on_transcribe_clicked) - - self.show_all() + @property + def button(self) -> Gtk.Button: + return self._transcribe_button def _on_transcribe_clicked(self, _button: Gtk.Button) -> None: - log.debug('config.data = %s', self._config.data) - model_name = self._config.data['model'] - model = self._config.data[model_name].instance - if model is None: + log.debug('config._instance = %s', self._config._instance) + self._model = self._config._instance + if self._model is None: return - self._model = model + if self._model.is_loaded: + text = _('Transcribing…') + elif self._model.will_download: + text = _('Downloading ') + self._model.model_id + else: + text = _('Loading model…') + self._transcription_label.set_text(text) + self._spinner.set_visible(True) + self._task = helper.BackgroundTask( + self._model.load, self._on_load_done) + self._task.start() - transcription_task = helper.BackgroundTask( - self._model.transcribe(self._result, self._audio_file), - self._show_result + def _on_load_done(self): + self._transcription_label.set_text(_('Transcribing…')) + self._task = helper.BackgroundTask( + lambda: self._model.recognize( + self._result, helper.load_audio(self._audio_file)), + self._show_result, ) - transcription_task.start() + self._task.start() def _show_result(self): assert self._model is not None @@ -77,3 +93,4 @@ class STTBox(Gtk.Box): self._transcription_label.set_text(self._text.strip()) else: self._transcription_label.set_text(_('_Have not heard any word!_')) + self._spinner.set_visible(False) diff --git a/stt_voice_messages/helper.py b/stt_voice_messages/helper.py index 88cc853..1baf5f6 100644 --- a/stt_voice_messages/helper.py +++ b/stt_voice_messages/helper.py @@ -13,16 +13,53 @@ # You should have received a copy of the GNU General Public License # along with Gajim. If not, see . +import logging +import typing from dataclasses import dataclass +from pathlib import Path +import gi +import numpy as np from gi.repository import Gio, GObject +try: + gi.require_version('Gst', '1.0') + from gi.repository import Gst +except Exception: + if typing.TYPE_CHECKING: + from gi.repository import Gst + +log = logging.getLogger('gajim.p.sttvm_helper') + @dataclass class Results: text: str +def load_audio(path: Path, sample_rate: int = 16000) -> np.ndarray: + Gst.init(None) + pipeline = Gst.parse_launch( + 'filesrc name=src ! decodebin ! audioconvert ! audioresample ! ' + f'audio/x-raw,format=F32LE,rate={sample_rate},channels=1 ! ' + 'appsink name=sink sync=false' + ) + pipeline.get_by_name('src').set_property('location', str(path)) + sink = pipeline.get_by_name('sink') + chunks: list[np.ndarray] = [] + + pipeline.set_state(Gst.State.PLAYING) + while (sample := sink.emit('try-pull-sample', 10 * Gst.SECOND)) is not None: + buf = sample.get_buffer() + _, info = buf.map(Gst.MapFlags.READ) + chunks.append(np.frombuffer(bytes(info.data), dtype=np.float32)) + buf.unmap(info) + pipeline.set_state(Gst.State.NULL) + + if not chunks: + raise RuntimeError(f'Could not decode audio: {path}') + return np.concatenate(chunks) + ''' https://discourse.gnome.org/t/gtk-threading-problem-with-glib-idle-add/13597/5 @@ -57,6 +94,7 @@ class BackgroundTask(GObject.Object): retval = self.function() task.return_value(retval) except Exception as e: + log.exception('Background task failed') task.return_value(e) def finish(self): diff --git a/stt_voice_messages/models/model_settings.py b/stt_voice_messages/models/model_settings.py index 55467f6..2cede5e 100644 --- a/stt_voice_messages/models/model_settings.py +++ b/stt_voice_messages/models/model_settings.py @@ -18,11 +18,6 @@ from dataclasses import dataclass, field @dataclass -class OpenAIWhisperSettings: - model_size: str = field(default='tiny', init=True) - translate_to_english: bool = field(default=False, init=True) - -@dataclass -class FasterWhisperSettings: - model_size: str = field(default='tiny', init=True) - translate_to_english: bool = field(default=False, init=True) \ No newline at end of file +class OnnxAsrSettings: + model_id: str = field(default='nemo-parakeet-tdt-0.6b-v3', init=True) + model_path: str = '' diff --git a/stt_voice_messages/models/model_template.py b/stt_voice_messages/models/model_template.py index fe11984..d3dbba0 100644 --- a/stt_voice_messages/models/model_template.py +++ b/stt_voice_messages/models/model_template.py @@ -14,16 +14,26 @@ # along with Gajim. If not, see . from abc import ABC, abstractmethod -from pathlib import Path from typing import Any +import numpy as np + from ..helper import Results class Model(ABC): + @property @abstractmethod - def transcribe(self, result: Results, audio_file: Path) -> None: + def is_loaded(self) -> bool: + pass + + @abstractmethod + def load(self) -> None: + pass + + @abstractmethod + def recognize(self, result: Results, audio: np.ndarray) -> None: pass @abstractmethod diff --git a/stt_voice_messages/models/stt.py b/stt_voice_messages/models/stt.py new file mode 100644 index 0000000..538de3a --- /dev/null +++ b/stt_voice_messages/models/stt.py @@ -0,0 +1,132 @@ +# This file is part of Gajim. +# +# Gajim is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Gajim is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Gajim. If not, see . + +import logging +import pickle +import subprocess +import sys +from pathlib import Path + +import numpy as np +from gi.repository import GLib + +from ..helper import Results +from .model_settings import OnnxAsrSettings +from .model_template import Model + +log = logging.getLogger('gajim.p.sttvm_onnx_asr') + + +_IDLE_UNLOAD_SECONDS = 300 + +class OnnxAsrModel(Model): + def __init__(self): + self._proc = None + self._loaded = False + self._config = OnnxAsrSettings() + self._unload_source = None + + @property + def is_loaded(self) -> bool: + return self._loaded + + @property + def will_download(self) -> bool: + if self.is_loaded or self._config.model_path: + return False + from huggingface_hub import try_to_load_from_cache + from onnx_asr.resolver import model_repos + repo = model_repos.get(self._config.model_id, self._config.model_id) + if '/' not in repo: + return False + return not isinstance(try_to_load_from_cache(repo, 'config.json'), str) + + def load(self) -> None: + if self._loaded: + self._reschedule_unload() + return + log.debug('Loading model %s in worker', self._config.model_id) + self._send({ + 'op': 'load', + 'model_id': self._config.model_id, + 'model_path': self._config.model_path, + }) + self._loaded = True + self._reschedule_unload() + + def recognize(self, result: Results, audio: np.ndarray) -> None: + self.load() + response = self._send({'op': 'recognize', 'audio': audio}) + result.text = response['text'] + self._reschedule_unload() + + def set_config(self, config: OnnxAsrSettings) -> None: + if (config.model_id != self._config.model_id + or config.model_path != self._config.model_path): + self.unload_now() + self._config = OnnxAsrSettings( + model_id=config.model_id, model_path=config.model_path) + + def unload_now(self) -> None: + if self._unload_source is not None: + GLib.source_remove(self._unload_source) + self._unload_source = None + if self._proc is not None: + log.debug('Terminating STT worker subprocess') + try: + self._proc.stdin.close() + self._proc.wait(timeout=2) + except subprocess.TimeoutExpired: + self._proc.kill() + self._proc.wait() + self._proc = None + self._loaded = False + + def _ensure_proc(self) -> None: + if self._proc is not None and self._proc.poll() is None: + return + log.debug('Starting STT worker subprocess') + self._proc = subprocess.Popen( + [sys.executable, str(Path(__file__).parent / 'stt_worker.py')], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + ) + self._loaded = False + + def _send(self, cmd: dict) -> dict: + self._ensure_proc() + pickle.dump(cmd, self._proc.stdin) + self._proc.stdin.flush() + try: + response = pickle.load(self._proc.stdout) + except EOFError as e: + self._proc = None + self._loaded = False + raise RuntimeError('Worker subprocess exited unexpectedly') from e + if not response.get('ok'): + raise RuntimeError(response.get('error', 'unknown worker error')) + return response + + def _reschedule_unload(self) -> None: + if self._unload_source is not None: + GLib.source_remove(self._unload_source) + self._unload_source = GLib.timeout_add_seconds( + _IDLE_UNLOAD_SECONDS, self._on_idle_unload) + + def _on_idle_unload(self) -> bool: + self._unload_source = None + log.debug('Idle unload after %ds', _IDLE_UNLOAD_SECONDS) + self.unload_now() + return GLib.SOURCE_REMOVE diff --git a/stt_voice_messages/models/stt_worker.py b/stt_voice_messages/models/stt_worker.py new file mode 100644 index 0000000..d4cea44 --- /dev/null +++ b/stt_voice_messages/models/stt_worker.py @@ -0,0 +1,54 @@ +# This file is part of Gajim. +# +# Gajim is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Gajim is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Gajim. If not, see . + +import pickle +import sys +import traceback + + +def _respond(response: dict) -> None: + pickle.dump(response, sys.stdout.buffer) + sys.stdout.buffer.flush() + + +def main() -> None: + model = None + while True: + try: + cmd = pickle.load(sys.stdin.buffer) + except EOFError: + return + try: + op = cmd['op'] + if op == 'load': + import onnx_asr + model = onnx_asr.load_model( + cmd['model_id'], cmd.get('model_path') or None) + _respond({'ok': True}) + elif op == 'recognize': + text = model.recognize(cmd['audio']) + _respond({'ok': True, 'text': text}) + else: + _respond({'ok': False, 'error': f'unknown op: {op}'}) + except Exception as e: + _respond({ + 'ok': False, + 'error': f'{type(e).__name__}: {e}', + 'traceback': traceback.format_exc(), + }) + + +if __name__ == '__main__': + main() diff --git a/stt_voice_messages/plugin-manifest.json b/stt_voice_messages/plugin-manifest.json index e2df3a3..cb38eae 100644 --- a/stt_voice_messages/plugin-manifest.json +++ b/stt_voice_messages/plugin-manifest.json @@ -13,7 +13,7 @@ "win32" ], "requirements": [ - "gajim>=1.9.0" + "gajim>=2.0.0" ], "short_name": "stt_voice_messages", "version": "0.0.1" diff --git a/stt_voice_messages/stt_voice_messages.py b/stt_voice_messages/stt_voice_messages.py index 22840af..c133322 100644 --- a/stt_voice_messages/stt_voice_messages.py +++ b/stt_voice_messages/stt_voice_messages.py @@ -15,17 +15,24 @@ from __future__ import annotations +import logging from functools import partial +from pathlib import Path +from gi.repository import GLib, Gtk + +from gajim.common import app from gajim.plugins import GajimPlugin from gajim.plugins.plugins_i18n import _ -from .gtk.config_dialog import * +from .gtk.config_dialog import Configuration, STTVoiceMessagesConfigDialog from .gtk.sttbox import STTBox -from .models.model_settings import * +from .models.model_settings import OnnxAsrSettings log = logging.getLogger('gajim.p.stt_voice_messages') +_FOCUS_LOSS_UNLOAD_SECONDS = 30 + class STTVoiceMessagesPlugin(GajimPlugin): def init(self) -> None: @@ -33,42 +40,64 @@ class STTVoiceMessagesPlugin(GajimPlugin): self.config_default_values = { 'auto_transcribe': (False, ''), - 'model': ('model_openaiwhisper', ''), - 'model_openaiwhisper': ( - OpenAIWhisperSettings( - model_size='tiny', - translate_to_english=False), - ''), - 'model_faster-whisper': ( - FasterWhisperSettings( - model_size='tiny', - translate_to_english=False), - '') + 'model_id': ('nemo-parakeet-tdt-0.6b-v3', ''), + 'model_path': ('', ''), } self._config = Configuration(self) - self._config.check_available_moduls() self.config_dialog = partial(STTVoiceMessagesConfigDialog, self._config) self.gui_extension_points = { 'preview_audio': (self._on_preview_audio_created, None), } - self._audio_file: str = '' - self._preview_audio_widget = None - self._stt_box = None + self._active_handler_id = 0 + self._focus_unload_source = None + + def activate(self) -> None: + if app.window is not None and self._active_handler_id == 0: + self._active_handler_id = app.window.connect( + 'notify::is-active', self._on_window_active_changed) + + def deactivate(self) -> None: + if self._focus_unload_source is not None: + GLib.source_remove(self._focus_unload_source) + self._focus_unload_source = None + if self._active_handler_id != 0 and app.window is not None: + app.window.disconnect(self._active_handler_id) + self._active_handler_id = 0 + if self._config.is_available: + self._config.unload_model() + + def _on_window_active_changed(self, + window: Gtk.Window, + _pspec: object, + ) -> None: + if window.is_active(): + if self._focus_unload_source is not None: + GLib.source_remove(self._focus_unload_source) + self._focus_unload_source = None + elif self._focus_unload_source is None: + self._focus_unload_source = GLib.timeout_add_seconds( + _FOCUS_LOSS_UNLOAD_SECONDS, self._on_focus_unload_fired) + + def _on_focus_unload_fired(self) -> bool: + self._focus_unload_source = None + if self._config.is_available: + self._config.unload_model() + return GLib.SOURCE_REMOVE def _on_preview_audio_created(self, - preview_audio_widget: Gtk.Box, + drawing_box: Gtk.Box, + control_box: Gtk.Box, audio_file: Path ) -> None: - self._preview_audio_widget = preview_audio_widget + self._drawing_box = drawing_box; + self._control_box = control_box; self._audio_file = audio_file.as_posix() self._create_stt_box() def _create_stt_box(self) -> None: - assert self._preview_audio_widget is not None - self._stt_box = STTBox(self._preview_audio_widget, - self.config, - self._audio_file) - self._preview_audio_widget.pack_end(self._stt_box, False, False, 0) + self._stt_box = STTBox(self._config, self._audio_file) + self._control_box.append(self._stt_box.button) + self._drawing_box.append(self._stt_box)