Improve multi-module support and refactor
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import asdict
|
||||
import logging
|
||||
import typing
|
||||
from pathlib import Path
|
||||
@@ -40,45 +41,51 @@ if TYPE_CHECKING:
|
||||
|
||||
log = logging.getLogger('gajim.p.sttvm_config_dialog')
|
||||
|
||||
SUPPORTED_MODELS: dict[str, dict[str, typing.Union[list[str], Any, str]]] = {
|
||||
'model_openaiwhisper': {
|
||||
'moduls': ['whisper'],
|
||||
'class': openai_whisper.WhisperModel,
|
||||
'name': 'OpenAI Whisper'
|
||||
},
|
||||
'model_ctranslate2': {
|
||||
'moduls': ['ctranslate2'],
|
||||
'class': None,
|
||||
'name': _('CTranslate2')
|
||||
},
|
||||
'model_faster-whisper': {
|
||||
'moduls': ['faster-whisper'],
|
||||
'class': None,
|
||||
'name:': _('Faster-Whisper')
|
||||
},
|
||||
'model_distill': {
|
||||
'moduls': ['transformers', 'accelerate', 'datasets[audio]'],
|
||||
'class': None,
|
||||
'name': _('Distill')
|
||||
}
|
||||
|
||||
@dataclass
class Model:
    """Registry entry describing one speech-to-text backend."""

    # Label shown for this backend in the model selection list.
    name: str
    # Names probed with __import__ to decide whether the backend is usable.
    required_moduls: list[str]
    # Backend class, called with no arguments to create an instance;
    # None for backends that have no implementation wired up yet.
    klass: typing.Optional[type]
    # Settings for the backend (a settings class or settings object);
    # None when the backend has no settings.
    config: Any
    # Slot for an instantiated backend; nothing is created up front.
    instance: typing.Optional[object] = None
|
||||
|
||||
|
||||
# Every backend the plugin knows about, keyed by its plugin-config key.
# Positional Model fields: name, required_moduls, klass, config.
# Entries with klass=None have no backend implementation wired up here.
#
# NOTE(review): the required_moduls strings are passed to __import__ by
# check_available_moduls. 'faster-whisper' and 'datasets[audio]' look like
# pip requirement names rather than importable module names (presumably
# 'faster_whisper' and 'datasets') - confirm before relying on detection.
SUPPORTED_MODELS: dict[str, Model] = {
    'model_openaiwhisper': Model(
        'OpenAI Whisper',
        ['whisper'],
        openai_whisper.WhisperModel,
        OpenAIWhisperSettings,
    ),
    'model_ctranslate2': Model(
        'CTranslate2',
        ['ctranslate2'],
        None,
        None,
    ),
    'model_faster-whisper': Model(
        'Faster-Whisper',
        ['faster-whisper'],
        None,
        None,
    ),
    'model_distill': Model(
        'Distill',
        ['transformers', 'accelerate', 'datasets[audio]'],
        None,
        None,
    ),
}
|
||||
|
||||
|
||||
class Configuration:
|
||||
def __init__(self, plugin: STTVoiceMessagesPlugin):
|
||||
self._plugin = plugin
|
||||
self._openaiwhisper_settings = OpenAIWhisperSettings()
|
||||
self._available_models: dict[
|
||||
str, dict[str, typing.Union[list[str], Any, str]]] = {}
|
||||
|
||||
self._available_models: dict[str, Model] = {}
|
||||
self.check_available_moduls()
|
||||
|
||||
log.debug('config = %s', self._plugin.config['model_openaiwhisper'])
|
||||
|
||||
@property
|
||||
def plugin(self) -> STTVoiceMessagesPlugin:
|
||||
return self._plugin
|
||||
|
||||
@property
|
||||
def available_models(self) -> dict[
|
||||
str, dict[str, typing.Union[list[str], Any, str]]]:
|
||||
def available_models(self) -> dict[str, Model]:
|
||||
return self._available_models
|
||||
|
||||
def on_setting(self, value: Any, data: Any) -> None:
|
||||
@@ -86,58 +93,61 @@ class Configuration:
|
||||
value.strip()
|
||||
|
||||
log.debug('plugin config before:\n %s', self.plugin.config.data)
|
||||
# TODO: Is 'modelname_key = value' a good design?
|
||||
self.plugin.config[data] = value
|
||||
|
||||
# TODO: Apply setting only to specific instance
|
||||
self._plugin.config['model_instance'].on_setting(data, value)
|
||||
log.debug('plugin config after:\n %s', self.plugin.config.data)
|
||||
|
||||
def on_set_model(self, value: Any, data: Any) -> None:
|
||||
def on_config_model(self, model: str, value: Any, data: Any) -> None:
    """Store one setting on a model's config object and hand it to the backend.

    :param model: key of the model's settings object in the plugin config
    :param value: new value; strings are stripped of surrounding whitespace
    :param data: name of the attribute to set on the settings object
    """
    if isinstance(value, str):
        # str.strip() returns a new string, so the result has to be rebound.
        value = value.strip()

    model_config = self.plugin.config.data[model]

    log.debug('plugin config before:\n %s', model_config)
    setattr(model_config, data, value)
    log.debug('plugin config after:\n %s', model_config)

    # A backend instance only exists once the model has been selected;
    # until then the stored settings are simply kept for later use.
    instance = getattr(model_config, 'instance', None)
    if instance is not None:
        instance.set_config(model_config)
|
||||
|
||||
def on_set_model(self, model: Any) -> None:
    """Make *model* the active backend, instantiating it on first use.

    Nothing is changed when the model is not among the available models
    or has no backend class.

    :param model: key of the model in the available-models registry and
        in the plugin config
    """
    if isinstance(model, str):
        # str.strip() returns a new string, so the result has to be rebound.
        model = model.strip()
    log.debug('plugin config before:\n %s', self.plugin.config.data)

    entry = self._available_models.get(model)
    if entry is None or entry.klass is None:
        # Unknown/unavailable model or no implementation: keep the current
        # selection instead of failing with a KeyError.
        return

    model_config = self.plugin.config.data[model]
    if getattr(model_config, 'instance', None) is None:
        model_config.instance = entry.klass()

    # Record the selection even when the instance already existed, so that
    # switching back to a previously used model takes effect.
    self.plugin.config['model'] = model
    log.debug('plugin config after:\n %s', self.plugin.config.data)
|
||||
|
||||
@staticmethod
|
||||
def is_module_available(module: str) -> bool:
|
||||
try:
|
||||
__import__(module)
|
||||
return True
|
||||
except ModuleNotFoundError:
|
||||
log.debug('Could not find module %s', module)
|
||||
return False
|
||||
except ImportError as ex:
|
||||
log.debug(str(ex))
|
||||
return False
|
||||
|
||||
def check_available_moduls(self):
|
||||
def is_module_available(module: str) -> bool:
|
||||
try:
|
||||
__import__(module)
|
||||
return True
|
||||
except ModuleNotFoundError:
|
||||
log.debug('Could not find module %s', module)
|
||||
return False
|
||||
except ImportError as ex:
|
||||
log.debug(str(ex))
|
||||
return False
|
||||
|
||||
for model in SUPPORTED_MODELS:
|
||||
available = True
|
||||
for modul in SUPPORTED_MODELS[model]['moduls']:
|
||||
if not self.is_module_available(modul):
|
||||
for modul in SUPPORTED_MODELS[model].required_moduls:
|
||||
if not is_module_available(modul):
|
||||
available = False
|
||||
continue
|
||||
if available:
|
||||
self._available_models[model] = SUPPORTED_MODELS[model]
|
||||
if SUPPORTED_MODELS[model].config is not None:
|
||||
log.debug('created config for model = %s: %s', model, self._available_models[model])
|
||||
log.debug('plugin config for model = %s', self.plugin.config[model])
|
||||
self.plugin.config.data[model].instance = None
|
||||
self._available_models[model].config = self.plugin.config[model]
|
||||
|
||||
if (self.plugin.config.data['model_class'] is None
|
||||
and len(self._available_models) > 0):
|
||||
model = list(self._available_models)[0]
|
||||
self.on_set_model(model, 'model')
|
||||
log.debug('Choose first available model!')
|
||||
else:
|
||||
log.debug('Available model already chosen!')
|
||||
self.on_set_model(self._plugin.config['model'])
|
||||
|
||||
log.debug('models = %s', self._available_models)
|
||||
|
||||
@@ -188,6 +198,9 @@ class STTVoiceMessagesConfigDialog(Gtk.ApplicationWindow):
|
||||
|
||||
self.show_all()
|
||||
|
||||
############################################################################
|
||||
# General Settings
|
||||
############################################################################
|
||||
class STTBehaviour(PreferenceBox):
|
||||
def __init__(self, config_dialog: STTVoiceMessagesConfigDialog) -> None:
|
||||
settings = [
|
||||
@@ -206,7 +219,7 @@ class STTVoiceMessagesConfigDialog(Gtk.ApplicationWindow):
|
||||
models: list[tuple[str, str]] = []
|
||||
for key, value in config_dialog.config.available_models.items():
|
||||
models.append(
|
||||
(key, str(value['name']))
|
||||
(key, str(value.name))
|
||||
)
|
||||
|
||||
settings = [
|
||||
@@ -222,29 +235,37 @@ class STTVoiceMessagesConfigDialog(Gtk.ApplicationWindow):
|
||||
|
||||
PreferenceBox.__init__(self, settings)
|
||||
|
||||
############################################################################
|
||||
# OpenAI Whisper Settings
|
||||
############################################################################
|
||||
class OpenAIWhisperGeneral(PreferenceBox):
|
||||
def __init__(self, config_dialog: STTVoiceMessagesConfigDialog) -> None:
|
||||
|
||||
self._model = 'model_openaiwhisper'
|
||||
self._config_dialog = config_dialog
|
||||
|
||||
settings = [
|
||||
Setting(SettingKind.POPOVER,
|
||||
_('Language Model Size'),
|
||||
SettingType.VALUE,
|
||||
value=config_dialog.plugin.config[
|
||||
'whisperai_model_size'],
|
||||
data='whisperai_model_size',
|
||||
callback=config_dialog.config.on_setting,
|
||||
value=config_dialog.config.available_models[self._model].config.model_size,
|
||||
data='model_size',
|
||||
callback=self._set_config,
|
||||
props={'entries': whisper.available_models()}),
|
||||
|
||||
Setting(SettingKind.SWITCH,
|
||||
_('Translate'),
|
||||
SettingType.VALUE,
|
||||
value=config_dialog.plugin.config[
|
||||
'whisperai_translate'],
|
||||
data='whisperai_translate',
|
||||
callback=config_dialog.config.on_setting)
|
||||
value=config_dialog.config.available_models[self._model].config.translate_to_english,
|
||||
data='translate_to_english',
|
||||
callback=self._set_config)
|
||||
]
|
||||
|
||||
PreferenceBox.__init__(self, settings)
|
||||
|
||||
def _set_config(self, value: Any, data: Any):
|
||||
self._config_dialog.config.on_config_model(self._model, value, data)
|
||||
|
||||
def _add_prefs(self, prefs: list[tuple[str, type[PreferenceBox]]]):
|
||||
for ui_name, klass in prefs:
|
||||
pref_box = getattr(self._ui, ui_name)
|
||||
|
||||
@@ -57,7 +57,8 @@ class STTBox(Gtk.Box):
|
||||
|
||||
def _on_transcribe_clicked(self, _button: Gtk.Button) -> None:
|
||||
log.debug('config.data = %s', self._config.data)
|
||||
model = self._config.data['model_instance']
|
||||
model_name = self._config.data['model']
|
||||
model = self._config.data[model_name].instance
|
||||
if model is None:
|
||||
return
|
||||
|
||||
|
||||
@@ -19,5 +19,5 @@ from dataclasses import dataclass, field
|
||||
|
||||
@dataclass
|
||||
class OpenAIWhisperSettings:
|
||||
whisperai_model_size: str = field(default='tiny', init=True)
|
||||
|
||||
model_size: str = field(default='tiny', init=True)
|
||||
translate_to_english: bool = field(default=False, init=True)
|
||||
@@ -16,8 +16,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
|
||||
from gajim.gtk.const import Setting
|
||||
|
||||
from ..helper import Results
|
||||
|
||||
|
||||
@@ -26,7 +24,3 @@ class Model(ABC):
|
||||
@abstractmethod
|
||||
def transcribe(self, result: Results, audio_file: Path) -> str:
|
||||
return ''
|
||||
|
||||
@abstractmethod
|
||||
def on_setting(self, setting: Setting):
|
||||
pass
|
||||
@@ -46,11 +46,10 @@ class WhisperModel(Model):
|
||||
return self._result
|
||||
|
||||
def transcribe(self, result: Results, audio_file: Path) -> str:
|
||||
model = whisper.load_model(self._config['whisperai_model_size'])
|
||||
log.debug('model size is used = %s', self._config['whisperai_model_size'])
|
||||
result.text = model.transcribe(audio_file)['text']
|
||||
model = whisper.load_model(self._config.model_size)
|
||||
log.debug('model size is used = %s', self._config.model_size)
|
||||
result.text = model.transcribe(audio_file)['text'] # pyright: ignore [reportAttributeAccessIssue]
|
||||
|
||||
def on_setting(self, key, value):
|
||||
log.debug('key = %s, value = %s', key, value)
|
||||
self._config[key] = value
|
||||
def set_config(self, config: OpenAIWhisperSettings) -> None:
|
||||
self._config = config
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ from gajim.plugins.plugins_i18n import _
|
||||
|
||||
from .gtk.config_dialog import *
|
||||
from .gtk.sttbox import STTBox
|
||||
from .models.model_settings import *
|
||||
|
||||
log = logging.getLogger('gajim.p.stt_voice_messages')
|
||||
|
||||
@@ -29,6 +30,17 @@ log = logging.getLogger('gajim.p.stt_voice_messages')
|
||||
class STTVoiceMessagesPlugin(GajimPlugin):
|
||||
def init(self) -> None:
|
||||
self.description = _('Transcribes voice messages to text.')
|
||||
|
||||
self.config_default_values = {
|
||||
'auto_transcribe': (False, ''),
|
||||
'model': ('model_openaiwhisper', ''),
|
||||
'model_openaiwhisper': (
|
||||
OpenAIWhisperSettings(
|
||||
model_size='tiny',
|
||||
translate_to_english=False),
|
||||
'')
|
||||
}
|
||||
|
||||
self._config = Configuration(self)
|
||||
self._config.check_available_moduls()
|
||||
self.config_dialog = partial(STTVoiceMessagesConfigDialog, self._config)
|
||||
@@ -37,14 +49,6 @@ class STTVoiceMessagesPlugin(GajimPlugin):
|
||||
'preview_audio': (self._on_preview_audio_created, None),
|
||||
}
|
||||
|
||||
self.config_default_values = {
|
||||
'auto_transcribe': (False, ''),
|
||||
'model': ('', ''),
|
||||
'model_class': (None, ''),
|
||||
'whisperai_model_size': ('tiny', ''),
|
||||
'whisperai_translate': (False, ''),
|
||||
}
|
||||
|
||||
self._audio_file: str = ''
|
||||
self._preview_audio_widget = None
|
||||
self._stt_box = None
|
||||
|
||||
Reference in New Issue
Block a user