@@ -1,4 +1,3 @@
-import math
 from typing import Optional
 
 from langchain import WikipediaAPIWrapper
@@ -50,6 +49,7 @@ class OrchestratorRuleParser:
         tool_configs = agent_mode_config.get('tools', [])
         agent_provider_name = model_dict.get('provider', 'openai')
         agent_model_name = model_dict.get('name', 'gpt-4')
+        dataset_configs = self.app_model_config.dataset_configs_dict
 
         agent_model_instance = ModelFactory.get_text_generation_model(
             tenant_id=self.tenant_id,
@@ -96,13 +96,14 @@ class OrchestratorRuleParser:
                 summary_model_instance = None
 
         tools = self.to_tools(
-            agent_model_instance=agent_model_instance,
             tool_configs=tool_configs,
+            callbacks=[agent_callback, DifyStdOutCallbackHandler()],
+            agent_model_instance=agent_model_instance,
             conversation_message_task=conversation_message_task,
             rest_tokens=rest_tokens,
-            callbacks=[agent_callback, DifyStdOutCallbackHandler()],
             return_resource=return_resource,
-            retriever_from=retriever_from
+            retriever_from=retriever_from,
+            dataset_configs=dataset_configs
         )
 
         if len(tools) == 0:
@@ -170,20 +171,12 @@ class OrchestratorRuleParser:
 
         return None
 
-    def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list,
-                 conversation_message_task: ConversationMessageTask,
-                 rest_tokens: int, callbacks: Callbacks = None, return_resource: bool = False,
-                 retriever_from: str = 'dev') -> list[BaseTool]:
+    def to_tools(self, tool_configs: list, callbacks: Callbacks = None, **kwargs) -> list[BaseTool]:
         """
         Convert app agent tool configs to tools
 
-        :param agent_model_instance:
-        :param rest_tokens:
         :param tool_configs: app agent tool configs
-        :param conversation_message_task:
         :param callbacks:
-        :param return_resource:
-        :param retriever_from:
         :return:
         """
         tools = []
@@ -195,15 +188,15 @@ class OrchestratorRuleParser:
 
             tool = None
             if tool_type == "dataset":
-                tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens, return_resource, retriever_from)
+                tool = self.to_dataset_retriever_tool(tool_config=tool_val, **kwargs)
             elif tool_type == "web_reader":
-                tool = self.to_web_reader_tool(agent_model_instance)
+                tool = self.to_web_reader_tool(tool_config=tool_val, **kwargs)
             elif tool_type == "google_search":
-                tool = self.to_google_search_tool()
+                tool = self.to_google_search_tool(tool_config=tool_val, **kwargs)
             elif tool_type == "wikipedia":
-                tool = self.to_wikipedia_tool()
+                tool = self.to_wikipedia_tool(tool_config=tool_val, **kwargs)
             elif tool_type == "current_datetime":
-                tool = self.to_current_datetime_tool()
+                tool = self.to_current_datetime_tool(tool_config=tool_val, **kwargs)
 
             if tool:
                 if tool.callbacks is not None:
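The two hunks above are one refactor: `to_tools` drops its tool-specific parameters in favor of a `**kwargs` bag that it forwards to every builder, and each builder keeps only the keywords it declares while swallowing the rest. A minimal sketch of the pattern under assumed names (a dispatch dict stands in for the if/elif chain, and the builders are hypothetical, not the real Dify methods):

```python
from typing import Callable

def to_dataset_tool(tool_config: dict, rest_tokens: int, **kwargs) -> str:
    # declares rest_tokens; agent_model_instance lands unused in **kwargs
    return f"dataset(rest_tokens={rest_tokens})"

def to_web_reader_tool(tool_config: dict, agent_model_instance: str, **kwargs) -> str:
    # declares the model instance; rest_tokens lands unused in **kwargs
    return f"web_reader(model={agent_model_instance})"

BUILDERS: dict[str, Callable[..., str]] = {
    "dataset": to_dataset_tool,
    "web_reader": to_web_reader_tool,
}

def to_tools(tool_configs: list[dict], **kwargs) -> list[str]:
    tools = []
    for config in tool_configs:
        builder = BUILDERS.get(config["type"])
        if builder:
            # the full kwargs bag travels to every builder unchanged
            tools.append(builder(tool_config=config, **kwargs))
    return tools

print(to_tools(
    [{"type": "dataset"}, {"type": "web_reader"}],
    rest_tokens=2000,
    agent_model_instance="gpt-4",
))
# ['dataset(rest_tokens=2000)', 'web_reader(model=gpt-4)']
```

The trade-off is that a misspelled keyword is no longer caught by the signature of `to_tools`; it rides along silently in `**kwargs` until a builder is missing the argument it needs.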
@@ -215,12 +208,15 @@ class OrchestratorRuleParser:
         return tools
 
     def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask,
-                                  rest_tokens: int, return_resource: bool = False, retriever_from: str = 'dev') \
+                                  dataset_configs: dict, rest_tokens: int,
+                                  return_resource: bool = False, retriever_from: str = 'dev',
+                                  **kwargs) \
             -> Optional[BaseTool]:
         """
         A dataset tool is a tool that can be used to retrieve information from a dataset
         :param rest_tokens:
         :param tool_config:
+        :param dataset_configs:
         :param conversation_message_task:
         :param return_resource:
         :param retriever_from:
@@ -238,10 +234,20 @@ class OrchestratorRuleParser:
         if dataset and dataset.available_document_count == 0:
             return None
 
-        k = self._dynamic_calc_retrieve_k(dataset, rest_tokens)
+        top_k = dataset_configs.get("top_k", 2)
+
+        # dynamically shrink top_k when the remaining tokens are not enough to support the configured value
+        top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens)
+
+        score_threshold = None
+        score_threshold_config = dataset_configs.get("score_threshold")
+        if score_threshold_config and score_threshold_config.get("enable"):
+            score_threshold = score_threshold_config.get("value")
+
         tool = DatasetRetrieverTool.from_dataset(
             dataset=dataset,
-            k=k,
+            top_k=top_k,
+            score_threshold=score_threshold,
             callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
             conversation_message_task=conversation_message_task,
             return_resource=return_resource,
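Note: the keys read above imply a `dataset_configs` shape along the following lines; the values are illustrative, not a documented schema:

```python
# Hypothetical config dict, shaped to match the .get(...) calls in the hunk above
dataset_configs = {
    "top_k": 4,
    "score_threshold": {"enable": True, "value": 0.6},
}

top_k = dataset_configs.get("top_k", 2)  # -> 4; falls back to 2 when unset
score_threshold = None
score_threshold_config = dataset_configs.get("score_threshold")
if score_threshold_config and score_threshold_config.get("enable"):
    score_threshold = score_threshold_config.get("value")  # -> 0.6
```

When the `score_threshold` block is absent or `enable` is falsy, `score_threshold` stays `None`, which presumably leaves threshold filtering off in `DatasetRetrieverTool`.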
@@ -250,7 +256,7 @@ class OrchestratorRuleParser:
 
         return tool
 
-    def to_web_reader_tool(self, agent_model_instance: BaseLLM) -> Optional[BaseTool]:
+    def to_web_reader_tool(self, tool_config: dict, agent_model_instance: BaseLLM, **kwargs) -> Optional[BaseTool]:
         """
         A tool for reading web pages
 
@@ -278,7 +284,7 @@ class OrchestratorRuleParser:
 
         return tool
 
-    def to_google_search_tool(self) -> Optional[BaseTool]:
+    def to_google_search_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
         tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id)
         func_kwargs = tool_provider.credentials_to_func_kwargs()
         if not func_kwargs:
@@ -296,12 +302,12 @@ class OrchestratorRuleParser:
 
         return tool
 
-    def to_current_datetime_tool(self) -> Optional[BaseTool]:
+    def to_current_datetime_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
         tool = DatetimeTool()
 
         return tool
 
-    def to_wikipedia_tool(self) -> Optional[BaseTool]:
+    def to_wikipedia_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
         class WikipediaInput(BaseModel):
             query: str = Field(..., description="search query.")
 
@@ -312,22 +318,18 @@ class OrchestratorRuleParser:
         )
 
     @classmethod
-    def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
-        DEFAULT_K = 2
-        CONTEXT_TOKENS_PERCENT = 0.3
-        MAX_K = 10
-
+    def _dynamic_calc_retrieve_k(cls, dataset: Dataset, top_k: int, rest_tokens: int) -> int:
         if rest_tokens == -1:
-            return DEFAULT_K
+            return top_k
 
         processing_rule = dataset.latest_process_rule
         if not processing_rule:
-            return DEFAULT_K
+            return top_k
 
         if processing_rule.mode == "custom":
             rules = processing_rule.rules_dict
             if not rules:
-                return DEFAULT_K
+                return top_k
 
             segmentation = rules["segmentation"]
             segment_max_tokens = segmentation["max_tokens"]
@@ -335,14 +337,7 @@ class OrchestratorRuleParser:
             segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
 
         # when rest_tokens is less than default context tokens
-        if rest_tokens < segment_max_tokens * DEFAULT_K:
+        if rest_tokens < segment_max_tokens * top_k:
             return rest_tokens // segment_max_tokens
 
-        context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
-
-        # when context_limit_tokens is less than default context tokens, use default_k
-        if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
-            return DEFAULT_K
-
-        # Expand the k value when there's still some room left in the 30% rest tokens space, but less than the MAX_K
-        return min(context_limit_tokens // segment_max_tokens, MAX_K)
+        return min(top_k, 10)
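Net effect of the last two hunks: the configured `top_k` becomes the ceiling, and the old `DEFAULT_K`/`CONTEXT_TOKENS_PERCENT`/`MAX_K` budget heuristic disappears except for the hard cap of 10. A self-contained simplification of the new logic with sample numbers (the process-rule lookup that resolves `segment_max_tokens` is elided):

```python
def dynamic_calc_retrieve_k(top_k: int, rest_tokens: int, segment_max_tokens: int) -> int:
    # -1 means no token budget was computed: keep the configured top_k
    if rest_tokens == -1:
        return top_k
    # shrink k when the remaining budget cannot hold top_k full segments
    if rest_tokens < segment_max_tokens * top_k:
        return rest_tokens // segment_max_tokens
    # otherwise honor the configured top_k, capped at 10 (the old MAX_K)
    return min(top_k, 10)

print(dynamic_calc_retrieve_k(top_k=4, rest_tokens=1000, segment_max_tokens=600))  # 1
print(dynamic_calc_retrieve_k(top_k=4, rest_tokens=5000, segment_max_tokens=600))  # 4
```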