Kaynağa Gözat

Fix: The topk parameter doesn't work in sagemaker rerank tool (#12150)

Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
ybalbert001 3 ay önce
ebeveyn
işleme
0fdb39f1c3

+ 3 - 5
api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py

@@ -10,8 +10,7 @@ from core.tools.tool.builtin_tool import BuiltinTool
 
 class SageMakerReRankTool(BuiltinTool):
     sagemaker_client: Any = None
-    sagemaker_endpoint: str | None = None
-    topk: int | None = None
+    sagemaker_endpoint: str = None
 
     def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str):
         inputs = [query_input] * len(docs)
@@ -47,8 +46,7 @@ class SageMakerReRankTool(BuiltinTool):
                 self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint")
 
             line = 2
-            if not self.topk:
-                self.topk = tool_parameters.get("topk", 5)
+            topk = tool_parameters.get("topk", 5)
 
             line = 3
             query = tool_parameters.get("query", "")
@@ -75,7 +73,7 @@ class SageMakerReRankTool(BuiltinTool):
             sorted_candidate_docs = sorted(candidate_docs, key=operator.itemgetter("score"), reverse=True)
 
             line = 9
-            return [self.create_json_message(res) for res in sorted_candidate_docs[: self.topk]]
+            return [self.create_json_message(res) for res in sorted_candidate_docs[:topk]]
 
         except Exception as e:
             return self.create_text_message(f"Exception {str(e)}, line : {line}")