diff --git a/src/alpaca.gresource.xml b/src/alpaca.gresource.xml index 4a6b21e..50cd23b 100644 --- a/src/alpaca.gresource.xml +++ b/src/alpaca.gresource.xml @@ -29,6 +29,7 @@ icons/edit-symbolic.svg icons/image-missing-symbolic.svg icons/update-symbolic.svg + icons/down-symbolic.svg window.ui gtk/help-overlay.ui diff --git a/src/icons/down-symbolic.svg b/src/icons/down-symbolic.svg new file mode 100644 index 0000000..652f8c0 --- /dev/null +++ b/src/icons/down-symbolic.svg @@ -0,0 +1,2 @@ + + diff --git a/src/style.css b/src/style.css index e1f8711..bbc20df 100644 --- a/src/style.css +++ b/src/style.css @@ -12,3 +12,9 @@ border-radius: 5px; padding: 5px; } +.model_list_box { + padding: 0; +} +.model_list_box > * { + margin: 0; +} diff --git a/src/window.py b/src/window.py index 4bdeddd..e38b34c 100644 --- a/src/window.py +++ b/src/window.py @@ -107,8 +107,6 @@ class AlpacaWindow(Adw.ApplicationWindow): file_filter_gguf = Gtk.Template.Child() file_filter_attachments = Gtk.Template.Child() attachment_button = Gtk.Template.Child() - model_drop_down = Gtk.Template.Child() - model_string_list = Gtk.Template.Child() chat_right_click_menu = Gtk.Template.Child() model_tag_list_box = Gtk.Template.Child() navigation_view_manage_models = Gtk.Template.Child() @@ -118,6 +116,9 @@ class AlpacaWindow(Adw.ApplicationWindow): model_searchbar = Gtk.Template.Child() no_results_page = Gtk.Template.Child() model_link_button = Gtk.Template.Child() + model_list_box = Gtk.Template.Child() + model_popover = Gtk.Template.Child() + model_selector_button = Gtk.Template.Child() manage_models_dialog = Gtk.Template.Child() pulling_model_list_box = Gtk.Template.Child() @@ -136,22 +137,6 @@ class AlpacaWindow(Adw.ApplicationWindow): style_manager = Adw.StyleManager() - @Gtk.Template.Callback() - def verify_if_image_can_be_used(self, pspec=None, user_data=None): - logger.debug("Verifying if image can be used") - if self.model_drop_down.get_selected_item() == None: - return True - selected = self.convert_model_name(self.model_drop_down.get_selected_item().get_string(), 1).split(":")[0] - if selected in [key for key, value in self.available_models.items() if value["image"]]: - for name, content in self.attachments.items(): - if content['type'] == 'image': - content['button'].set_css_classes(["flat"]) - return True - for name, content in self.attachments.items(): - if content['type'] == 'image': - content['button'].set_css_classes(["flat", "error"]) - return False - @Gtk.Template.Callback() def stop_message(self, button=None): if self.loading_spinner: @@ -191,7 +176,7 @@ class AlpacaWindow(Adw.ApplicationWindow): self.chats['order'].remove(self.chats['selected_chat']) self.chats['order'].insert(0, self.chats['selected_chat']) self.save_history() - current_model = self.convert_model_name(self.model_drop_down.get_selected_item().get_string(), 1) + current_model = self.get_current_model(1) if current_model is None: self.show_toast(_("Please select a model before chatting"), self.main_overlay) return @@ -288,10 +273,9 @@ class AlpacaWindow(Adw.ApplicationWindow): self.load_history_into_chat() if len(self.chats["chats"][self.chats["selected_chat"]]["messages"].keys()) > 0: last_model_used = self.chats["chats"][self.chats["selected_chat"]]["messages"][list(self.chats["chats"][self.chats["selected_chat"]]["messages"].keys())[-1]]["model"] - last_model_used = self.convert_model_name(last_model_used, 0) - for i in range(self.model_string_list.get_n_items()): - if self.model_string_list.get_string(i) == last_model_used: - self.model_drop_down.set_selected(i) + for i, m in enumerate(self.local_models): + if m == last_model_used: + self.model_list_box.select_row(self.model_list_box.get_row_at_index(i)) break self.save_history() @@ -427,11 +411,69 @@ class AlpacaWindow(Adw.ApplicationWindow): self.available_model_list_box.set_visible(True) self.no_results_page.set_visible(False) + @Gtk.Template.Callback() + def close_model_popup(self, *_): + self.model_popover.hide() + + @Gtk.Template.Callback() + def change_model(self, listbox=None, row=None): + if not row: + current_model = self.convert_model_name(self.model_selector_button.get_child().get_label(), 1) + print("c ", current_model) + for i, m in enumerate(self.local_models): + if m == current_model: + self.model_list_box.select_row(self.model_list_box.get_row_at_index(i)) + return + self.model_list_box.select_row(self.model_list_box.get_row_at_index(0)) + return + button_content = Gtk.Box( + spacing=10 + ) + button_content.append( + Gtk.Label( + label=row.get_child().get_label(), + ellipsize=2 + ) + ) + button_content.append( + Gtk.Image.new_from_icon_name("down-symbolic") + ) + self.model_selector_button.set_child(button_content) + self.close_model_popup() + self.verify_if_image_can_be_used() + + def verify_if_image_can_be_used(self): + logger.debug("Verifying if image can be used") + selected = self.get_current_model(1) + if selected == None: + return True + selected = selected.split(":")[0] + if selected in [key for key, value in self.available_models.items() if value["image"]]: + for name, content in self.attachments.items(): + if content['type'] == 'image': + content['button'].set_css_classes(["flat"]) + return True + for name, content in self.attachments.items(): + if content['type'] == 'image': + content['button'].set_css_classes(["flat", "error"]) + return False + def convert_model_name(self, name:str, mode:int) -> str: # mode=0 name:tag -> Name (tag) | mode=1 Name (tag) -> name:tag + try: + if mode == 0: + return "{} ({})".format(name.split(":")[0].replace("-", " ").title(), name.split(":")[1]) + if mode == 1: + return "{}:{}".format(name.split(" (")[0].replace(" ", "-").lower(), name.split(" (")[1][:-1]) + except Exception as e: + pass + + def get_current_model(self, mode:int) -> str: + if not self.model_list_box.get_selected_row(): + return None if mode == 0: - return "{} ({})".format(name.split(":")[0].replace("-", " ").title(), name.split(":")[1]) + return self.model_list_box.get_selected_row().get_child().get_label() if mode == 1: - return "{}:{}".format(name.split(" (")[0].replace(" ", "-").lower(), name.split(" (")[1][:-1]) + return self.model_list_box.get_selected_row().get_name() def check_alphanumeric(self, editable, text, length, position): new_text = ''.join([char for char in text if char.isalnum() or char in ['-', '.', ':', '_']]) @@ -593,7 +635,7 @@ Generate a title following these rules: ```PROMPT {message['content']} ```""" - current_model = self.convert_model_name(self.model_drop_down.get_selected_item().get_string(), 1) + current_model = self.get_current_model(1) data = {"model": current_model, "prompt": prompt, "stream": False} if 'images' in message: data["images"] = message['images'] @@ -776,8 +818,7 @@ Generate a title following these rules: logger.debug("Updating list of local models") self.local_models = [] response = connection_handler.simple_get(f"{connection_handler.url}/api/tags") - for i in range(self.model_string_list.get_n_items() -1, -1, -1): - self.model_string_list.remove(i) + self.model_list_box.remove_all() if response.status_code == 200: self.local_model_list_box.remove_all() if len(json.loads(response.text)['models']) == 0: @@ -801,9 +842,17 @@ Generate a title following these rules: model_row.add_suffix(button) self.local_model_list_box.append(model_row) - self.model_string_list.append(model_name) + selector_row = Gtk.ListBoxRow( + child = Gtk.Label( + label=model_name, halign=1, hexpand=True + ), + halign=0, + hexpand=True, + name=model["name"], + tooltip_text=model_name + ) + self.model_list_box.append(selector_row) self.local_models.append(model["name"]) - #self.verify_if_image_can_be_used() else: self.connection_error() @@ -1058,7 +1107,7 @@ Generate a title following these rules: if message_id in self.chats["chats"][self.chats["selected_chat"]]["messages"]: del self.chats["chats"][self.chats["selected_chat"]]["messages"][message_id] data = { - "model": self.convert_model_name(self.model_drop_down.get_selected_item().get_string(), 1), + "model": self.get_current_model(1), "messages": history, "options": {"temperature": self.model_tweaks["temperature"], "seed": self.model_tweaks["seed"]}, "keep_alive": f"{self.model_tweaks['keep_alive']}m" @@ -1091,6 +1140,7 @@ Generate a title following these rules: data = {"name": model} response = connection_handler.stream_post(f"{connection_handler.url}/api/pull", data=json.dumps(data), callback=lambda data, model_name=model: self.pull_model_update(data, model_name)) GLib.idle_add(self.update_list_local_models) + GLib.idle_add(self.change_model) if response.status_code == 200 and 'error' not in self.pulling_models[model]: GLib.idle_add(self.show_notification, _("Task Complete"), _("Model '{}' pulled successfully.").format(model), Gio.ThemedIcon.new("emblem-ok-symbolic")) @@ -1239,10 +1289,9 @@ Generate a title following these rules: self.chats["order"].append(chat_name) if len(self.chats["chats"][self.chats["selected_chat"]]["messages"].keys()) > 0: last_model_used = self.chats["chats"][self.chats["selected_chat"]]["messages"][list(self.chats["chats"][self.chats["selected_chat"]]["messages"].keys())[-1]]["model"] - last_model_used = self.convert_model_name(last_model_used, 0) - for i in range(self.model_string_list.get_n_items()): - if self.model_string_list.get_string(i) == last_model_used: - self.model_drop_down.set_selected(i) + for i, m in enumerate(self.local_models): + if m == last_model_used: + self.model_list_box.select_row(self.model_list_box.get_row_at_index(i)) break except Exception as e: logger.error(e) @@ -1322,6 +1371,7 @@ Generate a title following these rules: self.update_list_local_models() if response.status_code == 200: self.show_toast(_("Model deleted successfully"), self.manage_models_overlay) + self.change_model() else: self.manage_models_dialog.close() self.connection_error() @@ -1644,25 +1694,6 @@ Generate a title following these rules: clipboard.read_text_async(None, self.cb_text_received) clipboard.read_texture_async(None, self.cb_image_received) - - def on_model_dropdown_setup(self, factory, list_item): - label = Gtk.Label() - label.set_ellipsize(2) - label.set_xalign(0) - list_item.set_child(label) - - def on_model_dropdown_bind(self, factory, list_item): - label = list_item.get_child() - item = list_item.get_item() - label.set_text(item.get_string()) - label.set_tooltip_text(item.get_string()) - - def setup_model_dropdown(self): - factory = Gtk.SignalListItemFactory() - factory.connect("setup", self.on_model_dropdown_setup) - factory.connect("bind", self.on_model_dropdown_bind) - self.model_drop_down.set_factory(factory) - def handle_enter_key(self): self.send_message() return True @@ -1701,7 +1732,6 @@ Generate a title following these rules: self.remote_connection_entry.connect("entry-activated", lambda entry : entry.set_css_classes([])) self.remote_connection_switch.connect("notify", lambda pspec, user_data : self.connection_switched()) self.background_switch.connect("notify", lambda pspec, user_data : self.switch_run_on_background()) - self.setup_model_dropdown() if os.path.exists(os.path.join(self.config_dir, "server.json")): with open(os.path.join(self.config_dir, "server.json"), "r", encoding="utf-8") as f: data = json.load(f) diff --git a/src/window.ui b/src/window.ui index 2b3a665..82a6de9 100644 --- a/src/window.ui +++ b/src/window.ui @@ -76,30 +76,64 @@ 0 12 - - - 260 - false - Select Model - - - - + + Select Model + + + + + (None) + 2 + + + + + down-symbolic + + + + + 1 + + + + + + 1 + 10 + + + true + Manage Models + Manage Models + app.manage_models + + + + + + + + + + true + + + + + + -