Selaa lähdekoodia

fix: azure openai deployment list was deprecated suddenly (#611)

John Wang 1 vuosi sitten
vanhempi
commit
cae15013e0
1 muutettua tiedostoa jossa 32 lisäystä ja 44 poistoa
  1. 32 44
      api/core/llm/provider/azure_provider.py

+ 32 - 44
api/core/llm/provider/azure_provider.py

@@ -2,6 +2,7 @@ import json
 import logging
 from typing import Optional, Union
 
+import openai
 import requests
 
 from core.llm.provider.base import BaseProvider
@@ -11,30 +12,37 @@ from models.provider import ProviderName
 
 class AzureProvider(BaseProvider):
     def get_models(self, model_id: Optional[str] = None, credentials: Optional[dict] = None) -> list[dict]:
-        credentials = self.get_credentials(model_id) if not credentials else credentials
-        url = "{}/openai/deployments?api-version={}".format(
-            str(credentials.get('openai_api_base')),
-            str(credentials.get('openai_api_version'))
-        )
-
-        headers = {
-            "api-key": str(credentials.get('openai_api_key')),
-            "content-type": "application/json; charset=utf-8"
-        }
-
-        response = requests.get(url, headers=headers)
-
-        if response.status_code == 200:
-            result = response.json()
-            return [{
-                'id': deployment['id'],
-                'name': '{} ({})'.format(deployment['id'], deployment['model'])
-            } for deployment in result['data'] if deployment['status'] == 'succeeded']
-        else:
-            if response.status_code == 401:
-                raise AzureAuthenticationError()
+        return []
+
+    def check_embedding_model(self, credentials: Optional[dict] = None):
+        credentials = self.get_credentials('text-embedding-ada-002') if not credentials else credentials
+        try:
+            result = openai.Embedding.create(input=['test'],
+                                             engine='text-embedding-ada-0021',
+                                             timeout=60,
+                                             api_key=str(credentials.get('openai_api_key')),
+                                             api_base=str(credentials.get('openai_api_base')),
+                                             api_type='azure',
+                                             api_version=str(credentials.get('openai_api_version')))["data"][0][
+                "embedding"]
+        except openai.error.AuthenticationError as e:
+            raise AzureAuthenticationError(str(e))
+        except openai.error.APIConnectionError as e:
+            raise AzureRequestFailedError(
+                'Failed to request Azure OpenAI, please check your API Base Endpoint, The format is `https://xxx.openai.azure.com/`')
+        except openai.error.InvalidRequestError as e:
+            if e.http_status == 404:
+                raise AzureRequestFailedError("Please check your 'gpt-3.5-turbo' or 'text-embedding-ada-002' "
+                                              "deployment name is exists in Azure AI")
             else:
-                raise AzureRequestFailedError('Failed to request Azure OpenAI. Status code: {}'.format(response.status_code))
+                raise AzureRequestFailedError(
+                    'Failed to request Azure OpenAI. cause: {}'.format(str(e)))
+        except openai.error.OpenAIError as e:
+            raise AzureRequestFailedError(
+                'Failed to request Azure OpenAI. cause: {}'.format(str(e)))
+
+        if not isinstance(result, list):
+            raise AzureRequestFailedError('Failed to request Azure OpenAI.')
 
     def get_credentials(self, model_id: Optional[str] = None) -> dict:
         """
@@ -94,31 +102,11 @@ class AzureProvider(BaseProvider):
             if 'openai_api_version' not in config:
                 config['openai_api_version'] = '2023-03-15-preview'
 
-            models = self.get_models(credentials=config)
-
-            if not models:
-                raise ValidateFailedError("Please add deployments for "
-                                          "'gpt-3.5-turbo', 'text-embedding-ada-002' (required) "
-                                          "and 'gpt-4', 'gpt-35-turbo-16k', 'text-davinci-003' (optional).")
-
-            fixed_model_ids = [
-                'gpt-35-turbo',
-                'text-embedding-ada-002'
-            ]
-
-            current_model_ids = [model['id'] for model in models]
-
-            missing_model_ids = [fixed_model_id for fixed_model_id in fixed_model_ids if
-                                 fixed_model_id not in current_model_ids]
-
-            if missing_model_ids:
-                raise ValidateFailedError("Please add deployments for '{}'.".format(", ".join(missing_model_ids)))
+            self.check_embedding_model(credentials=config)
         except ValidateFailedError as e:
             raise e
         except AzureAuthenticationError:
             raise ValidateFailedError('Validation failed, please check your API Key.')
-        except (requests.ConnectionError, requests.RequestException):
-            raise ValidateFailedError('Validation failed, please check your API Base Endpoint.')
         except AzureRequestFailedError as ex:
             raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
         except Exception as ex: