diff --git a/stt_voice_messages/gtk/config_dialog.py b/stt_voice_messages/gtk/config_dialog.py index e88b77f..6d5fe8d 100644 --- a/stt_voice_messages/gtk/config_dialog.py +++ b/stt_voice_messages/gtk/config_dialog.py @@ -15,6 +15,7 @@ from __future__ import annotations +from dataclasses import asdict import logging import typing from pathlib import Path @@ -40,45 +41,51 @@ if TYPE_CHECKING: log = logging.getLogger('gajim.p.sttvm_config_dialog') -SUPPORTED_MODELS: dict[str, dict[str, typing.Union[list[str], Any, str]]] = { - 'model_openaiwhisper': { - 'moduls': ['whisper'], - 'class': openai_whisper.WhisperModel, - 'name': 'OpenAI Whisper' - }, - 'model_ctranslate2': { - 'moduls': ['ctranslate2'], - 'class': None, - 'name': _('CTranslate2') - }, - 'model_faster-whisper': { - 'moduls': ['faster-whisper'], - 'class': None, - 'name:': _('Faster-Whisper') - }, - 'model_distill': { - 'moduls': ['transformers', 'accelerate', 'datasets[audio]'], - 'class': None, - 'name': _('Distill') - } + +@dataclass +class Model: + name: str + required_moduls: list[str] + klass: object + config: Any + instance: typing.Optional[object] = None + + +SUPPORTED_MODELS: dict[str, Model] = { + 'model_openaiwhisper': Model('OpenAI Whisper', + ['whisper'], + openai_whisper.WhisperModel, + OpenAIWhisperSettings), + 'model_ctranslate2': Model('CTranslate2', + ['ctranslate2'], + None, + None), + 'model_faster-whisper': Model('Fast-Whisper', + ['faster-whisper'], + None, + None), + 'model_distill': Model('Distill', + ['transformers', 'accelerate', 'datasets[audio]'], + None, + None) } class Configuration: def __init__(self, plugin: STTVoiceMessagesPlugin): self._plugin = plugin - self._openaiwhisper_settings = OpenAIWhisperSettings() - self._available_models: dict[ - str, dict[str, typing.Union[list[str], Any, str]]] = {} + + self._available_models: dict[str, Model] = {} self.check_available_moduls() + log.debug('config = %s', self._plugin.config['model_openaiwhisper']) + @property def plugin(self) -> STTVoiceMessagesPlugin: return self._plugin @property - def available_models(self) -> dict[ - str, dict[str, typing.Union[list[str], Any, str]]]: + def available_models(self) -> dict[str, Model]: return self._available_models def on_setting(self, value: Any, data: Any) -> None: @@ -86,58 +93,61 @@ class Configuration: value.strip() log.debug('plugin config before:\n %s', self.plugin.config.data) - # TODO: Is 'modelname_key = value' a good design? self.plugin.config[data] = value - - # TODO: Apply setting only to specific instance - self._plugin.config['model_instance'].on_setting(data, value) log.debug('plugin config after:\n %s', self.plugin.config.data) - def on_set_model(self, value: Any, data: Any) -> None: + def on_config_model(self, model: str, value: Any, data: Any) -> None: if isinstance(value, str): value.strip() + + log.debug('plugin config before:\n %s', self.plugin.config.data[model]) + setattr(self.plugin.config.data[model], data, value) + log.debug('plugin config after:\n %s', self.plugin.config.data[model]) + + self._plugin.config.data[model].instance.set_config(self.plugin.config.data[model]) + + def on_set_model(self, model: Any) -> None: + if isinstance(model, str): + model.strip() log.debug('plugin config before:\n %s', self.plugin.config.data) - self._available_models[value]['model_instance'] = \ - self._available_models[value]['class']() + if (self.plugin.config.data[model].instance is None and + self._available_models[model].klass is not None): + self.plugin.config.data[model].instance = \ + self._available_models[model].klass() + else: + return - self.plugin.config['model_class'] = self._available_models[value][ - 'class'] - self.plugin.config['model_instance'] = self._available_models[value][ - 'model_instance'] - - self.on_setting(value, data) + self.plugin.config['model'] = model log.debug('plugin config after:\n %s', self.plugin.config.data) - @staticmethod - def is_module_available(module: str) -> bool: - try: - __import__(module) - return True - except ModuleNotFoundError: - log.debug('Could not find module %s', module) - return False - except ImportError as ex: - log.debug(str(ex)) - return False - def check_available_moduls(self): + def is_module_available(module: str) -> bool: + try: + __import__(module) + return True + except ModuleNotFoundError: + log.debug('Could not find module %s', module) + return False + except ImportError as ex: + log.debug(str(ex)) + return False + for model in SUPPORTED_MODELS: available = True - for modul in SUPPORTED_MODELS[model]['moduls']: - if not self.is_module_available(modul): + for modul in SUPPORTED_MODELS[model].required_moduls: + if not is_module_available(modul): available = False continue if available: self._available_models[model] = SUPPORTED_MODELS[model] + if SUPPORTED_MODELS[model].config is not None: + log.debug('created config for model = %s: %s', model, self._available_models[model]) + log.debug('plugin config for model = %s', self.plugin.config[model]) + self.plugin.config.data[model].instance = None + self._available_models[model].config = self.plugin.config[model] - if (self.plugin.config.data['model_class'] is None - and len(self._available_models) > 0): - model = list(self._available_models)[0] - self.on_set_model(model, 'model') - log.debug('Choose first available model!') - else: - log.debug('Available model already chosen!') + self.on_set_model(self._plugin.config['model']) log.debug('models = %s', self._available_models) @@ -188,6 +198,9 @@ class STTVoiceMessagesConfigDialog(Gtk.ApplicationWindow): self.show_all() + ############################################################################ + # General Settings + ############################################################################ class STTBehaviour(PreferenceBox): def __init__(self, config_dialog: STTVoiceMessagesConfigDialog) -> None: settings = [ @@ -206,7 +219,7 @@ class STTVoiceMessagesConfigDialog(Gtk.ApplicationWindow): models: list[tuple[str, str]] = [] for key, value in config_dialog.config.available_models.items(): models.append( - (key, str(value['name'])) + (key, str(value.name)) ) settings = [ @@ -222,29 +235,37 @@ class STTVoiceMessagesConfigDialog(Gtk.ApplicationWindow): PreferenceBox.__init__(self, settings) + ############################################################################ + # OpenAI Whisper Settings + ############################################################################ class OpenAIWhisperGeneral(PreferenceBox): def __init__(self, config_dialog: STTVoiceMessagesConfigDialog) -> None: + + self._model = 'model_openaiwhisper' + self._config_dialog = config_dialog + settings = [ Setting(SettingKind.POPOVER, _('Language Model Size'), SettingType.VALUE, - value=config_dialog.plugin.config[ - 'whisperai_model_size'], - data='whisperai_model_size', - callback=config_dialog.config.on_setting, + value=config_dialog.config.available_models[self._model].config.model_size, + data='model_size', + callback=self._set_config, props={'entries': whisper.available_models()}), Setting(SettingKind.SWITCH, _('Translate'), SettingType.VALUE, - value=config_dialog.plugin.config[ - 'whisperai_translate'], - data='whisperai_translate', - callback=config_dialog.config.on_setting) + value=config_dialog.config.available_models[self._model].config.translate_to_english, + data='translate_to_english', + callback=self._set_config) ] PreferenceBox.__init__(self, settings) + def _set_config(self, value: Any, data: Any): + self._config_dialog.config.on_config_model(self._model, value, data) + def _add_prefs(self, prefs: list[tuple[str, type[PreferenceBox]]]): for ui_name, klass in prefs: pref_box = getattr(self._ui, ui_name) diff --git a/stt_voice_messages/gtk/sttbox.py b/stt_voice_messages/gtk/sttbox.py index 4280c9a..dbb2666 100644 --- a/stt_voice_messages/gtk/sttbox.py +++ b/stt_voice_messages/gtk/sttbox.py @@ -57,7 +57,8 @@ class STTBox(Gtk.Box): def _on_transcribe_clicked(self, _button: Gtk.Button) -> None: log.debug('config.data = %s', self._config.data) - model = self._config.data['model_instance'] + model_name = self._config.data['model'] + model = self._config.data[model_name].instance if model is None: return diff --git a/stt_voice_messages/models/model_settings.py b/stt_voice_messages/models/model_settings.py index 3228888..5b697ed 100644 --- a/stt_voice_messages/models/model_settings.py +++ b/stt_voice_messages/models/model_settings.py @@ -19,5 +19,5 @@ from dataclasses import dataclass, field @dataclass class OpenAIWhisperSettings: - whisperai_model_size: str = field(default='tiny', init=True) - + model_size: str = field(default='tiny', init=True) + translate_to_english: bool = field(default=False, init=True) \ No newline at end of file diff --git a/stt_voice_messages/models/model_template.py b/stt_voice_messages/models/model_template.py index 99e3881..ce5bf7a 100644 --- a/stt_voice_messages/models/model_template.py +++ b/stt_voice_messages/models/model_template.py @@ -16,8 +16,6 @@ from abc import ABC, abstractmethod from pathlib import Path -from gajim.gtk.const import Setting - from ..helper import Results @@ -26,7 +24,3 @@ class Model(ABC): @abstractmethod def transcribe(self, result: Results, audio_file: Path) -> str: return '' - - @abstractmethod - def on_setting(self, setting: Setting): - pass \ No newline at end of file diff --git a/stt_voice_messages/models/openai_whisper.py b/stt_voice_messages/models/openai_whisper.py index 32f0242..4f1163d 100644 --- a/stt_voice_messages/models/openai_whisper.py +++ b/stt_voice_messages/models/openai_whisper.py @@ -46,11 +46,10 @@ class WhisperModel(Model): return self._result def transcribe(self, result: Results, audio_file: Path) -> str: - model = whisper.load_model(self._config['whisperai_model_size']) - log.debug('model size is used = %s', self._config['whisperai_model_size']) - result.text = model.transcribe(audio_file)['text'] + model = whisper.load_model(self._config.model_size) + log.debug('model size is used = %s', self._config.model_size) + result.text = model.transcribe(audio_file)['text'] # pyright: ignore [reportAttributeAccessIssue] - def on_setting(self, key, value): - log.debug('key = %s, value = %s', key, value) - self._config[key] = value + def set_config(self, config: OpenAIWhisperSettings) -> None: + self._config = config diff --git a/stt_voice_messages/stt_voice_messages.py b/stt_voice_messages/stt_voice_messages.py index 5b78d35..7c3df06 100644 --- a/stt_voice_messages/stt_voice_messages.py +++ b/stt_voice_messages/stt_voice_messages.py @@ -22,6 +22,7 @@ from gajim.plugins.plugins_i18n import _ from .gtk.config_dialog import * from .gtk.sttbox import STTBox +from .models.model_settings import * log = logging.getLogger('gajim.p.stt_voice_messages') @@ -29,6 +30,17 @@ log = logging.getLogger('gajim.p.stt_voice_messages') class STTVoiceMessagesPlugin(GajimPlugin): def init(self) -> None: self.description = _('Transcribes voice messages to text.') + + self.config_default_values = { + 'auto_transcribe': (False, ''), + 'model': ('model_openaiwhisper', ''), + 'model_openaiwhisper': ( + OpenAIWhisperSettings( + model_size='tiny', + translate_to_english=False), + '') + } + self._config = Configuration(self) self._config.check_available_moduls() self.config_dialog = partial(STTVoiceMessagesConfigDialog, self._config) @@ -37,14 +49,6 @@ class STTVoiceMessagesPlugin(GajimPlugin): 'preview_audio': (self._on_preview_audio_created, None), } - self.config_default_values = { - 'auto_transcribe': (False, ''), - 'model': ('', ''), - 'model_class': (None, ''), - 'whisperai_model_size': ('tiny', ''), - 'whisperai_translate': (False, ''), - } - self._audio_file: str = '' self._preview_audio_widget = None self._stt_box = None