diff --git a/src/connection_handler.py b/src/connection_handler.py index 8b43136..e687df6 100644 --- a/src/connection_handler.py +++ b/src/connection_handler.py @@ -4,17 +4,17 @@ import json, requests url = None bearer_token = None -def get_headers() -> dict: - headers = { - "Content-Type": "application/json" - } +def get_headers(include_json:bool) -> dict: + headers = {} + if include_json: + headers["Content-Type"] = "application/json" if bearer_token: - headers["Authorization"] = "Bearer " + bearer_token - return headers + headers["Authorization"] = "Bearer {}".format(bearer_token) + return headers if len(headers.keys()) > 0 else None def simple_get(connection_url:str) -> dict: try: - response = requests.get(connection_url, headers=get_headers()) + response = requests.get(connection_url, headers=get_headers(False)) if response.status_code == 200: return {"status": "ok", "text": response.text, "status_code": response.status_code} else: @@ -24,7 +24,7 @@ def simple_get(connection_url:str) -> dict: def simple_post(connection_url:str, data) -> dict: try: - response = requests.post(connection_url, headers=get_headers(), data=data, stream=False) + response = requests.post(connection_url, headers=get_headers(True), data=data, stream=False) if response.status_code == 200: return {"status": "ok", "text": response.text, "status_code": response.status_code} else: @@ -34,7 +34,7 @@ def simple_post(connection_url:str, data) -> dict: def simple_delete(connection_url:str, data) -> dict: try: - response = requests.delete(connection_url, headers=get_headers(), json=data) + response = requests.delete(connection_url, headers=get_headers(False), json=data) if response.status_code == 200: return {"status": "ok", "status_code": response.status_code} else: @@ -44,7 +44,7 @@ def simple_delete(connection_url:str, data) -> dict: def stream_post(connection_url:str, data, callback:callable) -> dict: try: - response = requests.post(connection_url, headers=get_headers(), data=data, stream=True) + response = requests.post(connection_url, headers=get_headers(True), data=data, stream=True) if response.status_code == 200: for line in response.iter_lines(): if line: diff --git a/src/window.py b/src/window.py index c265075..3badbff 100644 --- a/src/window.py +++ b/src/window.py @@ -156,8 +156,8 @@ class AlpacaWindow(Adw.ApplicationWindow): @Gtk.Template.Callback() def verify_if_image_can_be_used(self, pspec=None, user_data=None): if self.model_drop_down.get_selected_item() == None: return True - selected = self.model_drop_down.get_selected_item().get_string().split(" (")[0] - if selected in ['llava', 'bakllava', 'moondream', 'llava-llama3']: + selected = self.model_drop_down.get_selected_item().get_string().split(" (")[0].lower() + if selected in [key for key, value in self.available_models.items() if value["image"]]: for name, content in self.attachments.items(): if content['type'] == 'image': content['button'].set_css_classes(["flat"]) @@ -501,7 +501,7 @@ Generate a title following these rules: {message} ```""" current_model = self.model_drop_down.get_selected_item().get_string() - current_model = current_model.replace(' (', ':')[:-1] + current_model = current_model.replace(' (', ':')[:-1].lower() response = connection_handler.simple_post(f"{connection_handler.url}/api/generate", data=json.dumps({"model": current_model, "prompt": prompt, "stream": False})) new_chat_name = json.loads(response['text'])["response"].replace('"', '').replace("'", "") new_chat_name = self.generate_numbered_name(new_chat_name, self.chats["chats"].keys()) diff --git a/src/window.ui b/src/window.ui index c70a889..0f0fc73 100644 --- a/src/window.ui +++ b/src/window.ui @@ -1021,6 +1021,7 @@ pdf png jpeg + jpg webp gif