|
@@ -10,8 +10,7 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
|
|
|
|
|
class SageMakerReRankTool(BuiltinTool):
|
|
|
sagemaker_client: Any = None
|
|
|
- sagemaker_endpoint: str | None = None
|
|
|
- topk: int | None = None
|
|
|
+ sagemaker_endpoint: str = None
|
|
|
|
|
|
def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str):
|
|
|
inputs = [query_input] * len(docs)
|
|
@@ -47,8 +46,7 @@ class SageMakerReRankTool(BuiltinTool):
|
|
|
self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint")
|
|
|
|
|
|
line = 2
|
|
|
- if not self.topk:
|
|
|
- self.topk = tool_parameters.get("topk", 5)
|
|
|
+ topk = tool_parameters.get("topk", 5)
|
|
|
|
|
|
line = 3
|
|
|
query = tool_parameters.get("query", "")
|
|
@@ -75,7 +73,7 @@ class SageMakerReRankTool(BuiltinTool):
|
|
|
sorted_candidate_docs = sorted(candidate_docs, key=operator.itemgetter("score"), reverse=True)
|
|
|
|
|
|
line = 9
|
|
|
- return [self.create_json_message(res) for res in sorted_candidate_docs[: self.topk]]
|
|
|
+ return [self.create_json_message(res) for res in sorted_candidate_docs[:topk]]
|
|
|
|
|
|
except Exception as e:
|
|
|
return self.create_text_message(f"Exception {str(e)}, line : {line}")
|