diff --git a/src/ollama_slixmpp_omemo_bot/ollama_bot.py b/src/ollama_slixmpp_omemo_bot/ollama_bot.py index 75231b1..bbc5b43 100644 --- a/src/ollama_slixmpp_omemo_bot/ollama_bot.py +++ b/src/ollama_slixmpp_omemo_bot/ollama_bot.py @@ -1,5 +1,6 @@ -import re +from enum import Enum from typing import Dict, Optional +import re import ollama from slixmpp import ClientXMPP, JID @@ -7,6 +8,7 @@ from slixmpp.exceptions import IqTimeout, IqError from slixmpp.stanza import Message from slixmpp.types import JidStr, MessageTypes from slixmpp.xmlstream.handler import CoroutineCallback +from slixmpp.xmlstream.handler.coroutine_callback import CoroutineFunction from slixmpp.xmlstream.matcher import MatchXPath from slixmpp_omemo import ( EncryptionPrepareException, @@ -18,17 +20,24 @@ from slixmpp_omemo import ( from omemo.exceptions import MissingBundleException -LEVEL_DEBUG: int = 0 -LEVEL_ERROR: int = 1 +class LEVELS(Enum): + DEBUG = 0 + ERROR = 1 + + +class LLMS(Enum): + LLAMA3 = "llama3" + MISTRAL = "mistral" class OllamaBot(ClientXMPP): eme_ns: str = "eu.siacs.conversations.axolotl" cmd_prefix: str = "!" - debug_level: int = LEVEL_DEBUG + debug_level: LEVELS = LEVELS.DEBUG def __init__(self, jid: JidStr, password: str): ClientXMPP.__init__(self, jid, password) + self.model: LLMS = LLMS.LLAMA3 self.prefix_re: re.Pattern = re.compile(r"^%s" % self.cmd_prefix) self.cmd_re: re.Pattern = re.compile( r"^%s(?P\w+)(?:\s+(?P.*))?" % self.cmd_prefix @@ -58,58 +67,75 @@ class OllamaBot(ClientXMPP): groups = match.groupdict() cmd: str = groups["command"] # args = groups['args'] - if cmd == "help": - await self.cmd_help(mto, mtype) - elif cmd == "verbose": - await self.cmd_verbose(mto, mtype) - elif cmd == "error": - await self.cmd_error(mto, mtype) + match cmd: + case LLMS.LLAMA3.value: + await self.cmd_set_llama3(mto, mtype) + case LLMS.MISTRAL.value: + await self.cmd_set_mistral(mto, mtype) + case "verbose": + await self.cmd_verbose(mto, mtype) + case "error": + await self.cmd_error(mto, mtype) + case "help" | _: + await self.cmd_help(mto, mtype) return None async def cmd_help(self, mto: JID, mtype: Optional[MessageTypes]) -> None: body = ( - "Hello, I am the ollama_slixmpp_omemo_bot!" - "The following commands are available:\n" - f"{self.cmd_prefix}verbose Send message or reply with log messages\n" - f"{self.cmd_prefix}error Send message or reply only on error\n" - f"Typing anything else will be sent to llama3!\n" + "Hello, I am the ollama_slixmpp_omemo_bot!\n\n" + "The following commands are available:\n\n" + f"{self.cmd_prefix}verbose - Send message or reply with log messages.\n\n" + f"{self.cmd_prefix}error -Send message or reply only on error.\n\n" + f"{self.cmd_prefix}llama3 - Enable the llama3 model.\n\n" + f"{self.cmd_prefix}mistral - Enable the mistral model.\n\n" + f"Typing anything else will be sent to {self.model.value}!\n\n" ) return await self.encrypted_reply(mto, mtype, body) + async def cmd_set_llama3(self, mto: JID, mtype: Optional[MessageTypes]) -> None: + self.model = LLMS.LLAMA3 + body: str = f"""Model set to {LLMS.LLAMA3.value}""" + return await self.encrypted_reply(mto, mtype, body) + + async def cmd_set_mistral(self, mto: JID, mtype: Optional[MessageTypes]) -> None: + self.model = LLMS.MISTRAL + body: str = f"""Model set to {LLMS.MISTRAL.value}""" + return await self.encrypted_reply(mto, mtype, body) + async def cmd_verbose(self, mto: JID, mtype: Optional[MessageTypes]) -> None: - self.debug_level: int = LEVEL_DEBUG + self.debug_level = LEVELS.DEBUG body: str = """Debug level set to 'verbose'.""" return await self.encrypted_reply(mto, mtype, body) async def cmd_error(self, mto: JID, mtype: Optional[MessageTypes]) -> None: - self.debug_level: int = LEVEL_ERROR + self.debug_level = LEVELS.ERROR body: str = """Debug level set to 'error'.""" return await self.encrypted_reply(mto, mtype, body) async def message_handler( self, msg: Message, allow_untrusted: bool = False - ) -> None: + ) -> Optional[CoroutineFunction]: mfrom: JID = msg["from"] mto: JID = msg["from"] mtype: Optional[MessageTypes] = msg["type"] if mtype not in ("chat", "normal"): return None if not self["xep_0384"].is_encrypted(msg): - if self.debug_level == LEVEL_DEBUG: + if self.debug_level == LEVELS.DEBUG: await self.plain_reply( mto, mtype, f"Echo unencrypted message: {msg['body']}" ) return None try: encrypted = msg["omemo_encrypted"] - body: Optional[str] = await self["xep_0384"].decrypt_message( + body: Optional[bytes] = await self["xep_0384"].decrypt_message( encrypted, mfrom, allow_untrusted ) if body is not None: - decoded: Optional[str] = body.decode("utf8") + decoded: str = body.decode("utf8") if self.is_command(decoded): await self.handle_command(mto, mtype, decoded) - elif self.debug_level == LEVEL_DEBUG: + elif self.debug_level == LEVELS.DEBUG: ollama_server_response: Optional[str] = ( self.message_to_ollama_server(decoded) ) @@ -126,13 +152,16 @@ class OllamaBot(ClientXMPP): await self.encrypted_reply( mto, mtype, - "Error: Message uses an encrypted " "session I don't know about.", + "Error: Message uses an encrypted session I don't know about.", ) except (UndecidedException, UntrustedException) as exn: await self.plain_reply( mto, mtype, - f"Error: Your device '{exn.device}' is not in my trusted devices.", + ( + f"WARNING: Your device '{exn.device}' is not in my trusted devices." + f"Allowing untrusted..." + ), ) await self.message_handler(msg, allow_untrusted=True) except EncryptionPrepareException: @@ -199,13 +228,8 @@ class OllamaBot(ClientXMPP): def message_to_ollama_server(self, msg: Optional[str]) -> Optional[str]: if msg is not None: response = ollama.chat( - model="llama3", - messages=[ - { - "role": "user", - "content": f"{msg}", - }, - ], + model=self.model.value, + messages=[{"role": "user", "content": f"{msg}"}], ) return response["message"]["content"] return None