Explorar el Código

fix: dataset desc (#1045)

takatost hace 1 año
padre
commit
7b3314c5db

+ 1 - 1
api/core/agent/agent/multi_dataset_router_agent.py

@@ -52,7 +52,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
         elif len(self.tools) == 1:
             tool = next(iter(self.tools))
             tool = cast(DatasetRetrieverTool, tool)
-            rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']})
+            rst = tool.run(tool_input={'query': kwargs['input']})
             return AgentFinish(return_values={"output": rst}, log=rst)
 
         if intermediate_steps:

+ 1 - 1
api/core/agent/agent/openai_function_call.py

@@ -45,7 +45,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
         :return:
         """
         original_max_tokens = self.llm.max_tokens
-        self.llm.max_tokens = 15
+        self.llm.max_tokens = 40
 
         prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
         messages = prompt.to_messages()

+ 1 - 1
api/core/agent/agent/structed_multi_dataset_router_agent.py

@@ -90,7 +90,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
         elif len(self.dataset_tools) == 1:
             tool = next(iter(self.dataset_tools))
             tool = cast(DatasetRetrieverTool, tool)
-            rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']})
+            rst = tool.run(tool_input={'query': kwargs['input']})
             return AgentFinish(return_values={"output": rst}, log=rst)
 
         full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)

+ 10 - 4
api/core/callback_handler/dataset_tool_callback_handler.py

@@ -1,5 +1,6 @@
 import json
 import logging
+from json import JSONDecodeError
 
 from typing import Any, Dict, List, Union, Optional
 
@@ -44,10 +45,15 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
         input_str: str,
         **kwargs: Any,
     ) -> None:
-        # tool_name = serialized.get('name')
-        input_dict = json.loads(input_str.replace("'", "\""))
-        dataset_id = input_dict.get('dataset_id')
-        query = input_dict.get('query')
+        tool_name: str = serialized.get('name')
+        dataset_id = tool_name.removeprefix('dataset-')
+
+        try:
+            input_dict = json.loads(input_str.replace("'", "\""))
+            query = input_dict.get('query')
+        except JSONDecodeError:
+            query = input_str
+
         self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=query))
 
     def on_tool_end(

+ 4 - 11
api/core/tool/dataset_retriever_tool.py

@@ -1,4 +1,3 @@
-import re
 from typing import Type
 
 from flask import current_app
@@ -16,7 +15,6 @@ from models.dataset import Dataset, DocumentSegment
 
 
 class DatasetRetrieverToolInput(BaseModel):
-    dataset_id: str = Field(..., description="ID of dataset to be queried. MUST be UUID format.")
     query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
 
 
@@ -37,27 +35,22 @@ class DatasetRetrieverTool(BaseTool):
             description = 'useful for when you want to answer queries about the ' + dataset.name
 
         description = description.replace('\n', '').replace('\r', '')
-        description += '\nID of dataset MUST be ' + dataset.id
         return cls(
+            name=f'dataset-{dataset.id}',
             tenant_id=dataset.tenant_id,
             dataset_id=dataset.id,
             description=description,
             **kwargs
         )
 
-    def _run(self, dataset_id: str, query: str) -> str:
-        pattern = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b'
-        match = re.search(pattern, dataset_id, re.IGNORECASE)
-        if match:
-            dataset_id = match.group()
-
+    def _run(self, query: str) -> str:
         dataset = db.session.query(Dataset).filter(
             Dataset.tenant_id == self.tenant_id,
-            Dataset.id == dataset_id
+            Dataset.id == self.dataset_id
         ).first()
 
         if not dataset:
-            return f'[{self.name} failed to find dataset with id {dataset_id}.]'
+            return f'[{self.name} failed to find dataset with id {self.dataset_id}.]'
 
         if dataset.indexing_technique == "economy":
             # use keyword table query