From 527f7b445253775c98bb47affad89314e1b49856 Mon Sep 17 00:00:00 2001 From: jeffser Date: Fri, 28 Jun 2024 21:29:36 -0600 Subject: [PATCH] Added bearer token --- src/connection_handler.py | 23 +++++++++++++---------- src/window.py | 24 ++++++++++++++++++++++-- src/window.ui | 9 ++++++++- 3 files changed, 43 insertions(+), 13 deletions(-) diff --git a/src/connection_handler.py b/src/connection_handler.py index 0119c8f..8b43136 100644 --- a/src/connection_handler.py +++ b/src/connection_handler.py @@ -2,10 +2,19 @@ import json, requests url = None +bearer_token = None + +def get_headers() -> dict: + headers = { + "Content-Type": "application/json" + } + if bearer_token: + headers["Authorization"] = "Bearer " + bearer_token + return headers def simple_get(connection_url:str) -> dict: try: - response = requests.get(connection_url) + response = requests.get(connection_url, headers=get_headers()) if response.status_code == 200: return {"status": "ok", "text": response.text, "status_code": response.status_code} else: @@ -15,10 +24,7 @@ def simple_get(connection_url:str) -> dict: def simple_post(connection_url:str, data) -> dict: try: - headers = { - "Content-Type": "application/json" - } - response = requests.post(connection_url, headers=headers, data=data, stream=False) + response = requests.post(connection_url, headers=get_headers(), data=data, stream=False) if response.status_code == 200: return {"status": "ok", "text": response.text, "status_code": response.status_code} else: @@ -28,7 +34,7 @@ def simple_post(connection_url:str, data) -> dict: def simple_delete(connection_url:str, data) -> dict: try: - response = requests.delete(connection_url, json=data) + response = requests.delete(connection_url, headers=get_headers(), json=data) if response.status_code == 200: return {"status": "ok", "status_code": response.status_code} else: @@ -38,10 +44,7 @@ def simple_delete(connection_url:str, data) -> dict: def stream_post(connection_url:str, data, callback:callable) -> dict: try: - headers = { - "Content-Type": "application/json" - } - response = requests.post(connection_url, headers=headers, data=data, stream=True) + response = requests.post(connection_url, headers=get_headers(), data=data, stream=True) if response.status_code == 200: for line in response.iter_lines(): if line: diff --git a/src/window.py b/src/window.py index 8f2daf4..4c95350 100644 --- a/src/window.py +++ b/src/window.py @@ -49,6 +49,7 @@ class AlpacaWindow(Adw.ApplicationWindow): available_models = None run_on_background = False remote_url = "" + remote_bearer_token = "" run_remote = False model_tweaks = {"temperature": 0.7, "seed": 0, "keep_alive": 5} local_models = [] @@ -116,6 +117,7 @@ class AlpacaWindow(Adw.ApplicationWindow): background_switch = Gtk.Template.Child() remote_connection_switch = Gtk.Template.Child() remote_connection_entry = Gtk.Template.Child() + remote_bearer_token_entry = Gtk.Template.Child() toast_messages = { "error": [ @@ -280,6 +282,17 @@ class AlpacaWindow(Adw.ApplicationWindow): entry.set_css_classes(["error"]) self.show_toast("error", 1, self.preferences_dialog) + @Gtk.Template.Callback() + def change_remote_bearer_token(self, entry): + self.remote_bearer_token = entry.get_text() + self.save_server_config() + return + if self.remote_url and self.run_remote: + connection_handler.url = self.remote_url + if self.verify_connection() == False: + entry.set_css_classes(["error"]) + self.show_toast("error", 1, self.preferences_dialog) + @Gtk.Template.Callback() def pull_featured_model(self, button): action_row = button.get_parent().get_parent().get_parent() @@ -641,7 +654,7 @@ Generate a title following these rules: def save_server_config(self): with open(os.path.join(self.config_dir, "server.json"), "w+") as f: - json.dump({'remote_url': self.remote_url, 'run_remote': self.run_remote, 'local_port': local_instance.port, 'run_on_background': self.run_on_background, 'model_tweaks': self.model_tweaks, 'ollama_overrides': local_instance.overrides}, f, indent=6) + json.dump({'remote_url': self.remote_url, 'remote_bearer_token': self.remote_bearer_token, 'run_remote': self.run_remote, 'local_port': local_instance.port, 'run_on_background': self.run_on_background, 'model_tweaks': self.model_tweaks, 'ollama_overrides': local_instance.overrides}, f, indent=6) def verify_connection(self): response = connection_handler.simple_get(connection_handler.url) @@ -1077,6 +1090,7 @@ Generate a title following these rules: def connect_local(self): self.run_remote = False + connection_handler.bearer_token = None connection_handler.url = f"http://127.0.0.1:{local_instance.port}" local_instance.start() if self.verify_connection() == False: self.connection_error() @@ -1094,10 +1108,12 @@ Generate a title following these rules: if new_value != self.run_remote: self.run_remote = new_value if self.run_remote: + connection_handler.bearer_token = self.remote_bearer_token connection_handler.url = self.remote_url if self.verify_connection() == False: self.connection_error() else: local_instance.stop() else: + connection_handler.bearer_token = None connection_handler.url = f"http://127.0.0.1:{local_instance.port}" local_instance.start() if self.verify_connection() == False: self.connection_error() @@ -1349,6 +1365,7 @@ Generate a title following these rules: self.run_remote = data['run_remote'] local_instance.port = data['local_port'] self.remote_url = data['remote_url'] + self.remote_bearer_token = data['remote_bearer_token'] if 'remote_bearer_token' in data else '' self.run_on_background = data['run_on_background'] #Model Tweaks if "model_tweaks" in data: self.model_tweaks = data['model_tweaks'] @@ -1369,10 +1386,13 @@ Generate a title following these rules: self.background_switch.set_active(self.run_on_background) self.set_hide_on_close(self.run_on_background) self.remote_connection_entry.set_text(self.remote_url) + self.remote_bearer_token_entry.set_text(self.remote_bearer_token) if self.run_remote: - connection_handler.url = data['remote_url'] + connection_handler.bearer_token = self.remote_bearer_token + connection_handler.url = self.remote_url self.remote_connection_switch.set_active(True) else: + connection_handler.bearer_token = None self.remote_connection_switch.set_active(False) connection_handler.url = f"http://127.0.0.1:{local_instance.port}" local_instance.start() diff --git a/src/window.ui b/src/window.ui index 3d57a50..83dad85 100644 --- a/src/window.ui +++ b/src/window.ui @@ -303,6 +303,13 @@ true + + + + Bearer Token (Optional) + true + + @@ -311,7 +318,7 @@ Manage Alpaca's Behavior--> - Run In Background + Run Alpaca In Background