diff --git a/stt_voice_messages/configs.py b/stt_voice_messages/configs.py
index 831f2ff..c7d505c 100644
--- a/stt_voice_messages/configs.py
+++ b/stt_voice_messages/configs.py
@@ -1,21 +1,120 @@
-from dataclasses import dataclass, field
+# This file is part of Gajim.
+#
+# Gajim is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Gajim is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Gajim. If not, see <http://www.gnu.org/licenses/>.
 
-from whisper import available_models
+import logging
+import typing
 from gajim.common.app import Any
+from gajim.plugins.plugins_i18n import _
 
-from .models.model import Model
+from .models import openai_whisper
+
+if typing.TYPE_CHECKING:
+    # Annotation-only: stt_voice_messages imports gtk.config_dialog, which
+    # imports this module, so importing it here at runtime is circular.
+    from .stt_voice_messages import STTVoiceMessagesPlugin
+
+log = logging.getLogger('gajim.p.stt_voice_messages_config')
+
+# 'moduls' holds importable module names (not pip distribution names);
+# each one is passed to __import__() to decide whether a model is usable.
+SUPPORTED_MODELS: dict[str, dict[str, typing.Union[list[str], Any, str]]] = {
+    'model_openaiwhisper': {
+        'moduls': ['whisper'],
+        'class': openai_whisper.WhisperModel,
+        'name': 'OpenAI Whisper'
+    },
+    'model_ctranslate2': {
+        'moduls': ['ctranslate2'],
+        'class': None,
+        'name': _('CTranslate2')
+    },
+    'model_faster-whisper': {
+        'moduls': ['faster_whisper'],
+        'class': None,
+        'name': _('Faster-Whisper')
+    },
+    'model_distill': {
+        'moduls': ['transformers', 'accelerate', 'datasets'],
+        'class': None,
+        'name': _('Distill')
+    }
+}
 
 
-@dataclass
-class PluginConfig:
-    general: dict[str, Any] = field(default_factory=lambda: {
-        'model': None,
-        'auto_transcribe': None,
-    })
+class Configuration:
+    """Detect usable STT models and write settings to the plugin config."""
+
+    def __init__(self, plugin: 'STTVoiceMessagesPlugin'):
+        self.plugin = plugin
+        self._available_models: dict[
+            str, dict[str, typing.Union[list[str], Any, str]]] = {}
+        self.check_available_moduls()
 
-    openaiwhisper: dict[str, Any] = field(default_factory=lambda: {
-        'model_size': 'tiny',
-        'multilingual_model': True
-    })
+    @property
+    def available_models(self) -> dict[str, dict[str, typing.Union[list[str], Any, str]]]:
+        return self._available_models
 
+    def on_setting(self, value: Any, data: Any) -> None:
+        """Store `value` (stripped if it is a str) under config key `data`."""
+        if isinstance(value, str):
+            value = value.strip()
+        log.debug('plugin config before:\n %s', self.plugin.config.data)
+        self.plugin.config[data] = value
+        log.debug('plugin config after:\n %s', self.plugin.config.data)
+
+    def on_set_model(self, value: Any, data: Any) -> None:
+        """Store the class of model `value`, then `value` itself under `data`."""
+        if isinstance(value, str):
+            value = value.strip()
+        log.debug('plugin config before:\n %s', self.plugin.config.data)
+        self.plugin.config['model_class'] = self._available_models[value][
+            'class']
+        self.on_setting(value, data)
+        log.debug('plugin config after:\n %s', self.plugin.config.data)
+
+    @staticmethod
+    def is_module_available(module: str) -> bool:
+        """Return True if `module` imports without ImportError."""
+        try:
+            __import__(module)
+            return True
+        except ModuleNotFoundError:
+            log.debug('Could not find module %s', module)
+            return False
+        except ImportError as ex:
+            log.debug(str(ex))
+            return False
+
+    def check_available_moduls(self) -> None:
+        """Collect models whose modules all import; pick one if none is set."""
+        for model in SUPPORTED_MODELS:
+            available = True
+            for modul in SUPPORTED_MODELS[model]['moduls']:
+                if not self.is_module_available(modul):
+                    available = False
+                    break  # one missing module already rules the model out
+            if available:
+                self._available_models[model] = SUPPORTED_MODELS[model]
+
+        if (self.plugin.config['model_class'] is None
+                and self._available_models):
+            model = next(iter(self._available_models))
+            self.on_set_model(model, 'model')
+            log.debug('Choose first available model!')
+        else:
+            log.debug('Available model already chosen!')
+
+        log.debug('models = %s', self._available_models)
 
diff --git a/stt_voice_messages/gtk/config_dialog.py b/stt_voice_messages/gtk/config_dialog.py
index a7299fb..cf25efe 100644
--- a/stt_voice_messages/gtk/config_dialog.py
+++ b/stt_voice_messages/gtk/config_dialog.py
@@ -17,7 +17,7 @@ from __future__ import annotations
 
 import logging
 from pathlib import Path
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING
 
 import whisper
 from gi.repository import Gtk
@@ -30,25 +30,12 @@ from gajim.gtk.sidebar_switcher import SideBarSwitcher
 from gajim.plugins.helpers import get_builder
 from gajim.plugins.plugins_i18n import _
 
-from ..configs import *
+from ..configs import Configuration
 
 if TYPE_CHECKING:
-    from .. import stt_voice_messages
+    from ..stt_voice_messages import STTVoiceMessagesPlugin
 
-log = logging.getLogger('gajim.p.stt_voice_messages_config')
-
-
-@staticmethod
-def check_module(module: str) -> bool:
-    try:
-        __import__(module)
-        return True
-    except ModuleNotFoundError:
-        log.debug('Could not find module %s', module)
-        return False
-    except ImportError as ex:
-        log.debug(str(ex))
-        return False
+log = logging.getLogger('gajim.p.stt_voice_messages_config_dialog')
 
 
 class PreferenceBox(SettingsBox):
@@ -65,10 +52,8 @@ class PreferenceBox(SettingsBox):
 
 
 class STTVoiceMessagesConfigDialog(Gtk.ApplicationWindow):
-    def __init__(self, plugin: stt_voice_messages.STTVoiceMessagesPlugin,
-                 parent: Gtk.Window) -> None:
+    def __init__(self, plugin: STTVoiceMessagesPlugin, parent: Gtk.Window) -> None:
         Gtk.ApplicationWindow.__init__(self)
-        self.plugin = plugin
 
         self.set_application(app.app)
         self.set_position(Gtk.WindowPosition.CENTER)
@@ -97,34 +82,35 @@ class STTVoiceMessagesConfigDialog(Gtk.ApplicationWindow):
         self.show_all()
 
     class STTBehaviour(PreferenceBox):
-        def __init__(self, config_dialog: STTVoiceMessagesConfigDialog) -> None:
+        def __init__(self, config: Configuration) -> None:
 
             settings = [
                 Setting(SettingKind.SWITCH,
                         _('Auto Transcribe'),
                         SettingType.VALUE,
-                        value=config_dialog.plugin.config['auto_transcribe'],
+                        value=config.plugin.config['auto_transcribe'],
                         data='auto_transcribe',
-                        callback=config_dialog._on_setting)
+                        callback=config.on_setting)
             ]
 
             PreferenceBox.__init__(self, settings)
 
     class Models(PreferenceBox):
-        def __init__(self, config_dialog: STTVoiceMessagesConfigDialog) -> None:
-            models: list[tuple[str, str]] = [
-                ('model_openai', _('OpenAI Whisper')),
-                ('model_ctranslate2', _('CTranslate2 (not impl)')),
-                ('model_distill', _('Distill (not impl)')),
-            ]
+        def __init__(self, config: Configuration) -> None:
+            models: list[tuple[str, str]] = []
+            for key, value in config.available_models.items():
+                assert isinstance(value['name'], str)  # narrows the Union
+                models.append(
+                    (key, value['name'])
+                )
 
             settings = [
                 Setting(SettingKind.COMBO,
                         _('Speech To Text Model'),
                         SettingType.VALUE,
-                        value=config_dialog.plugin.config['model'],
+                        value=config.plugin.config['model'],
                         data='model',
-                        callback=config_dialog._on_setting,
+                        callback=config.on_set_model,
                         props={'combo_items': models},
                         desc=_('Choose Model to use')),
             ]
@@ -132,23 +118,23 @@ class STTVoiceMessagesConfigDialog(Gtk.ApplicationWindow):
             PreferenceBox.__init__(self, settings)
 
     class OpenAIWhisperGeneral(PreferenceBox):
-        def __init__(self, config_dialog: STTVoiceMessagesConfigDialog) -> None:
+        def __init__(self, config: Configuration) -> None:
 
             settings = [
                 Setting(SettingKind.POPOVER,
                         _('Language Model Size'),
                         SettingType.VALUE,
-                        value=config_dialog.plugin.config['whisperai_model_size'],
+                        value=config.plugin.config['whisperai_model_size'],
                         data='whisperai_model_size',
-                        callback=config_dialog._on_setting,
+                        callback=config.on_setting,
                         props={'entries': whisper.available_models()}),
 
                 Setting(SettingKind.SWITCH,
                         _('Translate'),
                         SettingType.VALUE,
-                        value=config_dialog.plugin.config['whisperai_translate'],
+                        value=config.plugin.config['whisperai_translate'],
                         data='whisperai_translate',
-                        callback=config_dialog._on_setting)
+                        callback=config.on_setting)
             ]
 
             PreferenceBox.__init__(self, settings)
@@ -160,9 +146,4 @@ class STTVoiceMessagesConfigDialog(Gtk.ApplicationWindow):
             pref_box.add(pref)
             self._prefs[ui_name] = pref
 
-    def _on_setting(self, value: Any, data: Any) -> None:
-        if isinstance(value, str):
-            value.strip()
-        log.debug('plugin config before:\n %s', self.plugin.config.data)
-        self.plugin.config[data] = value
-        log.debug('plugin config after:\n %s', self.plugin.config.data)
+
diff --git a/stt_voice_messages/gtk/sttbox.py b/stt_voice_messages/gtk/sttbox.py
index 225c324..c9b50ff 100644
--- a/stt_voice_messages/gtk/sttbox.py
+++ b/stt_voice_messages/gtk/sttbox.py
@@ -14,10 +14,14 @@
 # along with Gajim. If not, see <http://www.gnu.org/licenses/>.
 
 from gi.repository import Gtk
+import logging
+
+from .. import helper
 
 from gajim.plugins.gajimplugin import GajimPluginConfig
 from gajim.plugins.plugins_i18n import _
 
+log = logging.getLogger('gajim.p.stt_voice_messages_sttbox')
 
 class STTBox(Gtk.Box):
     def __init__(self,
@@ -48,15 +52,27 @@ class STTBox(Gtk.Box):
 
         self.show_all()
 
-    def _on_transcribe_clicked(self, _button: Gtk.Button):
-        #transcription_task = helper.BackgroundTask(
-        #    self._model.transcribe(),
-        #    self._show_result
-        #)
-        #transcription_task.start()
-        pass
+    #def update_config(self, config: GajimPluginConfig):
+    #    self._model = config.data['class']()
+
+    def _on_transcribe_clicked(self, _button: Gtk.Button) -> None:
+        log.debug('config.data = %s', self._config.data)
+        model_class = self._config['model_class']
+        if model_class is None:
+            return
+
+        self._model = model_class()
+
+        # Hand over a callable so transcription runs inside the task, not
+        # here on click. NOTE(review): assumes BackgroundTask takes a callable.
+        transcription_task = helper.BackgroundTask(
+            lambda: self._model.transcribe(self._audio_file),
+            self._show_result
+        )
+        transcription_task.start()
 
     def _show_result(self):
+        self._text = self._model.result
         if self._text.strip() != '':
             self._transcription_label.set_text(self._text.strip())
         else:
diff --git a/stt_voice_messages/models/distill.py b/stt_voice_messages/models/distill.py
new file mode 100644
index 0000000..1e60775
--- /dev/null
+++ b/stt_voice_messages/models/distill.py
@@ -0,0 +1,46 @@
+# This file is part of Gajim.
+#
+# Gajim is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Gajim is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Gajim. If not, see <http://www.gnu.org/licenses/>.
+
+import typing
+from pathlib import Path
+
+from gajim.gtk.const import Setting
+
+from .model import Model
+
+try:
+    import ctranslate2
+    CTRANSLATE2_AVAILABLE = True
+except ModuleNotFoundError:
+    CTRANSLATE2_AVAILABLE = False
+    if typing.TYPE_CHECKING:
+        import ctranslate2
+
+
+class WhisperModel(Model):
+    def __init__(self):
+        # TODO
+        self._result: str = ''
+
+    @property
+    def result(self) -> str:
+        return self._result
+
+    def transcribe(self, audio_file: Path) -> str:
+        pass
+
+    def on_setting(self, setting: Setting):
+        pass
+
diff --git a/stt_voice_messages/models/openai_whisper.py b/stt_voice_messages/models/openai_whisper.py
index edb08a0..e50e808 100644
--- a/stt_voice_messages/models/openai_whisper.py
+++ b/stt_voice_messages/models/openai_whisper.py
@@ -30,17 +30,24 @@ except ModuleNotFoundError:
 
 class WhisperModel(Model):
     def __init__(self):
+        # TODO
         self._model_sizes = ['tiny', 'small', 'base', 'medium', 'large']
         self._multilanguage = True
+        self._result: str = ''
 
         self._config = {
             'model_size': 'tiny'
         }
 
+    @property
+    def result(self) -> str:
+        return self._result
+
     def transcribe(self, audio_file: Path) -> str:
         model = whisper.load_model(self._config['model_size'])
         result = model.transcribe(audio_file)
-        return result['text']
+        self._result = result['text']
+        return self._result
 
     def on_setting(self, setting: Setting):
         pass
diff --git a/stt_voice_messages/stt_voice_messages.py b/stt_voice_messages/stt_voice_messages.py
index ca01dac..f761c1e 100644
--- a/stt_voice_messages/stt_voice_messages.py
+++ b/stt_voice_messages/stt_voice_messages.py
@@ -24,7 +24,8 @@ from gi.repository import Gtk
 from gajim.plugins import GajimPlugin
 from gajim.plugins.plugins_i18n import _
 
-from .gtk import config_dialog, sttbox
+from .gtk.sttbox import STTBox
+from .gtk.config_dialog import STTVoiceMessagesConfigDialog
 
 log = logging.getLogger('gajim.p.stt_voice_messages')
 
@@ -32,8 +33,7 @@ log = logging.getLogger('gajim.p.stt_voice_messages')
 class STTVoiceMessagesPlugin(GajimPlugin):
     def init(self) -> None:
         self.description = _('Transcribes voice messages to text.')
-        self.config_dialog = partial(config_dialog.STTVoiceMessagesConfigDialog,
-                                     self)
+        self.config_dialog = partial(STTVoiceMessagesConfigDialog, self)
 
         self.gui_extension_points = {
             'preview_audio': (self._on_preview_audio_created, None),
@@ -41,7 +41,8 @@ class STTVoiceMessagesPlugin(GajimPlugin):
 
         self.config_default_values = {
             'auto_transcribe': (False, ''),
-            'model': ('model_openai', ''),
+            'model': ('', ''),
+            'model_class': (None, ''),
             'whisperai_model_size': ('tiny', ''),
             'whisperai_translate': (False, ''),
         }
@@ -60,7 +61,7 @@ class STTVoiceMessagesPlugin(GajimPlugin):
     def _create_stt_box(self) -> None:
         assert self._preview_audio_widget is not None
 
-        self._stt_box = sttbox.STTBox(self._preview_audio_widget,
+        self._stt_box = STTBox(self._preview_audio_widget,
                                       self.config,
                                       self._audio_file)
         self._preview_audio_widget.pack_end(self._stt_box, False, False, 0)