459 lines
18 KiB
Python
459 lines
18 KiB
Python
# chat_widget.py
"""
Handles the chat widget, the sidebar chat rows and the chat list (testing)
"""
|
|
|
|
import gi
|
|
gi.require_version('Gtk', '4.0')
|
|
gi.require_version('GtkSource', '5')
|
|
from gi.repository import Gtk, Gio, Adw, Gdk, GLib
|
|
import logging, os, datetime, shutil, random, tempfile, tarfile, json
|
|
from ..internal import data_dir
|
|
from .message_widget import message
|
|
|
|
logger = logging.getLogger(__name__)

# Reference to the application window. It is None at import time and is read
# by the classes below as a module global.
# NOTE(review): presumably assigned by the main window module at startup - confirm.
window = None

# Sample prompts offered on the welcome screen; three of them are picked at
# random each time the screen is built (see chat.show_welcome_screen).
possible_prompts = [
    "What can you do?",
    "Give me a pancake recipe",
    "Why is the sky blue?",
    "Can you tell me a joke?",
    "Give me a healthy breakfast recipe",
    "How to make a pizza",
    "Can you write a poem?",
    "Can you write a story?",
    "What is GNU/Linux?",
    "Which is the best Linux distro?",
    "Why is Pluto not a planet?",
    "What is a black hole?",
    "Tell me how to stay fit",
    "Write a conversation between sun and Earth",
    "Why is the grass green?",
    "Write a haiku about AI",
    "What is the meaning of life?",
    "Explain quantum physics in simple terms",
    "Explain the theory of relativity",
    "Explain how photosynthesis works",
    "Recommend a film about nature",
    "What is nostalgia?",
]
|
|
|
|
class chat(Gtk.ScrolledWindow):
    """Scrollable view that holds the message widgets of a single chat.

    The widget name (``get_name()``) doubles as the chat name: it is used to
    build the on-disk path of the chat's images and attachments.

    Attributes (all set in ``__init__``):
        container: vertical Gtk.Box that receives the message widgets.
        clamp: Adw.Clamp wrapping ``container`` to limit its width.
        messages: dict mapping message id -> message widget.
        welcome_screen: the Adw.StatusPage currently shown, or None.
        regenerate_button: the "Regenerate Response" button, or None.
        busy: flag reset by ``stop_message``.
    """
    __gtype_name__ = 'AlpacaChat'

    def __init__(self, name:str):
        """Build the empty chat view.

        Args:
            name: chat name, stored as the widget name.
        """
        self.container = Gtk.Box(
            orientation=Gtk.Orientation.VERTICAL,
            hexpand=True,
            vexpand=True,
            spacing=12,
            margin_top=12,
            margin_bottom=12,
            margin_start=12,
            margin_end=12
        )
        self.clamp = Adw.Clamp(
            maximum_size=1000,
            tightening_threshold=800,
            child=self.container
        )
        super().__init__(
            child=self.clamp,
            propagate_natural_height=True,
            kinetic_scrolling=True,
            vexpand=True,
            hexpand=True,
            css_classes=["undershoot-bottom"],
            name=name,
            hscrollbar_policy=Gtk.PolicyType.NEVER
        )
        self.messages = {}
        self.welcome_screen = None
        self.regenerate_button = None
        self.busy = False
        #self.get_vadjustment().connect('notify::page-size', lambda va, *_: va.set_value(va.get_upper() - va.get_page_size()) if va.get_value() == 0 else None)
        ##TODO Figure out how to do this with the search thing

    def stop_message(self):
        """Mark the chat as idle and switch the window back to the send button."""
        self.busy = False
        window.switch_send_stop_button(True)

    def clear_chat(self):
        """Remove every message and widget, then show the welcome screen."""
        self.stop_message()
        self.messages = {}
        for widget in list(self.container):
            self.container.remove(widget)
        # Everything was just removed from the container, so drop the stale
        # references; otherwise later code would try to remove them again.
        self.welcome_screen = None
        self.regenerate_button = None
        self.show_welcome_screen(len(window.model_manager.get_model_list()) > 0)

    def add_message(self, message_id:str, model:str=None):
        """Create a message widget, register it under ``message_id`` and append it.

        Args:
            message_id: key used in ``self.messages``.
            model: model name passed to the message widget (None by default).
        """
        msg = message(message_id, model)
        self.messages[message_id] = msg
        self.container.append(msg)

    def send_sample_prompt(self, prompt):
        """Replace the window's message text with ``prompt`` and send it."""
        buffer = window.message_text_view.get_buffer()
        buffer.delete(buffer.get_start_iter(), buffer.get_end_iter())
        # Gtk.TextBuffer.insert takes the length in bytes, not in characters.
        buffer.insert(buffer.get_start_iter(), prompt, len(prompt.encode('utf-8')))
        window.send_message()

    def show_welcome_screen(self, show_prompts:bool):
        """Show the welcome page in an otherwise empty chat.

        Any previous welcome page is replaced. If the chat still contains
        other widgets it is cleared first (``clear_chat`` then calls back
        into this method with a freshly computed ``show_prompts``).

        Args:
            show_prompts: True to offer three random sample prompts, False to
                offer a button that opens the model manager instead.
        """
        if self.welcome_screen:
            self.container.remove(self.welcome_screen)
            self.welcome_screen = None
        if list(self.container):
            self.clear_chat()
            return

        button_container = Gtk.Box(
            orientation=Gtk.Orientation.VERTICAL,
            spacing=10,
            halign=Gtk.Align.CENTER
        )
        if show_prompts:
            for prompt in random.sample(possible_prompts, 3):
                prompt_button = Gtk.Button(
                    label=prompt,
                    tooltip_text=_("Send prompt: '{}'").format(prompt)
                )
                # Bind the prompt as a default argument so each button keeps its own text.
                prompt_button.connect('clicked', lambda *_, prompt=prompt : self.send_sample_prompt(prompt))
                button_container.append(prompt_button)
        else:
            button = Gtk.Button(
                label=_("Open Model Manager"),
                tooltip_text=_("Open Model Manager"),
                css_classes=["suggested-action", "pill"]
            )
            button.set_action_name('app.manage_models')
            button_container.append(button)

        self.welcome_screen = Adw.StatusPage(
            icon_name="com.jeffser.Alpaca",
            title="Alpaca",
            description=_("Try one of these prompts") if show_prompts else _("It looks like you don't have any models downloaded yet. Download models to get started!"),
            child=button_container,
            vexpand=True
        )

        self.container.append(self.welcome_screen)

    def load_chat_messages(self, messages:dict):
        """Populate the chat from a ``{message_id: message_data}`` mapping.

        Each ``message_data`` is read for the keys 'content', 'role', 'date'
        and optionally 'model', 'images' (list of file names) and 'files'
        (``{file_name: file_type}``). Entries with empty content are skipped.
        An empty mapping shows the welcome screen instead.
        """
        if not messages:
            self.show_welcome_screen(len(window.model_manager.get_model_list()) > 0)
            return

        if self.welcome_screen:
            self.container.remove(self.welcome_screen)
            self.welcome_screen = None

        for message_id, message_data in messages.items():
            if not message_data['content']:
                continue
            is_assistant = message_data['role'] == 'assistant'
            self.add_message(message_id, message_data.get('model') if is_assistant else None)
            message_element = self.messages[message_id]
            # Images and attachments are stored under <data_dir>/chats/<chat name>/<message id>/
            message_dir = os.path.join(data_dir, "chats", self.get_name(), message_id)
            if 'images' in message_data:
                message_element.add_images([os.path.join(message_dir, image) for image in message_data['images']])
            if 'files' in message_data:
                message_element.add_attachments({
                    os.path.join(message_dir, file_name): file_type
                    for file_name, file_type in message_data['files'].items()
                })
            GLib.idle_add(message_element.set_text, message_data['content'])
            # Dates stored without seconds (a single ':') get ':00' appended before parsing.
            date_text = message_data['date']
            if date_text.count(":") == 1:
                date_text += ":00"
            GLib.idle_add(message_element.add_footer, datetime.datetime.strptime(date_text, '%Y/%m/%d %H:%M:%S'))

    def messages_to_dict(self) -> dict:
        """Serialise the chat into the mapping understood by ``load_chat_messages``.

        Messages without text or without a timestamp are left out.
        """
        messages_dict = {}
        for message_id, message_element in self.messages.items():
            if not (message_element.text and message_element.dt):
                continue
            entry = {
                'role': 'assistant' if message_element.bot else 'user',
                'model': message_element.model,
                'date': message_element.dt.strftime("%Y/%m/%d %H:%M:%S"),
                'content': message_element.text
            }
            if message_element.image_c:
                entry['images'] = [file.image_name for file in message_element.image_c.files]
            if message_element.attachment_c:
                entry['files'] = {file.file_name: file.file_type for file in message_element.attachment_c.files}
            messages_dict[message_id] = entry
        return messages_dict

    def show_regenerate_button(self, msg:message):
        """Append a "Regenerate Response" button bound to ``msg``.

        A previously shown button is removed first. The button lives in
        ``self.container``, so that is where it has to be removed from.
        """
        if self.regenerate_button:
            if self.regenerate_button.get_parent() == self.container:
                self.container.remove(self.regenerate_button)
            self.regenerate_button = None
        self.regenerate_button = Gtk.Button(
            child=Adw.ButtonContent(
                icon_name='update-symbolic',
                label=_('Regenerate Response')
            ),
            css_classes=["suggested-action"],
            halign=Gtk.Align.CENTER
        )
        self.regenerate_button.connect('clicked', lambda *_: msg.action_buttons.regenerate_message())
        self.container.append(self.regenerate_button)
|
|
|
|
class chat_tab(Gtk.ListBoxRow):
    """Sidebar row representing one chat.

    Attributes (all set in ``__init__``):
        chat_window: the ``chat`` widget this row stands for.
        spinner: Gtk.Spinner, hidden by default.
        label: Gtk.Label showing the chat name (also used as tooltip).
        indicator: icon, hidden by default, styled with the 'accent' class.
        popover: context menu, created on the first right click and reused.
        gesture: click gesture listening for button 3.
    """
    __gtype_name__ = 'AlpacaChatTab'

    def __init__(self, chat_window:chat):
        """Build the row for ``chat_window``."""
        self.chat_window=chat_window
        self.spinner = Gtk.Spinner(
            spinning=True,
            visible=False
        )
        self.label = Gtk.Label(
            label=self.chat_window.get_name(),
            tooltip_text=self.chat_window.get_name(),
            hexpand=True,
            halign=Gtk.Align.FILL,
            wrap=True,
            ellipsize=3,  # Pango.EllipsizeMode.END
            wrap_mode=2,  # Pango.WrapMode.WORD_CHAR
            xalign=0
        )
        self.indicator = Gtk.Image.new_from_icon_name("chat-bubble-text-symbolic")
        self.indicator.set_visible(False)
        self.indicator.set_css_classes(['accent'])
        container = Gtk.Box(
            orientation=Gtk.Orientation.HORIZONTAL,
            spacing=5
        )
        container.append(self.label)
        container.append(self.spinner)
        container.append(self.indicator)
        super().__init__(
            css_classes = ["chat_row"],
            height_request = 45,
            child = container
        )

        # Created lazily by chat_click_handler and then reused.
        self.popover = None

        # Button 3 is the secondary (right) mouse button.
        self.gesture = Gtk.GestureClick(button=3)
        self.gesture.connect("released", self.chat_click_handler)
        self.add_controller(self.gesture)

    def chat_click_handler(self, gesture, n_press, x, y):
        """Open the chat context menu at the clicked position.

        The clicked row is stored in ``window.selected_chat_row`` before the
        menu is shown. A single popover is kept per row: building and
        parenting a new one on every click would pile up popovers on the
        row's child, since none of them is ever unparented.
        """
        chat_row = gesture.get_widget()
        window.selected_chat_row = chat_row
        if self.popover is None:
            self.popover = Gtk.PopoverMenu(
                menu_model=window.chat_right_click_menu,
                has_arrow=False,
                halign=Gtk.Align.START,
                height_request=155
            )
            self.popover.set_parent(chat_row.get_child())
        position = Gdk.Rectangle()
        # Gdk.Rectangle fields are integers; the gesture reports floats.
        position.x = int(x)
        position.y = int(y)
        self.popover.set_pointing_to(position)
        self.popover.popup()
|
|
|
|
class chat_list(Gtk.ListBox):
    """Sidebar list of chats.

    ``tab_list`` mirrors the rows of the list box in display order; every
    method that adds, removes or reorders rows updates both.
    """
    __gtype_name__ = 'AlpacaChatList'

    def __init__(self):
        """Create the empty list and react to row selection changes."""
        super().__init__(
            selection_mode=Gtk.SelectionMode.SINGLE,
            css_classes=["navigation-sidebar"]
        )
        self.connect("row-selected", lambda listbox, row: self.chat_changed(row))
        self.tab_list = []

    def _chat_names(self) -> list:
        """Return the names of all chats currently in the list."""
        return [tab.chat_window.get_name() for tab in self.tab_list]

    @staticmethod
    def _safe_tar_members(tar, destination):
        """Yield only the members of ``tar`` that are safe to extract.

        A member is accepted when it is a regular file or a directory and its
        resolved path stays inside ``destination``. Links, devices and names
        that escape the destination (absolute paths, '..') are skipped.
        """
        root = os.path.realpath(destination)
        for member in tar.getmembers():
            target = os.path.realpath(os.path.join(root, member.name))
            inside = target == root or target.startswith(root + os.sep)
            if inside and (member.isfile() or member.isdir()):
                yield member
            else:
                logger.warning("Skipping unsafe archive member: %s", member.name)

    def update_welcome_screens(self, show_prompts:bool):
        """Rebuild the welcome screen of every chat that is currently showing one."""
        for tab in self.tab_list:
            if tab.chat_window.welcome_screen:
                tab.chat_window.show_welcome_screen(show_prompts)

    def get_tab_by_name(self, chat_name:str) -> chat_tab:
        """Return the row whose chat is named ``chat_name``, or None."""
        for tab in self.tab_list:
            if tab.chat_window.get_name() == chat_name:
                return tab
        return None

    def get_chat_by_name(self, chat_name:str) -> chat:
        """Return the chat named ``chat_name``, or None."""
        tab = self.get_tab_by_name(chat_name)
        return tab.chat_window if tab else None

    def get_current_chat(self) -> chat:
        """Return the chat of the selected row, or None when nothing is selected."""
        row = self.get_selected_row()
        return row.chat_window if row else None

    def send_tab_to_top(self, tab:chat_tab):
        """Move ``tab`` to the first position and select it."""
        self.unselect_all()
        self.tab_list.remove(tab)
        self.tab_list.insert(0, tab)
        self.remove(tab)
        self.prepend(tab)
        self.select_row(tab)

    def append_chat(self, chat_name:str) -> chat:
        """Create a chat at the end of the list and return it.

        The name is passed through ``window.generate_numbered_name`` together
        with the existing chat names before use.
        """
        chat_name = window.generate_numbered_name(chat_name, self._chat_names())
        chat_window = chat(chat_name)
        tab = chat_tab(chat_window)
        self.append(tab)
        self.tab_list.append(tab)
        window.chat_stack.add_child(chat_window)
        return chat_window

    def prepend_chat(self, chat_name:str) -> chat:
        """Create a chat at the top of the list, show its welcome screen, select it and return it."""
        chat_name = window.generate_numbered_name(chat_name, self._chat_names())
        chat_window = chat(chat_name)
        tab = chat_tab(chat_window)
        self.prepend(tab)
        self.tab_list.insert(0, tab)
        chat_window.show_welcome_screen(len(window.model_manager.get_model_list()) > 0)
        window.chat_stack.add_child(chat_window)
        window.chat_list_box.select_row(tab)
        return chat_window

    def new_chat(self):
        """Create a chat called "New Chat" and save it to the history."""
        window.save_history(self.prepend_chat(_("New Chat")))

    def delete_chat(self, chat_name:str):
        """Delete the chat named ``chat_name`` together with its data directory.

        A fresh chat is created when the list would otherwise become empty,
        and the first row is selected when nothing else is. Does nothing when
        no chat has that name.
        """
        tab = self.get_tab_by_name(chat_name)
        if not tab:
            return
        tab.chat_window.stop_message()
        window.chat_stack.remove(tab.chat_window)
        self.tab_list.remove(tab)
        self.remove(tab)
        chat_dir = os.path.join(data_dir, "chats", chat_name)
        if os.path.exists(chat_dir):
            shutil.rmtree(chat_dir)
        if len(self.tab_list) == 0:
            self.new_chat()
        current_chat = self.get_current_chat()
        if not current_chat or current_chat == tab.chat_window:
            self.select_row(self.get_row_at_index(0))
        window.save_history()

    def rename_chat(self, old_chat_name:str, new_chat_name:str):
        """Rename a chat: label, tooltip, widget name and data directory.

        The new name is passed through ``window.generate_numbered_name``
        together with the existing chat names before use. Does nothing when
        both names are equal or no chat is named ``old_chat_name``.
        """
        if new_chat_name == old_chat_name:
            return
        tab = self.get_tab_by_name(old_chat_name)
        if not tab:
            return
        new_chat_name = window.generate_numbered_name(new_chat_name, self._chat_names())
        tab.label.set_label(new_chat_name)
        tab.label.set_tooltip_text(new_chat_name)
        tab.chat_window.set_name(new_chat_name)
        old_dir = os.path.join(data_dir, "chats", old_chat_name)
        if os.path.exists(old_dir):
            shutil.move(old_dir, os.path.join(data_dir, "chats", new_chat_name))
        window.save_history(tab.chat_window)

    def duplicate_chat(self, chat_name:str):
        """Create "Copy of <chat_name>" with the same messages and data files.

        The data directory is copied first so that image and attachment paths
        resolve while the messages are loaded. A failed copy is logged and the
        chat is still duplicated without its files.
        """
        source_tab = self.get_tab_by_name(chat_name)
        if not source_tab:
            logger.error("Cannot duplicate chat '%s': no such chat", chat_name)
            return
        new_chat_name = window.generate_numbered_name(_("Copy of {}").format(chat_name), self._chat_names())
        source_dir = os.path.join(data_dir, "chats", chat_name)
        # Chats without images or attachments may have no directory at all.
        if os.path.isdir(source_dir):
            try:
                shutil.copytree(source_dir, os.path.join(data_dir, "chats", new_chat_name))
            except OSError as e:
                logger.error(e)
        created_chat = self.prepend_chat(new_chat_name)
        created_chat.load_chat_messages(source_tab.chat_window.messages_to_dict())
        window.save_history(created_chat)

    def on_replace_contents(self, file, result):
        """Finish the asynchronous export write and report success.

        A failed write is logged and no success toast is shown.
        """
        try:
            file.replace_contents_finish(result)
        except GLib.Error as e:
            logger.error(e)
            return
        window.show_toast(_("Chat exported successfully"), window.main_overlay)

    def on_export_chat(self, file_dialog, result, chat_name):
        """Write the chat named ``chat_name`` to the file picked in the save dialog.

        The export is a tar archive containing ``data.json`` (the messages)
        and, when it exists, the chat's data directory.
        """
        try:
            file = file_dialog.save_finish(result)
        except GLib.Error as e:
            # Raised, among other cases, when the user dismisses the dialog.
            logger.debug("Chat export not started: %s", e)
            return
        if not file:
            return
        exported_chat = self.get_chat_by_name(chat_name)
        if not exported_chat:
            logger.error("Cannot export chat '%s': no such chat", chat_name)
            return
        json_data = json.dumps({chat_name: {"messages": exported_chat.messages_to_dict()}}, indent=4).encode("UTF-8")

        with tempfile.TemporaryDirectory() as temp_dir:
            json_path = os.path.join(temp_dir, "data.json")
            with open(json_path, "wb") as json_file:
                json_file.write(json_data)

            # Fixed file name: using the chat name here would clash with
            # data.json or break on names containing path separators.
            tar_path = os.path.join(temp_dir, "export.tar")
            with tarfile.open(tar_path, "w") as tar:
                tar.add(json_path, arcname="data.json")
                directory = os.path.join(data_dir, "chats", chat_name)
                if os.path.isdir(directory):
                    tar.add(directory, arcname=os.path.basename(directory))

            with open(tar_path, "rb") as tar_file:
                tar_content = tar_file.read()

        file.replace_contents_async(
            tar_content,
            etag=None,
            make_backup=False,
            flags=Gio.FileCreateFlags.NONE,
            cancellable=None,
            callback=self.on_replace_contents
        )

    def export_chat(self, chat_name:str):
        """Open a save dialog for exporting the chat named ``chat_name``."""
        logger.info("Exporting chat")
        file_dialog = Gtk.FileDialog(initial_name=f"{chat_name}.tar")
        file_dialog.save(parent=window, cancellable=None, callback=lambda file_dialog, result, chat_name=chat_name: self.on_export_chat(file_dialog, result, chat_name))

    def on_chat_imported(self, file_dialog, result):
        """Import every chat found in the archive picked in the open dialog.

        The archive is treated as untrusted input: it is read completely,
        only safe members are extracted into a scratch directory, and a
        chat's data directory is copied only when it resolves inside that
        directory. Failures are logged; the success toast is shown only when
        at least one chat was imported.
        """
        try:
            file = file_dialog.open_finish(result)
        except GLib.Error as e:
            # Raised, among other cases, when the user dismisses the dialog.
            logger.debug("Chat import not started: %s", e)
            return
        if not file:
            return

        imported = 0
        try:
            with tempfile.TemporaryDirectory() as temp_dir:
                tar_filename = os.path.join(temp_dir, "imported_chat.tar")
                # Copy the whole file in 1 MiB chunks until the stream is exhausted.
                stream = file.read(None)
                try:
                    with open(tar_filename, "wb") as tar_file:
                        while True:
                            chunk = stream.read_bytes(1024 * 1024, None)
                            if chunk.get_size() == 0:
                                break
                            tar_file.write(chunk.get_data())
                finally:
                    stream.close(None)

                # Extract into a dedicated sub-directory so archive members
                # cannot overwrite the archive file itself.
                extract_dir = os.path.join(temp_dir, "contents")
                os.makedirs(extract_dir)
                with tarfile.open(tar_filename, "r") as tar:
                    tar.extractall(path=extract_dir, members=self._safe_tar_members(tar, extract_dir))

                json_filepath = os.path.join(extract_dir, "data.json")
                if os.path.isfile(json_filepath):
                    with open(json_filepath, "r", encoding="utf-8") as json_file:
                        data = json.load(json_file)
                    if not isinstance(data, dict):
                        raise ValueError("data.json does not contain a chat mapping")

                    extract_root = os.path.realpath(extract_dir)
                    for chat_name, chat_content in data.items():
                        new_chat_name = window.generate_numbered_name(chat_name, self._chat_names())
                        src_path = os.path.realpath(os.path.join(extract_dir, chat_name))
                        dest_path = os.path.join(data_dir, "chats", new_chat_name)
                        # The chat name comes from the archive; never copy from outside the scratch directory.
                        src_is_inside = src_path.startswith(extract_root + os.sep)
                        if src_is_inside and os.path.isdir(src_path) and not os.path.exists(dest_path):
                            shutil.copytree(src_path, dest_path)

                        messages = chat_content.get('messages', {}) if isinstance(chat_content, dict) else {}
                        created_chat = self.prepend_chat(new_chat_name)
                        created_chat.load_chat_messages(messages)
                        window.save_history(created_chat)
                        imported += 1
        except (GLib.Error, OSError, tarfile.TarError, ValueError) as e:
            logger.error("Chat import failed: %s", e)
            return

        if imported:
            window.show_toast(_("Chat imported successfully"), window.main_overlay)
        else:
            logger.warning("No chats found in the imported archive")

    def import_chat(self):
        """Open a file dialog for picking a chat archive to import."""
        logger.info("Importing chat")
        file_dialog = Gtk.FileDialog(default_filter=window.file_filter_tar)
        file_dialog.open(window, None, self.on_chat_imported)

    def chat_changed(self, row):
        """Switch the window to the chat of the newly selected ``row``.

        Nothing happens when ``row`` is None or its chat is already the
        visible one. Otherwise the message search is reset, the stack slides
        to the chat, the send/stop button is updated, the model of the chat's
        last message is selected and the row's indicator is hidden.
        """
        if not row:
            return
        visible_chat = window.chat_stack.get_visible_child()
        current_tab_i = next((i for i, t in enumerate(self.tab_list) if t.chat_window == visible_chat), -1)
        row_i = self.tab_list.index(row)
        if row_i == current_tab_i:
            return

        if window.searchentry_messages.get_text() != '':
            window.searchentry_messages.set_text('')
            window.message_search_changed(window.searchentry_messages, window.chat_stack.get_visible_child())
        window.message_searchbar.set_search_mode(False)

        # Slide direction follows the relative position of the rows in the list.
        moving_down = row_i > current_tab_i
        window.chat_stack.set_transition_type(
            Gtk.StackTransitionType.SLIDE_UP if moving_down else Gtk.StackTransitionType.SLIDE_DOWN
        )
        window.chat_stack.set_visible_child(row.chat_window)
        window.switch_send_stop_button(not row.chat_window.busy)

        chat_messages = row.chat_window.messages
        if len(chat_messages) > 0:
            last_model_used = list(chat_messages.values())[-1].model
            window.model_manager.change_model(last_model_used)
        if row.indicator.get_visible():
            row.indicator.set_visible(False)
|