Browse Source

feat: Support for Vertex AI - load Default Application Configuration (#4641)

Co-authored-by: miendinh <miendinh@users.noreply.github.com>
Co-authored-by: crazywoola <427733928@qq.com>
miendinh 11 months ago
parent
commit
f804adbff3

+ 5 - 2
api/core/model_runtime/model_providers/vertex_ai/llm/llm.py

@@ -164,10 +164,13 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
             config_kwargs["stop_sequences"] = stop
 
         service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
-        service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
         project_id = credentials["vertex_project_id"]
         location = credentials["vertex_location"]
-        aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
+        if service_account_info:
+            service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
+            aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
+        else:
+            aiplatform.init(project=project_id, location=location)
 
         history = []
         system_instruction = GEMINI_BLOCK_MODE_PROMPT

+ 10 - 6
api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py

@@ -41,15 +41,16 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
         :return: embeddings result
         """
         service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
-        service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
         project_id = credentials["vertex_project_id"]
         location = credentials["vertex_location"]
-        aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
+        if service_account_info:
+            service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
+            aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
+        else:
+            aiplatform.init(project=project_id, location=location)
 
         client = VertexTextEmbeddingModel.from_pretrained(model)
 
-        
-
         embeddings_batch, embedding_used_tokens = self._embedding_invoke(
             client=client,
             texts=texts
@@ -103,10 +104,13 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
         """
         try:
             service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
-            service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
             project_id = credentials["vertex_project_id"]
             location = credentials["vertex_location"]
-            aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
+            if service_account_info:
+                service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
+                aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
+            else:
+                aiplatform.init(project=project_id, location=location)
 
             client = VertexTextEmbeddingModel.from_pretrained(model)
 

+ 2 - 2
api/core/model_runtime/model_providers/vertex_ai/vertex_ai.yaml

@@ -36,8 +36,8 @@ provider_credential_schema:
         en_US: Enter your Google Cloud Location
     - variable: vertex_service_account_key
       label:
-        en_US: Service Account Key
+        en_US: Service Account Key (Leave blank if you use Application Default Credentials)
       type: secret-input
-      required: true
+      required: false
       placeholder:
         en_US: Enter your Google Cloud Service Account Key in base64 format