# This file is part of Gajim.
#
# Gajim is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Gajim is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Gajim. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations
|
|
|
|
import logging
|
|
import typing
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
try:
|
|
import onnx_asr
|
|
except ModuleNotFoundError:
|
|
if typing.TYPE_CHECKING:
|
|
import onnx_asr
|
|
|
|
from gi.repository import Adw, Gtk
|
|
|
|
from gajim.gtk.const import Setting, SettingKind, SettingType
|
|
from gajim.gtk.filechoosers import Filter
|
|
from gajim.gtk.settings import GajimPreferencesGroup, SettingsDialog
|
|
from gajim.plugins.plugins_i18n import _
|
|
|
|
from ..models import stt
|
|
from ..models.model_settings import OnnxAsrSettings
|
|
|
|
if TYPE_CHECKING:
|
|
from ..stt_voice_messages import STTVoiceMessagesPlugin
|
|
|
|
# Module-level logger. NOTE(review): not referenced by the code in this view.
log = logging.getLogger('gajim.p.sttvm_config_dialog')
|
|
|
|
|
|
class Configuration:
    """State and callbacks behind the STT voice messages settings UI.

    Owns an ``stt.OnnxAsrModel`` that is reconfigured whenever the model id
    or model path in the plugin config changes. While the settings dialog
    is open, it also holds references to the dialog's widgets (assigned by
    ``STTVoiceMessagesConfigDialog`` and reset to ``None`` on cleanup), so
    widget sensitivity and the status text can be kept in sync.
    """

    def __init__(self, plugin: STTVoiceMessagesPlugin):
        self._plugin = plugin

        # Dialog widgets; None whenever the settings dialog is not shown.
        self._main_model_row = None
        self._preset_model_picker = None
        self._custom_model_id_entry = None
        self._local_model_file_picker = None
        self._status_group = None

        self._instance = stt.OnnxAsrModel()
        self._apply_instance_config()

        # Preset model ids offered in the dropdown (id -> label).
        self._model_data: dict[str, str] = self._steal_model_list()

    @property
    def plugin(self) -> STTVoiceMessagesPlugin:
        return self._plugin

    @property
    def is_available(self) -> bool:
        # NOTE(review): __init__ always assigns a model instance and nothing
        # visible here resets it to None, so this is effectively always True.
        return self._instance is not None

    def unload_model(self) -> None:
        """Ask the model instance (if any) to unload immediately."""
        if self._instance is not None:
            self._instance.unload_now()

    @staticmethod
    def _literal_strings(annotation: Any) -> list[str]:
        """Return the string values of every ``Literal`` inside *annotation*.

        Accepts a bare ``Literal[...]`` as well as (nested) unions such as
        ``Literal['a', 'b'] | Literal['c'] | str``. Non-literal members
        contribute nothing. Order of appearance is preserved.
        """
        if typing.get_origin(annotation) is typing.Literal:
            return [arg for arg in typing.get_args(annotation)
                    if isinstance(arg, str)]
        found: list[str] = []
        for arg in typing.get_args(annotation):
            found.extend(Configuration._literal_strings(arg))
        return found

    def _steal_model_list(self) -> dict[str, str]:
        """Return the preset model ids accepted by ``onnx_asr.load_model``.

        UGLY: onnx_asr offers no public list of supported models, so the
        choices are extracted from the ``Literal`` values in the type hint
        of ``load_model``'s ``model`` parameter.

        Returns:
            A mapping of model id to itself (dropdown key -> label). It is
            empty if onnx_asr is not importable or the hint cannot be read.
        """
        # The module-level import is wrapped in try/except, so the name
        # ``onnx_asr`` is unbound when the package is not installed.
        load_model = getattr(globals().get('onnx_asr'), 'load_model', None)
        if load_model is None:
            return {}
        try:
            # Resolves string annotations (PEP 563) into typing objects.
            hints = typing.get_type_hints(load_model)
        except (NameError, TypeError, AttributeError):
            hints = getattr(load_model, '__annotations__', {})
        # A dict comprehension de-duplicates while keeping first-seen order.
        return {name: name
                for name in self._literal_strings(hints.get('model'))}

    def on_setting(self, value: Any, data: Any) -> None:
        """Generic row callback: store *value* under config key *data*.

        String values are stripped of surrounding whitespace first.
        """
        if isinstance(value, str):
            value = value.strip()
        self.plugin.config[data] = value

    def _custom_entry_text(self) -> str:
        """Return the stripped text of the custom model entry, or ''."""
        if self._custom_model_id_entry is None:
            return ''
        return self._custom_model_id_entry.entry.get_text().strip()

    def _selected_preset_key(self) -> str | None:
        """Return the key selected in the preset dropdown, if the dropdown exists."""
        if self._preset_model_picker is None:
            return None
        return self._preset_model_picker._dropdown.get_selected_key()

    def on_preset_changed(self, value: str, data: Any) -> None:
        """Persist the newly selected preset unless a custom id overrides it."""
        # A non-empty custom model id takes precedence over the preset.
        if not self._custom_entry_text():
            self._write_model_id(value)
        self._update_model_status()

    def on_custom_model_id_changed(self, value: str, data: Any) -> None:
        """Persist the custom model id, or fall back to the preset when cleared."""
        value = value.strip()
        if value:
            self._write_model_id(value)
        else:
            preset_key = self._selected_preset_key()
            if preset_key is not None:
                self._write_model_id(preset_key)
        self._apply_sensitivity_state()
        self._update_model_status()

    def on_model_file_picked(self, value: str, data: Any) -> None:
        """Store the directory of the picked ``.onnx`` file as the model path.

        An empty *value* clears the model path.
        """
        self._write_model_path(str(Path(value).parent) if value else '')
        self._apply_sensitivity_state()
        self._update_model_status()

    def _apply_instance_config(self) -> None:
        """Push the configured model id/path to the model instance, if any."""
        if self._instance is None:
            return
        self._instance.set_config(OnnxAsrSettings(
            model_id=self.plugin.config['model_id'],
            model_path=self.plugin.config['model_path'],
        ))

    def _write_model_id(self, model_id: str) -> None:
        """Update ``model_id`` in the config and reconfigure on change."""
        if self.plugin.config['model_id'] == model_id:
            return
        self.plugin.config['model_id'] = model_id
        self._apply_instance_config()

    def _write_model_path(self, model_path: str) -> None:
        """Update ``model_path`` in the config and reconfigure on change."""
        if self.plugin.config['model_path'] == model_path:
            return
        self.plugin.config['model_path'] = model_path
        self._apply_instance_config()

    def sync_model_path_from_widget(self) -> None:
        """Copy the file picker's current selection into ``model_path``.

        The stored value is the parent directory of the selected file, or ''
        when no file is selected. This is a no-op if the picker is not shown.
        """
        if self._local_model_file_picker is None:
            return
        button = self._local_model_file_picker.get_activatable_widget()
        path = button.get_path()
        self._write_model_path(str(path.parent) if path else '')

    def _apply_sensitivity_state(self) -> None:
        """Enable or disable the model widgets according to precedence.

        A local model path disables both the custom entry and the preset
        dropdown. A non-empty custom id disables the preset dropdown. Each
        widget is updated only if it is currently present.
        """
        entry = self._custom_model_id_entry
        preset = self._preset_model_picker
        if entry is None and preset is None:
            return
        has_local = bool(self.plugin.config['model_path'])
        has_entry = bool(self._custom_entry_text())
        if entry is not None:
            entry.set_sensitive(not has_local)
        if preset is not None:
            preset.set_sensitive(not has_local and not has_entry)

    def _update_model_status(self) -> None:
        """Refresh the subpage row summary and the status group description.

        Precedence: local model path, then custom model id, then preset.
        """
        if self._main_model_row is None:
            return

        entry_text = self._custom_entry_text()

        if self.plugin.config['model_path']:
            path = Path(self.plugin.config['model_path'])
            summary = _('Local: {}').format(path.name or str(path))
            description = _('Loading model files from {}').format(path)
            if not (path / 'config.json').exists():
                description += '\n' + _(
                    'config.json not found in this directory — onnx-asr will'
                    ' fall back to Model preset or Custom Model ID for the'
                    ' architecture.')
        elif entry_text:
            summary = _('Custom: {}').format(entry_text)
            description = _('Using custom model: {}').format(entry_text)
        else:
            preset_key = self._selected_preset_key()
            summary = preset_key or _('(none)')
            description = (_('Using preset: {}').format(preset_key)
                           if preset_key else '')

        self._main_model_row._label.set_text(summary)
        if self._status_group is not None:
            self._status_group.set_description(description)


class STTVoiceMessagesConfigDialog(SettingsDialog):
    """Settings dialog for the STT voice messages plugin.

    The main page has an "Auto Transcribe" switch and a "Model" row leading
    to a subpage. On the subpage the user can choose a preset model, enter
    a custom model id, or pick a local ``.onnx`` file. The subpage widgets
    are handed to *config* (a ``Configuration``), which implements the row
    callbacks and keeps sensitivity/status text in sync.
    """

    def __init__(self, config: Configuration, parent: Gtk.Window) -> None:
        self.config = config
        self.plugin = self.config.plugin
        if not config.is_available:
            # NOTE(review): this returns before SettingsDialog.__init__ runs,
            # leaving the dialog uninitialised; callers presumably check
            # is_available beforehand - confirm.
            return

        rows = [
            Setting(SettingKind.SWITCH,
                    _('Auto Transcribe'),
                    SettingType.VALUE,
                    value=self.plugin.config['auto_transcribe'],
                    data='auto_transcribe',
                    callback=config.on_setting,
                    desc=_('Transcribe messages as they appear')),
            Setting(SettingKind.SUBPAGE,
                    _('Model'),
                    SettingType.VALUE,
                    value=None,
                    name='main_model',
                    props={'subpage': 'sttvm-model'}),
        ]

        SettingsDialog.__init__(
            self,
            parent,
            _('STT Voice Messages'),
            Gtk.DialogFlags.MODAL,
            rows,
            '',
        )

        config._main_model_row = self.get_setting('main_model')

        # Pre-fill the custom entry only when the stored id is not a preset.
        use_custom = self.plugin.config['model_id'] not in config._model_data

        subpage_rows: list[Setting] = [
            Setting(SettingKind.DROPDOWN,
                    _('Model'),
                    SettingType.VALUE,
                    value=self.plugin.config['model_id'],
                    name='preset_model',
                    callback=config.on_preset_changed,
                    props={'data': config._model_data}),
            Setting(SettingKind.ENTRY,
                    _('Custom Model'),
                    SettingType.VALUE,
                    value=self.plugin.config['model_id'] if use_custom else '',
                    name='custom_model',
                    callback=config.on_custom_model_id_changed,
                    desc=_('Custom HF model path or model ID')),
            Setting(SettingKind.FILECHOOSER,
                    _('Local File'),
                    SettingType.VALUE,
                    value='',
                    name='local_model_file',
                    callback=config.on_model_file_picked,
                    desc=_('Model ID is taken from config.json if not set'),
                    props={'filefilters': [
                        Filter(_('ONNX model'), suffixes=['onnx'], default=True),
                    ]}),
        ]

        controls_group = GajimPreferencesGroup('model_controls')
        for setting in subpage_rows:
            controls_group.add_setting(setting)

        # Receives the status description via Configuration._update_model_status.
        status_group = Adw.PreferencesGroup()

        pref_page = Adw.PreferencesPage()
        pref_page.add(controls_group)
        pref_page.add(status_group)

        toolbar = Adw.ToolbarView(content=pref_page)
        toolbar.add_top_bar(Adw.HeaderBar())

        page = Adw.NavigationPage(
            tag='sttvm-model', title=_('Model'), child=toolbar)
        self._nav.add(page)

        config._preset_model_picker = controls_group.get_setting('preset_model')
        config._custom_model_id_entry = controls_group.get_setting('custom_model')
        config._local_model_file_picker = controls_group.get_setting(
            'local_model_file')
        config._status_group = status_group

        config._custom_model_id_entry.entry.set_placeholder_text(
            _('onnx-community/whisper-large-v3-turbo'))

        button = config._local_model_file_picker.get_activatable_widget()
        # Placeholder label of the file button (was misspelled '.oonx').
        button._label_text = _('.onnx')
        button.reset()

        # Reflect the stored model directory in the picker by pre-selecting
        # an .onnx file from it. Path.glob order is filesystem-dependent, so
        # sort to make the choice deterministic.
        model_path = self.plugin.config['model_path']
        if model_path:
            onnx_in_dir = next(iter(sorted(Path(model_path).glob('*.onnx'))),
                               None)
            if onnx_in_dir is not None:
                button.set_path(onnx_in_dir)

        config._update_model_status()
        config._apply_sensitivity_state()

    def _cleanup(self) -> None:
        """Persist the picker's path and detach widgets from the config."""
        self.config.sync_model_path_from_widget()
        self.config._main_model_row = None
        self.config._preset_model_picker = None
        self.config._custom_model_id_entry = None
        self.config._local_model_file_picker = None
        self.config._status_group = None
        SettingsDialog._cleanup(self)
