Parcourir la source

[Fix] Fix sagemaker_chinese_toxicity_detector and bedrock_retrieve (#12227)

Warren Chen il y a 3 mois
Parent
commit
562450751f

+ 4 - 4
api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py

@@ -21,7 +21,7 @@ class BedrockRetrieveTool(BuiltinTool):
 
             retrieval_configuration = {"vectorSearchConfiguration": {"numberOfResults": num_results}}
 
-            # 如果有元数据过滤条件,则添加到检索配置中
+            # Add metadata filter to retrieval configuration if present
             if metadata_filter:
                 retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter
 
@@ -77,7 +77,7 @@ class BedrockRetrieveTool(BuiltinTool):
             if not query:
                 return self.create_text_message("Please input query")
 
-            # 获取元数据过滤条件(如果存在)
+            # Get metadata filter conditions (if they exist)
             metadata_filter_str = tool_parameters.get("metadata_filter")
             metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None
 
@@ -86,7 +86,7 @@ class BedrockRetrieveTool(BuiltinTool):
                 query_input=query,
                 knowledge_base_id=self.knowledge_base_id,
                 num_results=self.topk,
-                metadata_filter=metadata_filter,  # 将元数据过滤条件传递给检索方法
+                metadata_filter=metadata_filter,
             )
 
             line = 5
@@ -109,7 +109,7 @@ class BedrockRetrieveTool(BuiltinTool):
         if not parameters.get("query"):
             raise ValueError("query is required")
 
-        # 可选:可以验证元数据过滤条件是否为有效的 JSON 字符串(如果提供)
+        # Optional: Validate if metadata filter is a valid JSON string (if provided)
         metadata_filter_str = parameters.get("metadata_filter")
         if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict):
             raise ValueError("metadata_filter must be a valid JSON object")

+ 3 - 3
api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml

@@ -73,9 +73,9 @@ parameters:
     llm_description: AWS region where the Bedrock Knowledge Base is located
     form: form
 
-  - name: metadata_filter
-    type: string
-    required: false
+  - name: metadata_filter   # 新增的元数据过滤参数
+    type: string            # 可以是字符串类型,包含 JSON 格式的过滤条件
+    required: false         # 可选参数
     label:
       en_US: Metadata Filter
       zh_Hans: 元数据过滤器

+ 6 - 6
api/core/tools/provider/builtin/aws/tools/sagemaker_chinese_toxicity_detector.py

@@ -6,8 +6,8 @@ import boto3
 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool.builtin_tool import BuiltinTool
 
-# 定义标签映射
-LABEL_MAPPING = {"LABEL_0": "SAFE", "LABEL_1": "NO_SAFE"}
+# Define label mappings
+LABEL_MAPPING = {0: "SAFE", 1: "NO_SAFE"}
 
 
 class ContentModerationTool(BuiltinTool):
@@ -28,12 +28,12 @@ class ContentModerationTool(BuiltinTool):
         # Handle nested JSON if present
         if isinstance(json_obj, dict) and "body" in json_obj:
             body_content = json.loads(json_obj["body"])
-            raw_label = body_content.get("label")
+            prediction_result = body_content.get("prediction")
         else:
-            raw_label = json_obj.get("label")
+            prediction_result = json_obj.get("prediction")
 
-        # 映射标签并返回
-        result = LABEL_MAPPING.get(raw_label, "NO_SAFE")  # 如果映射中没有找到,默认返回NO_SAFE
+        # Map labels and return
+        result = LABEL_MAPPING.get(prediction_result, "NO_SAFE")  # If not found in mapping, default to NO_SAFE
         return result
 
     def _invoke(