Przeglądaj źródła

refactor: remove unused codes, move core/agent module into dataset retrieval feature (#2614)

takatost 1 rok temu
rodzic
commit
dd961985f0
29 zmienionych plików z 41 dodań i 2016 usunięć
  1. 0 49
      api/core/agent/agent/calc_token_mixin.py
  2. 0 361
      api/core/agent/agent/openai_function_call.py
  3. 0 306
      api/core/agent/agent/structured_chat.py
  4. 1 39
      api/core/app_runner/assistant_app_runner.py
  5. 1 1
      api/core/app_runner/basic_app_runner.py
  6. 8 0
      api/core/entities/agent_entities.py
  7. 0 199
      api/core/features/agent_runner.py
  8. 0 0
      api/core/features/dataset_retrieval/__init__.py
  9. 0 0
      api/core/features/dataset_retrieval/agent/__init__.py
  10. 0 0
      api/core/features/dataset_retrieval/agent/agent_llm_callback.py
  11. 0 0
      api/core/features/dataset_retrieval/agent/fake_llm.py
  12. 2 2
      api/core/features/dataset_retrieval/agent/llm_chain.py
  13. 1 1
      api/core/features/dataset_retrieval/agent/multi_dataset_router_agent.py
  14. 0 0
      api/core/features/dataset_retrieval/agent/output_parser/__init__.py
  15. 0 0
      api/core/features/dataset_retrieval/agent/output_parser/structured_chat.py
  16. 1 1
      api/core/features/dataset_retrieval/agent/structed_multi_dataset_router_agent.py
  17. 6 36
      api/core/features/dataset_retrieval/agent_based_dataset_executor.py
  18. 2 1
      api/core/features/dataset_retrieval/dataset_retrieval.py
  19. 0 189
      api/core/third_party/spark/spark_llm.py
  20. 0 24
      api/core/tool/current_datetime_tool.py
  21. 0 63
      api/core/tool/provider/base.py
  22. 0 2
      api/core/tool/provider/errors.py
  23. 0 77
      api/core/tool/provider/serpapi_provider.py
  24. 0 43
      api/core/tool/provider/tool_provider_service.py
  25. 0 51
      api/core/tool/serpapi_wrapper.py
  26. 0 443
      api/core/tool/web_reader_tool.py
  27. 18 18
      api/core/tools/tool/dataset_retriever_tool.py
  28. 0 109
      api/core/tools/utils/web_reader_tool.py
  29. 1 1
      api/services/app_model_config_service.py

+ 0 - 49
api/core/agent/agent/calc_token_mixin.py

@@ -1,49 +0,0 @@
-from typing import cast
-
-from core.entities.application_entities import ModelConfigEntity
-from core.model_runtime.entities.message_entities import PromptMessage
-from core.model_runtime.entities.model_entities import ModelPropertyKey
-from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-
-
-class CalcTokenMixin:
-
-    def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: list[PromptMessage], **kwargs) -> int:
-        """
-        Got the rest tokens available for the model after excluding messages tokens and completion max tokens
-
-        :param model_config:
-        :param messages:
-        :return:
-        """
-        model_type_instance = model_config.provider_model_bundle.model_type_instance
-        model_type_instance = cast(LargeLanguageModel, model_type_instance)
-
-        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
-
-        max_tokens = 0
-        for parameter_rule in model_config.model_schema.parameter_rules:
-            if (parameter_rule.name == 'max_tokens'
-                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
-                max_tokens = (model_config.parameters.get(parameter_rule.name)
-                              or model_config.parameters.get(parameter_rule.use_template)) or 0
-
-        if model_context_tokens is None:
-            return 0
-
-        if max_tokens is None:
-            max_tokens = 0
-
-        prompt_tokens = model_type_instance.get_num_tokens(
-            model_config.model,
-            model_config.credentials,
-            messages
-        )
-
-        rest_tokens = model_context_tokens - max_tokens - prompt_tokens
-
-        return rest_tokens
-
-
-class ExceededLLMTokensLimitError(Exception):
-    pass

+ 0 - 361
api/core/agent/agent/openai_function_call.py

@@ -1,361 +0,0 @@
-from collections.abc import Sequence
-from typing import Any, Optional, Union
-
-from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
-from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
-from langchain.callbacks.base import BaseCallbackManager
-from langchain.callbacks.manager import Callbacks
-from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken
-from langchain.memory.prompt import SUMMARY_PROMPT
-from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import (
-    AgentAction,
-    AgentFinish,
-    AIMessage,
-    BaseMessage,
-    HumanMessage,
-    SystemMessage,
-    get_buffer_string,
-)
-from langchain.tools import BaseTool
-from pydantic import root_validator
-
-from core.agent.agent.agent_llm_callback import AgentLLMCallback
-from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
-from core.chain.llm_chain import LLMChain
-from core.entities.application_entities import ModelConfigEntity
-from core.entities.message_entities import lc_messages_to_prompt_messages
-from core.model_manager import ModelInstance
-from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
-from core.third_party.langchain.llms.fake import FakeLLM
-
-
-class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
-    moving_summary_buffer: str = ""
-    moving_summary_index: int = 0
-    summary_model_config: ModelConfigEntity = None
-    model_config: ModelConfigEntity
-    agent_llm_callback: Optional[AgentLLMCallback] = None
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        arbitrary_types_allowed = True
-
-    @root_validator
-    def validate_llm(cls, values: dict) -> dict:
-        return values
-
-    @classmethod
-    def from_llm_and_tools(
-            cls,
-            model_config: ModelConfigEntity,
-            tools: Sequence[BaseTool],
-            callback_manager: Optional[BaseCallbackManager] = None,
-            extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
-            system_message: Optional[SystemMessage] = SystemMessage(
-                content="You are a helpful AI assistant."
-            ),
-            agent_llm_callback: Optional[AgentLLMCallback] = None,
-            **kwargs: Any,
-    ) -> BaseSingleActionAgent:
-        prompt = cls.create_prompt(
-            extra_prompt_messages=extra_prompt_messages,
-            system_message=system_message,
-        )
-        return cls(
-            model_config=model_config,
-            llm=FakeLLM(response=''),
-            prompt=prompt,
-            tools=tools,
-            callback_manager=callback_manager,
-            agent_llm_callback=agent_llm_callback,
-            **kwargs,
-        )
-
-    def should_use_agent(self, query: str):
-        """
-        return should use agent
-
-        :param query:
-        :return:
-        """
-        original_max_tokens = 0
-        for parameter_rule in self.model_config.model_schema.parameter_rules:
-            if (parameter_rule.name == 'max_tokens'
-                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
-                original_max_tokens = (self.model_config.parameters.get(parameter_rule.name)
-                              or self.model_config.parameters.get(parameter_rule.use_template)) or 0
-
-        self.model_config.parameters['max_tokens'] = 40
-
-        prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
-        messages = prompt.to_messages()
-
-        try:
-            prompt_messages = lc_messages_to_prompt_messages(messages)
-            model_instance = ModelInstance(
-                provider_model_bundle=self.model_config.provider_model_bundle,
-                model=self.model_config.model,
-            )
-
-            tools = []
-            for function in self.functions:
-                tool = PromptMessageTool(
-                    **function
-                )
-
-                tools.append(tool)
-
-            result = model_instance.invoke_llm(
-                prompt_messages=prompt_messages,
-                tools=tools,
-                stream=False,
-                model_parameters={
-                    'temperature': 0.2,
-                    'top_p': 0.3,
-                    'max_tokens': 1500
-                }
-            )
-        except Exception as e:
-            raise e
-
-        self.model_config.parameters['max_tokens'] = original_max_tokens
-
-        return True if result.message.tool_calls else False
-
-    def plan(
-            self,
-            intermediate_steps: list[tuple[AgentAction, str]],
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> Union[AgentAction, AgentFinish]:
-        """Given input, decided what to do.
-
-        Args:
-            intermediate_steps: Steps the LLM has taken to date, along with observations
-            **kwargs: User inputs.
-
-        Returns:
-            Action specifying what tool to use.
-        """
-        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
-        selected_inputs = {
-            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
-        }
-        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
-        prompt = self.prompt.format_prompt(**full_inputs)
-        messages = prompt.to_messages()
-
-        prompt_messages = lc_messages_to_prompt_messages(messages)
-
-        # summarize messages if rest_tokens < 0
-        try:
-            prompt_messages = self.summarize_messages_if_needed(prompt_messages, functions=self.functions)
-        except ExceededLLMTokensLimitError as e:
-            return AgentFinish(return_values={"output": str(e)}, log=str(e))
-
-        model_instance = ModelInstance(
-            provider_model_bundle=self.model_config.provider_model_bundle,
-            model=self.model_config.model,
-        )
-
-        tools = []
-        for function in self.functions:
-            tool = PromptMessageTool(
-                **function
-            )
-
-            tools.append(tool)
-
-        result = model_instance.invoke_llm(
-            prompt_messages=prompt_messages,
-            tools=tools,
-            stream=False,
-            callbacks=[self.agent_llm_callback] if self.agent_llm_callback else [],
-            model_parameters={
-                'temperature': 0.2,
-                'top_p': 0.3,
-                'max_tokens': 1500
-            }
-        )
-
-        ai_message = AIMessage(
-            content=result.message.content or "",
-            additional_kwargs={
-                'function_call': {
-                    'id': result.message.tool_calls[0].id,
-                    **result.message.tool_calls[0].function.dict()
-                } if result.message.tool_calls else None
-            }
-        )
-        agent_decision = _parse_ai_message(ai_message)
-
-        if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
-            tool_inputs = agent_decision.tool_input
-            if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
-                tool_inputs['query'] = kwargs['input']
-                agent_decision.tool_input = tool_inputs
-
-        return agent_decision
-
-    @classmethod
-    def get_system_message(cls):
-        return SystemMessage(content="You are a helpful AI assistant.\n"
-                                     "The current date or current time you know is wrong.\n"
-                                     "Respond directly if appropriate.")
-
-    def return_stopped_response(
-            self,
-            early_stopping_method: str,
-            intermediate_steps: list[tuple[AgentAction, str]],
-            **kwargs: Any,
-    ) -> AgentFinish:
-        try:
-            return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
-        except ValueError:
-            return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
-
-    def summarize_messages_if_needed(self, messages: list[PromptMessage], **kwargs) -> list[PromptMessage]:
-        # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
-        rest_tokens = self.get_message_rest_tokens(
-            self.model_config,
-            messages,
-            **kwargs
-        )
-
-        rest_tokens = rest_tokens - 20  # to deal with the inaccuracy of rest_tokens
-        if rest_tokens >= 0:
-            return messages
-
-        system_message = None
-        human_message = None
-        should_summary_messages = []
-        for message in messages:
-            if isinstance(message, SystemMessage):
-                system_message = message
-            elif isinstance(message, HumanMessage):
-                human_message = message
-            else:
-                should_summary_messages.append(message)
-
-        if len(should_summary_messages) > 2:
-            ai_message = should_summary_messages[-2]
-            function_message = should_summary_messages[-1]
-            should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
-            self.moving_summary_index = len(should_summary_messages)
-        else:
-            error_msg = "Exceeded LLM tokens limit, stopped."
-            raise ExceededLLMTokensLimitError(error_msg)
-
-        new_messages = [system_message, human_message]
-
-        if self.moving_summary_index == 0:
-            should_summary_messages.insert(0, human_message)
-
-        self.moving_summary_buffer = self.predict_new_summary(
-            messages=should_summary_messages,
-            existing_summary=self.moving_summary_buffer
-        )
-
-        new_messages.append(AIMessage(content=self.moving_summary_buffer))
-        new_messages.append(ai_message)
-        new_messages.append(function_message)
-
-        return new_messages
-
-    def predict_new_summary(
-        self, messages: list[BaseMessage], existing_summary: str
-    ) -> str:
-        new_lines = get_buffer_string(
-            messages,
-            human_prefix="Human",
-            ai_prefix="AI",
-        )
-
-        chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
-        return chain.predict(summary=existing_summary, new_lines=new_lines)
-
-    def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: list[BaseMessage], **kwargs) -> int:
-        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
-
-        Official documentation: https://github.com/openai/openai-cookbook/blob/
-        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
-        if model_config.provider == 'azure_openai':
-            model = model_config.model
-            model = model.replace("gpt-35", "gpt-3.5")
-        else:
-            model = model_config.credentials.get("base_model_name")
-
-        tiktoken_ = _import_tiktoken()
-        try:
-            encoding = tiktoken_.encoding_for_model(model)
-        except KeyError:
-            model = "cl100k_base"
-            encoding = tiktoken_.get_encoding(model)
-
-        if model.startswith("gpt-3.5-turbo"):
-            # every message follows <im_start>{role/name}\n{content}<im_end>\n
-            tokens_per_message = 4
-            # if there's a name, the role is omitted
-            tokens_per_name = -1
-        elif model.startswith("gpt-4"):
-            tokens_per_message = 3
-            tokens_per_name = 1
-        else:
-            raise NotImplementedError(
-                f"get_num_tokens_from_messages() is not presently implemented "
-                f"for model {model}."
-                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
-                "information on how messages are converted to tokens."
-            )
-        num_tokens = 0
-        for m in messages:
-            message = _convert_message_to_dict(m)
-            num_tokens += tokens_per_message
-            for key, value in message.items():
-                if key == "function_call":
-                    for f_key, f_value in value.items():
-                        num_tokens += len(encoding.encode(f_key))
-                        num_tokens += len(encoding.encode(f_value))
-                else:
-                    num_tokens += len(encoding.encode(value))
-
-                if key == "name":
-                    num_tokens += tokens_per_name
-        # every reply is primed with <im_start>assistant
-        num_tokens += 3
-
-        if kwargs.get('functions'):
-            for function in kwargs.get('functions'):
-                num_tokens += len(encoding.encode('name'))
-                num_tokens += len(encoding.encode(function.get("name")))
-                num_tokens += len(encoding.encode('description'))
-                num_tokens += len(encoding.encode(function.get("description")))
-                parameters = function.get("parameters")
-                num_tokens += len(encoding.encode('parameters'))
-                if 'title' in parameters:
-                    num_tokens += len(encoding.encode('title'))
-                    num_tokens += len(encoding.encode(parameters.get("title")))
-                num_tokens += len(encoding.encode('type'))
-                num_tokens += len(encoding.encode(parameters.get("type")))
-                if 'properties' in parameters:
-                    num_tokens += len(encoding.encode('properties'))
-                    for key, value in parameters.get('properties').items():
-                        num_tokens += len(encoding.encode(key))
-                        for field_key, field_value in value.items():
-                            num_tokens += len(encoding.encode(field_key))
-                            if field_key == 'enum':
-                                for enum_field in field_value:
-                                    num_tokens += 3
-                                    num_tokens += len(encoding.encode(enum_field))
-                            else:
-                                num_tokens += len(encoding.encode(field_key))
-                                num_tokens += len(encoding.encode(str(field_value)))
-                if 'required' in parameters:
-                    num_tokens += len(encoding.encode('required'))
-                    for required_field in parameters['required']:
-                        num_tokens += 3
-                        num_tokens += len(encoding.encode(required_field))
-
-        return num_tokens

+ 0 - 306
api/core/agent/agent/structured_chat.py

@@ -1,306 +0,0 @@
-import re
-from collections.abc import Sequence
-from typing import Any, Optional, Union, cast
-
-from langchain import BasePromptTemplate, PromptTemplate
-from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
-from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
-from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
-from langchain.callbacks.base import BaseCallbackManager
-from langchain.callbacks.manager import Callbacks
-from langchain.memory.prompt import SUMMARY_PROMPT
-from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
-from langchain.schema import (
-    AgentAction,
-    AgentFinish,
-    AIMessage,
-    BaseMessage,
-    HumanMessage,
-    OutputParserException,
-    get_buffer_string,
-)
-from langchain.tools import BaseTool
-
-from core.agent.agent.agent_llm_callback import AgentLLMCallback
-from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
-from core.chain.llm_chain import LLMChain
-from core.entities.application_entities import ModelConfigEntity
-from core.entities.message_entities import lc_messages_to_prompt_messages
-
-FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
-The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
-Valid "action" values: "Final Answer" or {tool_names}
-
-Provide only ONE action per $JSON_BLOB, as shown:
-
-```
-{{{{
-  "action": $TOOL_NAME,
-  "action_input": $INPUT
-}}}}
-```
-
-Follow this format:
-
-Question: input question to answer
-Thought: consider previous and subsequent steps
-Action:
-```
-$JSON_BLOB
-```
-Observation: action result
-... (repeat Thought/Action/Observation N times)
-Thought: I know what to respond
-Action:
-```
-{{{{
-  "action": "Final Answer",
-  "action_input": "Final response to human"
-}}}}
-```"""
-
-
-class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
-    moving_summary_buffer: str = ""
-    moving_summary_index: int = 0
-    summary_model_config: ModelConfigEntity = None
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        arbitrary_types_allowed = True
-
-    def should_use_agent(self, query: str):
-        """
-        return should use agent
-        Using the ReACT mode to determine whether an agent is needed is costly,
-        so it's better to just use an Agent for reasoning, which is cheaper.
-
-        :param query:
-        :return:
-        """
-        return True
-
-    def plan(
-        self,
-        intermediate_steps: list[tuple[AgentAction, str]],
-        callbacks: Callbacks = None,
-        **kwargs: Any,
-    ) -> Union[AgentAction, AgentFinish]:
-        """Given input, decided what to do.
-
-        Args:
-            intermediate_steps: Steps the LLM has taken to date,
-                along with observatons
-            callbacks: Callbacks to run.
-            **kwargs: User inputs.
-
-        Returns:
-            Action specifying what tool to use.
-        """
-        full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
-        prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])
-
-        messages = []
-        if prompts:
-            messages = prompts[0].to_messages()
-
-        prompt_messages = lc_messages_to_prompt_messages(messages)
-
-        rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_config, prompt_messages)
-        if rest_tokens < 0:
-            full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
-
-        try:
-            full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
-        except Exception as e:
-            raise e
-
-        try:
-            agent_decision = self.output_parser.parse(full_output)
-            if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
-                tool_inputs = agent_decision.tool_input
-                if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
-                    tool_inputs['query'] = kwargs['input']
-                    agent_decision.tool_input = tool_inputs
-            return agent_decision
-        except OutputParserException:
-            return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
-                                          "I don't know how to respond to that."}, "")
-
-    def summarize_messages(self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs):
-        if len(intermediate_steps) >= 2 and self.summary_model_config:
-            should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
-            should_summary_messages = [AIMessage(content=observation)
-                                       for _, observation in should_summary_intermediate_steps]
-            if self.moving_summary_index == 0:
-                should_summary_messages.insert(0, HumanMessage(content=kwargs.get("input")))
-
-            self.moving_summary_index = len(intermediate_steps)
-        else:
-            error_msg = "Exceeded LLM tokens limit, stopped."
-            raise ExceededLLMTokensLimitError(error_msg)
-
-        if self.moving_summary_buffer and 'chat_history' in kwargs:
-            kwargs["chat_history"].pop()
-
-        self.moving_summary_buffer = self.predict_new_summary(
-            messages=should_summary_messages,
-            existing_summary=self.moving_summary_buffer
-        )
-
-        if 'chat_history' in kwargs:
-            kwargs["chat_history"].append(AIMessage(content=self.moving_summary_buffer))
-
-        return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
-
-    def predict_new_summary(
-        self, messages: list[BaseMessage], existing_summary: str
-    ) -> str:
-        new_lines = get_buffer_string(
-            messages,
-            human_prefix="Human",
-            ai_prefix="AI",
-        )
-
-        chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
-        return chain.predict(summary=existing_summary, new_lines=new_lines)
-
-    @classmethod
-    def create_prompt(
-            cls,
-            tools: Sequence[BaseTool],
-            prefix: str = PREFIX,
-            suffix: str = SUFFIX,
-            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
-            format_instructions: str = FORMAT_INSTRUCTIONS,
-            input_variables: Optional[list[str]] = None,
-            memory_prompts: Optional[list[BasePromptTemplate]] = None,
-    ) -> BasePromptTemplate:
-        tool_strings = []
-        for tool in tools:
-            args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
-            tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
-        formatted_tools = "\n".join(tool_strings)
-        tool_names = ", ".join([('"' + tool.name + '"') for tool in tools])
-        format_instructions = format_instructions.format(tool_names=tool_names)
-        template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
-        if input_variables is None:
-            input_variables = ["input", "agent_scratchpad"]
-        _memory_prompts = memory_prompts or []
-        messages = [
-            SystemMessagePromptTemplate.from_template(template),
-            *_memory_prompts,
-            HumanMessagePromptTemplate.from_template(human_message_template),
-        ]
-        return ChatPromptTemplate(input_variables=input_variables, messages=messages)
-
-    @classmethod
-    def create_completion_prompt(
-            cls,
-            tools: Sequence[BaseTool],
-            prefix: str = PREFIX,
-            format_instructions: str = FORMAT_INSTRUCTIONS,
-            input_variables: Optional[list[str]] = None,
-    ) -> PromptTemplate:
-        """Create prompt in the style of the zero shot agent.
-
-        Args:
-            tools: List of tools the agent will have access to, used to format the
-                prompt.
-            prefix: String to put before the list of tools.
-            input_variables: List of input variables the final prompt will expect.
-
-        Returns:
-            A PromptTemplate with the template assembled from the pieces here.
-        """
-        suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
-Question: {input}
-Thought: {agent_scratchpad}
-"""
-
-        tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
-        tool_names = ", ".join([tool.name for tool in tools])
-        format_instructions = format_instructions.format(tool_names=tool_names)
-        template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
-        if input_variables is None:
-            input_variables = ["input", "agent_scratchpad"]
-        return PromptTemplate(template=template, input_variables=input_variables)
-
-    def _construct_scratchpad(
-        self, intermediate_steps: list[tuple[AgentAction, str]]
-    ) -> str:
-        agent_scratchpad = ""
-        for action, observation in intermediate_steps:
-            agent_scratchpad += action.log
-            agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
-
-        if not isinstance(agent_scratchpad, str):
-            raise ValueError("agent_scratchpad should be of type string.")
-        if agent_scratchpad:
-            llm_chain = cast(LLMChain, self.llm_chain)
-            if llm_chain.model_config.mode == "chat":
-                return (
-                    f"This was your previous work "
-                    f"(but I haven't seen any of it! I only see what "
-                    f"you return as final answer):\n{agent_scratchpad}"
-                )
-            else:
-                return agent_scratchpad
-        else:
-            return agent_scratchpad
-
-    @classmethod
-    def from_llm_and_tools(
-            cls,
-            model_config: ModelConfigEntity,
-            tools: Sequence[BaseTool],
-            callback_manager: Optional[BaseCallbackManager] = None,
-            output_parser: Optional[AgentOutputParser] = None,
-            prefix: str = PREFIX,
-            suffix: str = SUFFIX,
-            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
-            format_instructions: str = FORMAT_INSTRUCTIONS,
-            input_variables: Optional[list[str]] = None,
-            memory_prompts: Optional[list[BasePromptTemplate]] = None,
-            agent_llm_callback: Optional[AgentLLMCallback] = None,
-            **kwargs: Any,
-    ) -> Agent:
-        """Construct an agent from an LLM and tools."""
-        cls._validate_tools(tools)
-        if model_config.mode == "chat":
-            prompt = cls.create_prompt(
-                tools,
-                prefix=prefix,
-                suffix=suffix,
-                human_message_template=human_message_template,
-                format_instructions=format_instructions,
-                input_variables=input_variables,
-                memory_prompts=memory_prompts,
-            )
-        else:
-            prompt = cls.create_completion_prompt(
-                tools,
-                prefix=prefix,
-                format_instructions=format_instructions,
-                input_variables=input_variables,
-            )
-        llm_chain = LLMChain(
-            model_config=model_config,
-            prompt=prompt,
-            callback_manager=callback_manager,
-            agent_llm_callback=agent_llm_callback,
-            parameters={
-                'temperature': 0.2,
-                'top_p': 0.3,
-                'max_tokens': 1500
-            }
-        )
-        tool_names = [tool.name for tool in tools]
-        _output_parser = output_parser
-        return cls(
-            llm_chain=llm_chain,
-            allowed_tools=tool_names,
-            output_parser=_output_parser,
-            **kwargs,
-        )

+ 1 - 39
api/core/app_runner/assistant_app_runner.py

@@ -1,4 +1,3 @@
-import json
 import logging
 import logging
 from typing import cast
 from typing import cast
 
 
@@ -15,7 +14,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
 from core.moderation.base import ModerationException
 from core.moderation.base import ModerationException
 from core.tools.entities.tool_entities import ToolRuntimeVariablePool
 from core.tools.entities.tool_entities import ToolRuntimeVariablePool
 from extensions.ext_database import db
 from extensions.ext_database import db
-from models.model import App, Conversation, Message, MessageAgentThought, MessageChain
+from models.model import App, Conversation, Message, MessageAgentThought
 from models.tools import ToolConversationVariables
 from models.tools import ToolConversationVariables
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
@@ -173,11 +172,6 @@ class AssistantApplicationRunner(AppRunner):
 
 
         # convert db variables to tool variables
         # convert db variables to tool variables
         tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
         tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
-        
-        message_chain = self._init_message_chain(
-            message=message,
-            query=query
-        )
 
 
         # init model instance
         # init model instance
         model_instance = ModelInstance(
         model_instance = ModelInstance(
@@ -290,38 +284,6 @@ class AssistantApplicationRunner(AppRunner):
             'pool': db_variables.variables
             'pool': db_variables.variables
         })
         })
 
 
-    def _init_message_chain(self, message: Message, query: str) -> MessageChain:
-        """
-        Init MessageChain
-        :param message: message
-        :param query: query
-        :return:
-        """
-        message_chain = MessageChain(
-            message_id=message.id,
-            type="AgentExecutor",
-            input=json.dumps({
-                "input": query
-            })
-        )
-
-        db.session.add(message_chain)
-        db.session.commit()
-
-        return message_chain
-
-    def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
-        """
-        Save MessageChain
-        :param message_chain: message chain
-        :param output_text: output text
-        :return:
-        """
-        message_chain.output = json.dumps({
-            "output": output_text
-        })
-        db.session.commit()
-
     def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
     def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
                                          message: Message) -> LLMUsage:
                                          message: Message) -> LLMUsage:
         """
         """

+ 1 - 1
api/core/app_runner/basic_app_runner.py

@@ -5,7 +5,7 @@ from core.app_runner.app_runner import AppRunner
 from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity
 from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity
-from core.features.dataset_retrieval import DatasetRetrievalFeature
+from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_manager import ModelInstance
 from core.moderation.base import ModerationException
 from core.moderation.base import ModerationException

+ 8 - 0
api/core/entities/agent_entities.py

@@ -0,0 +1,8 @@
+from enum import Enum
+
+
+class PlanningStrategy(Enum):
+    ROUTER = 'router'
+    REACT_ROUTER = 'react_router'
+    REACT = 'react'
+    FUNCTION_CALL = 'function_call'

+ 0 - 199
api/core/features/agent_runner.py

@@ -1,199 +0,0 @@
-import logging
-from typing import Optional, cast
-
-from langchain.tools import BaseTool
-
-from core.agent.agent.agent_llm_callback import AgentLLMCallback
-from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
-from core.application_queue_manager import ApplicationQueueManager
-from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
-from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
-from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
-from core.entities.application_entities import (
-    AgentEntity,
-    AppOrchestrationConfigEntity,
-    InvokeFrom,
-    ModelConfigEntity,
-)
-from core.memory.token_buffer_memory import TokenBufferMemory
-from core.model_runtime.entities.model_entities import ModelFeature, ModelType
-from core.model_runtime.model_providers import model_provider_factory
-from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
-from extensions.ext_database import db
-from models.dataset import Dataset
-from models.model import Message
-
-logger = logging.getLogger(__name__)
-
-
-class AgentRunnerFeature:
-    def __init__(self, tenant_id: str,
-                 app_orchestration_config: AppOrchestrationConfigEntity,
-                 model_config: ModelConfigEntity,
-                 config: AgentEntity,
-                 queue_manager: ApplicationQueueManager,
-                 message: Message,
-                 user_id: str,
-                 agent_llm_callback: AgentLLMCallback,
-                 callback: AgentLoopGatherCallbackHandler,
-                 memory: Optional[TokenBufferMemory] = None,) -> None:
-        """
-        Agent runner
-        :param tenant_id: tenant id
-        :param app_orchestration_config: app orchestration config
-        :param model_config: model config
-        :param config: dataset config
-        :param queue_manager: queue manager
-        :param message: message
-        :param user_id: user id
-        :param agent_llm_callback: agent llm callback
-        :param callback: callback
-        :param memory: memory
-        """
-        self.tenant_id = tenant_id
-        self.app_orchestration_config = app_orchestration_config
-        self.model_config = model_config
-        self.config = config
-        self.queue_manager = queue_manager
-        self.message = message
-        self.user_id = user_id
-        self.agent_llm_callback = agent_llm_callback
-        self.callback = callback
-        self.memory = memory
-
-    def run(self, query: str,
-            invoke_from: InvokeFrom) -> Optional[str]:
-        """
-        Retrieve agent loop result.
-        :param query: query
-        :param invoke_from: invoke from
-        :return:
-        """
-        provider = self.config.provider
-        model = self.config.model
-        tool_configs = self.config.tools
-
-        # check model is support tool calling
-        provider_instance = model_provider_factory.get_provider_instance(provider=provider)
-        model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
-        model_type_instance = cast(LargeLanguageModel, model_type_instance)
-
-        # get model schema
-        model_schema = model_type_instance.get_model_schema(
-            model=model,
-            credentials=self.model_config.credentials
-        )
-
-        if not model_schema:
-            return None
-
-        planning_strategy = PlanningStrategy.REACT
-        features = model_schema.features
-        if features:
-            if ModelFeature.TOOL_CALL in features \
-                    or ModelFeature.MULTI_TOOL_CALL in features:
-                planning_strategy = PlanningStrategy.FUNCTION_CALL
-
-        tools = self.to_tools(
-            tool_configs=tool_configs,
-            invoke_from=invoke_from,
-            callbacks=[self.callback, DifyStdOutCallbackHandler()],
-        )
-
-        if len(tools) == 0:
-            return None
-
-        agent_configuration = AgentConfiguration(
-            strategy=planning_strategy,
-            model_config=self.model_config,
-            tools=tools,
-            memory=self.memory,
-            max_iterations=10,
-            max_execution_time=400.0,
-            early_stopping_method="generate",
-            agent_llm_callback=self.agent_llm_callback,
-            callbacks=[self.callback, DifyStdOutCallbackHandler()]
-        )
-
-        agent_executor = AgentExecutor(agent_configuration)
-
-        try:
-            # check if should use agent
-            should_use_agent = agent_executor.should_use_agent(query)
-            if not should_use_agent:
-                return None
-
-            result = agent_executor.run(query)
-            return result.output
-        except Exception as ex:
-            logger.exception("agent_executor run failed")
-            return None
-
-    def to_dataset_retriever_tool(self, tool_config: dict,
-                                  invoke_from: InvokeFrom) \
-            -> Optional[BaseTool]:
-        """
-        A dataset tool is a tool that can be used to retrieve information from a dataset
-        :param tool_config: tool config
-        :param invoke_from: invoke from
-        """
-        show_retrieve_source = self.app_orchestration_config.show_retrieve_source
-
-        hit_callback = DatasetIndexToolCallbackHandler(
-            queue_manager=self.queue_manager,
-            app_id=self.message.app_id,
-            message_id=self.message.id,
-            user_id=self.user_id,
-            invoke_from=invoke_from
-        )
-
-        # get dataset from dataset id
-        dataset = db.session.query(Dataset).filter(
-            Dataset.tenant_id == self.tenant_id,
-            Dataset.id == tool_config.get("id")
-        ).first()
-
-        # pass if dataset is not available
-        if not dataset:
-            return None
-
-        # pass if dataset is not available
-        if (dataset and dataset.available_document_count == 0
-                and dataset.available_document_count == 0):
-            return None
-
-        # get retrieval model config
-        default_retrieval_model = {
-            'search_method': 'semantic_search',
-            'reranking_enable': False,
-            'reranking_model': {
-                'reranking_provider_name': '',
-                'reranking_model_name': ''
-            },
-            'top_k': 2,
-            'score_threshold_enabled': False
-        }
-
-        retrieval_model_config = dataset.retrieval_model \
-            if dataset.retrieval_model else default_retrieval_model
-
-        # get top k
-        top_k = retrieval_model_config['top_k']
-
-        # get score threshold
-        score_threshold = None
-        score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
-        if score_threshold_enabled:
-            score_threshold = retrieval_model_config.get("score_threshold")
-
-        tool = DatasetRetrieverTool.from_dataset(
-            dataset=dataset,
-            top_k=top_k,
-            score_threshold=score_threshold,
-            hit_callbacks=[hit_callback],
-            return_resource=show_retrieve_source,
-            retriever_from=invoke_from.to_source()
-        )
-
-        return tool

+ 0 - 0
api/core/third_party/langchain/llms/__init__.py → api/core/features/dataset_retrieval/__init__.py


+ 0 - 0
api/core/third_party/spark/__init__.py → api/core/features/dataset_retrieval/agent/__init__.py


+ 0 - 0
api/core/agent/agent/agent_llm_callback.py → api/core/features/dataset_retrieval/agent/agent_llm_callback.py


+ 0 - 0
api/core/third_party/langchain/llms/fake.py → api/core/features/dataset_retrieval/agent/fake_llm.py


+ 2 - 2
api/core/chain/llm_chain.py → api/core/features/dataset_retrieval/agent/llm_chain.py

@@ -5,11 +5,11 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.schema import Generation, LLMResult
 from langchain.schema import Generation, LLMResult
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.schema.language_model import BaseLanguageModel
 
 
-from core.agent.agent.agent_llm_callback import AgentLLMCallback
 from core.entities.application_entities import ModelConfigEntity
 from core.entities.application_entities import ModelConfigEntity
 from core.entities.message_entities import lc_messages_to_prompt_messages
 from core.entities.message_entities import lc_messages_to_prompt_messages
+from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback
+from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
 from core.model_manager import ModelInstance
 from core.model_manager import ModelInstance
-from core.third_party.langchain.llms.fake import FakeLLM
 
 
 
 
 class LLMChain(LCLLMChain):
 class LLMChain(LCLLMChain):

+ 1 - 1
api/core/agent/agent/multi_dataset_router_agent.py → api/core/features/dataset_retrieval/agent/multi_dataset_router_agent.py

@@ -12,9 +12,9 @@ from pydantic import root_validator
 
 
 from core.entities.application_entities import ModelConfigEntity
 from core.entities.application_entities import ModelConfigEntity
 from core.entities.message_entities import lc_messages_to_prompt_messages
 from core.entities.message_entities import lc_messages_to_prompt_messages
+from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
 from core.model_manager import ModelInstance
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.message_entities import PromptMessageTool
 from core.model_runtime.entities.message_entities import PromptMessageTool
-from core.third_party.langchain.llms.fake import FakeLLM
 
 
 
 
 class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
 class MultiDatasetRouterAgent(OpenAIFunctionsAgent):

+ 0 - 0
api/core/data_loader/file_extractor.py → api/core/features/dataset_retrieval/agent/output_parser/__init__.py


+ 0 - 0
api/core/agent/agent/output_parser/structured_chat.py → api/core/features/dataset_retrieval/agent/output_parser/structured_chat.py


+ 1 - 1
api/core/agent/agent/structed_multi_dataset_router_agent.py → api/core/features/dataset_retrieval/agent/structed_multi_dataset_router_agent.py

@@ -12,8 +12,8 @@ from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, Sy
 from langchain.schema import AgentAction, AgentFinish, OutputParserException
 from langchain.schema import AgentAction, AgentFinish, OutputParserException
 from langchain.tools import BaseTool
 from langchain.tools import BaseTool
 
 
-from core.chain.llm_chain import LLMChain
 from core.entities.application_entities import ModelConfigEntity
 from core.entities.application_entities import ModelConfigEntity
+from core.features.dataset_retrieval.agent.llm_chain import LLMChain
 
 
 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
 The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
 The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.

+ 6 - 36
api/core/agent/agent_executor.py → api/core/features/dataset_retrieval/agent_based_dataset_executor.py

@@ -1,4 +1,3 @@
-import enum
 import logging
 import logging
 from typing import Optional, Union
 from typing import Optional, Union
 
 
@@ -8,14 +7,13 @@ from langchain.callbacks.manager import Callbacks
 from langchain.tools import BaseTool
 from langchain.tools import BaseTool
 from pydantic import BaseModel, Extra
 from pydantic import BaseModel, Extra
 
 
-from core.agent.agent.agent_llm_callback import AgentLLMCallback
-from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
-from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
-from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
-from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
-from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
+from core.entities.agent_entities import PlanningStrategy
 from core.entities.application_entities import ModelConfigEntity
 from core.entities.application_entities import ModelConfigEntity
 from core.entities.message_entities import prompt_messages_to_lc_messages
 from core.entities.message_entities import prompt_messages_to_lc_messages
+from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback
+from core.features.dataset_retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
+from core.features.dataset_retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
+from core.features.dataset_retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
 from core.helper import moderation
 from core.helper import moderation
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.errors.invoke import InvokeError
 from core.model_runtime.errors.invoke import InvokeError
@@ -23,13 +21,6 @@ from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import Datas
 from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
 from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
 
 
 
 
-class PlanningStrategy(str, enum.Enum):
-    ROUTER = 'router'
-    REACT_ROUTER = 'react_router'
-    REACT = 'react'
-    FUNCTION_CALL = 'function_call'
-
-
 class AgentConfiguration(BaseModel):
 class AgentConfiguration(BaseModel):
     strategy: PlanningStrategy
     strategy: PlanningStrategy
     model_config: ModelConfigEntity
     model_config: ModelConfigEntity
@@ -62,28 +53,7 @@ class AgentExecutor:
         self.agent = self._init_agent()
         self.agent = self._init_agent()
 
 
     def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
     def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
-        if self.configuration.strategy == PlanningStrategy.REACT:
-            agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
-                model_config=self.configuration.model_config,
-                tools=self.configuration.tools,
-                output_parser=StructuredChatOutputParser(),
-                summary_model_config=self.configuration.summary_model_config
-                if self.configuration.summary_model_config else None,
-                agent_llm_callback=self.configuration.agent_llm_callback,
-                verbose=True
-            )
-        elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
-            agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
-                model_config=self.configuration.model_config,
-                tools=self.configuration.tools,
-                extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
-                if self.configuration.memory else None,  # used for read chat histories memory
-                summary_model_config=self.configuration.summary_model_config
-                if self.configuration.summary_model_config else None,
-                agent_llm_callback=self.configuration.agent_llm_callback,
-                verbose=True
-            )
-        elif self.configuration.strategy == PlanningStrategy.ROUTER:
+        if self.configuration.strategy == PlanningStrategy.ROUTER:
             self.configuration.tools = [t for t in self.configuration.tools
             self.configuration.tools = [t for t in self.configuration.tools
                                         if isinstance(t, DatasetRetrieverTool)
                                         if isinstance(t, DatasetRetrieverTool)
                                         or isinstance(t, DatasetMultiRetrieverTool)]
                                         or isinstance(t, DatasetMultiRetrieverTool)]

+ 2 - 1
api/core/features/dataset_retrieval.py → api/core/features/dataset_retrieval/dataset_retrieval.py

@@ -2,9 +2,10 @@ from typing import Optional, cast
 
 
 from langchain.tools import BaseTool
 from langchain.tools import BaseTool
 
 
-from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+from core.entities.agent_entities import PlanningStrategy
 from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity
 from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity
+from core.features.dataset_retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.model_entities import ModelFeature
 from core.model_runtime.entities.model_entities import ModelFeature
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

+ 0 - 189
api/core/third_party/spark/spark_llm.py

@@ -1,189 +0,0 @@
-import base64
-import hashlib
-import hmac
-import json
-import queue
-import ssl
-from datetime import datetime
-from time import mktime
-from typing import Optional
-from urllib.parse import urlencode, urlparse
-from wsgiref.handlers import format_date_time
-
-import websocket
-
-
-class SparkLLMClient:
-    def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
-        domain = 'spark-api.xf-yun.com'
-        endpoint = 'chat'
-        if api_domain:
-            domain = api_domain
-            if model_name == 'spark-v3':
-                endpoint = 'multimodal'
-
-        model_api_configs = {
-            'spark': {
-                'version': 'v1.1',
-                'chat_domain': 'general'
-            },
-            'spark-v2': {
-                'version': 'v2.1',
-                'chat_domain': 'generalv2'
-            },
-            'spark-v3': {
-                'version': 'v3.1',
-                'chat_domain': 'generalv3'
-            },
-            'spark-v3.5': {
-                'version': 'v3.5',
-                'chat_domain': 'generalv3.5'
-            }
-        }
-
-        api_version = model_api_configs[model_name]['version']
-
-        self.chat_domain = model_api_configs[model_name]['chat_domain']
-        self.api_base = f"wss://{domain}/{api_version}/{endpoint}"
-        self.app_id = app_id
-        self.ws_url = self.create_url(
-            urlparse(self.api_base).netloc,
-            urlparse(self.api_base).path,
-            self.api_base,
-            api_key,
-            api_secret
-        )
-
-        self.queue = queue.Queue()
-        self.blocking_message = ''
-
-    def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str:
-        # generate timestamp by RFC1123
-        now = datetime.now()
-        date = format_date_time(mktime(now.timetuple()))
-
-        signature_origin = "host: " + host + "\n"
-        signature_origin += "date: " + date + "\n"
-        signature_origin += "GET " + path + " HTTP/1.1"
-
-        # encrypt using hmac-sha256
-        signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
-                                 digestmod=hashlib.sha256).digest()
-
-        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
-
-        authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
-
-        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
-
-        v = {
-            "authorization": authorization,
-            "date": date,
-            "host": host
-        }
-        # generate url
-        url = api_base + '?' + urlencode(v)
-        return url
-
-    def run(self, messages: list, user_id: str,
-            model_kwargs: Optional[dict] = None, streaming: bool = False):
-        websocket.enableTrace(False)
-        ws = websocket.WebSocketApp(
-            self.ws_url,
-            on_message=self.on_message,
-            on_error=self.on_error,
-            on_close=self.on_close,
-            on_open=self.on_open
-        )
-        ws.messages = messages
-        ws.user_id = user_id
-        ws.model_kwargs = model_kwargs
-        ws.streaming = streaming
-        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
-
-    def on_error(self, ws, error):
-        self.queue.put({
-            'status_code': error.status_code,
-            'error': error.resp_body.decode('utf-8')
-        })
-        ws.close()
-
-    def on_close(self, ws, close_status_code, close_reason):
-        self.queue.put({'done': True})
-
-    def on_open(self, ws):
-        self.blocking_message = ''
-        data = json.dumps(self.gen_params(
-            messages=ws.messages,
-            user_id=ws.user_id,
-            model_kwargs=ws.model_kwargs
-        ))
-        ws.send(data)
-
-    def on_message(self, ws, message):
-        data = json.loads(message)
-        code = data['header']['code']
-        if code != 0:
-            self.queue.put({
-                'status_code': 400,
-                'error': f"Code: {code}, Error: {data['header']['message']}"
-            })
-            ws.close()
-        else:
-            choices = data["payload"]["choices"]
-            status = choices["status"]
-            content = choices["text"][0]["content"]
-            if ws.streaming:
-                self.queue.put({'data': content})
-            else:
-                self.blocking_message += content
-
-            if status == 2:
-                if not ws.streaming:
-                    self.queue.put({'data': self.blocking_message})
-                ws.close()
-
-    def gen_params(self, messages: list, user_id: str,
-                   model_kwargs: Optional[dict] = None) -> dict:
-        data = {
-            "header": {
-                "app_id": self.app_id,
-                "uid": user_id
-            },
-            "parameter": {
-                "chat": {
-                    "domain": self.chat_domain
-                }
-            },
-            "payload": {
-                "message": {
-                    "text": messages
-                }
-            }
-        }
-
-        if model_kwargs:
-            data['parameter']['chat'].update(model_kwargs)
-
-        return data
-
-    def subscribe(self):
-        while True:
-            content = self.queue.get()
-            if 'error' in content:
-                if content['status_code'] == 401:
-                    raise SparkError('[Spark] The credentials you provided are incorrect. '
-                                     'Please double-check and fill them in again.')
-                elif content['status_code'] == 403:
-                    raise SparkError("[Spark] Sorry, the credentials you provided are access denied. "
-                                     "Please try again after obtaining the necessary permissions.")
-                else:
-                    raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}")
-
-            if 'data' not in content:
-                break
-            yield content
-
-
-class SparkError(Exception):
-    pass

+ 0 - 24
api/core/tool/current_datetime_tool.py

@@ -1,24 +0,0 @@
-from datetime import datetime
-
-from langchain.tools import BaseTool
-from pydantic import BaseModel, Field
-
-
-class DatetimeToolInput(BaseModel):
-    type: str = Field(..., description="Type for current time, must be: datetime.")
-
-
-class DatetimeTool(BaseTool):
-    """Tool for querying current datetime."""
-    name: str = "current_datetime"
-    args_schema: type[BaseModel] = DatetimeToolInput
-    description: str = "A tool when you want to get the current date, time, week, month or year, " \
-                       "and the time zone is UTC. Result is \"<date> <time> <timezone> <week>\"."
-
-    def _run(self, type: str) -> str:
-        # get current time
-        current_time = datetime.utcnow()
-        return current_time.strftime("%Y-%m-%d %H:%M:%S UTC+0000 %A")
-
-    async def _arun(self, tool_input: str) -> str:
-        raise NotImplementedError()

+ 0 - 63
api/core/tool/provider/base.py

@@ -1,63 +0,0 @@
-import base64
-from abc import ABC, abstractmethod
-from typing import Optional
-
-from extensions.ext_database import db
-from libs import rsa
-from models.account import Tenant
-from models.tool import ToolProvider, ToolProviderName
-
-
-class BaseToolProvider(ABC):
-    def __init__(self, tenant_id: str):
-        self.tenant_id = tenant_id
-
-    @abstractmethod
-    def get_provider_name(self) -> ToolProviderName:
-        raise NotImplementedError
-
-    @abstractmethod
-    def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
-        raise NotImplementedError
-
-    @abstractmethod
-    def credentials_to_func_kwargs(self) -> Optional[dict]:
-        raise NotImplementedError
-
-    @abstractmethod
-    def credentials_validate(self, credentials: dict):
-        raise NotImplementedError
-
-    def get_provider(self, must_enabled: bool = False) -> Optional[ToolProvider]:
-        """
-        Returns the Provider instance for the given tenant_id and tool_name.
-        """
-        query = db.session.query(ToolProvider).filter(
-            ToolProvider.tenant_id == self.tenant_id,
-            ToolProvider.tool_name == self.get_provider_name().value
-        )
-
-        if must_enabled:
-            query = query.filter(ToolProvider.is_enabled == True)
-
-        return query.first()
-
-    def encrypt_token(self, token) -> str:
-        tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
-        encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
-        return base64.b64encode(encrypted_token).decode()
-
-    def decrypt_token(self, token: str, obfuscated: bool = False) -> str:
-        token = rsa.decrypt(base64.b64decode(token), self.tenant_id)
-
-        if obfuscated:
-            return self._obfuscated_token(token)
-
-        return token
-
-    def _obfuscated_token(self, token: str) -> str:
-        return token[:6] + '*' * (len(token) - 8) + token[-2:]

+ 0 - 2
api/core/tool/provider/errors.py

@@ -1,2 +0,0 @@
-class ToolValidateFailedError(Exception):
-    description = "Tool Provider Validate failed"

+ 0 - 77
api/core/tool/provider/serpapi_provider.py

@@ -1,77 +0,0 @@
-from typing import Optional
-
-from core.tool.provider.base import BaseToolProvider
-from core.tool.provider.errors import ToolValidateFailedError
-from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper
-from models.tool import ToolProviderName
-
-
-class SerpAPIToolProvider(BaseToolProvider):
-    def get_provider_name(self) -> ToolProviderName:
-        """
-        Returns the name of the provider.
-
-        :return:
-        """
-        return ToolProviderName.SERPAPI
-
-    def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
-        """
-        Returns the credentials for SerpAPI as a dictionary.
-
-        :param obfuscated: obfuscate credentials if True
-        :return:
-        """
-        tool_provider = self.get_provider(must_enabled=True)
-        if not tool_provider:
-            return None
-
-        credentials = tool_provider.credentials
-        if not credentials:
-            return None
-
-        if credentials.get('api_key'):
-            credentials['api_key'] = self.decrypt_token(credentials.get('api_key'), obfuscated)
-
-        return credentials
-
-    def credentials_to_func_kwargs(self) -> Optional[dict]:
-        """
-        Returns the credentials function kwargs as a dictionary.
-
-        :return:
-        """
-        credentials = self.get_credentials()
-        if not credentials:
-            return None
-
-        return {
-            'serpapi_api_key': credentials.get('api_key')
-        }
-
-    def credentials_validate(self, credentials: dict):
-        """
-        Validates the given credentials.
-
-        :param credentials:
-        :return:
-        """
-        if 'api_key' not in credentials or not credentials.get('api_key'):
-            raise ToolValidateFailedError("SerpAPI api_key is required.")
-
-        api_key = credentials.get('api_key')
-
-        try:
-            OptimizedSerpAPIWrapper(serpapi_api_key=api_key).run(query='test')
-        except Exception as e:
-            raise ToolValidateFailedError("SerpAPI api_key is invalid. {}".format(e))
-
-    def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
-        """
-        Encrypts the given credentials.
-
-        :param credentials:
-        :return:
-        """
-        credentials['api_key'] = self.encrypt_token(credentials.get('api_key'))
-        return credentials

+ 0 - 43
api/core/tool/provider/tool_provider_service.py

@@ -1,43 +0,0 @@
-from typing import Optional
-
-from core.tool.provider.base import BaseToolProvider
-from core.tool.provider.serpapi_provider import SerpAPIToolProvider
-
-
-class ToolProviderService:
-
-    def __init__(self, tenant_id: str, provider_name: str):
-        self.provider = self._init_provider(tenant_id, provider_name)
-
-    def _init_provider(self, tenant_id: str, provider_name: str) -> BaseToolProvider:
-        if provider_name == 'serpapi':
-            return SerpAPIToolProvider(tenant_id)
-        else:
-            raise Exception('tool provider {} not found'.format(provider_name))
-
-    def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
-        """
-        Returns the credentials for Tool as a dictionary.
-
-        :param obfuscated:
-        :return:
-        """
-        return self.provider.get_credentials(obfuscated)
-
-    def credentials_validate(self, credentials: dict):
-        """
-        Validates the given credentials.
-
-        :param credentials:
-        :raises: ValidateFailedError
-        """
-        return self.provider.credentials_validate(credentials)
-
-    def encrypt_credentials(self, credentials: dict):
-        """
-        Encrypts the given credentials.
-
-        :param credentials:
-        :return:
-        """
-        return self.provider.encrypt_credentials(credentials)

+ 0 - 51
api/core/tool/serpapi_wrapper.py

@@ -1,51 +0,0 @@
-from langchain import SerpAPIWrapper
-from pydantic import BaseModel, Field
-
-
-class OptimizedSerpAPIInput(BaseModel):
-    query: str = Field(..., description="search query.")
-
-
-class OptimizedSerpAPIWrapper(SerpAPIWrapper):
-
-    @staticmethod
-    def _process_response(res: dict, num_results: int = 5) -> str:
-        """Process response from SerpAPI."""
-        if "error" in res.keys():
-            raise ValueError(f"Got error from SerpAPI: {res['error']}")
-        if "answer_box" in res.keys() and type(res["answer_box"]) == list:
-            res["answer_box"] = res["answer_box"][0]
-        if "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
-            toret = res["answer_box"]["answer"]
-        elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
-            toret = res["answer_box"]["snippet"]
-        elif (
-            "answer_box" in res.keys()
-            and "snippet_highlighted_words" in res["answer_box"].keys()
-        ):
-            toret = res["answer_box"]["snippet_highlighted_words"][0]
-        elif (
-            "sports_results" in res.keys()
-            and "game_spotlight" in res["sports_results"].keys()
-        ):
-            toret = res["sports_results"]["game_spotlight"]
-        elif (
-            "shopping_results" in res.keys()
-            and "title" in res["shopping_results"][0].keys()
-        ):
-            toret = res["shopping_results"][:3]
-        elif (
-            "knowledge_graph" in res.keys()
-            and "description" in res["knowledge_graph"].keys()
-        ):
-            toret = res["knowledge_graph"]["description"]
-        elif 'organic_results' in res.keys() and len(res['organic_results']) > 0:
-            toret = ""
-            for result in res["organic_results"][:num_results]:
-                if "link" in result:
-                    toret += "----------------\nlink: " + result["link"] + "\n"
-                if "snippet" in result:
-                    toret += "snippet: " + result["snippet"] + "\n"
-        else:
-            toret = "No good search result found"
-        return "search result:\n" + toret

+ 0 - 443
api/core/tool/web_reader_tool.py

@@ -1,443 +0,0 @@
-import hashlib
-import json
-import os
-import re
-import site
-import subprocess
-import tempfile
-import unicodedata
-from contextlib import contextmanager
-from typing import Any
-
-import requests
-from bs4 import BeautifulSoup, CData, Comment, NavigableString
-from langchain.chains import RefineDocumentsChain
-from langchain.chains.summarize import refine_prompts
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.tools.base import BaseTool
-from newspaper import Article
-from pydantic import BaseModel, Field
-from regex import regex
-
-from core.chain.llm_chain import LLMChain
-from core.entities.application_entities import ModelConfigEntity
-from core.rag.extractor import extract_processor
-from core.rag.extractor.extract_processor import ExtractProcessor
-from core.rag.models.document import Document
-
-FULL_TEMPLATE = """
-TITLE: {title}
-AUTHORS: {authors}
-PUBLISH DATE: {publish_date}
-TOP_IMAGE_URL: {top_image}
-TEXT:
-
-{text}
-"""
-
-
-class WebReaderToolInput(BaseModel):
-    url: str = Field(..., description="URL of the website to read")
-    summary: bool = Field(
-        default=False,
-        description="When the user's question requires extracting the summarizing content of the webpage, "
-                    "set it to true."
-    )
-    cursor: int = Field(
-        default=0,
-        description="Start reading from this character."
-        "Use when the first response was truncated"
-        "and you want to continue reading the page."
-        "The value cannot exceed 24000.",
-    )
-
-
-class WebReaderTool(BaseTool):
-    """Reader tool for getting website title and contents. Gives more control than SimpleReaderTool."""
-
-    name: str = "web_reader"
-    args_schema: type[BaseModel] = WebReaderToolInput
-    description: str = "use this to read a website. " \
-                       "If you can answer the question based on the information provided, " \
-                       "there is no need to use."
-    page_contents: str = None
-    url: str = None
-    max_chunk_length: int = 4000
-    summary_chunk_tokens: int = 4000
-    summary_chunk_overlap: int = 0
-    summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
-    continue_reading: bool = True
-    model_config: ModelConfigEntity
-    model_parameters: dict[str, Any]
-
-    def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
-        try:
-            if not self.page_contents or self.url != url:
-                page_contents = get_url(url)
-                self.page_contents = page_contents
-                self.url = url
-            else:
-                page_contents = self.page_contents
-        except Exception as e:
-            return f'Read this website failed, caused by: {str(e)}.'
-
-        if summary:
-            character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
-                chunk_size=self.summary_chunk_tokens,
-                chunk_overlap=self.summary_chunk_overlap,
-                separators=self.summary_separators
-            )
-
-            texts = character_splitter.split_text(page_contents)
-            docs = [Document(page_content=t) for t in texts]
-
-            if len(docs) == 0 or docs[0].page_content.endswith('TEXT:'):
-                return "No content found."
-
-            # only use first 5 docs
-            if len(docs) > 5:
-                docs = docs[:5]
-
-            chain = self.get_summary_chain()
-            try:
-                page_contents = chain.run(docs)
-            except Exception as e:
-                return f'Read this website failed, caused by: {str(e)}.'
-        else:
-            page_contents = page_result(page_contents, cursor, self.max_chunk_length)
-
-            if self.continue_reading and len(page_contents) >= self.max_chunk_length:
-                page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \
-                                 f"THEN DIRECT ANSWER AND STOP INVOKING web_reader TOOL, OTHERWISE USE " \
-                                 f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING."
-
-        return page_contents
-
-    async def _arun(self, url: str) -> str:
-        raise NotImplementedError
-
-    def get_summary_chain(self) -> RefineDocumentsChain:
-        initial_chain = LLMChain(
-            model_config=self.model_config,
-            prompt=refine_prompts.PROMPT,
-            parameters=self.model_parameters
-        )
-        refine_chain = LLMChain(
-            model_config=self.model_config,
-            prompt=refine_prompts.REFINE_PROMPT,
-            parameters=self.model_parameters
-        )
-        return RefineDocumentsChain(
-            initial_llm_chain=initial_chain,
-            refine_llm_chain=refine_chain,
-            document_variable_name="text",
-            initial_response_name="existing_answer",
-            callbacks=self.callbacks
-        )
-
-
-def page_result(text: str, cursor: int, max_length: int) -> str:
-    """Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
-    return text[cursor: cursor + max_length]
-
-
-def get_url(url: str) -> str:
-    """Fetch URL and return the contents as a string."""
-    headers = {
-        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
-    }
-    supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
-
-    head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10))
-
-    if head_response.status_code != 200:
-        return "URL returned status code {}.".format(head_response.status_code)
-
-    # check content-type
-    main_content_type = head_response.headers.get('Content-Type').split(';')[0].strip()
-    if main_content_type not in supported_content_types:
-        return "Unsupported content-type [{}] of URL.".format(main_content_type)
-
-    if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
-        return ExtractProcessor.load_from_url(url, return_text=True)
-
-    response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30))
-    a = extract_using_readabilipy(response.text)
-
-    if not a['plain_text'] or not a['plain_text'].strip():
-        return get_url_from_newspaper3k(url)
-
-    res = FULL_TEMPLATE.format(
-        title=a['title'],
-        authors=a['byline'],
-        publish_date=a['date'],
-        top_image="",
-        text=a['plain_text'] if a['plain_text'] else "",
-    )
-
-    return res
-
-
-def get_url_from_newspaper3k(url: str) -> str:
-
-    a = Article(url)
-    a.download()
-    a.parse()
-
-    res = FULL_TEMPLATE.format(
-        title=a.title,
-        authors=a.authors,
-        publish_date=a.publish_date,
-        top_image=a.top_image,
-        text=a.text,
-    )
-
-    return res
-
-
-def extract_using_readabilipy(html):
-    with tempfile.NamedTemporaryFile(delete=False, mode='w+') as f_html:
-        f_html.write(html)
-        f_html.close()
-    html_path = f_html.name
-
-    # Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file
-    article_json_path = html_path + ".json"
-    jsdir = os.path.join(find_module_path('readabilipy'), 'javascript')
-    with chdir(jsdir):
-        subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])
-
-    # Read output of call to Readability.parse() from JSON file and return as Python dictionary
-    with open(article_json_path, encoding="utf-8") as json_file:
-        input_json = json.loads(json_file.read())
-
-    # Deleting files after processing
-    os.unlink(article_json_path)
-    os.unlink(html_path)
-
-    article_json = {
-        "title": None,
-        "byline": None,
-        "date": None,
-        "content": None,
-        "plain_content": None,
-        "plain_text": None
-    }
-    # Populate article fields from readability fields where present
-    if input_json:
-        if "title" in input_json and input_json["title"]:
-            article_json["title"] = input_json["title"]
-        if "byline" in input_json and input_json["byline"]:
-            article_json["byline"] = input_json["byline"]
-        if "date" in input_json and input_json["date"]:
-            article_json["date"] = input_json["date"]
-        if "content" in input_json and input_json["content"]:
-            article_json["content"] = input_json["content"]
-            article_json["plain_content"] = plain_content(article_json["content"], False, False)
-            article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"])
-        if "textContent" in input_json and input_json["textContent"]:
-            article_json["plain_text"] = input_json["textContent"]
-            article_json["plain_text"] = re.sub(r'\n\s*\n', '\n', article_json["plain_text"])
-
-    return article_json
-
-
-def find_module_path(module_name):
-    for package_path in site.getsitepackages():
-        potential_path = os.path.join(package_path, module_name)
-        if os.path.exists(potential_path):
-            return potential_path
-
-    return None
-
-@contextmanager
-def chdir(path):
-    """Change directory in context and return to original on exit"""
-    # From https://stackoverflow.com/a/37996581, couldn't find a built-in
-    original_path = os.getcwd()
-    os.chdir(path)
-    try:
-        yield
-    finally:
-        os.chdir(original_path)
-
-
-def extract_text_blocks_as_plain_text(paragraph_html):
-    # Load article as DOM
-    soup = BeautifulSoup(paragraph_html, 'html.parser')
-    # Select all lists
-    list_elements = soup.find_all(['ul', 'ol'])
-    # Prefix text in all list items with "* " and make lists paragraphs
-    for list_element in list_elements:
-        plain_items = "".join(list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all('li')])))
-        list_element.string = plain_items
-        list_element.name = "p"
-    # Select all text blocks
-    text_blocks = [s.parent for s in soup.find_all(string=True)]
-    text_blocks = [plain_text_leaf_node(block) for block in text_blocks]
-    # Drop empty paragraphs
-    text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks))
-    return text_blocks
-
-
-def plain_text_leaf_node(element):
-    # Extract all text, stripped of any child HTML elements and normalise it
-    plain_text = normalise_text(element.get_text())
-    if plain_text != "" and element.name == "li":
-        plain_text = "* {}, ".format(plain_text)
-    if plain_text == "":
-        plain_text = None
-    if "data-node-index" in element.attrs:
-        plain = {"node_index": element["data-node-index"], "text": plain_text}
-    else:
-        plain = {"text": plain_text}
-    return plain
-
-
-def plain_content(readability_content, content_digests, node_indexes):
-    # Load article as DOM
-    soup = BeautifulSoup(readability_content, 'html.parser')
-    # Make all elements plain
-    elements = plain_elements(soup.contents, content_digests, node_indexes)
-    if node_indexes:
-        # Add node index attributes to nodes
-        elements = [add_node_indexes(element) for element in elements]
-    # Replace article contents with plain elements
-    soup.contents = elements
-    return str(soup)
-
-
-def plain_elements(elements, content_digests, node_indexes):
-    # Get plain content versions of all elements
-    elements = [plain_element(element, content_digests, node_indexes)
-                for element in elements]
-    if content_digests:
-        # Add content digest attribute to nodes
-        elements = [add_content_digest(element) for element in elements]
-    return elements
-
-
-def plain_element(element, content_digests, node_indexes):
-    # For lists, we make each item plain text
-    if is_leaf(element):
-        # For leaf node elements, extract the text content, discarding any HTML tags
-        # 1. Get element contents as text
-        plain_text = element.get_text()
-        # 2. Normalise the extracted text string to a canonical representation
-        plain_text = normalise_text(plain_text)
-        # 3. Update element content to be plain text
-        element.string = plain_text
-    elif is_text(element):
-        if is_non_printing(element):
-            # The simplified HTML may have come from Readability.js so might
-            # have non-printing text (e.g. Comment or CData). In this case, we
-            # keep the structure, but ensure that the string is empty.
-            element = type(element)("")
-        else:
-            plain_text = element.string
-            plain_text = normalise_text(plain_text)
-            element = type(element)(plain_text)
-    else:
-        # If not a leaf node or leaf type call recursively on child nodes, replacing
-        element.contents = plain_elements(element.contents, content_digests, node_indexes)
-    return element
-
-
-def add_node_indexes(element, node_index="0"):
-    # Can't add attributes to string types
-    if is_text(element):
-        return element
-    # Add index to current element
-    element["data-node-index"] = node_index
-    # Add index to child elements
-    for local_idx, child in enumerate(
-            [c for c in element.contents if not is_text(c)], start=1):
-        # Can't add attributes to leaf string types
-        child_index = "{stem}.{local}".format(
-            stem=node_index, local=local_idx)
-        add_node_indexes(child, node_index=child_index)
-    return element
-
-
-def normalise_text(text):
-    """Normalise unicode and whitespace."""
-    # Normalise unicode first to try and standardise whitespace characters as much as possible before normalising them
-    text = strip_control_characters(text)
-    text = normalise_unicode(text)
-    text = normalise_whitespace(text)
-    return text
-
-
-def strip_control_characters(text):
-    """Strip out unicode control characters which might break the parsing."""
-    # Unicode control characters
-    #   [Cc]: Other, Control [includes new lines]
-    #   [Cf]: Other, Format
-    #   [Cn]: Other, Not Assigned
-    #   [Co]: Other, Private Use
-    #   [Cs]: Other, Surrogate
-    control_chars = set(['Cc', 'Cf', 'Cn', 'Co', 'Cs'])
-    retained_chars = ['\t', '\n', '\r', '\f']
-
-    # Remove non-printing control characters
-    return "".join(["" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char for char in text])
-
-
-def normalise_unicode(text):
-    """Normalise unicode such that things that are visually equivalent map to the same unicode string where possible."""
-    normal_form = "NFKC"
-    text = unicodedata.normalize(normal_form, text)
-    return text
-
-
-def normalise_whitespace(text):
-    """Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed."""
-    text = regex.sub(r"\s+", " ", text)
-    # Remove leading and trailing whitespace
-    text = text.strip()
-    return text
-
-def is_leaf(element):
-    return (element.name in ['p', 'li'])
-
-
-def is_text(element):
-    return isinstance(element, NavigableString)
-
-
-def is_non_printing(element):
-    return any(isinstance(element, _e) for _e in [Comment, CData])
-
-
-def add_content_digest(element):
-    if not is_text(element):
-        element["data-content-digest"] = content_digest(element)
-    return element
-
-
-def content_digest(element):
-    if is_text(element):
-        # Hash
-        trimmed_string = element.string.strip()
-        if trimmed_string == "":
-            digest = ""
-        else:
-            digest = hashlib.sha256(trimmed_string.encode('utf-8')).hexdigest()
-    else:
-        contents = element.contents
-        num_contents = len(contents)
-        if num_contents == 0:
-            # No hash when no child elements exist
-            digest = ""
-        elif num_contents == 1:
-            # If single child, use digest of child
-            digest = content_digest(contents[0])
-        else:
-            # Build content digest from the "non-empty" digests of child nodes
-            digest = hashlib.sha256()
-            child_digests = list(
-                filter(lambda x: x != "", [content_digest(content) for content in contents]))
-            for child in child_digests:
-                digest.update(child.encode('utf-8'))
-            digest = digest.hexdigest()
-    return digest

+ 18 - 18
api/core/tools/tool/dataset_retriever_tool.py

@@ -4,7 +4,7 @@ from langchain.tools import BaseTool
 
 
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom
 from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom
-from core.features.dataset_retrieval import DatasetRetrievalFeature
+from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter
 from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter
 from core.tools.tool.tool import Tool
 from core.tools.tool.tool import Tool
@@ -15,12 +15,12 @@ class DatasetRetrieverTool(Tool):
 
 
     @staticmethod
     @staticmethod
     def get_dataset_tools(tenant_id: str,
     def get_dataset_tools(tenant_id: str,
-                         dataset_ids: list[str],
-                         retrieve_config: DatasetRetrieveConfigEntity,
-                         return_resource: bool,
-                         invoke_from: InvokeFrom,
-                         hit_callback: DatasetIndexToolCallbackHandler
-    ) -> list['DatasetRetrieverTool']:
+                          dataset_ids: list[str],
+                          retrieve_config: DatasetRetrieveConfigEntity,
+                          return_resource: bool,
+                          invoke_from: InvokeFrom,
+                          hit_callback: DatasetIndexToolCallbackHandler
+                          ) -> list['DatasetRetrieverTool']:
         """
         """
         get dataset tool
         get dataset tool
         """
         """
@@ -46,7 +46,7 @@ class DatasetRetrieverTool(Tool):
         )
         )
         # restore retrieve strategy
         # restore retrieve strategy
         retrieve_config.retrieve_strategy = original_retriever_mode
         retrieve_config.retrieve_strategy = original_retriever_mode
-        
+
         # convert langchain tools to Tools
         # convert langchain tools to Tools
         tools = []
         tools = []
         for langchain_tool in langchain_tools:
         for langchain_tool in langchain_tools:
@@ -60,7 +60,7 @@ class DatasetRetrieverTool(Tool):
                     llm=langchain_tool.description),
                     llm=langchain_tool.description),
                 runtime=DatasetRetrieverTool.Runtime()
                 runtime=DatasetRetrieverTool.Runtime()
             )
             )
-            
+
             tools.append(tool)
             tools.append(tool)
 
 
         return tools
         return tools
@@ -68,13 +68,13 @@ class DatasetRetrieverTool(Tool):
     def get_runtime_parameters(self) -> list[ToolParameter]:
     def get_runtime_parameters(self) -> list[ToolParameter]:
         return [
         return [
             ToolParameter(name='query',
             ToolParameter(name='query',
-                         label=I18nObject(en_US='', zh_Hans=''),
-                         human_description=I18nObject(en_US='', zh_Hans=''),
-                         type=ToolParameter.ToolParameterType.STRING,
-                         form=ToolParameter.ToolParameterForm.LLM,
-                         llm_description='Query for the dataset to be used to retrieve the dataset.',
-                         required=True,
-                         default=''),
+                          label=I18nObject(en_US='', zh_Hans=''),
+                          human_description=I18nObject(en_US='', zh_Hans=''),
+                          type=ToolParameter.ToolParameterType.STRING,
+                          form=ToolParameter.ToolParameterForm.LLM,
+                          llm_description='Query for the dataset to be used to retrieve the dataset.',
+                          required=True,
+                          default=''),
         ]
         ]
 
 
     def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
     def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
@@ -84,7 +84,7 @@ class DatasetRetrieverTool(Tool):
         query = tool_parameters.get('query', None)
         query = tool_parameters.get('query', None)
         if not query:
         if not query:
             return self.create_text_message(text='please input query')
             return self.create_text_message(text='please input query')
-        
+
         # invoke dataset retriever tool
         # invoke dataset retriever tool
         result = self.langchain_tool._run(query=query)
         result = self.langchain_tool._run(query=query)
 
 
@@ -94,4 +94,4 @@ class DatasetRetrieverTool(Tool):
         """
         """
         validate the credentials for dataset retriever tool
         validate the credentials for dataset retriever tool
         """
         """
-        pass
+        pass

+ 0 - 109
api/core/tools/utils/web_reader_tool.py

@@ -7,23 +7,14 @@ import subprocess
 import tempfile
 import tempfile
 import unicodedata
 import unicodedata
 from contextlib import contextmanager
 from contextlib import contextmanager
-from typing import Any
 
 
 import requests
 import requests
 from bs4 import BeautifulSoup, CData, Comment, NavigableString
 from bs4 import BeautifulSoup, CData, Comment, NavigableString
-from langchain.chains import RefineDocumentsChain
-from langchain.chains.summarize import refine_prompts
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.tools.base import BaseTool
 from newspaper import Article
 from newspaper import Article
-from pydantic import BaseModel, Field
 from regex import regex
 from regex import regex
 
 
-from core.chain.llm_chain import LLMChain
-from core.entities.application_entities import ModelConfigEntity
 from core.rag.extractor import extract_processor
 from core.rag.extractor import extract_processor
 from core.rag.extractor.extract_processor import ExtractProcessor
 from core.rag.extractor.extract_processor import ExtractProcessor
-from core.rag.models.document import Document
 
 
 FULL_TEMPLATE = """
 FULL_TEMPLATE = """
 TITLE: {title}
 TITLE: {title}
@@ -36,106 +27,6 @@ TEXT:
 """
 """
 
 
 
 
-class WebReaderToolInput(BaseModel):
-    url: str = Field(..., description="URL of the website to read")
-    summary: bool = Field(
-        default=False,
-        description="When the user's question requires extracting the summarizing content of the webpage, "
-                    "set it to true."
-    )
-    cursor: int = Field(
-        default=0,
-        description="Start reading from this character."
-        "Use when the first response was truncated"
-        "and you want to continue reading the page."
-        "The value cannot exceed 24000.",
-    )
-
-
-class WebReaderTool(BaseTool):
-    """Reader tool for getting website title and contents. Gives more control than SimpleReaderTool."""
-
-    name: str = "web_reader"
-    args_schema: type[BaseModel] = WebReaderToolInput
-    description: str = "use this to read a website. " \
-                       "If you can answer the question based on the information provided, " \
-                       "there is no need to use."
-    page_contents: str = None
-    url: str = None
-    max_chunk_length: int = 4000
-    summary_chunk_tokens: int = 4000
-    summary_chunk_overlap: int = 0
-    summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
-    continue_reading: bool = True
-    model_config: ModelConfigEntity
-    model_parameters: dict[str, Any]
-
-    def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
-        try:
-            if not self.page_contents or self.url != url:
-                page_contents = get_url(url)
-                self.page_contents = page_contents
-                self.url = url
-            else:
-                page_contents = self.page_contents
-        except Exception as e:
-            return f'Read this website failed, caused by: {str(e)}.'
-
-        if summary:
-            character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
-                chunk_size=self.summary_chunk_tokens,
-                chunk_overlap=self.summary_chunk_overlap,
-                separators=self.summary_separators
-            )
-
-            texts = character_splitter.split_text(page_contents)
-            docs = [Document(page_content=t) for t in texts]
-
-            if len(docs) == 0 or docs[0].page_content.endswith('TEXT:'):
-                return "No content found."
-
-            # only use first 5 docs
-            if len(docs) > 5:
-                docs = docs[:5]
-
-            chain = self.get_summary_chain()
-            try:
-                page_contents = chain.run(docs)
-            except Exception as e:
-                return f'Read this website failed, caused by: {str(e)}.'
-        else:
-            page_contents = page_result(page_contents, cursor, self.max_chunk_length)
-
-            if self.continue_reading and len(page_contents) >= self.max_chunk_length:
-                page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \
-                                 f"THEN DIRECT ANSWER AND STOP INVOKING web_reader TOOL, OTHERWISE USE " \
-                                 f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING."
-
-        return page_contents
-
-    async def _arun(self, url: str) -> str:
-        raise NotImplementedError
-
-    def get_summary_chain(self) -> RefineDocumentsChain:
-        initial_chain = LLMChain(
-            model_config=self.model_config,
-            prompt=refine_prompts.PROMPT,
-            parameters=self.model_parameters
-        )
-        refine_chain = LLMChain(
-            model_config=self.model_config,
-            prompt=refine_prompts.REFINE_PROMPT,
-            parameters=self.model_parameters
-        )
-        return RefineDocumentsChain(
-            initial_llm_chain=initial_chain,
-            refine_llm_chain=refine_chain,
-            document_variable_name="text",
-            initial_response_name="existing_answer",
-            callbacks=self.callbacks
-        )
-
-
 def page_result(text: str, cursor: int, max_length: int) -> str:
 def page_result(text: str, cursor: int, max_length: int) -> str:
     """Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
     """Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
     return text[cursor: cursor + max_length]
     return text[cursor: cursor + max_length]

+ 1 - 1
api/services/app_model_config_service.py

@@ -1,7 +1,7 @@
 import re
 import re
 import uuid
 import uuid
 
 
-from core.agent.agent_executor import PlanningStrategy
+from core.entities.agent_entities import PlanningStrategy
 from core.external_data_tool.factory import ExternalDataToolFactory
 from core.external_data_tool.factory import ExternalDataToolFactory
 from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
 from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
 from core.model_runtime.model_providers import model_provider_factory
 from core.model_runtime.model_providers import model_provider_factory