John Wang 1 рік тому
батько
коміт
6da5e54180

+ 1 - 1
api/controllers/console/workspace/providers.py

@@ -157,7 +157,7 @@ class ProviderTokenValidateApi(Resource):
         args = parser.parse_args()
 
         # todo: remove this when the provider is supported
-        if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
+        if provider in [ProviderName.ANTHROPIC.value, ProviderName.COHERE.value,
                         ProviderName.HUGGINGFACEHUB.value]:
             return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}
 

+ 25 - 7
api/core/llm/provider/azure_provider.py

@@ -78,7 +78,7 @@ class AzureProvider(BaseProvider):
 
     def get_token_type(self):
         # TODO: change to dict when implemented
-        return lambda value: value
+        return dict
 
     def config_validate(self, config: Union[dict | str]):
         """
@@ -91,16 +91,34 @@ class AzureProvider(BaseProvider):
             if 'openai_api_version' not in config:
                 config['openai_api_version'] = '2023-03-15-preview'
 
-            self.get_models(credentials=config)
+            models = self.get_models(credentials=config)
+
+            if not models:
+                raise ValidateFailedError("Please add deployments for 'text-davinci-003', "
+                                          "'gpt-3.5-turbo', 'text-embedding-ada-002'.")
+
+            fixed_model_ids = [
+                'text-davinci-003',
+                'gpt-35-turbo',
+                'text-embedding-ada-002'
+            ]
+
+            current_model_ids = [model['id'] for model in models]
+
+            missing_model_ids = [fixed_model_id for fixed_model_id in fixed_model_ids if
+                                 fixed_model_id not in current_model_ids]
+
+            if missing_model_ids:
+                raise ValidateFailedError("Please add deployments for '{}'.".format(", ".join(missing_model_ids)))
         except AzureAuthenticationError:
-            raise ValidateFailedError('Azure OpenAI Credentials validation failed, please check your API Key.')
-        except requests.ConnectionError:
-            raise ValidateFailedError('Azure OpenAI Credentials validation failed, please check your API Base Endpoint.')
+            raise ValidateFailedError('Validation failed, please check your API Key.')
+        except (requests.ConnectionError, requests.RequestException):
+            raise ValidateFailedError('Validation failed, please check your API Base Endpoint.')
         except AzureRequestFailedError as ex:
-            raise ValidateFailedError('Azure OpenAI Credentials validation failed, error: {}.'.format(str(ex)))
+            raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
         except Exception as ex:
             logging.exception('Azure OpenAI Credentials validation failed')
-            raise ex
+            raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
 
     def get_encrypted_token(self, config: Union[dict | str]):
         """