Переглянути джерело

[feat] Add AWS Bedrock rerank (#11349)

Co-authored-by: crazywoola <427733928@qq.com>
Warren Chen 4 місяців тому
батько
коміт
376726cf90

+ 1 - 0
api/core/model_runtime/model_providers/bedrock/bedrock.yaml

@@ -16,6 +16,7 @@ help:
 supported_model_types:
   - llm
   - text-embedding
+  - rerank
 configurate_methods:
   - predefined-model
 provider_credential_schema:

+ 0 - 0
api/core/model_runtime/model_providers/bedrock/rerank/__init__.py


+ 2 - 0
api/core/model_runtime/model_providers/bedrock/rerank/_position.yaml

@@ -0,0 +1,2 @@
+- amazon.rerank-v1
+- cohere.rerank-v3-5:0

+ 4 - 0
api/core/model_runtime/model_providers/bedrock/rerank/amazon.rerank-v1:0.yaml

@@ -0,0 +1,4 @@
+model: amazon.rerank-v1:0
+model_type: rerank
+model_properties:
+  context_size: 5120

+ 4 - 0
api/core/model_runtime/model_providers/bedrock/rerank/cohere.rerank-v3-5:0.yaml

@@ -0,0 +1,4 @@
+model: cohere.rerank-v3-5:0
+model_type: rerank
+model_properties:
+  context_size: 5120

+ 147 - 0
api/core/model_runtime/model_providers/bedrock/rerank/rerank.py

@@ -0,0 +1,147 @@
+from typing import Optional
+
+import boto3
+from botocore.config import Config
+
+from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.rerank_model import RerankModel
+
+
+class BedrockRerankModel(RerankModel):
+    """
+    Model class for Cohere rerank model.
+    """
+
+    def _invoke(
+        self,
+        model: str,
+        credentials: dict,
+        query: str,
+        docs: list[str],
+        score_threshold: Optional[float] = None,
+        top_n: Optional[int] = None,
+        user: Optional[str] = None,
+    ) -> RerankResult:
+        """
+        Invoke rerank model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param query: search query
+        :param docs: docs for reranking
+        :param score_threshold: score threshold
+        :param top_n: top n
+        :param user: unique user id
+        :return: rerank result
+        """
+
+        if len(docs) == 0:
+            return RerankResult(model=model, docs=docs)
+
+        # initialize client
+        client_config = Config(region_name=credentials["aws_region"])
+        bedrock_runtime = boto3.client(
+            service_name="bedrock-agent-runtime",
+            config=client_config,
+            aws_access_key_id=credentials.get("aws_access_key_id", ""),
+            aws_secret_access_key=credentials.get("aws_secret_access_key"),
+        )
+        queries = [{"type": "TEXT", "textQuery": {"text": query}}]
+        text_sources = []
+        for text in docs:
+            text_sources.append(
+                {
+                    "type": "INLINE",
+                    "inlineDocumentSource": {
+                        "type": "TEXT",
+                        "textDocument": {
+                            "text": text,
+                        },
+                    },
+                }
+            )
+        modelId = model
+        region = credentials["aws_region"]
+        model_package_arn = f"arn:aws:bedrock:{region}::foundation-model/{modelId}"
+        rerankingConfiguration = {
+            "type": "BEDROCK_RERANKING_MODEL",
+            "bedrockRerankingConfiguration": {
+                "numberOfResults": top_n,
+                "modelConfiguration": {
+                    "modelArn": model_package_arn,
+                },
+            },
+        }
+        response = bedrock_runtime.rerank(
+            queries=queries, sources=text_sources, rerankingConfiguration=rerankingConfiguration
+        )
+
+        rerank_documents = []
+        for idx, result in enumerate(response["results"]):
+            # format document
+            index = result["index"]
+            rerank_document = RerankDocument(
+                index=index,
+                text=docs[index],
+                score=result["relevanceScore"],
+            )
+
+            # score threshold check
+            if score_threshold is not None:
+                if rerank_document.score >= score_threshold:
+                    rerank_documents.append(rerank_document)
+            else:
+                rerank_documents.append(rerank_document)
+
+        return RerankResult(model=model, docs=rerank_documents)
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            self.invoke(
+                model=model,
+                credentials=credentials,
+                query="What is the capital of the United States?",
+                docs=[
+                    "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
+                    "Census, Carson City had a population of 55,274.",
+                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
+                    "are a political division controlled by the United States. Its capital is Saipan.",
+                ],
+                score_threshold=0.8,
+            )
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the ermd = genai.GenerativeModel(model) error type thrown to the caller
+        The value is the md = genai.GenerativeModel(model) error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke emd = genai.GenerativeModel(model) error mapping
+        """
+        return {
+            InvokeConnectionError: [],
+            InvokeServerUnavailableError: [],
+            InvokeRateLimitError: [],
+            InvokeAuthorizationError: [],
+            InvokeBadRequestError: [],
+        }

Різницю між файлами не показано, бо вона завелика
+ 406 - 346
api/poetry.lock


+ 1 - 1
api/pyproject.toml

@@ -20,7 +20,7 @@ azure-ai-inference = "~1.0.0b3"
 azure-ai-ml = "~1.20.0"
 azure-identity = "1.16.1"
 beautifulsoup4 = "4.12.2"
-boto3 = "1.35.17"
+boto3 = "1.35.74"
 bs4 = "~0.0.1"
 cachetools = "~5.3.0"
 celery = "~5.4.0"

Деякі файли не було показано, через те що забагато файлів було змінено