فهرست منبع

Fix model provider of vertex ai (#11437)

Kazuki Takamatsu 4 ماه پیش
والد
کامیت
4d7cfd0de5

+ 6 - 4
api/core/model_runtime/model_providers/vertex_ai/llm/llm.py

@@ -104,13 +104,14 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
         """
         # use Anthropic official SDK references
         # - https://github.com/anthropics/anthropic-sdk-python
-        service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
+        service_account_key = credentials.get("vertex_service_account_key", "")
         project_id = credentials["vertex_project_id"]
         SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]
         token = ""
 
         # get access token from service account credential
-        if service_account_info:
+        if service_account_key:
+            service_account_info = json.loads(base64.b64decode(service_account_key))
             credentials = service_account.Credentials.from_service_account_info(service_account_info, scopes=SCOPES)
             request = google.auth.transport.requests.Request()
             credentials.refresh(request)
@@ -478,10 +479,11 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
         if stop:
             config_kwargs["stop_sequences"] = stop
 
-        service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
+        service_account_key = credentials.get("vertex_service_account_key", "")
         project_id = credentials["vertex_project_id"]
         location = credentials["vertex_location"]
-        if service_account_info:
+        if service_account_key:
+            service_account_info = json.loads(base64.b64decode(service_account_key))
             service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
             aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
         else:

+ 6 - 4
api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py

@@ -48,10 +48,11 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
         :param input_type: input type
         :return: embeddings result
         """
-        service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
+        service_account_key = credentials.get("vertex_service_account_key", "")
         project_id = credentials["vertex_project_id"]
         location = credentials["vertex_location"]
-        if service_account_info:
+        if service_account_key:
+            service_account_info = json.loads(base64.b64decode(service_account_key))
             service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
             aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
         else:
@@ -100,10 +101,11 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
         :return:
         """
         try:
-            service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
+            service_account_key = credentials.get("vertex_service_account_key", "")
             project_id = credentials["vertex_project_id"]
             location = credentials["vertex_location"]
-            if service_account_info:
+            if service_account_key:
+                service_account_info = json.loads(base64.b64decode(service_account_key))
                 service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
                 aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
             else: