|
@@ -2,6 +2,7 @@ import json
|
|
|
import logging
|
|
|
from typing import Optional, Union
|
|
|
|
|
|
+import openai
|
|
|
import requests
|
|
|
|
|
|
from core.llm.provider.base import BaseProvider
|
|
@@ -11,30 +12,37 @@ from models.provider import ProviderName
|
|
|
|
|
|
class AzureProvider(BaseProvider):
|
|
|
def get_models(self, model_id: Optional[str] = None, credentials: Optional[dict] = None) -> list[dict]:
|
|
|
- credentials = self.get_credentials(model_id) if not credentials else credentials
|
|
|
- url = "{}/openai/deployments?api-version={}".format(
|
|
|
- str(credentials.get('openai_api_base')),
|
|
|
- str(credentials.get('openai_api_version'))
|
|
|
- )
|
|
|
-
|
|
|
- headers = {
|
|
|
- "api-key": str(credentials.get('openai_api_key')),
|
|
|
- "content-type": "application/json; charset=utf-8"
|
|
|
- }
|
|
|
-
|
|
|
- response = requests.get(url, headers=headers)
|
|
|
-
|
|
|
- if response.status_code == 200:
|
|
|
- result = response.json()
|
|
|
- return [{
|
|
|
- 'id': deployment['id'],
|
|
|
- 'name': '{} ({})'.format(deployment['id'], deployment['model'])
|
|
|
- } for deployment in result['data'] if deployment['status'] == 'succeeded']
|
|
|
- else:
|
|
|
- if response.status_code == 401:
|
|
|
- raise AzureAuthenticationError()
|
|
|
+ return []
|
|
|
+
|
|
|
+ def check_embedding_model(self, credentials: Optional[dict] = None):
|
|
|
+ credentials = self.get_credentials('text-embedding-ada-002') if not credentials else credentials
|
|
|
+ try:
|
|
|
+ result = openai.Embedding.create(input=['test'],
|
|
|
+ engine='text-embedding-ada-0021',
|
|
|
+ timeout=60,
|
|
|
+ api_key=str(credentials.get('openai_api_key')),
|
|
|
+ api_base=str(credentials.get('openai_api_base')),
|
|
|
+ api_type='azure',
|
|
|
+ api_version=str(credentials.get('openai_api_version')))["data"][0][
|
|
|
+ "embedding"]
|
|
|
+ except openai.error.AuthenticationError as e:
|
|
|
+ raise AzureAuthenticationError(str(e))
|
|
|
+ except openai.error.APIConnectionError as e:
|
|
|
+ raise AzureRequestFailedError(
|
|
|
+ 'Failed to request Azure OpenAI, please check your API Base Endpoint, The format is `https://xxx.openai.azure.com/`')
|
|
|
+ except openai.error.InvalidRequestError as e:
|
|
|
+ if e.http_status == 404:
|
|
|
+ raise AzureRequestFailedError("Please check your 'gpt-3.5-turbo' or 'text-embedding-ada-002' "
|
|
|
+ "deployment name is exists in Azure AI")
|
|
|
else:
|
|
|
- raise AzureRequestFailedError('Failed to request Azure OpenAI. Status code: {}'.format(response.status_code))
|
|
|
+ raise AzureRequestFailedError(
|
|
|
+ 'Failed to request Azure OpenAI. cause: {}'.format(str(e)))
|
|
|
+ except openai.error.OpenAIError as e:
|
|
|
+ raise AzureRequestFailedError(
|
|
|
+ 'Failed to request Azure OpenAI. cause: {}'.format(str(e)))
|
|
|
+
|
|
|
+ if not isinstance(result, list):
|
|
|
+ raise AzureRequestFailedError('Failed to request Azure OpenAI.')
|
|
|
|
|
|
def get_credentials(self, model_id: Optional[str] = None) -> dict:
|
|
|
"""
|
|
@@ -94,31 +102,11 @@ class AzureProvider(BaseProvider):
|
|
|
if 'openai_api_version' not in config:
|
|
|
config['openai_api_version'] = '2023-03-15-preview'
|
|
|
|
|
|
- models = self.get_models(credentials=config)
|
|
|
-
|
|
|
- if not models:
|
|
|
- raise ValidateFailedError("Please add deployments for "
|
|
|
- "'gpt-3.5-turbo', 'text-embedding-ada-002' (required) "
|
|
|
- "and 'gpt-4', 'gpt-35-turbo-16k', 'text-davinci-003' (optional).")
|
|
|
-
|
|
|
- fixed_model_ids = [
|
|
|
- 'gpt-35-turbo',
|
|
|
- 'text-embedding-ada-002'
|
|
|
- ]
|
|
|
-
|
|
|
- current_model_ids = [model['id'] for model in models]
|
|
|
-
|
|
|
- missing_model_ids = [fixed_model_id for fixed_model_id in fixed_model_ids if
|
|
|
- fixed_model_id not in current_model_ids]
|
|
|
-
|
|
|
- if missing_model_ids:
|
|
|
- raise ValidateFailedError("Please add deployments for '{}'.".format(", ".join(missing_model_ids)))
|
|
|
+ self.check_embedding_model(credentials=config)
|
|
|
except ValidateFailedError as e:
|
|
|
raise e
|
|
|
except AzureAuthenticationError:
|
|
|
raise ValidateFailedError('Validation failed, please check your API Key.')
|
|
|
- except (requests.ConnectionError, requests.RequestException):
|
|
|
- raise ValidateFailedError('Validation failed, please check your API Base Endpoint.')
|
|
|
except AzureRequestFailedError as ex:
|
|
|
raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
|
|
|
except Exception as ex:
|