|
@@ -78,7 +78,7 @@ class AzureProvider(BaseProvider):
|
|
|
|
|
|
def get_token_type(self):
|
|
|
# TODO: change to dict when implemented
|
|
|
- return lambda value: value
|
|
|
+ return dict
|
|
|
|
|
|
def config_validate(self, config: Union[dict | str]):
|
|
|
"""
|
|
@@ -91,16 +91,34 @@ class AzureProvider(BaseProvider):
|
|
|
if 'openai_api_version' not in config:
|
|
|
config['openai_api_version'] = '2023-03-15-preview'
|
|
|
|
|
|
- self.get_models(credentials=config)
|
|
|
+ models = self.get_models(credentials=config)
|
|
|
+
|
|
|
+ if not models:
|
|
|
+ raise ValidateFailedError("Please add deployments for 'text-davinci-003', "
|
|
|
+ "'gpt-3.5-turbo', 'text-embedding-ada-002'.")
|
|
|
+
|
|
|
+ fixed_model_ids = [
|
|
|
+ 'text-davinci-003',
|
|
|
+ 'gpt-35-turbo',
|
|
|
+ 'text-embedding-ada-002'
|
|
|
+ ]
|
|
|
+
|
|
|
+ current_model_ids = [model['id'] for model in models]
|
|
|
+
|
|
|
+ missing_model_ids = [fixed_model_id for fixed_model_id in fixed_model_ids if
|
|
|
+ fixed_model_id not in current_model_ids]
|
|
|
+
|
|
|
+ if missing_model_ids:
|
|
|
+ raise ValidateFailedError("Please add deployments for '{}'.".format(", ".join(missing_model_ids)))
|
|
|
except AzureAuthenticationError:
|
|
|
- raise ValidateFailedError('Azure OpenAI Credentials validation failed, please check your API Key.')
|
|
|
- except requests.ConnectionError:
|
|
|
- raise ValidateFailedError('Azure OpenAI Credentials validation failed, please check your API Base Endpoint.')
|
|
|
+ raise ValidateFailedError('Validation failed, please check your API Key.')
|
|
|
+ except (requests.ConnectionError, requests.RequestException):
|
|
|
+ raise ValidateFailedError('Validation failed, please check your API Base Endpoint.')
|
|
|
except AzureRequestFailedError as ex:
|
|
|
- raise ValidateFailedError('Azure OpenAI Credentials validation failed, error: {}.'.format(str(ex)))
|
|
|
+ raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
|
|
|
except Exception as ex:
|
|
|
logging.exception('Azure OpenAI Credentials validation failed')
|
|
|
- raise ex
|
|
|
+ raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
|
|
|
|
|
|
def get_encrypted_token(self, config: Union[dict | str]):
|
|
|
"""
|