added model toggle
This commit is contained in:
parent
a57635a16f
commit
1576952d8f
@ -1,5 +1,6 @@
|
|||||||
import re
|
from enum import Enum
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
import re
|
||||||
|
|
||||||
import ollama
|
import ollama
|
||||||
from slixmpp import ClientXMPP, JID
|
from slixmpp import ClientXMPP, JID
|
||||||
@ -7,6 +8,7 @@ from slixmpp.exceptions import IqTimeout, IqError
|
|||||||
from slixmpp.stanza import Message
|
from slixmpp.stanza import Message
|
||||||
from slixmpp.types import JidStr, MessageTypes
|
from slixmpp.types import JidStr, MessageTypes
|
||||||
from slixmpp.xmlstream.handler import CoroutineCallback
|
from slixmpp.xmlstream.handler import CoroutineCallback
|
||||||
|
from slixmpp.xmlstream.handler.coroutine_callback import CoroutineFunction
|
||||||
from slixmpp.xmlstream.matcher import MatchXPath
|
from slixmpp.xmlstream.matcher import MatchXPath
|
||||||
from slixmpp_omemo import (
|
from slixmpp_omemo import (
|
||||||
EncryptionPrepareException,
|
EncryptionPrepareException,
|
||||||
@ -18,17 +20,24 @@ from slixmpp_omemo import (
|
|||||||
from omemo.exceptions import MissingBundleException
|
from omemo.exceptions import MissingBundleException
|
||||||
|
|
||||||
|
|
||||||
LEVEL_DEBUG: int = 0
|
class LEVELS(Enum):
|
||||||
LEVEL_ERROR: int = 1
|
DEBUG = 0
|
||||||
|
ERROR = 1
|
||||||
|
|
||||||
|
|
||||||
|
class LLMS(Enum):
|
||||||
|
LLAMA3 = "llama3"
|
||||||
|
MISTRAL = "mistral"
|
||||||
|
|
||||||
|
|
||||||
class OllamaBot(ClientXMPP):
|
class OllamaBot(ClientXMPP):
|
||||||
eme_ns: str = "eu.siacs.conversations.axolotl"
|
eme_ns: str = "eu.siacs.conversations.axolotl"
|
||||||
cmd_prefix: str = "!"
|
cmd_prefix: str = "!"
|
||||||
debug_level: int = LEVEL_DEBUG
|
debug_level: LEVELS = LEVELS.DEBUG
|
||||||
|
|
||||||
def __init__(self, jid: JidStr, password: str):
|
def __init__(self, jid: JidStr, password: str):
|
||||||
ClientXMPP.__init__(self, jid, password)
|
ClientXMPP.__init__(self, jid, password)
|
||||||
|
self.model: LLMS = LLMS.LLAMA3
|
||||||
self.prefix_re: re.Pattern = re.compile(r"^%s" % self.cmd_prefix)
|
self.prefix_re: re.Pattern = re.compile(r"^%s" % self.cmd_prefix)
|
||||||
self.cmd_re: re.Pattern = re.compile(
|
self.cmd_re: re.Pattern = re.compile(
|
||||||
r"^%s(?P<command>\w+)(?:\s+(?P<args>.*))?" % self.cmd_prefix
|
r"^%s(?P<command>\w+)(?:\s+(?P<args>.*))?" % self.cmd_prefix
|
||||||
@ -58,58 +67,75 @@ class OllamaBot(ClientXMPP):
|
|||||||
groups = match.groupdict()
|
groups = match.groupdict()
|
||||||
cmd: str = groups["command"]
|
cmd: str = groups["command"]
|
||||||
# args = groups['args']
|
# args = groups['args']
|
||||||
if cmd == "help":
|
match cmd:
|
||||||
await self.cmd_help(mto, mtype)
|
case LLMS.LLAMA3.value:
|
||||||
elif cmd == "verbose":
|
await self.cmd_set_llama3(mto, mtype)
|
||||||
await self.cmd_verbose(mto, mtype)
|
case LLMS.MISTRAL.value:
|
||||||
elif cmd == "error":
|
await self.cmd_set_mistral(mto, mtype)
|
||||||
await self.cmd_error(mto, mtype)
|
case "verbose":
|
||||||
|
await self.cmd_verbose(mto, mtype)
|
||||||
|
case "error":
|
||||||
|
await self.cmd_error(mto, mtype)
|
||||||
|
case "help" | _:
|
||||||
|
await self.cmd_help(mto, mtype)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def cmd_help(self, mto: JID, mtype: Optional[MessageTypes]) -> None:
|
async def cmd_help(self, mto: JID, mtype: Optional[MessageTypes]) -> None:
|
||||||
body = (
|
body = (
|
||||||
"Hello, I am the ollama_slixmpp_omemo_bot!"
|
"Hello, I am the ollama_slixmpp_omemo_bot!\n\n"
|
||||||
"The following commands are available:\n"
|
"The following commands are available:\n\n"
|
||||||
f"{self.cmd_prefix}verbose Send message or reply with log messages\n"
|
f"{self.cmd_prefix}verbose - Send message or reply with log messages.\n\n"
|
||||||
f"{self.cmd_prefix}error Send message or reply only on error\n"
|
f"{self.cmd_prefix}error -Send message or reply only on error.\n\n"
|
||||||
f"Typing anything else will be sent to llama3!\n"
|
f"{self.cmd_prefix}llama3 - Enable the llama3 model.\n\n"
|
||||||
|
f"{self.cmd_prefix}mistral - Enable the mistral model.\n\n"
|
||||||
|
f"Typing anything else will be sent to {self.model.value}!\n\n"
|
||||||
)
|
)
|
||||||
return await self.encrypted_reply(mto, mtype, body)
|
return await self.encrypted_reply(mto, mtype, body)
|
||||||
|
|
||||||
|
async def cmd_set_llama3(self, mto: JID, mtype: Optional[MessageTypes]) -> None:
|
||||||
|
self.model = LLMS.LLAMA3
|
||||||
|
body: str = f"""Model set to {LLMS.LLAMA3.value}"""
|
||||||
|
return await self.encrypted_reply(mto, mtype, body)
|
||||||
|
|
||||||
|
async def cmd_set_mistral(self, mto: JID, mtype: Optional[MessageTypes]) -> None:
|
||||||
|
self.model = LLMS.MISTRAL
|
||||||
|
body: str = f"""Model set to {LLMS.MISTRAL.value}"""
|
||||||
|
return await self.encrypted_reply(mto, mtype, body)
|
||||||
|
|
||||||
async def cmd_verbose(self, mto: JID, mtype: Optional[MessageTypes]) -> None:
|
async def cmd_verbose(self, mto: JID, mtype: Optional[MessageTypes]) -> None:
|
||||||
self.debug_level: int = LEVEL_DEBUG
|
self.debug_level = LEVELS.DEBUG
|
||||||
body: str = """Debug level set to 'verbose'."""
|
body: str = """Debug level set to 'verbose'."""
|
||||||
return await self.encrypted_reply(mto, mtype, body)
|
return await self.encrypted_reply(mto, mtype, body)
|
||||||
|
|
||||||
async def cmd_error(self, mto: JID, mtype: Optional[MessageTypes]) -> None:
|
async def cmd_error(self, mto: JID, mtype: Optional[MessageTypes]) -> None:
|
||||||
self.debug_level: int = LEVEL_ERROR
|
self.debug_level = LEVELS.ERROR
|
||||||
body: str = """Debug level set to 'error'."""
|
body: str = """Debug level set to 'error'."""
|
||||||
return await self.encrypted_reply(mto, mtype, body)
|
return await self.encrypted_reply(mto, mtype, body)
|
||||||
|
|
||||||
async def message_handler(
|
async def message_handler(
|
||||||
self, msg: Message, allow_untrusted: bool = False
|
self, msg: Message, allow_untrusted: bool = False
|
||||||
) -> None:
|
) -> Optional[CoroutineFunction]:
|
||||||
mfrom: JID = msg["from"]
|
mfrom: JID = msg["from"]
|
||||||
mto: JID = msg["from"]
|
mto: JID = msg["from"]
|
||||||
mtype: Optional[MessageTypes] = msg["type"]
|
mtype: Optional[MessageTypes] = msg["type"]
|
||||||
if mtype not in ("chat", "normal"):
|
if mtype not in ("chat", "normal"):
|
||||||
return None
|
return None
|
||||||
if not self["xep_0384"].is_encrypted(msg):
|
if not self["xep_0384"].is_encrypted(msg):
|
||||||
if self.debug_level == LEVEL_DEBUG:
|
if self.debug_level == LEVELS.DEBUG:
|
||||||
await self.plain_reply(
|
await self.plain_reply(
|
||||||
mto, mtype, f"Echo unencrypted message: {msg['body']}"
|
mto, mtype, f"Echo unencrypted message: {msg['body']}"
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
try:
|
try:
|
||||||
encrypted = msg["omemo_encrypted"]
|
encrypted = msg["omemo_encrypted"]
|
||||||
body: Optional[str] = await self["xep_0384"].decrypt_message(
|
body: Optional[bytes] = await self["xep_0384"].decrypt_message(
|
||||||
encrypted, mfrom, allow_untrusted
|
encrypted, mfrom, allow_untrusted
|
||||||
)
|
)
|
||||||
if body is not None:
|
if body is not None:
|
||||||
decoded: Optional[str] = body.decode("utf8")
|
decoded: str = body.decode("utf8")
|
||||||
if self.is_command(decoded):
|
if self.is_command(decoded):
|
||||||
await self.handle_command(mto, mtype, decoded)
|
await self.handle_command(mto, mtype, decoded)
|
||||||
elif self.debug_level == LEVEL_DEBUG:
|
elif self.debug_level == LEVELS.DEBUG:
|
||||||
ollama_server_response: Optional[str] = (
|
ollama_server_response: Optional[str] = (
|
||||||
self.message_to_ollama_server(decoded)
|
self.message_to_ollama_server(decoded)
|
||||||
)
|
)
|
||||||
@ -126,13 +152,16 @@ class OllamaBot(ClientXMPP):
|
|||||||
await self.encrypted_reply(
|
await self.encrypted_reply(
|
||||||
mto,
|
mto,
|
||||||
mtype,
|
mtype,
|
||||||
"Error: Message uses an encrypted " "session I don't know about.",
|
"Error: Message uses an encrypted session I don't know about.",
|
||||||
)
|
)
|
||||||
except (UndecidedException, UntrustedException) as exn:
|
except (UndecidedException, UntrustedException) as exn:
|
||||||
await self.plain_reply(
|
await self.plain_reply(
|
||||||
mto,
|
mto,
|
||||||
mtype,
|
mtype,
|
||||||
f"Error: Your device '{exn.device}' is not in my trusted devices.",
|
(
|
||||||
|
f"WARNING: Your device '{exn.device}' is not in my trusted devices."
|
||||||
|
f"Allowing untrusted..."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
await self.message_handler(msg, allow_untrusted=True)
|
await self.message_handler(msg, allow_untrusted=True)
|
||||||
except EncryptionPrepareException:
|
except EncryptionPrepareException:
|
||||||
@ -199,13 +228,8 @@ class OllamaBot(ClientXMPP):
|
|||||||
def message_to_ollama_server(self, msg: Optional[str]) -> Optional[str]:
|
def message_to_ollama_server(self, msg: Optional[str]) -> Optional[str]:
|
||||||
if msg is not None:
|
if msg is not None:
|
||||||
response = ollama.chat(
|
response = ollama.chat(
|
||||||
model="llama3",
|
model=self.model.value,
|
||||||
messages=[
|
messages=[{"role": "user", "content": f"{msg}"}],
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": f"{msg}",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
return response["message"]["content"]
|
return response["message"]["content"]
|
||||||
return None
|
return None
|
||||||
|
Loading…
x
Reference in New Issue
Block a user