Bläddra i källkod

Remove langchain dataset retrival agent logic (#3311)

Jyong 1 år sedan
förälder
incheckning
b6de97ad53

+ 2 - 0
api/core/app/apps/chat/app_runner.py

@@ -156,6 +156,8 @@ class ChatAppRunner(AppRunner):
 
             dataset_retrieval = DatasetRetrieval()
             context = dataset_retrieval.retrieve(
+                app_id=app_record.id,
+                user_id=application_generate_entity.user_id,
                 tenant_id=app_record.tenant_id,
                 model_config=application_generate_entity.model_config,
                 config=app_config.dataset,

+ 2 - 0
api/core/app/apps/completion/app_runner.py

@@ -116,6 +116,8 @@ class CompletionAppRunner(AppRunner):
 
             dataset_retrieval = DatasetRetrieval()
             context = dataset_retrieval.retrieve(
+                app_id=app_record.id,
+                user_id=application_generate_entity.user_id,
                 tenant_id=app_record.tenant_id,
                 model_config=application_generate_entity.model_config,
                 config=dataset_config,

+ 0 - 59
api/core/rag/retrieval/agent/fake_llm.py

@@ -1,59 +0,0 @@
-import time
-from collections.abc import Mapping
-from typing import Any, Optional
-
-from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.chat_models.base import SimpleChatModel
-from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult
-
-
-class FakeLLM(SimpleChatModel):
-    """Fake ChatModel for testing purposes."""
-
-    streaming: bool = False
-    """Whether to stream the results or not."""
-    response: str
-
-    @property
-    def _llm_type(self) -> str:
-        return "fake-chat-model"
-
-    def _call(
-        self,
-        messages: list[BaseMessage],
-        stop: Optional[list[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> str:
-        """First try to lookup in queries, else return 'foo' or 'bar'."""
-        return self.response
-
-    @property
-    def _identifying_params(self) -> Mapping[str, Any]:
-        return {"response": self.response}
-
-    def get_num_tokens(self, text: str) -> int:
-        return 0
-
-    def _generate(
-        self,
-        messages: list[BaseMessage],
-        stop: Optional[list[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> ChatResult:
-        output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
-        if self.streaming:
-            for token in output_str:
-                if run_manager:
-                    run_manager.on_llm_new_token(token)
-                    time.sleep(0.01)
-
-        message = AIMessage(content=output_str)
-        generation = ChatGeneration(message=message)
-        llm_output = {"token_usage": {
-            'prompt_tokens': 0,
-            'completion_tokens': 0,
-            'total_tokens': 0,
-        }}
-        return ChatResult(generations=[generation], llm_output=llm_output)

+ 0 - 46
api/core/rag/retrieval/agent/llm_chain.py

@@ -1,46 +0,0 @@
-from typing import Any, Optional
-
-from langchain import LLMChain as LCLLMChain
-from langchain.callbacks.manager import CallbackManagerForChainRun
-from langchain.schema import Generation, LLMResult
-from langchain.schema.language_model import BaseLanguageModel
-
-from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
-from core.entities.message_entities import lc_messages_to_prompt_messages
-from core.model_manager import ModelInstance
-from core.rag.retrieval.agent.fake_llm import FakeLLM
-
-
-class LLMChain(LCLLMChain):
-    model_config: ModelConfigWithCredentialsEntity
-    """The language model instance to use."""
-    llm: BaseLanguageModel = FakeLLM(response="")
-    parameters: dict[str, Any] = {}
-
-    def generate(
-        self,
-        input_list: list[dict[str, Any]],
-        run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> LLMResult:
-        """Generate LLM result from inputs."""
-        prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
-        messages = prompts[0].to_messages()
-        prompt_messages = lc_messages_to_prompt_messages(messages)
-
-        model_instance = ModelInstance(
-            provider_model_bundle=self.model_config.provider_model_bundle,
-            model=self.model_config.model,
-        )
-
-        result = model_instance.invoke_llm(
-            prompt_messages=prompt_messages,
-            stream=False,
-            stop=stop,
-            model_parameters=self.parameters
-        )
-
-        generations = [
-            [Generation(text=result.message.content)]
-        ]
-
-        return LLMResult(generations=generations)

+ 0 - 179
api/core/rag/retrieval/agent/multi_dataset_router_agent.py

@@ -1,179 +0,0 @@
-from collections.abc import Sequence
-from typing import Any, Optional, Union
-
-from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
-from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
-from langchain.callbacks.base import BaseCallbackManager
-from langchain.callbacks.manager import Callbacks
-from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import AgentAction, AgentFinish, AIMessage, SystemMessage
-from langchain.tools import BaseTool
-from pydantic import root_validator
-
-from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
-from core.entities.message_entities import lc_messages_to_prompt_messages
-from core.model_manager import ModelInstance
-from core.model_runtime.entities.message_entities import PromptMessageTool
-from core.rag.retrieval.agent.fake_llm import FakeLLM
-
-
-class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
-    """
-    An Multi Dataset Retrieve Agent driven by Router.
-    """
-    model_config: ModelConfigWithCredentialsEntity
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        arbitrary_types_allowed = True
-
-    @root_validator
-    def validate_llm(cls, values: dict) -> dict:
-        return values
-
-    def should_use_agent(self, query: str):
-        """
-        return should use agent
-
-        :param query:
-        :return:
-        """
-        return True
-
-    def plan(
-        self,
-        intermediate_steps: list[tuple[AgentAction, str]],
-        callbacks: Callbacks = None,
-        **kwargs: Any,
-    ) -> Union[AgentAction, AgentFinish]:
-        """Given input, decided what to do.
-
-        Args:
-            intermediate_steps: Steps the LLM has taken to date, along with observations
-            **kwargs: User inputs.
-
-        Returns:
-            Action specifying what tool to use.
-        """
-        if len(self.tools) == 0:
-            return AgentFinish(return_values={"output": ''}, log='')
-        elif len(self.tools) == 1:
-            tool = next(iter(self.tools))
-            rst = tool.run(tool_input={'query': kwargs['input']})
-            # output = ''
-            # rst_json = json.loads(rst)
-            # for item in rst_json:
-            #     output += f'{item["content"]}\n'
-            return AgentFinish(return_values={"output": rst}, log=rst)
-
-        if intermediate_steps:
-            _, observation = intermediate_steps[-1]
-            return AgentFinish(return_values={"output": observation}, log=observation)
-
-        try:
-            agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
-            if isinstance(agent_decision, AgentAction):
-                tool_inputs = agent_decision.tool_input
-                if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs:
-                    tool_inputs['query'] = kwargs['input']
-                    agent_decision.tool_input = tool_inputs
-            else:
-                agent_decision.return_values['output'] = ''
-            return agent_decision
-        except Exception as e:
-            raise e
-
-    def real_plan(
-        self,
-        intermediate_steps: list[tuple[AgentAction, str]],
-        callbacks: Callbacks = None,
-        **kwargs: Any,
-    ) -> Union[AgentAction, AgentFinish]:
-        """Given input, decided what to do.
-
-        Args:
-            intermediate_steps: Steps the LLM has taken to date, along with observations
-            **kwargs: User inputs.
-
-        Returns:
-            Action specifying what tool to use.
-        """
-        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
-        selected_inputs = {
-            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
-        }
-        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
-        prompt = self.prompt.format_prompt(**full_inputs)
-        messages = prompt.to_messages()
-        prompt_messages = lc_messages_to_prompt_messages(messages)
-
-        model_instance = ModelInstance(
-            provider_model_bundle=self.model_config.provider_model_bundle,
-            model=self.model_config.model,
-        )
-
-        tools = []
-        for function in self.functions:
-            tool = PromptMessageTool(
-                **function
-            )
-
-            tools.append(tool)
-
-        result = model_instance.invoke_llm(
-            prompt_messages=prompt_messages,
-            tools=tools,
-            stream=False,
-            model_parameters={
-                'temperature': 0.2,
-                'top_p': 0.3,
-                'max_tokens': 1500
-            }
-        )
-
-        ai_message = AIMessage(
-            content=result.message.content or "",
-            additional_kwargs={
-                'function_call': {
-                    'id': result.message.tool_calls[0].id,
-                    **result.message.tool_calls[0].function.dict()
-                } if result.message.tool_calls else None
-            }
-        )
-
-        agent_decision = _parse_ai_message(ai_message)
-        return agent_decision
-
-    async def aplan(
-            self,
-            intermediate_steps: list[tuple[AgentAction, str]],
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> Union[AgentAction, AgentFinish]:
-        raise NotImplementedError()
-
-    @classmethod
-    def from_llm_and_tools(
-            cls,
-            model_config: ModelConfigWithCredentialsEntity,
-            tools: Sequence[BaseTool],
-            callback_manager: Optional[BaseCallbackManager] = None,
-            extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
-            system_message: Optional[SystemMessage] = SystemMessage(
-                content="You are a helpful AI assistant."
-            ),
-            **kwargs: Any,
-    ) -> BaseSingleActionAgent:
-        prompt = cls.create_prompt(
-            extra_prompt_messages=extra_prompt_messages,
-            system_message=system_message,
-        )
-        return cls(
-            model_config=model_config,
-            llm=FakeLLM(response=''),
-            prompt=prompt,
-            tools=tools,
-            callback_manager=callback_manager,
-            **kwargs,
-        )

+ 0 - 0
api/core/rag/retrieval/agent/output_parser/__init__.py


+ 0 - 259
api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py

@@ -1,259 +0,0 @@
-import re
-from collections.abc import Sequence
-from typing import Any, Optional, Union, cast
-
-from langchain import BasePromptTemplate, PromptTemplate
-from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
-from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
-from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
-from langchain.callbacks.base import BaseCallbackManager
-from langchain.callbacks.manager import Callbacks
-from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
-from langchain.schema import AgentAction, AgentFinish, OutputParserException
-from langchain.tools import BaseTool
-
-from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
-from core.rag.retrieval.agent.llm_chain import LLMChain
-
-FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
-The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
-Valid "action" values: "Final Answer" or {tool_names}
-
-Provide only ONE action per $JSON_BLOB, as shown:
-
-```
-{{{{
-  "action": $TOOL_NAME,
-  "action_input": $INPUT
-}}}}
-```
-
-Follow this format:
-
-Question: input question to answer
-Thought: consider previous and subsequent steps
-Action:
-```
-$JSON_BLOB
-```
-Observation: action result
-... (repeat Thought/Action/Observation N times)
-Thought: I know what to respond
-Action:
-```
-{{{{
-  "action": "Final Answer",
-  "action_input": "Final response to human"
-}}}}
-```"""
-
-
-class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
-    dataset_tools: Sequence[BaseTool]
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        arbitrary_types_allowed = True
-
-    def should_use_agent(self, query: str):
-        """
-        return should use agent
-        Using the ReACT mode to determine whether an agent is needed is costly,
-        so it's better to just use an Agent for reasoning, which is cheaper.
-
-        :param query:
-        :return:
-        """
-        return True
-
-    def plan(
-            self,
-            intermediate_steps: list[tuple[AgentAction, str]],
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> Union[AgentAction, AgentFinish]:
-        """Given input, decided what to do.
-
-        Args:
-            intermediate_steps: Steps the LLM has taken to date,
-                along with observations
-            callbacks: Callbacks to run.
-            **kwargs: User inputs.
-
-        Returns:
-            Action specifying what tool to use.
-        """
-        if len(self.dataset_tools) == 0:
-            return AgentFinish(return_values={"output": ''}, log='')
-        elif len(self.dataset_tools) == 1:
-            tool = next(iter(self.dataset_tools))
-            rst = tool.run(tool_input={'query': kwargs['input']})
-            return AgentFinish(return_values={"output": rst}, log=rst)
-
-        if intermediate_steps:
-            _, observation = intermediate_steps[-1]
-            return AgentFinish(return_values={"output": observation}, log=observation)
-
-        full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
-
-        try:
-            full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
-        except Exception as e:
-            raise e
-
-        try:
-            agent_decision = self.output_parser.parse(full_output)
-            if isinstance(agent_decision, AgentAction):
-                tool_inputs = agent_decision.tool_input
-                if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
-                    tool_inputs['query'] = kwargs['input']
-                    agent_decision.tool_input = tool_inputs
-                elif isinstance(tool_inputs, str):
-                    agent_decision.tool_input = kwargs['input']
-            else:
-                agent_decision.return_values['output'] = ''
-            return agent_decision
-        except OutputParserException:
-            return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
-                                          "I don't know how to respond to that."}, "")
-
-    @classmethod
-    def create_prompt(
-            cls,
-            tools: Sequence[BaseTool],
-            prefix: str = PREFIX,
-            suffix: str = SUFFIX,
-            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
-            format_instructions: str = FORMAT_INSTRUCTIONS,
-            input_variables: Optional[list[str]] = None,
-            memory_prompts: Optional[list[BasePromptTemplate]] = None,
-    ) -> BasePromptTemplate:
-        tool_strings = []
-        for tool in tools:
-            args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
-            tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
-        formatted_tools = "\n".join(tool_strings)
-        unique_tool_names = set(tool.name for tool in tools)
-        tool_names = ", ".join('"' + name + '"' for name in unique_tool_names)
-        format_instructions = format_instructions.format(tool_names=tool_names)
-        template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
-        if input_variables is None:
-            input_variables = ["input", "agent_scratchpad"]
-        _memory_prompts = memory_prompts or []
-        messages = [
-            SystemMessagePromptTemplate.from_template(template),
-            *_memory_prompts,
-            HumanMessagePromptTemplate.from_template(human_message_template),
-        ]
-        return ChatPromptTemplate(input_variables=input_variables, messages=messages)
-
-    @classmethod
-    def create_completion_prompt(
-            cls,
-            tools: Sequence[BaseTool],
-            prefix: str = PREFIX,
-            format_instructions: str = FORMAT_INSTRUCTIONS,
-            input_variables: Optional[list[str]] = None,
-    ) -> PromptTemplate:
-        """Create prompt in the style of the zero shot agent.
-
-        Args:
-            tools: List of tools the agent will have access to, used to format the
-                prompt.
-            prefix: String to put before the list of tools.
-            input_variables: List of input variables the final prompt will expect.
-
-        Returns:
-            A PromptTemplate with the template assembled from the pieces here.
-        """
-        suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
-Question: {input}
-Thought: {agent_scratchpad}
-"""
-
-        tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
-        tool_names = ", ".join([tool.name for tool in tools])
-        format_instructions = format_instructions.format(tool_names=tool_names)
-        template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
-        if input_variables is None:
-            input_variables = ["input", "agent_scratchpad"]
-        return PromptTemplate(template=template, input_variables=input_variables)
-
-    def _construct_scratchpad(
-            self, intermediate_steps: list[tuple[AgentAction, str]]
-    ) -> str:
-        agent_scratchpad = ""
-        for action, observation in intermediate_steps:
-            agent_scratchpad += action.log
-            agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
-
-        if not isinstance(agent_scratchpad, str):
-            raise ValueError("agent_scratchpad should be of type string.")
-        if agent_scratchpad:
-            llm_chain = cast(LLMChain, self.llm_chain)
-            if llm_chain.model_config.mode == "chat":
-                return (
-                    f"This was your previous work "
-                    f"(but I haven't seen any of it! I only see what "
-                    f"you return as final answer):\n{agent_scratchpad}"
-                )
-            else:
-                return agent_scratchpad
-        else:
-            return agent_scratchpad
-
-    @classmethod
-    def from_llm_and_tools(
-            cls,
-            model_config: ModelConfigWithCredentialsEntity,
-            tools: Sequence[BaseTool],
-            callback_manager: Optional[BaseCallbackManager] = None,
-            output_parser: Optional[AgentOutputParser] = None,
-            prefix: str = PREFIX,
-            suffix: str = SUFFIX,
-            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
-            format_instructions: str = FORMAT_INSTRUCTIONS,
-            input_variables: Optional[list[str]] = None,
-            memory_prompts: Optional[list[BasePromptTemplate]] = None,
-            **kwargs: Any,
-    ) -> Agent:
-        """Construct an agent from an LLM and tools."""
-        cls._validate_tools(tools)
-        if model_config.mode == "chat":
-            prompt = cls.create_prompt(
-                tools,
-                prefix=prefix,
-                suffix=suffix,
-                human_message_template=human_message_template,
-                format_instructions=format_instructions,
-                input_variables=input_variables,
-                memory_prompts=memory_prompts,
-            )
-        else:
-            prompt = cls.create_completion_prompt(
-                tools,
-                prefix=prefix,
-                format_instructions=format_instructions,
-                input_variables=input_variables
-            )
-
-        llm_chain = LLMChain(
-            model_config=model_config,
-            prompt=prompt,
-            callback_manager=callback_manager,
-            parameters={
-                'temperature': 0.2,
-                'top_p': 0.3,
-                'max_tokens': 1500
-            }
-        )
-        tool_names = [tool.name for tool in tools]
-        _output_parser = output_parser
-        return cls(
-            llm_chain=llm_chain,
-            allowed_tools=tool_names,
-            output_parser=_output_parser,
-            dataset_tools=tools,
-            **kwargs,
-        )

+ 0 - 117
api/core/rag/retrieval/agent_based_dataset_executor.py

@@ -1,117 +0,0 @@
-import logging
-from typing import Optional, Union
-
-from langchain.agents import AgentExecutor as LCAgentExecutor
-from langchain.agents import BaseMultiActionAgent, BaseSingleActionAgent
-from langchain.callbacks.manager import Callbacks
-from langchain.tools import BaseTool
-from pydantic import BaseModel, Extra
-
-from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
-from core.entities.agent_entities import PlanningStrategy
-from core.entities.message_entities import prompt_messages_to_lc_messages
-from core.helper import moderation
-from core.memory.token_buffer_memory import TokenBufferMemory
-from core.model_runtime.errors.invoke import InvokeError
-from core.rag.retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
-from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
-from core.rag.retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
-from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
-from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
-
-
-class AgentConfiguration(BaseModel):
-    strategy: PlanningStrategy
-    model_config: ModelConfigWithCredentialsEntity
-    tools: list[BaseTool]
-    summary_model_config: Optional[ModelConfigWithCredentialsEntity] = None
-    memory: Optional[TokenBufferMemory] = None
-    callbacks: Callbacks = None
-    max_iterations: int = 6
-    max_execution_time: Optional[float] = None
-    early_stopping_method: str = "generate"
-    # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
-        arbitrary_types_allowed = True
-
-
-class AgentExecuteResult(BaseModel):
-    strategy: PlanningStrategy
-    output: Optional[str]
-    configuration: AgentConfiguration
-
-
-class AgentExecutor:
-    def __init__(self, configuration: AgentConfiguration):
-        self.configuration = configuration
-        self.agent = self._init_agent()
-
-    def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
-        if self.configuration.strategy == PlanningStrategy.ROUTER:
-            self.configuration.tools = [t for t in self.configuration.tools
-                                        if isinstance(t, DatasetRetrieverTool)
-                                        or isinstance(t, DatasetMultiRetrieverTool)]
-            agent = MultiDatasetRouterAgent.from_llm_and_tools(
-                model_config=self.configuration.model_config,
-                tools=self.configuration.tools,
-                extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
-                if self.configuration.memory else None,
-                verbose=True
-            )
-        elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
-            self.configuration.tools = [t for t in self.configuration.tools
-                                        if isinstance(t, DatasetRetrieverTool)
-                                        or isinstance(t, DatasetMultiRetrieverTool)]
-            agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
-                model_config=self.configuration.model_config,
-                tools=self.configuration.tools,
-                output_parser=StructuredChatOutputParser(),
-                verbose=True
-            )
-        else:
-            raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}")
-
-        return agent
-
-    def should_use_agent(self, query: str) -> bool:
-        return self.agent.should_use_agent(query)
-
-    def run(self, query: str) -> AgentExecuteResult:
-        moderation_result = moderation.check_moderation(
-            self.configuration.model_config,
-            query
-        )
-
-        if moderation_result:
-            return AgentExecuteResult(
-                output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
-                strategy=self.configuration.strategy,
-                configuration=self.configuration
-            )
-
-        agent_executor = LCAgentExecutor.from_agent_and_tools(
-            agent=self.agent,
-            tools=self.configuration.tools,
-            max_iterations=self.configuration.max_iterations,
-            max_execution_time=self.configuration.max_execution_time,
-            early_stopping_method=self.configuration.early_stopping_method,
-            callbacks=self.configuration.callbacks
-        )
-
-        try:
-            output = agent_executor.run(input=query)
-        except InvokeError as ex:
-            raise ex
-        except Exception as ex:
-            logging.exception("agent_executor run failed")
-            output = None
-
-        return AgentExecuteResult(
-            output=output,
-            strategy=self.configuration.strategy,
-            configuration=self.configuration
-        )

+ 287 - 94
api/core/rag/retrieval/dataset_retrieval.py

@@ -1,23 +1,40 @@
+import threading
 from typing import Optional, cast
 
-from langchain.tools import BaseTool
+from flask import Flask, current_app
 
 from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
 from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.entities.agent_entities import PlanningStrategy
 from core.memory.token_buffer_memory import TokenBufferMemory
-from core.model_runtime.entities.model_entities import ModelFeature
+from core.model_manager import ModelInstance, ModelManager
+from core.model_runtime.entities.message_entities import PromptMessageTool
+from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.rag.retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor
-from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
-from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
+from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.models.document import Document
+from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
+from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
+from core.rerank.rerank import RerankRunner
 from extensions.ext_database import db
-from models.dataset import Dataset
+from models.dataset import Dataset, DatasetQuery, DocumentSegment
+from models.dataset import Document as DatasetDocument
+
+default_retrieval_model = {
+    'search_method': 'semantic_search',
+    'reranking_enable': False,
+    'reranking_model': {
+        'reranking_provider_name': '',
+        'reranking_model_name': ''
+    },
+    'top_k': 2,
+    'score_threshold_enabled': False
+}
 
 
 class DatasetRetrieval:
-    def retrieve(self, tenant_id: str,
+    def retrieve(self, app_id: str, user_id: str, tenant_id: str,
                  model_config: ModelConfigWithCredentialsEntity,
                  config: DatasetEntity,
                  query: str,
@@ -27,6 +44,8 @@ class DatasetRetrieval:
                  memory: Optional[TokenBufferMemory] = None) -> Optional[str]:
         """
         Retrieve dataset.
+        :param app_id: app_id
+        :param user_id: user_id
         :param tenant_id: tenant id
         :param model_config: model config
         :param config: dataset config
@@ -38,12 +57,22 @@ class DatasetRetrieval:
         :return:
         """
         dataset_ids = config.dataset_ids
+        if len(dataset_ids) == 0:
+            return None
         retrieve_config = config.retrieve_config
 
         # check model is support tool calling
         model_type_instance = model_config.provider_model_bundle.model_type_instance
         model_type_instance = cast(LargeLanguageModel, model_type_instance)
 
+        model_manager = ModelManager()
+        model_instance = model_manager.get_model_instance(
+            tenant_id=tenant_id,
+            model_type=ModelType.LLM,
+            provider=model_config.provider,
+            model=model_config.model
+        )
+
         # get model schema
         model_schema = model_type_instance.get_model_schema(
             model=model_config.model,
@@ -59,56 +88,6 @@ class DatasetRetrieval:
             if ModelFeature.TOOL_CALL in features \
                     or ModelFeature.MULTI_TOOL_CALL in features:
                 planning_strategy = PlanningStrategy.ROUTER
-
-        dataset_retriever_tools = self.to_dataset_retriever_tool(
-            tenant_id=tenant_id,
-            dataset_ids=dataset_ids,
-            retrieve_config=retrieve_config,
-            return_resource=show_retrieve_source,
-            invoke_from=invoke_from,
-            hit_callback=hit_callback
-        )
-
-        if len(dataset_retriever_tools) == 0:
-            return None
-
-        agent_configuration = AgentConfiguration(
-            strategy=planning_strategy,
-            model_config=model_config,
-            tools=dataset_retriever_tools,
-            memory=memory,
-            max_iterations=10,
-            max_execution_time=400.0,
-            early_stopping_method="generate"
-        )
-
-        agent_executor = AgentExecutor(agent_configuration)
-
-        should_use_agent = agent_executor.should_use_agent(query)
-        if not should_use_agent:
-            return None
-
-        result = agent_executor.run(query)
-
-        return result.output
-
-    def to_dataset_retriever_tool(self, tenant_id: str,
-                                  dataset_ids: list[str],
-                                  retrieve_config: DatasetRetrieveConfigEntity,
-                                  return_resource: bool,
-                                  invoke_from: InvokeFrom,
-                                  hit_callback: DatasetIndexToolCallbackHandler) \
-            -> Optional[list[BaseTool]]:
-        """
-        A dataset tool is a tool that can be used to retrieve information from a dataset
-        :param tenant_id: tenant id
-        :param dataset_ids: dataset ids
-        :param retrieve_config: retrieve config
-        :param return_resource: return resource
-        :param invoke_from: invoke from
-        :param hit_callback: hit callback
-        """
-        tools = []
         available_datasets = []
         for dataset_id in dataset_ids:
             # get dataset from dataset id
@@ -127,56 +106,270 @@ class DatasetRetrieval:
                 continue
 
             available_datasets.append(dataset)
-
+        all_documents = []
+        user_from = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'
         if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
+            all_documents = self.single_retrieve(app_id, tenant_id, user_id, user_from, available_datasets, query,
+                                                 model_instance,
+                                                 model_config, planning_strategy)
+        elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
+            all_documents = self.multiple_retrieve(app_id, tenant_id, user_id, user_from,
+                                                   available_datasets, query, retrieve_config.top_k,
+                                                   retrieve_config.score_threshold,
+                                                   retrieve_config.reranking_model.get('reranking_provider_name'),
+                                                   retrieve_config.reranking_model.get('reranking_model_name'))
+
+        document_score_list = {}
+        for item in all_documents:
+            if 'score' in item.metadata and item.metadata['score']:
+                document_score_list[item.metadata['doc_id']] = item.metadata['score']
+
+        document_context_list = []
+        index_node_ids = [document.metadata['doc_id'] for document in all_documents]
+        segments = DocumentSegment.query.filter(
+            DocumentSegment.dataset_id.in_(dataset_ids),
+            DocumentSegment.completed_at.isnot(None),
+            DocumentSegment.status == 'completed',
+            DocumentSegment.enabled == True,
+            DocumentSegment.index_node_id.in_(index_node_ids)
+        ).all()
+
+        if segments:
+            index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
+            sorted_segments = sorted(segments,
+                                     key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
+                                                                                       float('inf')))
+            for segment in sorted_segments:
+                if segment.answer:
+                    document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
+                else:
+                    document_context_list.append(segment.content)
+            if show_retrieve_source:
+                context_list = []
+                resource_number = 1
+                for segment in sorted_segments:
+                    dataset = Dataset.query.filter_by(
+                        id=segment.dataset_id
+                    ).first()
+                    document = DatasetDocument.query.filter(DatasetDocument.id == segment.document_id,
+                                                            DatasetDocument.enabled == True,
+                                                            DatasetDocument.archived == False,
+                                                            ).first()
+                    if dataset and document:
+                        source = {
+                            'position': resource_number,
+                            'dataset_id': dataset.id,
+                            'dataset_name': dataset.name,
+                            'document_id': document.id,
+                            'document_name': document.name,
+                            'data_source_type': document.data_source_type,
+                            'segment_id': segment.id,
+                            'retriever_from': invoke_from.to_source(),
+                            'score': document_score_list.get(segment.index_node_id, None)
+                        }
+
+                        if invoke_from.to_source() == 'dev':
+                            source['hit_count'] = segment.hit_count
+                            source['word_count'] = segment.word_count
+                            source['segment_position'] = segment.position
+                            source['index_node_hash'] = segment.index_node_hash
+                        if segment.answer:
+                            source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
+                        else:
+                            source['content'] = segment.content
+                        context_list.append(source)
+                    resource_number += 1
+                if hit_callback:
+                    hit_callback.return_retriever_resource_info(context_list)
+
+            return str("\n".join(document_context_list))
+        return ''
+
+    def single_retrieve(self, app_id: str,
+                        tenant_id: str,
+                        user_id: str,
+                        user_from: str,
+                        available_datasets: list,
+                        query: str,
+                        model_instance: ModelInstance,
+                        model_config: ModelConfigWithCredentialsEntity,
+                        planning_strategy: PlanningStrategy,
+                        ):
+        tools = []
+        for dataset in available_datasets:
+            description = dataset.description
+            if not description:
+                description = 'useful for when you want to answer queries about the ' + dataset.name
+
+            description = description.replace('\n', '').replace('\r', '')
+            message_tool = PromptMessageTool(
+                name=dataset.id,
+                description=description,
+                parameters={
+                    "type": "object",
+                    "properties": {},
+                    "required": [],
+                }
+            )
+            tools.append(message_tool)
+        dataset_id = None
+        if planning_strategy == PlanningStrategy.REACT_ROUTER:
+            react_multi_dataset_router = ReactMultiDatasetRouter()
+            dataset_id = react_multi_dataset_router.invoke(query, tools, model_config, model_instance,
+                                                           user_id, tenant_id)
+
+        elif planning_strategy == PlanningStrategy.ROUTER:
+            function_call_router = FunctionCallMultiDatasetRouter()
+            dataset_id = function_call_router.invoke(query, tools, model_config, model_instance)
+
+        if dataset_id:
             # get retrieval model config
-            default_retrieval_model = {
-                'search_method': 'semantic_search',
-                'reranking_enable': False,
-                'reranking_model': {
-                    'reranking_provider_name': '',
-                    'reranking_model_name': ''
-                },
-                'top_k': 2,
-                'score_threshold_enabled': False
-            }
-
-            for dataset in available_datasets:
+            dataset = db.session.query(Dataset).filter(
+                Dataset.id == dataset_id
+            ).first()
+            if dataset:
                 retrieval_model_config = dataset.retrieval_model \
                     if dataset.retrieval_model else default_retrieval_model
 
                 # get top k
                 top_k = retrieval_model_config['top_k']
-
+                # get retrieval method
+                if dataset.indexing_technique == "economy":
+                    retrival_method = 'keyword_search'
+                else:
+                    retrival_method = retrieval_model_config['search_method']
+                # get reranking model
+                reranking_model = retrieval_model_config['reranking_model'] \
+                    if retrieval_model_config['reranking_enable'] else None
                 # get score threshold
-                score_threshold = None
+                score_threshold = .0
                 score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
                 if score_threshold_enabled:
                     score_threshold = retrieval_model_config.get("score_threshold")
 
-                tool = DatasetRetrieverTool.from_dataset(
-                    dataset=dataset,
-                    top_k=top_k,
-                    score_threshold=score_threshold,
-                    hit_callbacks=[hit_callback],
-                    return_resource=return_resource,
-                    retriever_from=invoke_from.to_source()
-                )
+                results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id,
+                                                    query=query,
+                                                    top_k=top_k, score_threshold=score_threshold,
+                                                    reranking_model=reranking_model)
+                self._on_query(query, [dataset_id], app_id, user_from, user_id)
+                if results:
+                    self._on_retrival_end(results)
+                return results
+        return []
 
-                tools.append(tool)
-        elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
-            tool = DatasetMultiRetrieverTool.from_dataset(
-                dataset_ids=[dataset.id for dataset in available_datasets],
-                tenant_id=tenant_id,
-                top_k=retrieve_config.top_k or 2,
-                score_threshold=retrieve_config.score_threshold,
-                hit_callbacks=[hit_callback],
-                return_resource=return_resource,
-                retriever_from=invoke_from.to_source(),
-                reranking_provider_name=retrieve_config.reranking_model.get('reranking_provider_name'),
-                reranking_model_name=retrieve_config.reranking_model.get('reranking_model_name')
+    def multiple_retrieve(self,
+                          app_id: str,
+                          tenant_id: str,
+                          user_id: str,
+                          user_from: str,
+                          available_datasets: list,
+                          query: str,
+                          top_k: int,
+                          score_threshold: float,
+                          reranking_provider_name: str,
+                          reranking_model_name: str):
+        threads = []
+        all_documents = []
+        dataset_ids = [dataset.id for dataset in available_datasets]
+        for dataset in available_datasets:
+            retrieval_thread = threading.Thread(target=self._retriever, kwargs={
+                'flask_app': current_app._get_current_object(),
+                'dataset_id': dataset.id,
+                'query': query,
+                'top_k': top_k,
+                'all_documents': all_documents,
+            })
+            threads.append(retrieval_thread)
+            retrieval_thread.start()
+        for thread in threads:
+            thread.join()
+        # do rerank for searched documents
+        model_manager = ModelManager()
+        rerank_model_instance = model_manager.get_model_instance(
+            tenant_id=tenant_id,
+            provider=reranking_provider_name,
+            model_type=ModelType.RERANK,
+            model=reranking_model_name
+        )
+
+        rerank_runner = RerankRunner(rerank_model_instance)
+        all_documents = rerank_runner.run(query, all_documents,
+                                          score_threshold,
+                                          top_k)
+        self._on_query(query, dataset_ids, app_id, user_from, user_id)
+        if all_documents:
+            self._on_retrival_end(all_documents)
+        return all_documents
+
+    def _on_retrival_end(self, documents: list[Document]) -> None:
+        """Handle retrival end."""
+        for document in documents:
+            query = db.session.query(DocumentSegment).filter(
+                DocumentSegment.index_node_id == document.metadata['doc_id']
             )
 
-            tools.append(tool)
+            # if 'dataset_id' in document.metadata:
+            if 'dataset_id' in document.metadata:
+                query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id'])
+
+            # add hit count to document segment
+            query.update(
+                {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
+                synchronize_session=False
+            )
+
+            db.session.commit()
+
+    def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str) -> None:
+        """
+        Handle query.
+        """
+        if not query:
+            return
+        for dataset_id in dataset_ids:
+            dataset_query = DatasetQuery(
+                dataset_id=dataset_id,
+                content=query,
+                source='app',
+                source_app_id=app_id,
+                created_by_role=user_from,
+                created_by=user_id
+            )
+            db.session.add(dataset_query)
+        db.session.commit()
+
+    def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
+        with flask_app.app_context():
+            dataset = db.session.query(Dataset).filter(
+                Dataset.id == dataset_id
+            ).first()
+
+            if not dataset:
+                return []
+
+            # get retrieval model , if the model is not setting , using default
+            retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
+
+            if dataset.indexing_technique == "economy":
+                # use keyword table query
+                documents = RetrievalService.retrieve(retrival_method='keyword_search',
+                                                      dataset_id=dataset.id,
+                                                      query=query,
+                                                      top_k=top_k
+                                                      )
+                if documents:
+                    all_documents.extend(documents)
+            else:
+                if top_k > 0:
+                    # retrieval source
+                    documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
+                                                          dataset_id=dataset.id,
+                                                          query=query,
+                                                          top_k=top_k,
+                                                          score_threshold=retrieval_model['score_threshold']
+                                                          if retrieval_model['score_threshold_enabled'] else None,
+                                                          reranking_model=retrieval_model['reranking_model']
+                                                          if retrieval_model['reranking_enable'] else None
+                                                          )
 
-        return tools
+                    all_documents.extend(documents)

+ 0 - 0
api/core/rag/retrieval/agent/__init__.py → api/core/rag/retrieval/output_parser/__init__.py


+ 0 - 0
api/core/rag/retrieval/agent/output_parser/structured_chat.py → api/core/rag/retrieval/output_parser/structured_chat.py


+ 0 - 0
api/core/workflow/nodes/knowledge_retrieval/multi_dataset_function_call_router.py → api/core/rag/retrieval/router/multi_dataset_function_call_router.py


+ 11 - 11
api/core/workflow/nodes/knowledge_retrieval/multi_dataset_react_route.py → api/core/rag/retrieval/router/multi_dataset_react_route.py

@@ -12,8 +12,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage
-from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
-from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
+from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
 from core.workflow.nodes.llm.llm_node import LLMNode
 
 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
@@ -55,11 +54,10 @@ class ReactMultiDatasetRouter:
             self,
             query: str,
             dataset_tools: list[PromptMessageTool],
-            node_data: KnowledgeRetrievalNodeData,
             model_config: ModelConfigWithCredentialsEntity,
             model_instance: ModelInstance,
             user_id: str,
-            tenant_id: str,
+            tenant_id: str
 
     ) -> Union[str, None]:
         """Given input, decided what to do.
@@ -72,7 +70,8 @@ class ReactMultiDatasetRouter:
             return dataset_tools[0].name
 
         try:
-            return self._react_invoke(query=query, node_data=node_data, model_config=model_config, model_instance=model_instance,
+            return self._react_invoke(query=query, model_config=model_config,
+                                      model_instance=model_instance,
                                       tools=dataset_tools, user_id=user_id, tenant_id=tenant_id)
         except Exception as e:
             return None
@@ -80,7 +79,6 @@ class ReactMultiDatasetRouter:
     def _react_invoke(
             self,
             query: str,
-            node_data: KnowledgeRetrievalNodeData,
             model_config: ModelConfigWithCredentialsEntity,
             model_instance: ModelInstance,
             tools: Sequence[PromptMessageTool],
@@ -121,7 +119,7 @@ class ReactMultiDatasetRouter:
             model_config=model_config
         )
         result_text, usage = self._invoke_llm(
-            node_data=node_data,
+            completion_param=model_config.parameters,
             model_instance=model_instance,
             prompt_messages=prompt_messages,
             stop=stop,
@@ -134,10 +132,11 @@ class ReactMultiDatasetRouter:
             return agent_decision.tool
         return None
 
-    def _invoke_llm(self, node_data: KnowledgeRetrievalNodeData,
+    def _invoke_llm(self, completion_param: dict,
                     model_instance: ModelInstance,
                     prompt_messages: list[PromptMessage],
-                    stop: list[str], user_id: str, tenant_id: str) -> tuple[str, LLMUsage]:
+                    stop: list[str], user_id: str, tenant_id: str
+                    ) -> tuple[str, LLMUsage]:
         """
             Invoke large language model
             :param node_data: node data
@@ -148,7 +147,7 @@ class ReactMultiDatasetRouter:
         """
         invoke_result = model_instance.invoke_llm(
             prompt_messages=prompt_messages,
-            model_parameters=node_data.single_retrieval_config.model.completion_params,
+            model_parameters=completion_param,
             stop=stop,
             stream=True,
             user=user_id,
@@ -203,7 +202,8 @@ class ReactMultiDatasetRouter:
     ) -> list[ChatModelMessage]:
         tool_strings = []
         for tool in tools:
-            tool_strings.append(f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}")
+            tool_strings.append(
+                f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}")
         formatted_tools = "\n".join(tool_strings)
         unique_tool_names = set(tool.name for tool in tools)
         tool_names = ", ".join('"' + name + '"' for name in unique_tool_names)

+ 39 - 201
api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py

@@ -1,28 +1,21 @@
-import threading
 from typing import Any, cast
 
-from flask import Flask, current_app
-
 from core.app.app_config.entities import DatasetRetrieveConfigEntity
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.agent_entities import PlanningStrategy
 from core.entities.model_entities import ModelStatus
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_manager import ModelInstance, ModelManager
-from core.model_runtime.entities.message_entities import PromptMessageTool
 from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.rag.datasource.retrieval_service import RetrievalService
-from core.rerank.rerank import RerankRunner
+from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
-from core.workflow.nodes.knowledge_retrieval.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
-from core.workflow.nodes.knowledge_retrieval.multi_dataset_react_route import ReactMultiDatasetRouter
 from extensions.ext_database import db
-from models.dataset import Dataset, DatasetQuery, Document, DocumentSegment
+from models.dataset import Dataset, Document, DocumentSegment
 from models.workflow import WorkflowNodeExecutionStatus
 
 default_retrieval_model = {
@@ -106,10 +99,45 @@ class KnowledgeRetrievalNode(BaseNode):
 
             available_datasets.append(dataset)
         all_documents = []
+        dataset_retrieval = DatasetRetrieval()
         if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
-            all_documents = self._single_retrieve(available_datasets, node_data, query)
+            # fetch model config
+            model_instance, model_config = self._fetch_model_config(node_data)
+            # check model is support tool calling
+            model_type_instance = model_config.provider_model_bundle.model_type_instance
+            model_type_instance = cast(LargeLanguageModel, model_type_instance)
+            # get model schema
+            model_schema = model_type_instance.get_model_schema(
+                model=model_config.model,
+                credentials=model_config.credentials
+            )
+
+            if model_schema:
+                planning_strategy = PlanningStrategy.REACT_ROUTER
+                features = model_schema.features
+                if features:
+                    if ModelFeature.TOOL_CALL in features \
+                            or ModelFeature.MULTI_TOOL_CALL in features:
+                        planning_strategy = PlanningStrategy.ROUTER
+                all_documents = dataset_retrieval.single_retrieve(
+                    available_datasets=available_datasets,
+                    tenant_id=self.tenant_id,
+                    user_id=self.user_id,
+                    app_id=self.app_id,
+                    user_from=self.user_from.value,
+                    query=query,
+                    model_config=model_config,
+                    model_instance=model_instance,
+                    planning_strategy=planning_strategy
+                )
         elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
-            all_documents = self._multiple_retrieve(available_datasets, node_data, query)
+            all_documents = dataset_retrieval.multiple_retrieve(self.app_id, self.tenant_id, self.user_id,
+                                                                self.user_from.value,
+                                                                available_datasets, query,
+                                                                node_data.multiple_retrieval_config.top_k,
+                                                                node_data.multiple_retrieval_config.score_threshold,
+                                                                node_data.multiple_retrieval_config.reranking_model.provider,
+                                                                node_data.multiple_retrieval_config.reranking_model.model)
 
         context_list = []
         if all_documents:
@@ -184,87 +212,6 @@ class KnowledgeRetrievalNode(BaseNode):
         variable_mapping['query'] = node_data.query_variable_selector
         return variable_mapping
 
-    def _single_retrieve(self, available_datasets, node_data, query):
-        tools = []
-        for dataset in available_datasets:
-            description = dataset.description
-            if not description:
-                description = 'useful for when you want to answer queries about the ' + dataset.name
-
-            description = description.replace('\n', '').replace('\r', '')
-            message_tool = PromptMessageTool(
-                name=dataset.id,
-                description=description,
-                parameters={
-                    "type": "object",
-                    "properties": {},
-                    "required": [],
-                }
-            )
-            tools.append(message_tool)
-        # fetch model config
-        model_instance, model_config = self._fetch_model_config(node_data)
-        # check model is support tool calling
-        model_type_instance = model_config.provider_model_bundle.model_type_instance
-        model_type_instance = cast(LargeLanguageModel, model_type_instance)
-        # get model schema
-        model_schema = model_type_instance.get_model_schema(
-            model=model_config.model,
-            credentials=model_config.credentials
-        )
-
-        if not model_schema:
-            return None
-        planning_strategy = PlanningStrategy.REACT_ROUTER
-        features = model_schema.features
-        if features:
-            if ModelFeature.TOOL_CALL in features \
-                    or ModelFeature.MULTI_TOOL_CALL in features:
-                planning_strategy = PlanningStrategy.ROUTER
-        dataset_id = None
-        if planning_strategy == PlanningStrategy.REACT_ROUTER:
-            react_multi_dataset_router = ReactMultiDatasetRouter()
-            dataset_id = react_multi_dataset_router.invoke(query, tools, node_data, model_config, model_instance,
-                                                           self.user_id, self.tenant_id)
-
-        elif planning_strategy == PlanningStrategy.ROUTER:
-            function_call_router = FunctionCallMultiDatasetRouter()
-            dataset_id = function_call_router.invoke(query, tools, model_config, model_instance)
-        if dataset_id:
-            # get retrieval model config
-            dataset = db.session.query(Dataset).filter(
-                Dataset.id == dataset_id
-            ).first()
-            if dataset:
-                retrieval_model_config = dataset.retrieval_model \
-                    if dataset.retrieval_model else default_retrieval_model
-
-                # get top k
-                top_k = retrieval_model_config['top_k']
-                # get retrieval method
-                if dataset.indexing_technique == "economy":
-                    retrival_method = 'keyword_search'
-                else:
-                    retrival_method = retrieval_model_config['search_method']
-                # get reranking model
-                reranking_model=retrieval_model_config['reranking_model'] \
-                    if retrieval_model_config['reranking_enable'] else None
-                # get score threshold
-                score_threshold = .0
-                score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
-                if score_threshold_enabled:
-                    score_threshold = retrieval_model_config.get("score_threshold")
-
-                results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id,
-                                                    query=query,
-                                                    top_k=top_k, score_threshold=score_threshold,
-                                                    reranking_model=reranking_model)
-                self._on_query(query, [dataset_id])
-                if results:
-                    self._on_retrival_end(results)
-                return results
-        return []
-
     def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[
         ModelInstance, ModelConfigWithCredentialsEntity]:
         """
@@ -335,112 +282,3 @@ class KnowledgeRetrievalNode(BaseNode):
             parameters=completion_params,
             stop=stop,
         )
-
-    def _multiple_retrieve(self, available_datasets, node_data, query):
-        threads = []
-        all_documents = []
-        dataset_ids = [dataset.id for dataset in available_datasets]
-        for dataset in available_datasets:
-            retrieval_thread = threading.Thread(target=self._retriever, kwargs={
-                'flask_app': current_app._get_current_object(),
-                'dataset_id': dataset.id,
-                'query': query,
-                'top_k': node_data.multiple_retrieval_config.top_k,
-                'all_documents': all_documents,
-            })
-            threads.append(retrieval_thread)
-            retrieval_thread.start()
-        for thread in threads:
-            thread.join()
-        # do rerank for searched documents
-        model_manager = ModelManager()
-        rerank_model_instance = model_manager.get_model_instance(
-            tenant_id=self.tenant_id,
-            provider=node_data.multiple_retrieval_config.reranking_model.provider,
-            model_type=ModelType.RERANK,
-            model=node_data.multiple_retrieval_config.reranking_model.model
-        )
-
-        rerank_runner = RerankRunner(rerank_model_instance)
-        all_documents = rerank_runner.run(query, all_documents,
-                                          node_data.multiple_retrieval_config.score_threshold,
-                                          node_data.multiple_retrieval_config.top_k)
-        self._on_query(query, dataset_ids)
-        if all_documents:
-            self._on_retrival_end(all_documents)
-        return all_documents
-
-    def _on_retrival_end(self, documents: list[Document]) -> None:
-        """Handle retrival end."""
-        for document in documents:
-            query = db.session.query(DocumentSegment).filter(
-                DocumentSegment.index_node_id == document.metadata['doc_id']
-            )
-
-            # if 'dataset_id' in document.metadata:
-            if 'dataset_id' in document.metadata:
-                query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id'])
-
-            # add hit count to document segment
-            query.update(
-                {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
-                synchronize_session=False
-            )
-
-            db.session.commit()
-
-    def _on_query(self, query: str, dataset_ids: list[str]) -> None:
-        """
-        Handle query.
-        """
-        if not query:
-            return
-        for dataset_id in dataset_ids:
-            dataset_query = DatasetQuery(
-                dataset_id=dataset_id,
-                content=query,
-                source='app',
-                source_app_id=self.app_id,
-                created_by_role=self.user_from.value,
-                created_by=self.user_id
-            )
-            db.session.add(dataset_query)
-        db.session.commit()
-
-    def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
-        with flask_app.app_context():
-            dataset = db.session.query(Dataset).filter(
-                Dataset.tenant_id == self.tenant_id,
-                Dataset.id == dataset_id
-            ).first()
-
-            if not dataset:
-                return []
-
-            # get retrieval model , if the model is not setting , using default
-            retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
-
-            if dataset.indexing_technique == "economy":
-                # use keyword table query
-                documents = RetrievalService.retrieve(retrival_method='keyword_search',
-                                                      dataset_id=dataset.id,
-                                                      query=query,
-                                                      top_k=top_k
-                                                      )
-                if documents:
-                    all_documents.extend(documents)
-            else:
-                if top_k > 0:
-                    # retrieval source
-                    documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
-                                                          dataset_id=dataset.id,
-                                                          query=query,
-                                                          top_k=top_k,
-                                                          score_threshold=retrieval_model['score_threshold']
-                                                          if retrieval_model['score_threshold_enabled'] else None,
-                                                          reranking_model=retrieval_model['reranking_model']
-                                                          if retrieval_model['reranking_enable'] else None
-                                                          )
-
-                    all_documents.extend(documents)
-