소스 검색

[ref] use one method to get boto client for aws bedrock (#11506)

Warren Chen 4 달 전
부모
커밋
7b5839335a

+ 21 - 0
api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py

@@ -0,0 +1,21 @@
+import boto3
+from botocore.config import Config
+
+
+def get_bedrock_client(service_name, credentials=None):
+    """Return a boto3 client for ``service_name`` in ``credentials["aws_region"]``.
+
+    Static keys are passed only when both are set; otherwise boto3 resolves credentials itself.
+    Both keys are read with ``.get`` so a mapping without them takes the fallback, not KeyError.
+    """
+    client_config = Config(region_name=credentials["aws_region"])
+    aws_access_key_id = credentials.get("aws_access_key_id")
+    aws_secret_access_key = credentials.get("aws_secret_access_key")
+    if aws_access_key_id and aws_secret_access_key:
+        return boto3.client(
+            service_name=service_name,
+            config=client_config,
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key,
+        )
+    return boto3.client(service_name=service_name, config=client_config)

+ 2 - 7
api/core/model_runtime/model_providers/bedrock/llm/llm.py

@@ -40,6 +40,7 @@ from core.model_runtime.errors.invoke import (
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client
 
 logger = logging.getLogger(__name__)
 ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
@@ -173,13 +174,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         :param stream: is stream response
         :return: full response or stream response chunk generator result
         """
-        bedrock_client = boto3.client(
-            service_name="bedrock-runtime",
-            aws_access_key_id=credentials.get("aws_access_key_id"),
-            aws_secret_access_key=credentials.get("aws_secret_access_key"),
-            region_name=credentials["aws_region"],
-        )
-
+        bedrock_client = get_bedrock_client("bedrock-runtime", credentials)
         system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages)
         inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop)
 

+ 2 - 10
api/core/model_runtime/model_providers/bedrock/rerank/rerank.py

@@ -1,8 +1,5 @@
 from typing import Optional
 
-import boto3
-from botocore.config import Config
-
 from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
 from core.model_runtime.errors.invoke import (
     InvokeAuthorizationError,
@@ -14,6 +11,7 @@ from core.model_runtime.errors.invoke import (
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.rerank_model import RerankModel
+from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client
 
 
 class BedrockRerankModel(RerankModel):
@@ -48,13 +46,7 @@ class BedrockRerankModel(RerankModel):
             return RerankResult(model=model, docs=docs)
 
         # initialize client
-        client_config = Config(region_name=credentials["aws_region"])
-        bedrock_runtime = boto3.client(
-            service_name="bedrock-agent-runtime",
-            config=client_config,
-            aws_access_key_id=credentials.get("aws_access_key_id", ""),
-            aws_secret_access_key=credentials.get("aws_secret_access_key"),
-        )
+        bedrock_runtime = get_bedrock_client("bedrock-agent-runtime", credentials)
         queries = [{"type": "TEXT", "textQuery": {"text": query}}]
         text_sources = []
         for text in docs:

+ 2 - 10
api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py

@@ -3,8 +3,6 @@ import logging
 import time
 from typing import Optional
 
-import boto3
-from botocore.config import Config
 from botocore.exceptions import (
     ClientError,
     EndpointConnectionError,
@@ -25,6 +23,7 @@ from core.model_runtime.errors.invoke import (
     InvokeServerUnavailableError,
 )
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client
 
 logger = logging.getLogger(__name__)
 
@@ -48,14 +47,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
         :param input_type: input type
         :return: embeddings result
         """
-        client_config = Config(region_name=credentials["aws_region"])
-
-        bedrock_runtime = boto3.client(
-            service_name="bedrock-runtime",
-            config=client_config,
-            aws_access_key_id=credentials.get("aws_access_key_id"),
-            aws_secret_access_key=credentials.get("aws_secret_access_key"),
-        )
+        bedrock_runtime = get_bedrock_client("bedrock-runtime", credentials)
 
         embeddings = []
         token_usage = 0