From 7026655116c1ef6940f890cca765cc57334d331b Mon Sep 17 00:00:00 2001 From: jeffser Date: Sat, 31 Aug 2024 17:14:39 -0600 Subject: [PATCH] New instance manager 'ollama_instance' --- src/connection_handler.py | 105 ++++++++-- src/custom_widgets/chat_widget.py | 18 +- src/custom_widgets/model_widget.py | 38 ++-- src/dialogs.py | 21 +- src/local_instance.py | 61 ------ src/main.py | 2 +- src/meson.build | 1 - src/window.py | 314 ++++++++++++----------------- src/window.ui | 12 +- 9 files changed, 272 insertions(+), 300 deletions(-) delete mode 100644 src/local_instance.py diff --git a/src/connection_handler.py b/src/connection_handler.py index 97c3585..4537361 100644 --- a/src/connection_handler.py +++ b/src/connection_handler.py @@ -2,33 +2,94 @@ """ Handles requests to remote and integrated instances of Ollama """ -import json -import requests +import json, os, requests, subprocess, threading +from .internal import data_dir, cache_dir +from logging import getLogger +from time import sleep + +logger = getLogger(__name__) #OK=200 response.status_code URL = None BEARER_TOKEN = None -def get_headers(include_json:bool) -> dict: - headers = {} - if include_json: - headers["Content-Type"] = "application/json" - if BEARER_TOKEN: - headers["Authorization"] = "Bearer {}".format(BEARER_TOKEN) - return headers if len(headers.keys()) > 0 else None +def log_output(pipe): + with open(os.path.join(data_dir, 'tmp.log'), 'a') as f: + with pipe: + try: + for line in iter(pipe.readline, ''): + print(line, end='') + f.write(line) + f.flush() + except: + pass -def simple_get(connection_url:str) -> dict: - return requests.get(connection_url, headers=get_headers(False)) +class instance(): -def simple_post(connection_url:str, data) -> dict: - return requests.post(connection_url, headers=get_headers(True), data=data, stream=False) + def __init__(self, local_port:int, remote_url:str, remote:bool, tweaks:dict, overrides:dict, bearer_token:str=None): + self.local_port=local_port + self.remote_url=remote_url + self.remote=remote + self.tweaks=tweaks + self.overrides=overrides + self.bearer_token=bearer_token + self.instance = None + if not self.remote: + self.start() -def simple_delete(connection_url:str, data) -> dict: - return requests.delete(connection_url, headers=get_headers(False), json=data) + def get_headers(self, include_json:bool) -> dict: + headers = {} + if include_json: + headers["Content-Type"] = "application/json" + if self.bearer_token and self.remote: + headers["Authorization"] = "Bearer " + self.bearer_token + return headers if len(headers.keys()) > 0 else None -def stream_post(connection_url:str, data, callback:callable) -> dict: - response = requests.post(connection_url, headers=get_headers(True), data=data, stream=True) - if response.status_code == 200: - for line in response.iter_lines(): - if line: - callback(json.loads(line.decode("utf-8"))) - return response + def request(self, connection_type:str, connection_url:str, data:dict=None, callback:callable=None) -> requests.models.Response: + connection_url = '{}/{}'.format(self.remote_url if self.remote else 'http://127.0.0.1:{}'.format(self.local_port), connection_url) + logger.info('Connection: {} : {}'.format(connection_type, connection_url)) + match connection_type: + case "GET": + return requests.get(connection_url, headers=self.get_headers(False)) + case "POST": + if callback: + response = requests.post(connection_url, headers=self.get_headers(True), data=data, stream=True) + if response.status_code == 200: + for line in response.iter_lines(): + if line: + callback(json.loads(line.decode("utf-8"))) + return response + else: + return requests.post(connection_url, headers=self.get_headers(True), data=data, stream=False) + case "DELETE": + return requests.delete(connection_url, headers=self.get_headers(False), json=data) + + def start(self): + if not os.path.isdir(os.path.join(cache_dir, 'tmp/ollama')): + os.mkdir(os.path.join(cache_dir, 'tmp/ollama')) + + params = self.overrides.copy() + params["OLLAMA_HOST"] = f"127.0.0.1:{self.local_port}" # You can't change this directly sorry :3 + params["HOME"] = data_dir + params["TMPDIR"] = os.path.join(cache_dir, 'tmp/ollama') + self.instance = subprocess.Popen(["ollama", "serve"], env={**os.environ, **params}, stderr=subprocess.PIPE, stdout=subprocess.PIPE, text=True) + threading.Thread(target=log_output, args=(self.instance.stdout,)).start() + threading.Thread(target=log_output, args=(self.instance.stderr,)).start() + logger.info("Starting Alpaca's Ollama instance...") + logger.debug(params) + logger.info("Started Alpaca's Ollama instance") + v_str = subprocess.check_output("ollama -v", shell=True).decode('utf-8') + logger.info('Ollama version: {}'.format(v_str.split('client version is ')[1].strip())) + + def stop(self): + logger.info("Stopping Alpaca's Ollama instance") + if self.instance: + self.instance.terminate() + self.instance.wait() + self.instance = None + logger.info("Stopped Alpaca's Ollama instance") + + def reset(self): + logger.info("Resetting Alpaca's Ollama instance") + self.stop() + sleep(1) + self.start() diff --git a/src/custom_widgets/chat_widget.py b/src/custom_widgets/chat_widget.py index 9db1b33..0f39f54 100644 --- a/src/custom_widgets/chat_widget.py +++ b/src/custom_widgets/chat_widget.py @@ -217,9 +217,25 @@ class chat_tab(Gtk.ListBoxRow): ) self.gesture = Gtk.GestureClick(button=3) - self.gesture.connect("released", window.chat_click_handler) + self.gesture.connect("released", self.chat_click_handler) self.add_controller(self.gesture) + def chat_click_handler(self, gesture, n_press, x, y): + chat_row = gesture.get_widget() + popover = Gtk.PopoverMenu( + menu_model=window.chat_right_click_menu, + has_arrow=False, + halign=1, + height_request=155 + ) + window.selected_chat_row = chat_row + position = Gdk.Rectangle() + position.x = x + position.y = y + popover.set_parent(chat_row.get_child()) + popover.set_pointing_to(position) + popover.popup() + class chat_list(Gtk.ListBox): __gtype_name__ = 'AlpacaChatList' diff --git a/src/custom_widgets/model_widget.py b/src/custom_widgets/model_widget.py index 8fb4cd8..69284c4 100644 --- a/src/custom_widgets/model_widget.py +++ b/src/custom_widgets/model_widget.py @@ -9,7 +9,7 @@ gi.require_version('GtkSource', '5') from gi.repository import Gtk, GObject, Gio, Adw, GtkSource, GLib, Gdk import logging, os, datetime, re, shutil, threading, json, sys from ..internal import config_dir, data_dir, cache_dir, source_dir -from .. import connection_handler, available_models_descriptions, dialogs +from .. import available_models_descriptions, dialogs logger = logging.getLogger(__name__) @@ -352,8 +352,7 @@ class available_model(Gtk.ListBoxRow): self.add_controller(event_controller_key) def confirm_pull_model(self, model_name): - ##TODO I really need that instance manager - threading.Thread(target=window.model_manager.pull_model, args=('http://0.0.0.0:11435', model_name)).start() + threading.Thread(target=window.model_manager.pull_model, args=(model_name,)).start() window.navigation_view_manage_models.pop() def show_pull_menu(self): @@ -433,7 +432,7 @@ class model_manager_container(Gtk.Box): def remove_local_model(self, model_name:str): logger.debug("Deleting model") - response = connection_handler.simple_delete(f"{connection_handler.URL}/api/delete", data={"name": model_name}) + response = window.ollama_instance.request("DELETE", "api/delete", json.dumps({"name": model_name})) if response.status_code == 200: self.local_list.remove_model(model_name) @@ -455,17 +454,21 @@ class model_manager_container(Gtk.Box): #Should only be called when the app starts def update_local_list(self): - response = connection_handler.simple_get(f"{connection_handler.URL}/api/tags") - if response.status_code == 200: - self.local_list.remove_all() - data = json.loads(response.text) - if len(data['models']) == 0: - self.local_list.set_visible(False) + try: + response = window.ollama_instance.request("GET", "api/tags") + if response.status_code == 200: + self.local_list.remove_all() + data = json.loads(response.text) + if len(data['models']) == 0: + self.local_list.set_visible(False) + else: + self.local_list.set_visible(True) + for model in data['models']: + self.add_local_model(model['name']) else: - self.local_list.set_visible(True) - for model in data['models']: - self.add_local_model(model['name']) - else: + window.connection_error() + except Exception as e: + logger.error(e) window.connection_error() #Should only be called when the app starts @@ -494,8 +497,7 @@ class model_manager_container(Gtk.Box): content['button'].set_css_classes(["flat", "error"]) return False - #Important: Call this using a thread, if not the app crashes - def pull_model(self, url:str, model_name:str, modelfile:str=None): ##TODO, once you make an instance manager remove the url from this + def pull_model(self, model_name:str, modelfile:str=None): if ':' not in model_name: model_name += ':latest' if model_name not in [model.get_name() for model in list(self.pulling_list)] and model_name not in [model.get_name() for model in list(self.local_list)]: @@ -506,9 +508,9 @@ class model_manager_container(Gtk.Box): GLib.idle_add(self.pulling_list.set_visible, True) if modelfile: - response = connection_handler.stream_post("{}/api/create".format(url), data=json.dumps({"name": model_name, "modelfile": modelfile}), callback=lambda data: model.update(data)) + response = self.ollama_instance.request("POST", "api/create", json.dumps({"name": model_name, "modelfile": modelfile}), lambda data: model.update(data)) else: - response = connection_handler.stream_post("{}/api/pull".format(url), data=json.dumps({"name": model_name}), callback=lambda data: model.update(data)) + response = self.ollama_instance.request("POST", "api/pull", json.dumps({"name": model_name}), lambda data: model.update(data)) if response.status_code == 200 and not model.error: GLib.idle_add(window.show_notification, _("Task Complete"), _("Model '{}' pulled successfully.").format(model_name), Gio.ThemedIcon.new("emblem-ok-symbolic")) diff --git a/src/dialogs.py b/src/dialogs.py index 70834e0..7352f4e 100644 --- a/src/dialogs.py +++ b/src/dialogs.py @@ -3,11 +3,10 @@ Handles UI dialogs """ import os -import logging +import logging, requests from pytube import YouTube from html2text import html2text from gi.repository import Adw, Gtk -from . import connection_handler logger = logging.getLogger(__name__) # CLEAR CHAT | WORKS @@ -188,21 +187,25 @@ def remove_attached_file(self, name): def reconnect_remote_response(self, dialog, task, url_entry, bearer_entry): response = dialog.choose_finish(task) if not task or response == "remote": - self.connect_remote(url_entry.get_text(), bearer_entry.get_text()) + self.ollama_instance.remote_url = url_entry.get_text() + self.ollama_instance.bearer_token = bearer_entry.get_text() + self.ollama_instance.remote = True + self.model_manager.update_local_list() elif response == "local": - self.connect_local() + self.ollama_instance.remote = False + self.model_manager.update_local_list() elif response == "close": self.destroy() -def reconnect_remote(self, current_url, current_bearer_token): +def reconnect_remote(self): entry_url = Gtk.Entry( css_classes = ["error"], - text = current_url, + text = self.ollama_instance.remote_url, placeholder_text = "URL" ) entry_bearer_token = Gtk.Entry( - css_classes = ["error"] if current_bearer_token else None, - text = current_bearer_token, + css_classes = ["error"] if self.ollama_instance.bearer_token else None, + text = self.ollama_instance.bearer_token, placeholder_text = "Bearer Token (Optional)" ) container = Gtk.Box( @@ -374,7 +377,7 @@ def youtube_caption(self, video_url): def attach_website_response(self, dialog, task, url): if dialog.choose_finish(task) == "accept": - response = connection_handler.simple_get(url) + response = requests.get(url) if response.status_code == 200: html = response.text md = html2text(html) diff --git a/src/local_instance.py b/src/local_instance.py deleted file mode 100644 index dbb03a6..0000000 --- a/src/local_instance.py +++ /dev/null @@ -1,61 +0,0 @@ -# local_instance.py -""" -Handles running, stopping and resetting the integrated Ollama instance -""" -import subprocess -import threading -import os -from time import sleep -from logging import getLogger -from .internal import data_dir, cache_dir - - -logger = getLogger(__name__) - -instance = None -port = 11435 -overrides = {} - -def log_output(pipe): - with open(os.path.join(data_dir, 'tmp.log'), 'a') as f: - with pipe: - try: - for line in iter(pipe.readline, ''): - print(line, end='') - f.write(line) - f.flush() - except: - pass - -def start(): - if not os.path.isdir(os.path.join(cache_dir, 'tmp/ollama')): - os.mkdir(os.path.join(cache_dir, 'tmp/ollama')) - global instance - params = overrides.copy() - params["OLLAMA_HOST"] = f"127.0.0.1:{port}" # You can't change this directly sorry :3 - params["HOME"] = data_dir - params["TMPDIR"] = os.path.join(cache_dir, 'tmp/ollama') - instance = subprocess.Popen(["ollama", "serve"], env={**os.environ, **params}, stderr=subprocess.PIPE, stdout=subprocess.PIPE, text=True) - threading.Thread(target=log_output, args=(instance.stdout,)).start() - threading.Thread(target=log_output, args=(instance.stderr,)).start() - logger.info("Starting Alpaca's Ollama instance...") - logger.debug(params) - sleep(1) - logger.info("Started Alpaca's Ollama instance") - v_str = subprocess.check_output("ollama -v", shell=True).decode('utf-8') - logger.info('Ollama version: {}'.format(v_str.split('client version is ')[1].strip())) - -def stop(): - logger.info("Stopping Alpaca's Ollama instance") - global instance - if instance: - instance.terminate() - instance.wait() - instance = None - logger.info("Stopped Alpaca's Ollama instance") - -def reset(): - logger.info("Resetting Alpaca's Ollama instance") - stop() - sleep(1) - start() diff --git a/src/main.py b/src/main.py index 7066681..c86052c 100644 --- a/src/main.py +++ b/src/main.py @@ -57,7 +57,7 @@ class AlpacaApplication(Adw.Application): super().__init__(application_id='com.jeffser.Alpaca', flags=Gio.ApplicationFlags.DEFAULT_FLAGS) self.create_action('quit', lambda *_: self.props.active_window.closing_app(None), ['w', 'q']) - self.create_action('preferences', lambda *_: AlpacaWindow.show_preferences_dialog(self.props.active_window), ['comma']) + self.create_action('preferences', lambda *_: self.props.active_window.preferences_dialog.present(self.props.active_window), ['comma']) self.create_action('about', self.on_about_action) self.version = version diff --git a/src/meson.build b/src/meson.build index aae28e8..1655fc4 100644 --- a/src/meson.build +++ b/src/meson.build @@ -41,7 +41,6 @@ alpaca_sources = [ 'window.py', 'connection_handler.py', 'dialogs.py', - 'local_instance.py', 'available_models.json', 'available_models_descriptions.py', 'internal.py' diff --git a/src/window.py b/src/window.py index 2e1123a..10ad53f 100644 --- a/src/window.py +++ b/src/window.py @@ -31,7 +31,7 @@ gi.require_version('GdkPixbuf', '2.0') from gi.repository import Adw, Gtk, Gdk, GLib, GtkSource, Gio, GdkPixbuf -from . import dialogs, local_instance, connection_handler, available_models_descriptions +from . import dialogs, connection_handler from .custom_widgets import table_widget, message_widget, chat_widget, model_widget from .internal import config_dir, data_dir, cache_dir, source_dir @@ -40,9 +40,6 @@ logger = logging.getLogger(__name__) @Gtk.Template(resource_path='/com/jeffser/Alpaca/window.ui') class AlpacaWindow(Adw.ApplicationWindow): app_dir = os.getenv("FLATPAK_DEST") - config_dir = config_dir - data_dir = data_dir - cache_dir = cache_dir __gtype_name__ = 'AlpacaWindow' @@ -53,17 +50,14 @@ class AlpacaWindow(Adw.ApplicationWindow): _ = gettext.gettext #Variables - available_models = None - run_on_background = False - remote_url = "" - remote_bearer_token = "" - run_remote = False + model_tweaks = {"temperature": 0.7, "seed": 0, "keep_alive": 5} pulling_models = {} attachments = {} header_bar = Gtk.Template.Child() #Override elements + overrides_group = Gtk.Template.Child() override_HSA_OVERRIDE_GFX_VERSION = Gtk.Template.Child() override_CUDA_VISIBLE_DEVICES = Gtk.Template.Child() override_HIP_VISIBLE_DEVICES = Gtk.Template.Child() @@ -76,9 +70,7 @@ class AlpacaWindow(Adw.ApplicationWindow): create_model_name = Gtk.Template.Child() create_model_system = Gtk.Template.Child() create_model_modelfile = Gtk.Template.Child() - temperature_spin = Gtk.Template.Child() - seed_spin = Gtk.Template.Child() - keep_alive_spin = Gtk.Template.Child() + tweaks_group = Gtk.Template.Child() preferences_dialog = Gtk.Template.Child() shortcut_window : Gtk.ShortcutsWindow = Gtk.Template.Child() file_preview_dialog = Gtk.Template.Child() @@ -115,10 +107,10 @@ class AlpacaWindow(Adw.ApplicationWindow): chat_list_container = Gtk.Template.Child() chat_list_box = None + ollama_instance = None + model_manager = None add_chat_button = Gtk.Template.Child() - loading_spinner = None - background_switch = Gtk.Template.Child() remote_connection_switch = Gtk.Template.Child() remote_connection_entry = Gtk.Template.Child() @@ -155,12 +147,12 @@ class AlpacaWindow(Adw.ApplicationWindow): for name, content in self.attachments.items(): if content["type"] == 'image': if self.model_manager.verify_if_image_can_be_used(): - attached_images.append(os.path.join(self.data_dir, "chats", current_chat.get_name(), message_id, name)) + attached_images.append(os.path.join(data_dir, "chats", current_chat.get_name(), message_id, name)) else: - attached_files[os.path.join(self.data_dir, "chats", current_chat.get_name(), message_id, name)] = content['type'] - if not os.path.exists(os.path.join(self.data_dir, "chats", current_chat.get_name(), message_id)): - os.makedirs(os.path.join(self.data_dir, "chats", current_chat.get_name(), message_id)) - shutil.copy(content['path'], os.path.join(self.data_dir, "chats", current_chat.get_name(), message_id, name)) + attached_files[os.path.join(data_dir, "chats", current_chat.get_name(), message_id, name)] = content['type'] + if not os.path.exists(os.path.join(data_dir, "chats", current_chat.get_name(), message_id)): + os.makedirs(os.path.join(data_dir, "chats", current_chat.get_name(), message_id)) + shutil.copy(content['path'], os.path.join(data_dir, "chats", current_chat.get_name(), message_id, name)) content["button"].get_parent().remove(content["button"]) self.attachments = {} self.attachment_box.set_visible(False) @@ -179,8 +171,8 @@ class AlpacaWindow(Adw.ApplicationWindow): data = { "model": current_model, "messages": self.convert_history_to_ollama(current_chat), - "options": {"temperature": self.model_tweaks["temperature"], "seed": self.model_tweaks["seed"]}, - "keep_alive": f"{self.model_tweaks['keep_alive']}m" + "options": {"temperature": self.ollama_instance.tweaks["temperature"], "seed": self.ollama_instance.tweaks["seed"]}, + "keep_alive": f"{self.ollama_instance.tweaks['keep_alive']}m" } self.message_text_view.get_buffer().set_text("", 0) @@ -221,32 +213,41 @@ class AlpacaWindow(Adw.ApplicationWindow): self.welcome_carousel.scroll_to(self.welcome_carousel.get_nth_page(self.welcome_carousel.get_position()+1), True) else: self.welcome_dialog.force_close() - if not self.verify_connection(): - self.connection_error() + + @Gtk.Template.Callback() + def change_remote_connection(self, switcher, *_): + logger.debug("Connection switched") + self.ollama_instance.remote = self.remote_connection_switch.get_active() + if self.model_manager: + self.model_manager.update_local_list() + self.save_server_config() @Gtk.Template.Callback() def change_remote_url(self, entry): if not entry.get_text().startswith("http"): entry.set_text("http://{}".format(entry.get_text())) return - self.remote_url = entry.get_text() + if entry.get_text() != entry.get_text().rstrip('/'): + entry.set_text(entry.get_text().rstrip('/')) + return logger.debug(f"Changing remote url: {self.remote_url}") - if self.run_remote: - connection_handler.URL = self.remote_url - if self.verify_connection() == False: - entry.set_css_classes(["error"]) - self.show_toast(_("Failed to connect to server"), self.preferences_dialog) + self.ollama_instance.remote_url = entry.get_text() + if self.ollama_instance.remote and self.model_manager: + self.model_manager.update_local_list() + self.save_server_config() @Gtk.Template.Callback() def change_remote_bearer_token(self, entry): - self.remote_bearer_token = entry.get_text() + self.ollama_instance.bearer_token = entry.get_text() + if self.ollama_instance.remote_url and self.ollama_instance.remote and self.model_manager: + self.model_manager.update_local_list() + self.save_server_config() + + @Gtk.Template.Callback() + def switch_run_on_background(self): + logger.debug("Switching run on background") + self.set_hide_on_close(self.background_switch.get_active()) self.save_server_config() - return - if self.remote_url and self.run_remote: - connection_handler.URL = self.remote_url - if self.verify_connection() == False: - entry.set_css_classes(["error"]) - self.show_toast(_("Failed to connect to server"), self.preferences_dialog) @Gtk.Template.Callback() def closing_app(self, user_data): @@ -256,7 +257,7 @@ class AlpacaWindow(Adw.ApplicationWindow): logger.info("Hiding app...") else: logger.info("Closing app...") - local_instance.stop() + self.ollama_instance.stop() self.get_application().quit() @Gtk.Template.Callback() @@ -266,8 +267,8 @@ class AlpacaWindow(Adw.ApplicationWindow): value = round(value) else: value = round(value, 1) - if self.model_tweaks[spin.get_name()] is not None and self.model_tweaks[spin.get_name()] != value: - self.model_tweaks[spin.get_name()] = value + if self.ollama_instance.tweaks[spin.get_name()] != value: + self.ollama_instance.tweaks[spin.get_name()] = value self.save_server_config() @Gtk.Template.Callback() @@ -279,22 +280,21 @@ class AlpacaWindow(Adw.ApplicationWindow): for line in modelfile_raw.split('\n'): if not line.startswith('SYSTEM') and not line.startswith('FROM'): modelfile.append(line) - threading.Thread(target=self.model_manager.pull_model, kwargs={"url": connection_handler.URL, "model_name": name, "modelfile": '\n'.join(modelfile)}).start() + threading.Thread(target=self.model_manager.pull_model, kwargs={"model_name": name, "modelfile": '\n'.join(modelfile)}).start() self.navigation_view_manage_models.pop() @Gtk.Template.Callback() def override_changed(self, entry): name = entry.get_name() value = entry.get_text() - if (not value and name not in local_instance.overrides) or (value and value in local_instance.overrides and local_instance.overrides[name] == value): - return - if not value: - del local_instance.overrides[name] - else: - local_instance.overrides[name] = value - self.save_server_config() - if not self.run_remote: - local_instance.reset() + if self.ollama_instance: + if value: + self.ollama_instance.overrides[name] = value + elif name in self.ollama_instance.overrides: + del self.ollama_instance.overrides[name] + if not self.ollama_instance.remote: + self.ollama_instance.reset() + self.save_server_config() @Gtk.Template.Callback() def link_button_handler(self, button): @@ -339,7 +339,7 @@ class AlpacaWindow(Adw.ApplicationWindow): modelfile_buffer.delete(modelfile_buffer.get_start_iter(), modelfile_buffer.get_end_iter()) self.create_model_system.set_text('') if not file: - response = connection_handler.simple_post(f"{connection_handler.URL}/api/show", json.dumps({"name": self.convert_model_name(model, 1)})) + response = self.ollama_instance.request("POST", "api/show", json.dumps({"name": self.convert_model_name(model, 1)})) if response.status_code == 200: data = json.loads(response.text) modelfile = [] @@ -426,7 +426,7 @@ class AlpacaWindow(Adw.ApplicationWindow): del new_message['files'] new_message['content'] = '' for name, file_type in message['files'].items(): - file_path = os.path.join(self.data_dir, "chats", chat.get_name(), message_id, name) + file_path = os.path.join(data_dir, "chats", chat.get_name(), message_id, name) file_data = self.get_content_of_file(file_path, file_type) if file_data: new_message['content'] += f"```[{name}]\n{file_data}\n```" @@ -434,7 +434,7 @@ class AlpacaWindow(Adw.ApplicationWindow): if 'images' in message and len(message['images']) > 0: new_message['images'] = [] for name in message['images']: - file_path = os.path.join(self.data_dir, "chats", chat.get_name(), message_id, name) + file_path = os.path.join(data_dir, "chats", chat.get_name(), message_id, name) image_data = self.get_content_of_file(file_path, 'image') if image_data: new_message['images'].append(image_data) @@ -460,19 +460,29 @@ Generate a title following these rules: data = {"model": current_model, "prompt": prompt, "stream": False} if 'images' in message: data["images"] = message['images'] - response = connection_handler.simple_post(f"{connection_handler.URL}/api/generate", data=json.dumps(data)) + response = self.ollama_instance.request("POST", "api/generate", json.dumps(data)) if response.status_code == 200: new_chat_name = json.loads(response.text)["response"].strip().removeprefix("Title: ").removeprefix("title: ").strip('\'"').replace('\n', ' ').title().replace('\'S', '\'s') new_chat_name = new_chat_name[:50] + (new_chat_name[50:] and '...') self.chat_list_box.rename_chat(old_chat_name, new_chat_name) def save_server_config(self): - with open(os.path.join(self.config_dir, "server.json"), "w+", encoding="utf-8") as f: - json.dump({'remote_url': self.remote_url, 'remote_bearer_token': self.remote_bearer_token, 'run_remote': self.run_remote, 'local_port': local_instance.port, 'run_on_background': self.run_on_background, 'model_tweaks': self.model_tweaks, 'ollama_overrides': local_instance.overrides}, f, indent=6) + with open(os.path.join(config_dir, "server.json"), "w+", encoding="utf-8") as f: + data = { + 'remote_url': self.ollama_instance.remote_url, + 'remote_bearer_token': self.ollama_instance.bearer_token, + 'run_remote': self.ollama_instance.remote, + 'local_port': self.ollama_instance.local_port, + 'run_on_background': self.background_switch.get_active(), + 'model_tweaks': self.ollama_instance.tweaks, + 'ollama_overrides': self.ollama_instance.overrides + } + + json.dump(data, f, indent=6) def verify_connection(self): try: - response = connection_handler.simple_get(f"{connection_handler.URL}/api/tags") + response = self.ollama_instance.request("GET", "api/tags") if response.status_code == 200: self.save_server_config() #self.update_list_local_models() @@ -515,7 +525,7 @@ Generate a title following these rules: if self.regenerate_button: GLib.idle_add(self.chat_list_box.get_current_chat().remove, self.regenerate_button) try: - response = connection_handler.stream_post(f"{connection_handler.URL}/api/chat", data=json.dumps(data), callback=lambda data, message_element=message_element: GLib.idle_add(message_element.update_message, data)) + response = self.ollama_instance.request("POST", "api/generate", json.dumps(data), lambda data, message_element=message_element: GLib.idle_add(message_element.update_message, data)) if response.status_code != 200: raise Exception('Network Error') except Exception as e: @@ -528,10 +538,10 @@ Generate a title following these rules: def save_history(self, chat:chat_widget.chat=None): logger.debug("Saving history") history = None - if chat and os.path.exists(os.path.join(self.data_dir, "chats", "chats.json")): + if chat and os.path.exists(os.path.join(data_dir, "chats", "chats.json")): history = {'chats': {chat.get_name(): {'messages': chat.messages_to_dict()}}} try: - with open(os.path.join(self.data_dir, "chats", "chats.json"), "r", encoding="utf-8") as f: + with open(os.path.join(data_dir, "chats", "chats.json"), "r", encoding="utf-8") as f: data = json.load(f) for chat_tab in self.chat_list_box.tab_list: if chat_tab.chat_window.get_name() != chat.get_name(): @@ -545,20 +555,20 @@ Generate a title following these rules: for chat_tab in self.chat_list_box.tab_list: history['chats'][chat_tab.chat_window.get_name()] = {'messages': chat_tab.chat_window.messages_to_dict()} - with open(os.path.join(self.data_dir, "chats", "chats.json"), "w+", encoding="utf-8") as f: + with open(os.path.join(data_dir, "chats", "chats.json"), "w+", encoding="utf-8") as f: json.dump(history, f, indent=4) def load_history(self): logger.debug("Loading history") - if os.path.exists(os.path.join(self.data_dir, "chats", "chats.json")): + if os.path.exists(os.path.join(data_dir, "chats", "chats.json")): try: - with open(os.path.join(self.data_dir, "chats", "chats.json"), "r", encoding="utf-8") as f: + with open(os.path.join(data_dir, "chats", "chats.json"), "r", encoding="utf-8") as f: data = json.load(f) selected_chat = None if len(list(data)) == 0: data['chats'][_("New Chat")] = {"messages": {}} - if os.path.exists(os.path.join(self.data_dir, "chats", "selected_chat.txt")): - with open(os.path.join(self.data_dir, "chats", "selected_chat.txt"), 'r') as scf: + if os.path.exists(os.path.join(data_dir, "chats", "selected_chat.txt")): + with open(os.path.join(data_dir, "chats", "selected_chat.txt"), 'r') as scf: selected_chat = scf.read() elif 'selected_chat' in data and data['selected_chat'] in data['chats']: selected_chat = data['selected_chat'] @@ -602,78 +612,14 @@ Generate a title following these rules: self.pulling_models[model_name]['overlay'].get_parent().get_parent().remove(self.pulling_models[model_name]['overlay'].get_parent()) del self.pulling_models[model_name] - def chat_click_handler(self, gesture, n_press, x, y): - chat_row = gesture.get_widget() - popover = Gtk.PopoverMenu( - menu_model=self.chat_right_click_menu, - has_arrow=False, - halign=1, - height_request=155 - ) - self.selected_chat_row = chat_row - position = Gdk.Rectangle() - position.x = x - position.y = y - popover.set_parent(chat_row.get_child()) - popover.set_pointing_to(position) - popover.popup() - - def show_preferences_dialog(self): - logger.debug("Showing preferences dialog") - self.preferences_dialog.present(self) - - def connect_remote(self, url, bearer_token): - logger.debug(f"Connecting to remote: {url}") - connection_handler.URL = url - connection_handler.BEARER_TOKEN = bearer_token - self.remote_url = connection_handler.URL - self.remote_connection_entry.set_text(self.remote_url) - if self.verify_connection() == False: self.connection_error() - - def connect_local(self): - logger.debug("Connecting to Alpaca's Ollama instance") - self.run_remote = False - connection_handler.BEARER_TOKEN = None - connection_handler.URL = f"http://127.0.0.1:{local_instance.port}" - local_instance.start() - if self.verify_connection() == False: - self.connection_error() - else: - self.remote_connection_switch.set_active(False) - def connection_error(self): logger.error("Connection error") - if self.run_remote: - dialogs.reconnect_remote(self, connection_handler.URL, connection_handler.BEARER_TOKEN) + if self.ollama_instance.remote: + dialogs.reconnect_remote(self) else: - local_instance.reset() + self.ollama_instance.reset() self.show_toast(_("There was an error with the local Ollama instance, so it has been reset"), self.main_overlay) - def connection_switched(self): - logger.debug("Connection switched") - new_value = self.remote_connection_switch.get_active() - if new_value != self.run_remote: - self.run_remote = new_value - if self.run_remote: - connection_handler.BEARER_TOKEN = self.remote_bearer_token - connection_handler.URL = self.remote_url - if self.verify_connection() == False: - self.connection_error() - else: - local_instance.stop() - else: - connection_handler.BEARER_TOKEN = None - connection_handler.URL = f"http://127.0.0.1:{local_instance.port}" - local_instance.start() - if self.verify_connection() == False: - self.connection_error() - - def switch_run_on_background(self): - logger.debug("Switching run on background") - self.run_on_background = self.background_switch.get_active() - self.set_hide_on_close(self.run_on_background) - self.verify_connection() - def get_content_of_file(self, file_path, file_type): if not os.path.exists(file_path): return None if file_type == 'image': @@ -792,11 +738,11 @@ Generate a title following these rules: if texture: if self.model_manager.verify_if_image_can_be_used(): pixbuf = Gdk.pixbuf_get_from_texture(texture) - if not os.path.exists(os.path.join(self.cache_dir, 'tmp/images/')): - os.makedirs(os.path.join(self.cache_dir, 'tmp/images/')) - image_name = self.generate_numbered_name('image.png', os.listdir(os.path.join(self.cache_dir, os.path.join(self.cache_dir, 'tmp/images')))) - pixbuf.savev(os.path.join(self.cache_dir, 'tmp/images/{}'.format(image_name)), "png", [], []) - self.attach_file(os.path.join(self.cache_dir, 'tmp/images/{}'.format(image_name)), 'image') + if not os.path.exists(os.path.join(cache_dir, 'tmp/images/')): + os.makedirs(os.path.join(cache_dir, 'tmp/images/')) + image_name = self.generate_numbered_name('image.png', os.listdir(os.path.join(cache_dir, os.path.join(cache_dir, 'tmp/images')))) + pixbuf.savev(os.path.join(cache_dir, 'tmp/images/{}'.format(image_name)), "png", [], []) + self.attach_file(os.path.join(cache_dir, 'tmp/images/{}'.format(image_name)), 'image') else: self.show_toast(_("Image recognition is only available on specific models"), self.main_overlay) except Exception as e: @@ -840,8 +786,8 @@ Generate a title following these rules: self.chat_list_box = chat_widget.chat_list() self.chat_list_container.set_child(self.chat_list_box) GtkSource.init() - if not os.path.exists(os.path.join(self.data_dir, "chats")): - os.makedirs(os.path.join(self.data_dir, "chats")) + if not os.path.exists(os.path.join(data_dir, "chats")): + os.makedirs(os.path.join(data_dir, "chats")) enter_key_controller = Gtk.EventControllerKey.new() enter_key_controller.connect("key-pressed", lambda controller, keyval, keycode, state: self.handle_enter_key() if keyval==Gdk.KEY_Return and not (state & Gdk.ModifierType.SHIFT_MASK) else None) self.message_text_view.add_controller(enter_key_controller) @@ -873,54 +819,58 @@ Generate a title following these rules: self.message_text_view.connect("paste-clipboard", self.on_clipboard_paste) self.file_preview_remove_button.connect('clicked', lambda button : dialogs.remove_attached_file(self, button.get_name())) self.attachment_button.connect("clicked", lambda button, file_filter=self.file_filter_attachments: dialogs.attach_file(self, file_filter)) - self.create_model_name.get_delegate().connect("insert-text", lambda *_ : self.check_alphanumeric(*_, ['-', '.', '_'])) + self.create_model_name.get_delegate().connect("insert-text", lambda *_: self.check_alphanumeric(*_, ['-', '.', '_'])) self.remote_connection_entry.connect("entry-activated", lambda entry : entry.set_css_classes([])) - self.remote_connection_switch.connect("notify", lambda pspec, user_data : self.connection_switched()) - self.background_switch.connect("notify", lambda pspec, user_data : self.switch_run_on_background()) self.set_focus(self.message_text_view) - if os.path.exists(os.path.join(self.config_dir, "server.json")): - with open(os.path.join(self.config_dir, "server.json"), "r", encoding="utf-8") as f: - data = json.load(f) - self.run_remote = data['run_remote'] - local_instance.port = data['local_port'] - self.remote_url = data['remote_url'] - self.remote_bearer_token = data['remote_bearer_token'] if 'remote_bearer_token' in data else '' - self.run_on_background = data['run_on_background'] - #Model Tweaks - if "model_tweaks" in data: self.model_tweaks = data['model_tweaks'] - self.temperature_spin.set_value(self.model_tweaks['temperature']) - self.seed_spin.set_value(self.model_tweaks['seed']) - self.keep_alive_spin.set_value(self.model_tweaks['keep_alive']) - #Overrides - if "ollama_overrides" in data: - local_instance.overrides = data['ollama_overrides'] - for element in [ - self.override_HSA_OVERRIDE_GFX_VERSION, - self.override_CUDA_VISIBLE_DEVICES, - self.override_HIP_VISIBLE_DEVICES]: - override = element.get_name() - if override in local_instance.overrides: - element.set_text(local_instance.overrides[override]) + if os.path.exists(os.path.join(config_dir, "server.json")): + try: + with open(os.path.join(config_dir, "server.json"), "r", encoding="utf-8") as f: + data = json.load(f) + self.ollama_instance = connection_handler.instance(data['local_port'], data['remote_url'], data['run_remote'], data['model_tweaks'], data['ollama_overrides'], data['remote_bearer_token']) - self.background_switch.set_active(self.run_on_background) - self.set_hide_on_close(self.run_on_background) - self.remote_connection_entry.set_text(self.remote_url) - self.remote_bearer_token_entry.set_text(self.remote_bearer_token) - if self.run_remote: - connection_handler.BEARER_TOKEN = self.remote_bearer_token - connection_handler.URL = self.remote_url - self.remote_connection_switch.set_active(True) - else: - connection_handler.BEARER_TOKEN = None - self.remote_connection_switch.set_active(False) - connection_handler.URL = f"http://127.0.0.1:{local_instance.port}" - local_instance.start() - else: - local_instance.start() - connection_handler.URL = f"http://127.0.0.1:{local_instance.port}" + #self.run_remote = data['run_remote'] + #local_instance.port = data['local_port'] + #self.remote_url = data['remote_url'] + #self.remote_bearer_token = data['remote_bearer_token'] if 'remote_bearer_token' in data else '' + #self.run_on_background = data['run_on_background'] + ##Model Tweaks + #if "model_tweaks" in data: self.model_tweaks = data['model_tweaks'] + #self.temperature_spin.set_value(self.model_tweaks['temperature']) + #self.seed_spin.set_value(self.model_tweaks['seed']) + #self.keep_alive_spin.set_value(self.model_tweaks['keep_alive']) + #Overrides + #if "ollama_overrides" in data: + #local_instance.overrides = data['ollama_overrides'] + + for element in list(list(list(list(self.tweaks_group)[0])[1])[0]): + if element.get_name() in self.ollama_instance.tweaks: + element.set_value(self.ollama_instance.tweaks[element.get_name()]) + + for element in list(list(list(list(self.overrides_group)[0])[1])[0]): + if element.get_name() in self.ollama_instance.overrides: + element.set_text(self.ollama_instance.overrides[element.get_name()]) + + self.background_switch.set_active(data['run_on_background']) + self.set_hide_on_close(self.background_switch.get_active()) + self.remote_connection_entry.set_text(self.ollama_instance.remote_url) + self.remote_bearer_token_entry.set_text(self.ollama_instance.bearer_token) + self.remote_connection_switch.set_active(self.ollama_instance.remote) + #if self.run_remote: + #connection_handler.BEARER_TOKEN = self.remote_bearer_token + #connection_handler.URL = self.remote_url + #self.remote_connection_switch.set_active(True) + #else: + #connection_handler.BEARER_TOKEN = None + #self.remote_connection_switch.set_active(False) + #connection_handler.URL = f"http://127.0.0.1:{local_instance.port}" + #local_instance.start() + except Exception as e: + logger.error(e) + if not self.ollama_instance: + self.ollama_instance = connection_handler.instance(11435, '', False, {'temperature': 0.7, 'seed': 0, 'keep_alive': 5}, {}, None) + self.save_server_config() self.welcome_dialog.present(self) - if self.verify_connection() is False: - self.connection_error() + self.model_manager = model_widget.model_manager_container() self.model_scroller.set_child(self.model_manager) self.model_manager.update_local_list() diff --git a/src/window.ui b/src/window.ui index 08ca3b8..7161f1f 100644 --- a/src/window.ui +++ b/src/window.ui @@ -234,6 +234,7 @@ + Use Remote Connection to Ollama @@ -257,15 +258,16 @@ + Run Alpaca In Background - + - + temperature Temperature @@ -281,7 +283,7 @@ - + seed Seed @@ -296,7 +298,7 @@ - + keep_alive Keep Alive Time @@ -319,7 +321,7 @@ Ollama Instance brain-augemnted-symbolic - + Ollama Overrides Manage the arguments used on Ollama, any changes on this page only applies to the integrated instance, the instance will restart if you make changes.