ソースを参照

Fix #12448 - update bedrock retrieve tool, support hybrid search type and re… (#12446)

Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
ybalbert001 3 ヶ月 前
コミット
2a14c67edc

+ 34 - 5
api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py

@@ -14,14 +14,38 @@ class BedrockRetrieveTool(BuiltinTool):
     topk: int = None
 
     def _bedrock_retrieve(
-        self, query_input: str, knowledge_base_id: str, num_results: int, metadata_filter: Optional[dict] = None
+        self,
+        query_input: str,
+        knowledge_base_id: str,
+        num_results: int,
+        search_type: str,
+        rerank_model_id: str,
+        metadata_filter: Optional[dict] = None,
     ):
         try:
             retrieval_query = {"text": query_input}
 
-            retrieval_configuration = {"vectorSearchConfiguration": {"numberOfResults": num_results}}
+            if search_type not in ["HYBRID", "SEMANTIC"]:
+                raise RuntimeException("search_type should be HYBRID or SEMANTIC")
+
+            retrieval_configuration = {
+                "vectorSearchConfiguration": {"numberOfResults": num_results, "overrideSearchType": search_type}
+            }
+
+            if rerank_model_id != "default":
+                model_for_rerank_arn = f"arn:aws:bedrock:us-west-2::foundation-model/{rerank_model_id}"
+                rerankingConfiguration = {
+                    "bedrockRerankingConfiguration": {
+                        "numberOfRerankedResults": num_results,
+                        "modelConfiguration": {"modelArn": model_for_rerank_arn},
+                    },
+                    "type": "BEDROCK_RERANKING_MODEL",
+                }
 
-            # Add metadata filter to retrieval configuration if present
+                retrieval_configuration["vectorSearchConfiguration"]["rerankingConfiguration"] = rerankingConfiguration
+                retrieval_configuration["vectorSearchConfiguration"]["numberOfResults"] = num_results * 5
+
+            # 如果有元数据过滤条件,则添加到检索配置中
             if metadata_filter:
                 retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter
 
@@ -77,15 +101,20 @@ class BedrockRetrieveTool(BuiltinTool):
             if not query:
                 return self.create_text_message("Please input query")
 
-            # Get metadata filter conditions (if they exist)
+            # 获取元数据过滤条件(如果存在)
             metadata_filter_str = tool_parameters.get("metadata_filter")
             metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None
 
+            search_type = tool_parameters.get("search_type")
+            rerank_model_id = tool_parameters.get("rerank_model_id")
+
             line = 4
             retrieved_docs = self._bedrock_retrieve(
                 query_input=query,
                 knowledge_base_id=self.knowledge_base_id,
                 num_results=self.topk,
+                search_type=search_type,
+                rerank_model_id=rerank_model_id,
                 metadata_filter=metadata_filter,
             )
 
@@ -109,7 +138,7 @@ class BedrockRetrieveTool(BuiltinTool):
         if not parameters.get("query"):
             raise ValueError("query is required")
 
-        # Optional: Validate if metadata filter is a valid JSON string (if provided)
+        # 可选:可以验证元数据过滤条件是否为有效的 JSON 字符串(如果提供)
         metadata_filter_str = parameters.get("metadata_filter")
         if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict):
             raise ValueError("metadata_filter must be a valid JSON object")

+ 51 - 0
api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml

@@ -59,6 +59,57 @@ parameters:
     max: 10
     default: 5
 
+  - name: search_type
+    type: select
+    required: false
+    label:
+      en_US: search type
+      zh_Hans: 搜索类型
+      pt_BR: search type
+    human_description:
+      en_US: search type
+      zh_Hans: 搜索类型
+      pt_BR: search type
+    llm_description: search type
+    default: SEMANTIC
+    options:
+      - value: SEMANTIC
+        label:
+          en_US: SEMANTIC
+          zh_Hans: 语义搜索
+      - value: HYBRID
+        label:
+          en_US: HYBRID
+          zh_Hans: 混合搜索
+    form: form
+
+  - name: rerank_model_id
+    type: select
+    required: false
+    label:
+      en_US: rerank model id
+      zh_Hans: 重拍模型ID
+      pt_BR: rerank model id
+    human_description:
+      en_US: rerank model id
+      zh_Hans: 重拍模型ID
+      pt_BR: rerank model id
+    llm_description: rerank model id
+    options:
+      - value: default
+        label:
+          en_US: default
+          zh_Hans: 默认
+      - value: cohere.rerank-v3-5:0
+        label:
+          en_US: cohere.rerank-v3-5:0
+          zh_Hans: cohere.rerank-v3-5:0
+      - value: amazon.rerank-v1:0
+        label:
+          en_US: amazon.rerank-v1:0
+          zh_Hans: amazon.rerank-v1:0
+    form: form
+
   - name: aws_region
     type: string
     required: false