added model toggle

This commit is contained in:
Matt Freeman 2024-06-30 18:25:32 -04:00
parent a57635a16f
commit 1576952d8f

View File

@ -1,5 +1,6 @@
import re from enum import Enum
from typing import Dict, Optional from typing import Dict, Optional
import re
import ollama import ollama
from slixmpp import ClientXMPP, JID from slixmpp import ClientXMPP, JID
@ -7,6 +8,7 @@ from slixmpp.exceptions import IqTimeout, IqError
from slixmpp.stanza import Message from slixmpp.stanza import Message
from slixmpp.types import JidStr, MessageTypes from slixmpp.types import JidStr, MessageTypes
from slixmpp.xmlstream.handler import CoroutineCallback from slixmpp.xmlstream.handler import CoroutineCallback
from slixmpp.xmlstream.handler.coroutine_callback import CoroutineFunction
from slixmpp.xmlstream.matcher import MatchXPath from slixmpp.xmlstream.matcher import MatchXPath
from slixmpp_omemo import ( from slixmpp_omemo import (
EncryptionPrepareException, EncryptionPrepareException,
@ -18,17 +20,24 @@ from slixmpp_omemo import (
from omemo.exceptions import MissingBundleException from omemo.exceptions import MissingBundleException
LEVEL_DEBUG: int = 0 class LEVELS(Enum):
LEVEL_ERROR: int = 1 DEBUG = 0
ERROR = 1
class LLMS(Enum):
LLAMA3 = "llama3"
MISTRAL = "mistral"
class OllamaBot(ClientXMPP): class OllamaBot(ClientXMPP):
eme_ns: str = "eu.siacs.conversations.axolotl" eme_ns: str = "eu.siacs.conversations.axolotl"
cmd_prefix: str = "!" cmd_prefix: str = "!"
debug_level: int = LEVEL_DEBUG debug_level: LEVELS = LEVELS.DEBUG
def __init__(self, jid: JidStr, password: str): def __init__(self, jid: JidStr, password: str):
ClientXMPP.__init__(self, jid, password) ClientXMPP.__init__(self, jid, password)
self.model: LLMS = LLMS.LLAMA3
self.prefix_re: re.Pattern = re.compile(r"^%s" % self.cmd_prefix) self.prefix_re: re.Pattern = re.compile(r"^%s" % self.cmd_prefix)
self.cmd_re: re.Pattern = re.compile( self.cmd_re: re.Pattern = re.compile(
r"^%s(?P<command>\w+)(?:\s+(?P<args>.*))?" % self.cmd_prefix r"^%s(?P<command>\w+)(?:\s+(?P<args>.*))?" % self.cmd_prefix
@ -58,58 +67,75 @@ class OllamaBot(ClientXMPP):
groups = match.groupdict() groups = match.groupdict()
cmd: str = groups["command"] cmd: str = groups["command"]
# args = groups['args'] # args = groups['args']
if cmd == "help": match cmd:
await self.cmd_help(mto, mtype) case LLMS.LLAMA3.value:
elif cmd == "verbose": await self.cmd_set_llama3(mto, mtype)
await self.cmd_verbose(mto, mtype) case LLMS.MISTRAL.value:
elif cmd == "error": await self.cmd_set_mistral(mto, mtype)
await self.cmd_error(mto, mtype) case "verbose":
await self.cmd_verbose(mto, mtype)
case "error":
await self.cmd_error(mto, mtype)
case "help" | _:
await self.cmd_help(mto, mtype)
return None return None
async def cmd_help(self, mto: JID, mtype: Optional[MessageTypes]) -> None: async def cmd_help(self, mto: JID, mtype: Optional[MessageTypes]) -> None:
body = ( body = (
"Hello, I am the ollama_slixmpp_omemo_bot!" "Hello, I am the ollama_slixmpp_omemo_bot!\n\n"
"The following commands are available:\n" "The following commands are available:\n\n"
f"{self.cmd_prefix}verbose Send message or reply with log messages\n" f"{self.cmd_prefix}verbose - Send message or reply with log messages.\n\n"
f"{self.cmd_prefix}error Send message or reply only on error\n" f"{self.cmd_prefix}error -Send message or reply only on error.\n\n"
f"Typing anything else will be sent to llama3!\n" f"{self.cmd_prefix}llama3 - Enable the llama3 model.\n\n"
f"{self.cmd_prefix}mistral - Enable the mistral model.\n\n"
f"Typing anything else will be sent to {self.model.value}!\n\n"
) )
return await self.encrypted_reply(mto, mtype, body) return await self.encrypted_reply(mto, mtype, body)
async def cmd_set_llama3(self, mto: JID, mtype: Optional[MessageTypes]) -> None:
self.model = LLMS.LLAMA3
body: str = f"""Model set to {LLMS.LLAMA3.value}"""
return await self.encrypted_reply(mto, mtype, body)
async def cmd_set_mistral(self, mto: JID, mtype: Optional[MessageTypes]) -> None:
self.model = LLMS.MISTRAL
body: str = f"""Model set to {LLMS.MISTRAL.value}"""
return await self.encrypted_reply(mto, mtype, body)
async def cmd_verbose(self, mto: JID, mtype: Optional[MessageTypes]) -> None: async def cmd_verbose(self, mto: JID, mtype: Optional[MessageTypes]) -> None:
self.debug_level: int = LEVEL_DEBUG self.debug_level = LEVELS.DEBUG
body: str = """Debug level set to 'verbose'.""" body: str = """Debug level set to 'verbose'."""
return await self.encrypted_reply(mto, mtype, body) return await self.encrypted_reply(mto, mtype, body)
async def cmd_error(self, mto: JID, mtype: Optional[MessageTypes]) -> None: async def cmd_error(self, mto: JID, mtype: Optional[MessageTypes]) -> None:
self.debug_level: int = LEVEL_ERROR self.debug_level = LEVELS.ERROR
body: str = """Debug level set to 'error'.""" body: str = """Debug level set to 'error'."""
return await self.encrypted_reply(mto, mtype, body) return await self.encrypted_reply(mto, mtype, body)
async def message_handler( async def message_handler(
self, msg: Message, allow_untrusted: bool = False self, msg: Message, allow_untrusted: bool = False
) -> None: ) -> Optional[CoroutineFunction]:
mfrom: JID = msg["from"] mfrom: JID = msg["from"]
mto: JID = msg["from"] mto: JID = msg["from"]
mtype: Optional[MessageTypes] = msg["type"] mtype: Optional[MessageTypes] = msg["type"]
if mtype not in ("chat", "normal"): if mtype not in ("chat", "normal"):
return None return None
if not self["xep_0384"].is_encrypted(msg): if not self["xep_0384"].is_encrypted(msg):
if self.debug_level == LEVEL_DEBUG: if self.debug_level == LEVELS.DEBUG:
await self.plain_reply( await self.plain_reply(
mto, mtype, f"Echo unencrypted message: {msg['body']}" mto, mtype, f"Echo unencrypted message: {msg['body']}"
) )
return None return None
try: try:
encrypted = msg["omemo_encrypted"] encrypted = msg["omemo_encrypted"]
body: Optional[str] = await self["xep_0384"].decrypt_message( body: Optional[bytes] = await self["xep_0384"].decrypt_message(
encrypted, mfrom, allow_untrusted encrypted, mfrom, allow_untrusted
) )
if body is not None: if body is not None:
decoded: Optional[str] = body.decode("utf8") decoded: str = body.decode("utf8")
if self.is_command(decoded): if self.is_command(decoded):
await self.handle_command(mto, mtype, decoded) await self.handle_command(mto, mtype, decoded)
elif self.debug_level == LEVEL_DEBUG: elif self.debug_level == LEVELS.DEBUG:
ollama_server_response: Optional[str] = ( ollama_server_response: Optional[str] = (
self.message_to_ollama_server(decoded) self.message_to_ollama_server(decoded)
) )
@ -126,13 +152,16 @@ class OllamaBot(ClientXMPP):
await self.encrypted_reply( await self.encrypted_reply(
mto, mto,
mtype, mtype,
"Error: Message uses an encrypted " "session I don't know about.", "Error: Message uses an encrypted session I don't know about.",
) )
except (UndecidedException, UntrustedException) as exn: except (UndecidedException, UntrustedException) as exn:
await self.plain_reply( await self.plain_reply(
mto, mto,
mtype, mtype,
f"Error: Your device '{exn.device}' is not in my trusted devices.", (
f"WARNING: Your device '{exn.device}' is not in my trusted devices."
f"Allowing untrusted..."
),
) )
await self.message_handler(msg, allow_untrusted=True) await self.message_handler(msg, allow_untrusted=True)
except EncryptionPrepareException: except EncryptionPrepareException:
@ -199,13 +228,8 @@ class OllamaBot(ClientXMPP):
def message_to_ollama_server(self, msg: Optional[str]) -> Optional[str]: def message_to_ollama_server(self, msg: Optional[str]) -> Optional[str]:
if msg is not None: if msg is not None:
response = ollama.chat( response = ollama.chat(
model="llama3", model=self.model.value,
messages=[ messages=[{"role": "user", "content": f"{msg}"}],
{
"role": "user",
"content": f"{msg}",
},
],
) )
return response["message"]["content"] return response["message"]["content"]
return None return None