|
@@ -1,8 +1,5 @@
|
|
|
from typing import Optional
|
|
|
|
|
|
-import boto3
|
|
|
-from botocore.config import Config
|
|
|
-
|
|
|
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
|
|
|
from core.model_runtime.errors.invoke import (
|
|
|
InvokeAuthorizationError,
|
|
@@ -14,6 +11,7 @@ from core.model_runtime.errors.invoke import (
|
|
|
)
|
|
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
|
|
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
|
|
|
+from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client
|
|
|
|
|
|
|
|
|
class BedrockRerankModel(RerankModel):
|
|
@@ -48,13 +46,7 @@ class BedrockRerankModel(RerankModel):
|
|
|
return RerankResult(model=model, docs=docs)
|
|
|
|
|
|
# initialize client
|
|
|
- client_config = Config(region_name=credentials["aws_region"])
|
|
|
- bedrock_runtime = boto3.client(
|
|
|
- service_name="bedrock-agent-runtime",
|
|
|
- config=client_config,
|
|
|
- aws_access_key_id=credentials.get("aws_access_key_id", ""),
|
|
|
- aws_secret_access_key=credentials.get("aws_secret_access_key"),
|
|
|
- )
|
|
|
+ bedrock_runtime = get_bedrock_client("bedrock-agent-runtime", credentials)
|
|
|
queries = [{"type": "TEXT", "textQuery": {"text": query}}]
|
|
|
text_sources = []
|
|
|
for text in docs:
|