ソースを参照

refactor: remove unused codes, move core/agent module into dataset retrieval feature (#2614)

takatost 1 年間 前
コミット
dd961985f0
29 ファイル変更41 行追加2016 行削除
  1. 0 49
      api/core/agent/agent/calc_token_mixin.py
  2. 0 361
      api/core/agent/agent/openai_function_call.py
  3. 0 306
      api/core/agent/agent/structured_chat.py
  4. 1 39
      api/core/app_runner/assistant_app_runner.py
  5. 1 1
      api/core/app_runner/basic_app_runner.py
  6. 8 0
      api/core/entities/agent_entities.py
  7. 0 199
      api/core/features/agent_runner.py
  8. 0 0
      api/core/features/dataset_retrieval/__init__.py
  9. 0 0
      api/core/features/dataset_retrieval/agent/__init__.py
  10. 0 0
      api/core/features/dataset_retrieval/agent/agent_llm_callback.py
  11. 0 0
      api/core/features/dataset_retrieval/agent/fake_llm.py
  12. 2 2
      api/core/features/dataset_retrieval/agent/llm_chain.py
  13. 1 1
      api/core/features/dataset_retrieval/agent/multi_dataset_router_agent.py
  14. 0 0
      api/core/features/dataset_retrieval/agent/output_parser/__init__.py
  15. 0 0
      api/core/features/dataset_retrieval/agent/output_parser/structured_chat.py
  16. 1 1
      api/core/features/dataset_retrieval/agent/structed_multi_dataset_router_agent.py
  17. 6 36
      api/core/features/dataset_retrieval/agent_based_dataset_executor.py
  18. 2 1
      api/core/features/dataset_retrieval/dataset_retrieval.py
  19. 0 189
      api/core/third_party/spark/spark_llm.py
  20. 0 24
      api/core/tool/current_datetime_tool.py
  21. 0 63
      api/core/tool/provider/base.py
  22. 0 2
      api/core/tool/provider/errors.py
  23. 0 77
      api/core/tool/provider/serpapi_provider.py
  24. 0 43
      api/core/tool/provider/tool_provider_service.py
  25. 0 51
      api/core/tool/serpapi_wrapper.py
  26. 0 443
      api/core/tool/web_reader_tool.py
  27. 18 18
      api/core/tools/tool/dataset_retriever_tool.py
  28. 0 109
      api/core/tools/utils/web_reader_tool.py
  29. 1 1
      api/services/app_model_config_service.py

+ 0 - 49
api/core/agent/agent/calc_token_mixin.py

@@ -1,49 +0,0 @@
-from typing import cast
-
-from core.entities.application_entities import ModelConfigEntity
-from core.model_runtime.entities.message_entities import PromptMessage
-from core.model_runtime.entities.model_entities import ModelPropertyKey
-from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-
-
-class CalcTokenMixin:
-
-    def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: list[PromptMessage], **kwargs) -> int:
-        """
-        Got the rest tokens available for the model after excluding messages tokens and completion max tokens
-
-        :param model_config:
-        :param messages:
-        :return:
-        """
-        model_type_instance = model_config.provider_model_bundle.model_type_instance
-        model_type_instance = cast(LargeLanguageModel, model_type_instance)
-
-        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
-
-        max_tokens = 0
-        for parameter_rule in model_config.model_schema.parameter_rules:
-            if (parameter_rule.name == 'max_tokens'
-                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
-                max_tokens = (model_config.parameters.get(parameter_rule.name)
-                              or model_config.parameters.get(parameter_rule.use_template)) or 0
-
-        if model_context_tokens is None:
-            return 0
-
-        if max_tokens is None:
-            max_tokens = 0
-
-        prompt_tokens = model_type_instance.get_num_tokens(
-            model_config.model,
-            model_config.credentials,
-            messages
-        )
-
-        rest_tokens = model_context_tokens - max_tokens - prompt_tokens
-
-        return rest_tokens
-
-
-class ExceededLLMTokensLimitError(Exception):
-    pass

+ 0 - 361
api/core/agent/agent/openai_function_call.py

@@ -1,361 +0,0 @@
-from collections.abc import Sequence
-from typing import Any, Optional, Union
-
-from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
-from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
-from langchain.callbacks.base import BaseCallbackManager
-from langchain.callbacks.manager import Callbacks
-from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken
-from langchain.memory.prompt import SUMMARY_PROMPT
-from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import (
-    AgentAction,
-    AgentFinish,
-    AIMessage,
-    BaseMessage,
-    HumanMessage,
-    SystemMessage,
-    get_buffer_string,
-)
-from langchain.tools import BaseTool
-from pydantic import root_validator
-
-from core.agent.agent.agent_llm_callback import AgentLLMCallback
-from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
-from core.chain.llm_chain import LLMChain
-from core.entities.application_entities import ModelConfigEntity
-from core.entities.message_entities import lc_messages_to_prompt_messages
-from core.model_manager import ModelInstance
-from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
-from core.third_party.langchain.llms.fake import FakeLLM
-
-
-class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
-    moving_summary_buffer: str = ""
-    moving_summary_index: int = 0
-    summary_model_config: ModelConfigEntity = None
-    model_config: ModelConfigEntity
-    agent_llm_callback: Optional[AgentLLMCallback] = None
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        arbitrary_types_allowed = True
-
-    @root_validator
-    def validate_llm(cls, values: dict) -> dict:
-        return values
-
-    @classmethod
-    def from_llm_and_tools(
-            cls,
-            model_config: ModelConfigEntity,
-            tools: Sequence[BaseTool],
-            callback_manager: Optional[BaseCallbackManager] = None,
-            extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
-            system_message: Optional[SystemMessage] = SystemMessage(
-                content="You are a helpful AI assistant."
-            ),
-            agent_llm_callback: Optional[AgentLLMCallback] = None,
-            **kwargs: Any,
-    ) -> BaseSingleActionAgent:
-        prompt = cls.create_prompt(
-            extra_prompt_messages=extra_prompt_messages,
-            system_message=system_message,
-        )
-        return cls(
-            model_config=model_config,
-            llm=FakeLLM(response=''),
-            prompt=prompt,
-            tools=tools,
-            callback_manager=callback_manager,
-            agent_llm_callback=agent_llm_callback,
-            **kwargs,
-        )
-
-    def should_use_agent(self, query: str):
-        """
-        return should use agent
-
-        :param query:
-        :return:
-        """
-        original_max_tokens = 0
-        for parameter_rule in self.model_config.model_schema.parameter_rules:
-            if (parameter_rule.name == 'max_tokens'
-                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
-                original_max_tokens = (self.model_config.parameters.get(parameter_rule.name)
-                              or self.model_config.parameters.get(parameter_rule.use_template)) or 0
-
-        self.model_config.parameters['max_tokens'] = 40
-
-        prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
-        messages = prompt.to_messages()
-
-        try:
-            prompt_messages = lc_messages_to_prompt_messages(messages)
-            model_instance = ModelInstance(
-                provider_model_bundle=self.model_config.provider_model_bundle,
-                model=self.model_config.model,
-            )
-
-            tools = []
-            for function in self.functions:
-                tool = PromptMessageTool(
-                    **function
-                )
-
-                tools.append(tool)
-
-            result = model_instance.invoke_llm(
-                prompt_messages=prompt_messages,
-                tools=tools,
-                stream=False,
-                model_parameters={
-                    'temperature': 0.2,
-                    'top_p': 0.3,
-                    'max_tokens': 1500
-                }
-            )
-        except Exception as e:
-            raise e
-
-        self.model_config.parameters['max_tokens'] = original_max_tokens
-
-        return True if result.message.tool_calls else False
-
-    def plan(
-            self,
-            intermediate_steps: list[tuple[AgentAction, str]],
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> Union[AgentAction, AgentFinish]:
-        """Given input, decided what to do.
-
-        Args:
-            intermediate_steps: Steps the LLM has taken to date, along with observations
-            **kwargs: User inputs.
-
-        Returns:
-            Action specifying what tool to use.
-        """
-        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
-        selected_inputs = {
-            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
-        }
-        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
-        prompt = self.prompt.format_prompt(**full_inputs)
-        messages = prompt.to_messages()
-
-        prompt_messages = lc_messages_to_prompt_messages(messages)
-
-        # summarize messages if rest_tokens < 0
-        try:
-            prompt_messages = self.summarize_messages_if_needed(prompt_messages, functions=self.functions)
-        except ExceededLLMTokensLimitError as e:
-            return AgentFinish(return_values={"output": str(e)}, log=str(e))
-
-        model_instance = ModelInstance(
-            provider_model_bundle=self.model_config.provider_model_bundle,
-            model=self.model_config.model,
-        )
-
-        tools = []
-        for function in self.functions:
-            tool = PromptMessageTool(
-                **function
-            )
-
-            tools.append(tool)
-
-        result = model_instance.invoke_llm(
-            prompt_messages=prompt_messages,
-            tools=tools,
-            stream=False,
-            callbacks=[self.agent_llm_callback] if self.agent_llm_callback else [],
-            model_parameters={
-                'temperature': 0.2,
-                'top_p': 0.3,
-                'max_tokens': 1500
-            }
-        )
-
-        ai_message = AIMessage(
-            content=result.message.content or "",
-            additional_kwargs={
-                'function_call': {
-                    'id': result.message.tool_calls[0].id,
-                    **result.message.tool_calls[0].function.dict()
-                } if result.message.tool_calls else None
-            }
-        )
-        agent_decision = _parse_ai_message(ai_message)
-
-        if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
-            tool_inputs = agent_decision.tool_input
-            if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
-                tool_inputs['query'] = kwargs['input']
-                agent_decision.tool_input = tool_inputs
-
-        return agent_decision
-
-    @classmethod
-    def get_system_message(cls):
-        return SystemMessage(content="You are a helpful AI assistant.\n"
-                                     "The current date or current time you know is wrong.\n"
-                                     "Respond directly if appropriate.")
-
-    def return_stopped_response(
-            self,
-            early_stopping_method: str,
-            intermediate_steps: list[tuple[AgentAction, str]],
-            **kwargs: Any,
-    ) -> AgentFinish:
-        try:
-            return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
-        except ValueError:
-            return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
-
-    def summarize_messages_if_needed(self, messages: list[PromptMessage], **kwargs) -> list[PromptMessage]:
-        # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
-        rest_tokens = self.get_message_rest_tokens(
-            self.model_config,
-            messages,
-            **kwargs
-        )
-
-        rest_tokens = rest_tokens - 20  # to deal with the inaccuracy of rest_tokens
-        if rest_tokens >= 0:
-            return messages
-
-        system_message = None
-        human_message = None
-        should_summary_messages = []
-        for message in messages:
-            if isinstance(message, SystemMessage):
-                system_message = message
-            elif isinstance(message, HumanMessage):
-                human_message = message
-            else:
-                should_summary_messages.append(message)
-
-        if len(should_summary_messages) > 2:
-            ai_message = should_summary_messages[-2]
-            function_message = should_summary_messages[-1]
-            should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
-            self.moving_summary_index = len(should_summary_messages)
-        else:
-            error_msg = "Exceeded LLM tokens limit, stopped."
-            raise ExceededLLMTokensLimitError(error_msg)
-
-        new_messages = [system_message, human_message]
-
-        if self.moving_summary_index == 0:
-            should_summary_messages.insert(0, human_message)
-
-        self.moving_summary_buffer = self.predict_new_summary(
-            messages=should_summary_messages,
-            existing_summary=self.moving_summary_buffer
-        )
-
-        new_messages.append(AIMessage(content=self.moving_summary_buffer))
-        new_messages.append(ai_message)
-        new_messages.append(function_message)
-
-        return new_messages
-
-    def predict_new_summary(
-        self, messages: list[BaseMessage], existing_summary: str
-    ) -> str:
-        new_lines = get_buffer_string(
-            messages,
-            human_prefix="Human",
-            ai_prefix="AI",
-        )
-
-        chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
-        return chain.predict(summary=existing_summary, new_lines=new_lines)
-
-    def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: list[BaseMessage], **kwargs) -> int:
-        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
-
-        Official documentation: https://github.com/openai/openai-cookbook/blob/
-        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
-        if model_config.provider == 'azure_openai':
-            model = model_config.model
-            model = model.replace("gpt-35", "gpt-3.5")
-        else:
-            model = model_config.credentials.get("base_model_name")
-
-        tiktoken_ = _import_tiktoken()
-        try:
-            encoding = tiktoken_.encoding_for_model(model)
-        except KeyError:
-            model = "cl100k_base"
-            encoding = tiktoken_.get_encoding(model)
-
-        if model.startswith("gpt-3.5-turbo"):
-            # every message follows <im_start>{role/name}\n{content}<im_end>\n
-            tokens_per_message = 4
-            # if there's a name, the role is omitted
-            tokens_per_name = -1
-        elif model.startswith("gpt-4"):
-            tokens_per_message = 3
-            tokens_per_name = 1
-        else:
-            raise NotImplementedError(
-                f"get_num_tokens_from_messages() is not presently implemented "
-                f"for model {model}."
-                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
-                "information on how messages are converted to tokens."
-            )
-        num_tokens = 0
-        for m in messages:
-            message = _convert_message_to_dict(m)
-            num_tokens += tokens_per_message
-            for key, value in message.items():
-                if key == "function_call":
-                    for f_key, f_value in value.items():
-                        num_tokens += len(encoding.encode(f_key))
-                        num_tokens += len(encoding.encode(f_value))
-                else:
-                    num_tokens += len(encoding.encode(value))
-
-                if key == "name":
-                    num_tokens += tokens_per_name
-        # every reply is primed with <im_start>assistant
-        num_tokens += 3
-
-        if kwargs.get('functions'):
-            for function in kwargs.get('functions'):
-                num_tokens += len(encoding.encode('name'))
-                num_tokens += len(encoding.encode(function.get("name")))
-                num_tokens += len(encoding.encode('description'))
-                num_tokens += len(encoding.encode(function.get("description")))
-                parameters = function.get("parameters")
-                num_tokens += len(encoding.encode('parameters'))
-                if 'title' in parameters:
-                    num_tokens += len(encoding.encode('title'))
-                    num_tokens += len(encoding.encode(parameters.get("title")))
-                num_tokens += len(encoding.encode('type'))
-                num_tokens += len(encoding.encode(parameters.get("type")))
-                if 'properties' in parameters:
-                    num_tokens += len(encoding.encode('properties'))
-                    for key, value in parameters.get('properties').items():
-                        num_tokens += len(encoding.encode(key))
-                        for field_key, field_value in value.items():
-                            num_tokens += len(encoding.encode(field_key))
-                            if field_key == 'enum':
-                                for enum_field in field_value:
-                                    num_tokens += 3
-                                    num_tokens += len(encoding.encode(enum_field))
-                            else:
-                                num_tokens += len(encoding.encode(field_key))
-                                num_tokens += len(encoding.encode(str(field_value)))
-                if 'required' in parameters:
-                    num_tokens += len(encoding.encode('required'))
-                    for required_field in parameters['required']:
-                        num_tokens += 3
-                        num_tokens += len(encoding.encode(required_field))
-
-        return num_tokens

+ 0 - 306
api/core/agent/agent/structured_chat.py

@@ -1,306 +0,0 @@
-import re
-from collections.abc import Sequence
-from typing import Any, Optional, Union, cast
-
-from langchain import BasePromptTemplate, PromptTemplate
-from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
-from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
-from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
-from langchain.callbacks.base import BaseCallbackManager
-from langchain.callbacks.manager import Callbacks
-from langchain.memory.prompt import SUMMARY_PROMPT
-from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
-from langchain.schema import (
-    AgentAction,
-    AgentFinish,
-    AIMessage,
-    BaseMessage,
-    HumanMessage,
-    OutputParserException,
-    get_buffer_string,
-)
-from langchain.tools import BaseTool
-
-from core.agent.agent.agent_llm_callback import AgentLLMCallback
-from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
-from core.chain.llm_chain import LLMChain
-from core.entities.application_entities import ModelConfigEntity
-from core.entities.message_entities import lc_messages_to_prompt_messages
-
-FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
-The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
-Valid "action" values: "Final Answer" or {tool_names}
-
-Provide only ONE action per $JSON_BLOB, as shown:
-
-```
-{{{{
-  "action": $TOOL_NAME,
-  "action_input": $INPUT
-}}}}
-```
-
-Follow this format:
-
-Question: input question to answer
-Thought: consider previous and subsequent steps
-Action:
-```
-$JSON_BLOB
-```
-Observation: action result
-... (repeat Thought/Action/Observation N times)
-Thought: I know what to respond
-Action:
-```
-{{{{
-  "action": "Final Answer",
-  "action_input": "Final response to human"
-}}}}
-```"""
-
-
-class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
-    moving_summary_buffer: str = ""
-    moving_summary_index: int = 0
-    summary_model_config: ModelConfigEntity = None
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        arbitrary_types_allowed = True
-
-    def should_use_agent(self, query: str):
-        """
-        return should use agent
-        Using the ReACT mode to determine whether an agent is needed is costly,
-        so it's better to just use an Agent for reasoning, which is cheaper.
-
-        :param query:
-        :return:
-        """
-        return True
-
-    def plan(
-        self,
-        intermediate_steps: list[tuple[AgentAction, str]],
-        callbacks: Callbacks = None,
-        **kwargs: Any,
-    ) -> Union[AgentAction, AgentFinish]:
-        """Given input, decided what to do.
-
-        Args:
-            intermediate_steps: Steps the LLM has taken to date,
-                along with observatons
-            callbacks: Callbacks to run.
-            **kwargs: User inputs.
-
-        Returns:
-            Action specifying what tool to use.
-        """
-        full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
-        prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])
-
-        messages = []
-        if prompts:
-            messages = prompts[0].to_messages()
-
-        prompt_messages = lc_messages_to_prompt_messages(messages)
-
-        rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_config, prompt_messages)
-        if rest_tokens < 0:
-            full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
-
-        try:
-            full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
-        except Exception as e:
-            raise e
-
-        try:
-            agent_decision = self.output_parser.parse(full_output)
-            if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
-                tool_inputs = agent_decision.tool_input
-                if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
-                    tool_inputs['query'] = kwargs['input']
-                    agent_decision.tool_input = tool_inputs
-            return agent_decision
-        except OutputParserException:
-            return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
-                                          "I don't know how to respond to that."}, "")
-
-    def summarize_messages(self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs):
-        if len(intermediate_steps) >= 2 and self.summary_model_config:
-            should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
-            should_summary_messages = [AIMessage(content=observation)
-                                       for _, observation in should_summary_intermediate_steps]
-            if self.moving_summary_index == 0:
-                should_summary_messages.insert(0, HumanMessage(content=kwargs.get("input")))
-
-            self.moving_summary_index = len(intermediate_steps)
-        else:
-            error_msg = "Exceeded LLM tokens limit, stopped."
-            raise ExceededLLMTokensLimitError(error_msg)
-
-        if self.moving_summary_buffer and 'chat_history' in kwargs:
-            kwargs["chat_history"].pop()
-
-        self.moving_summary_buffer = self.predict_new_summary(
-            messages=should_summary_messages,
-            existing_summary=self.moving_summary_buffer
-        )
-
-        if 'chat_history' in kwargs:
-            kwargs["chat_history"].append(AIMessage(content=self.moving_summary_buffer))
-
-        return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
-
-    def predict_new_summary(
-        self, messages: list[BaseMessage], existing_summary: str
-    ) -> str:
-        new_lines = get_buffer_string(
-            messages,
-            human_prefix="Human",
-            ai_prefix="AI",
-        )
-
-        chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
-        return chain.predict(summary=existing_summary, new_lines=new_lines)
-
-    @classmethod
-    def create_prompt(
-            cls,
-            tools: Sequence[BaseTool],
-            prefix: str = PREFIX,
-            suffix: str = SUFFIX,
-            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
-            format_instructions: str = FORMAT_INSTRUCTIONS,
-            input_variables: Optional[list[str]] = None,
-            memory_prompts: Optional[list[BasePromptTemplate]] = None,
-    ) -> BasePromptTemplate:
-        tool_strings = []
-        for tool in tools:
-            args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
-            tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
-        formatted_tools = "\n".join(tool_strings)
-        tool_names = ", ".join([('"' + tool.name + '"') for tool in tools])
-        format_instructions = format_instructions.format(tool_names=tool_names)
-        template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
-        if input_variables is None:
-            input_variables = ["input", "agent_scratchpad"]
-        _memory_prompts = memory_prompts or []
-        messages = [
-            SystemMessagePromptTemplate.from_template(template),
-            *_memory_prompts,
-            HumanMessagePromptTemplate.from_template(human_message_template),
-        ]
-        return ChatPromptTemplate(input_variables=input_variables, messages=messages)
-
-    @classmethod
-    def create_completion_prompt(
-            cls,
-            tools: Sequence[BaseTool],
-            prefix: str = PREFIX,
-            format_instructions: str = FORMAT_INSTRUCTIONS,
-            input_variables: Optional[list[str]] = None,
-    ) -> PromptTemplate:
-        """Create prompt in the style of the zero shot agent.
-
-        Args:
-            tools: List of tools the agent will have access to, used to format the
-                prompt.
-            prefix: String to put before the list of tools.
-            input_variables: List of input variables the final prompt will expect.
-
-        Returns:
-            A PromptTemplate with the template assembled from the pieces here.
-        """
-        suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
-Question: {input}
-Thought: {agent_scratchpad}
-"""
-
-        tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
-        tool_names = ", ".join([tool.name for tool in tools])
-        format_instructions = format_instructions.format(tool_names=tool_names)
-        template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
-        if input_variables is None:
-            input_variables = ["input", "agent_scratchpad"]
-        return PromptTemplate(template=template, input_variables=input_variables)
-
-    def _construct_scratchpad(
-        self, intermediate_steps: list[tuple[AgentAction, str]]
-    ) -> str:
-        agent_scratchpad = ""
-        for action, observation in intermediate_steps:
-            agent_scratchpad += action.log
-            agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
-
-        if not isinstance(agent_scratchpad, str):
-            raise ValueError("agent_scratchpad should be of type string.")
-        if agent_scratchpad:
-            llm_chain = cast(LLMChain, self.llm_chain)
-            if llm_chain.model_config.mode == "chat":
-                return (
-                    f"This was your previous work "
-                    f"(but I haven't seen any of it! I only see what "
-                    f"you return as final answer):\n{agent_scratchpad}"
-                )
-            else:
-                return agent_scratchpad
-        else:
-            return agent_scratchpad
-
-    @classmethod
-    def from_llm_and_tools(
-            cls,
-            model_config: ModelConfigEntity,
-            tools: Sequence[BaseTool],
-            callback_manager: Optional[BaseCallbackManager] = None,
-            output_parser: Optional[AgentOutputParser] = None,
-            prefix: str = PREFIX,
-            suffix: str = SUFFIX,
-            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
-            format_instructions: str = FORMAT_INSTRUCTIONS,
-            input_variables: Optional[list[str]] = None,
-            memory_prompts: Optional[list[BasePromptTemplate]] = None,
-            agent_llm_callback: Optional[AgentLLMCallback] = None,
-            **kwargs: Any,
-    ) -> Agent:
-        """Construct an agent from an LLM and tools."""
-        cls._validate_tools(tools)
-        if model_config.mode == "chat":
-            prompt = cls.create_prompt(
-                tools,
-                prefix=prefix,
-                suffix=suffix,
-                human_message_template=human_message_template,
-                format_instructions=format_instructions,
-                input_variables=input_variables,
-                memory_prompts=memory_prompts,
-            )
-        else:
-            prompt = cls.create_completion_prompt(
-                tools,
-                prefix=prefix,
-                format_instructions=format_instructions,
-                input_variables=input_variables,
-            )
-        llm_chain = LLMChain(
-            model_config=model_config,
-            prompt=prompt,
-            callback_manager=callback_manager,
-            agent_llm_callback=agent_llm_callback,
-            parameters={
-                'temperature': 0.2,
-                'top_p': 0.3,
-                'max_tokens': 1500
-            }
-        )
-        tool_names = [tool.name for tool in tools]
-        _output_parser = output_parser
-        return cls(
-            llm_chain=llm_chain,
-            allowed_tools=tool_names,
-            output_parser=_output_parser,
-            **kwargs,
-        )

+ 1 - 39
api/core/app_runner/assistant_app_runner.py

@@ -1,4 +1,3 @@
-import json
 import logging
 from typing import cast
 
@@ -15,7 +14,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
 from core.moderation.base import ModerationException
 from core.tools.entities.tool_entities import ToolRuntimeVariablePool
 from extensions.ext_database import db
-from models.model import App, Conversation, Message, MessageAgentThought, MessageChain
+from models.model import App, Conversation, Message, MessageAgentThought
 from models.tools import ToolConversationVariables
 
 logger = logging.getLogger(__name__)
@@ -173,11 +172,6 @@ class AssistantApplicationRunner(AppRunner):
 
         # convert db variables to tool variables
         tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
-        
-        message_chain = self._init_message_chain(
-            message=message,
-            query=query
-        )
 
         # init model instance
         model_instance = ModelInstance(
@@ -290,38 +284,6 @@ class AssistantApplicationRunner(AppRunner):
             'pool': db_variables.variables
         })
 
-    def _init_message_chain(self, message: Message, query: str) -> MessageChain:
-        """
-        Init MessageChain
-        :param message: message
-        :param query: query
-        :return:
-        """
-        message_chain = MessageChain(
-            message_id=message.id,
-            type="AgentExecutor",
-            input=json.dumps({
-                "input": query
-            })
-        )
-
-        db.session.add(message_chain)
-        db.session.commit()
-
-        return message_chain
-
-    def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
-        """
-        Save MessageChain
-        :param message_chain: message chain
-        :param output_text: output text
-        :return:
-        """
-        message_chain.output = json.dumps({
-            "output": output_text
-        })
-        db.session.commit()
-
     def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
                                          message: Message) -> LLMUsage:
         """

+ 1 - 1
api/core/app_runner/basic_app_runner.py

@@ -5,7 +5,7 @@ from core.app_runner.app_runner import AppRunner
 from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity
-from core.features.dataset_retrieval import DatasetRetrievalFeature
+from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.moderation.base import ModerationException

+ 8 - 0
api/core/entities/agent_entities.py

@@ -0,0 +1,8 @@
+from enum import Enum
+
+
+class PlanningStrategy(Enum):
+    ROUTER = 'router'
+    REACT_ROUTER = 'react_router'
+    REACT = 'react'
+    FUNCTION_CALL = 'function_call'

+ 0 - 199
api/core/features/agent_runner.py

@@ -1,199 +0,0 @@
-import logging
-from typing import Optional, cast
-
-from langchain.tools import BaseTool
-
-from core.agent.agent.agent_llm_callback import AgentLLMCallback
-from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
-from core.application_queue_manager import ApplicationQueueManager
-from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
-from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
-from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
-from core.entities.application_entities import (
-    AgentEntity,
-    AppOrchestrationConfigEntity,
-    InvokeFrom,
-    ModelConfigEntity,
-)
-from core.memory.token_buffer_memory import TokenBufferMemory
-from core.model_runtime.entities.model_entities import ModelFeature, ModelType
-from core.model_runtime.model_providers import model_provider_factory
-from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
-from extensions.ext_database import db
-from models.dataset import Dataset
-from models.model import Message
-
-logger = logging.getLogger(__name__)
-
-
-class AgentRunnerFeature:
-    def __init__(self, tenant_id: str,
-                 app_orchestration_config: AppOrchestrationConfigEntity,
-                 model_config: ModelConfigEntity,
-                 config: AgentEntity,
-                 queue_manager: ApplicationQueueManager,
-                 message: Message,
-                 user_id: str,
-                 agent_llm_callback: AgentLLMCallback,
-                 callback: AgentLoopGatherCallbackHandler,
-                 memory: Optional[TokenBufferMemory] = None,) -> None:
-        """
-        Agent runner
-        :param tenant_id: tenant id
-        :param app_orchestration_config: app orchestration config
-        :param model_config: model config
-        :param config: dataset config
-        :param queue_manager: queue manager
-        :param message: message
-        :param user_id: user id
-        :param agent_llm_callback: agent llm callback
-        :param callback: callback
-        :param memory: memory
-        """
-        self.tenant_id = tenant_id
-        self.app_orchestration_config = app_orchestration_config
-        self.model_config = model_config
-        self.config = config
-        self.queue_manager = queue_manager
-        self.message = message
-        self.user_id = user_id
-        self.agent_llm_callback = agent_llm_callback
-        self.callback = callback
-        self.memory = memory
-
-    def run(self, query: str,
-            invoke_from: InvokeFrom) -> Optional[str]:
-        """
-        Retrieve agent loop result.
-        :param query: query
-        :param invoke_from: invoke from
-        :return:
-        """
-        provider = self.config.provider
-        model = self.config.model
-        tool_configs = self.config.tools
-
-        # check model is support tool calling
-        provider_instance = model_provider_factory.get_provider_instance(provider=provider)
-        model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
-        model_type_instance = cast(LargeLanguageModel, model_type_instance)
-
-        # get model schema
-        model_schema = model_type_instance.get_model_schema(
-            model=model,
-            credentials=self.model_config.credentials
-        )
-
-        if not model_schema:
-            return None
-
-        planning_strategy = PlanningStrategy.REACT
-        features = model_schema.features
-        if features:
-            if ModelFeature.TOOL_CALL in features \
-                    or ModelFeature.MULTI_TOOL_CALL in features:
-                planning_strategy = PlanningStrategy.FUNCTION_CALL
-
-        tools = self.to_tools(
-            tool_configs=tool_configs,
-            invoke_from=invoke_from,
-            callbacks=[self.callback, DifyStdOutCallbackHandler()],
-        )
-
-        if len(tools) == 0:
-            return None
-
-        agent_configuration = AgentConfiguration(
-            strategy=planning_strategy,
-            model_config=self.model_config,
-            tools=tools,
-            memory=self.memory,
-            max_iterations=10,
-            max_execution_time=400.0,
-            early_stopping_method="generate",
-            agent_llm_callback=self.agent_llm_callback,
-            callbacks=[self.callback, DifyStdOutCallbackHandler()]
-        )
-
-        agent_executor = AgentExecutor(agent_configuration)
-
-        try:
-            # check if should use agent
-            should_use_agent = agent_executor.should_use_agent(query)
-            if not should_use_agent:
-                return None
-
-            result = agent_executor.run(query)
-            return result.output
-        except Exception as ex:
-            logger.exception("agent_executor run failed")
-            return None
-
-    def to_dataset_retriever_tool(self, tool_config: dict,
-                                  invoke_from: InvokeFrom) \
-            -> Optional[BaseTool]:
-        """
-        A dataset tool is a tool that can be used to retrieve information from a dataset
-        :param tool_config: tool config
-        :param invoke_from: invoke from
-        """
-        show_retrieve_source = self.app_orchestration_config.show_retrieve_source
-
-        hit_callback = DatasetIndexToolCallbackHandler(
-            queue_manager=self.queue_manager,
-            app_id=self.message.app_id,
-            message_id=self.message.id,
-            user_id=self.user_id,
-            invoke_from=invoke_from
-        )
-
-        # get dataset from dataset id
-        dataset = db.session.query(Dataset).filter(
-            Dataset.tenant_id == self.tenant_id,
-            Dataset.id == tool_config.get("id")
-        ).first()
-
-        # pass if dataset is not available
-        if not dataset:
-            return None
-
-        # pass if dataset is not available
-        if (dataset and dataset.available_document_count == 0
-                and dataset.available_document_count == 0):
-            return None
-
-        # get retrieval model config
-        default_retrieval_model = {
-            'search_method': 'semantic_search',
-            'reranking_enable': False,
-            'reranking_model': {
-                'reranking_provider_name': '',
-                'reranking_model_name': ''
-            },
-            'top_k': 2,
-            'score_threshold_enabled': False
-        }
-
-        retrieval_model_config = dataset.retrieval_model \
-            if dataset.retrieval_model else default_retrieval_model
-
-        # get top k
-        top_k = retrieval_model_config['top_k']
-
-        # get score threshold
-        score_threshold = None
-        score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
-        if score_threshold_enabled:
-            score_threshold = retrieval_model_config.get("score_threshold")
-
-        tool = DatasetRetrieverTool.from_dataset(
-            dataset=dataset,
-            top_k=top_k,
-            score_threshold=score_threshold,
-            hit_callbacks=[hit_callback],
-            return_resource=show_retrieve_source,
-            retriever_from=invoke_from.to_source()
-        )
-
-        return tool

+ 0 - 0
api/core/third_party/langchain/llms/__init__.py → api/core/features/dataset_retrieval/__init__.py


+ 0 - 0
api/core/third_party/spark/__init__.py → api/core/features/dataset_retrieval/agent/__init__.py


+ 0 - 0
api/core/agent/agent/agent_llm_callback.py → api/core/features/dataset_retrieval/agent/agent_llm_callback.py


+ 0 - 0
api/core/third_party/langchain/llms/fake.py → api/core/features/dataset_retrieval/agent/fake_llm.py


+ 2 - 2
api/core/chain/llm_chain.py → api/core/features/dataset_retrieval/agent/llm_chain.py

@@ -5,11 +5,11 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.schema import Generation, LLMResult
 from langchain.schema.language_model import BaseLanguageModel
 
-from core.agent.agent.agent_llm_callback import AgentLLMCallback
 from core.entities.application_entities import ModelConfigEntity
 from core.entities.message_entities import lc_messages_to_prompt_messages
+from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback
+from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
 from core.model_manager import ModelInstance
-from core.third_party.langchain.llms.fake import FakeLLM
 
 
 class LLMChain(LCLLMChain):

+ 1 - 1
api/core/agent/agent/multi_dataset_router_agent.py → api/core/features/dataset_retrieval/agent/multi_dataset_router_agent.py

@@ -12,9 +12,9 @@ from pydantic import root_validator
 
 from core.entities.application_entities import ModelConfigEntity
 from core.entities.message_entities import lc_messages_to_prompt_messages
+from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.message_entities import PromptMessageTool
-from core.third_party.langchain.llms.fake import FakeLLM
 
 
 class MultiDatasetRouterAgent(OpenAIFunctionsAgent):

+ 0 - 0
api/core/data_loader/file_extractor.py → api/core/features/dataset_retrieval/agent/output_parser/__init__.py


+ 0 - 0
api/core/agent/agent/output_parser/structured_chat.py → api/core/features/dataset_retrieval/agent/output_parser/structured_chat.py


+ 1 - 1
api/core/agent/agent/structed_multi_dataset_router_agent.py → api/core/features/dataset_retrieval/agent/structed_multi_dataset_router_agent.py

@@ -12,8 +12,8 @@ from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, Sy
 from langchain.schema import AgentAction, AgentFinish, OutputParserException
 from langchain.tools import BaseTool
 
-from core.chain.llm_chain import LLMChain
 from core.entities.application_entities import ModelConfigEntity
+from core.features.dataset_retrieval.agent.llm_chain import LLMChain
 
 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
 The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.

+ 6 - 36
api/core/agent/agent_executor.py → api/core/features/dataset_retrieval/agent_based_dataset_executor.py

@@ -1,4 +1,3 @@
-import enum
 import logging
 from typing import Optional, Union
 
@@ -8,14 +7,13 @@ from langchain.callbacks.manager import Callbacks
 from langchain.tools import BaseTool
 from pydantic import BaseModel, Extra
 
-from core.agent.agent.agent_llm_callback import AgentLLMCallback
-from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
-from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
-from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
-from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
-from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
+from core.entities.agent_entities import PlanningStrategy
 from core.entities.application_entities import ModelConfigEntity
 from core.entities.message_entities import prompt_messages_to_lc_messages
+from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback
+from core.features.dataset_retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
+from core.features.dataset_retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
+from core.features.dataset_retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
 from core.helper import moderation
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.errors.invoke import InvokeError
@@ -23,13 +21,6 @@ from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import Datas
 from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
 
 
-class PlanningStrategy(str, enum.Enum):
-    ROUTER = 'router'
-    REACT_ROUTER = 'react_router'
-    REACT = 'react'
-    FUNCTION_CALL = 'function_call'
-
-
 class AgentConfiguration(BaseModel):
     strategy: PlanningStrategy
     model_config: ModelConfigEntity
@@ -62,28 +53,7 @@ class AgentExecutor:
         self.agent = self._init_agent()
 
     def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
-        if self.configuration.strategy == PlanningStrategy.REACT:
-            agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
-                model_config=self.configuration.model_config,
-                tools=self.configuration.tools,
-                output_parser=StructuredChatOutputParser(),
-                summary_model_config=self.configuration.summary_model_config
-                if self.configuration.summary_model_config else None,
-                agent_llm_callback=self.configuration.agent_llm_callback,
-                verbose=True
-            )
-        elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
-            agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
-                model_config=self.configuration.model_config,
-                tools=self.configuration.tools,
-                extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
-                if self.configuration.memory else None,  # used for read chat histories memory
-                summary_model_config=self.configuration.summary_model_config
-                if self.configuration.summary_model_config else None,
-                agent_llm_callback=self.configuration.agent_llm_callback,
-                verbose=True
-            )
-        elif self.configuration.strategy == PlanningStrategy.ROUTER:
+        if self.configuration.strategy == PlanningStrategy.ROUTER:
             self.configuration.tools = [t for t in self.configuration.tools
                                         if isinstance(t, DatasetRetrieverTool)
                                         or isinstance(t, DatasetMultiRetrieverTool)]

+ 2 - 1
api/core/features/dataset_retrieval.py → api/core/features/dataset_retrieval/dataset_retrieval.py

@@ -2,9 +2,10 @@ from typing import Optional, cast
 
 from langchain.tools import BaseTool
 
-from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+from core.entities.agent_entities import PlanningStrategy
 from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity
+from core.features.dataset_retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.model_entities import ModelFeature
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

+ 0 - 189
api/core/third_party/spark/spark_llm.py

@@ -1,189 +0,0 @@
-import base64
-import hashlib
-import hmac
-import json
-import queue
-import ssl
-from datetime import datetime
-from time import mktime
-from typing import Optional
-from urllib.parse import urlencode, urlparse
-from wsgiref.handlers import format_date_time
-
-import websocket
-
-
-class SparkLLMClient:
-    def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
-        domain = 'spark-api.xf-yun.com'
-        endpoint = 'chat'
-        if api_domain:
-            domain = api_domain
-            if model_name == 'spark-v3':
-                endpoint = 'multimodal'
-
-        model_api_configs = {
-            'spark': {
-                'version': 'v1.1',
-                'chat_domain': 'general'
-            },
-            'spark-v2': {
-                'version': 'v2.1',
-                'chat_domain': 'generalv2'
-            },
-            'spark-v3': {
-                'version': 'v3.1',
-                'chat_domain': 'generalv3'
-            },
-            'spark-v3.5': {
-                'version': 'v3.5',
-                'chat_domain': 'generalv3.5'
-            }
-        }
-
-        api_version = model_api_configs[model_name]['version']
-
-        self.chat_domain = model_api_configs[model_name]['chat_domain']
-        self.api_base = f"wss://{domain}/{api_version}/{endpoint}"
-        self.app_id = app_id
-        self.ws_url = self.create_url(
-            urlparse(self.api_base).netloc,
-            urlparse(self.api_base).path,
-            self.api_base,
-            api_key,
-            api_secret
-        )
-
-        self.queue = queue.Queue()
-        self.blocking_message = ''
-
-    def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str:
-        # generate timestamp by RFC1123
-        now = datetime.now()
-        date = format_date_time(mktime(now.timetuple()))
-
-        signature_origin = "host: " + host + "\n"
-        signature_origin += "date: " + date + "\n"
-        signature_origin += "GET " + path + " HTTP/1.1"
-
-        # encrypt using hmac-sha256
-        signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
-                                 digestmod=hashlib.sha256).digest()
-
-        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
-
-        authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
-
-        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
-
-        v = {
-            "authorization": authorization,
-            "date": date,
-            "host": host
-        }
-        # generate url
-        url = api_base + '?' + urlencode(v)
-        return url
-
-    def run(self, messages: list, user_id: str,
-            model_kwargs: Optional[dict] = None, streaming: bool = False):
-        websocket.enableTrace(False)
-        ws = websocket.WebSocketApp(
-            self.ws_url,
-            on_message=self.on_message,
-            on_error=self.on_error,
-            on_close=self.on_close,
-            on_open=self.on_open
-        )
-        ws.messages = messages
-        ws.user_id = user_id
-        ws.model_kwargs = model_kwargs
-        ws.streaming = streaming
-        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
-
-    def on_error(self, ws, error):
-        self.queue.put({
-            'status_code': error.status_code,
-            'error': error.resp_body.decode('utf-8')
-        })
-        ws.close()
-
-    def on_close(self, ws, close_status_code, close_reason):
-        self.queue.put({'done': True})
-
-    def on_open(self, ws):
-        self.blocking_message = ''
-        data = json.dumps(self.gen_params(
-            messages=ws.messages,
-            user_id=ws.user_id,
-            model_kwargs=ws.model_kwargs
-        ))
-        ws.send(data)
-
-    def on_message(self, ws, message):
-        data = json.loads(message)
-        code = data['header']['code']
-        if code != 0:
-            self.queue.put({
-                'status_code': 400,
-                'error': f"Code: {code}, Error: {data['header']['message']}"
-            })
-            ws.close()
-        else:
-            choices = data["payload"]["choices"]
-            status = choices["status"]
-            content = choices["text"][0]["content"]
-            if ws.streaming:
-                self.queue.put({'data': content})
-            else:
-                self.blocking_message += content
-
-            if status == 2:
-                if not ws.streaming:
-                    self.queue.put({'data': self.blocking_message})
-                ws.close()
-
-    def gen_params(self, messages: list, user_id: str,
-                   model_kwargs: Optional[dict] = None) -> dict:
-        data = {
-            "header": {
-                "app_id": self.app_id,
-                "uid": user_id
-            },
-            "parameter": {
-                "chat": {
-                    "domain": self.chat_domain
-                }
-            },
-            "payload": {
-                "message": {
-                    "text": messages
-                }
-            }
-        }
-
-        if model_kwargs:
-            data['parameter']['chat'].update(model_kwargs)
-
-        return data
-
-    def subscribe(self):
-        while True:
-            content = self.queue.get()
-            if 'error' in content:
-                if content['status_code'] == 401:
-                    raise SparkError('[Spark] The credentials you provided are incorrect. '
-                                     'Please double-check and fill them in again.')
-                elif content['status_code'] == 403:
-                    raise SparkError("[Spark] Sorry, the credentials you provided are access denied. "
-                                     "Please try again after obtaining the necessary permissions.")
-                else:
-                    raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}")
-
-            if 'data' not in content:
-                break
-            yield content
-
-
-class SparkError(Exception):
-    pass

+ 0 - 24
api/core/tool/current_datetime_tool.py

@@ -1,24 +0,0 @@
-from datetime import datetime
-
-from langchain.tools import BaseTool
-from pydantic import BaseModel, Field
-
-
-class DatetimeToolInput(BaseModel):
-    type: str = Field(..., description="Type for current time, must be: datetime.")
-
-
-class DatetimeTool(BaseTool):
-    """Tool for querying current datetime."""
-    name: str = "current_datetime"
-    args_schema: type[BaseModel] = DatetimeToolInput
-    description: str = "A tool when you want to get the current date, time, week, month or year, " \
-                       "and the time zone is UTC. Result is \"<date> <time> <timezone> <week>\"."
-
-    def _run(self, type: str) -> str:
-        # get current time
-        current_time = datetime.utcnow()
-        return current_time.strftime("%Y-%m-%d %H:%M:%S UTC+0000 %A")
-
-    async def _arun(self, tool_input: str) -> str:
-        raise NotImplementedError()

+ 0 - 63
api/core/tool/provider/base.py

@@ -1,63 +0,0 @@
-import base64
-from abc import ABC, abstractmethod
-from typing import Optional
-
-from extensions.ext_database import db
-from libs import rsa
-from models.account import Tenant
-from models.tool import ToolProvider, ToolProviderName
-
-
-class BaseToolProvider(ABC):
-    def __init__(self, tenant_id: str):
-        self.tenant_id = tenant_id
-
-    @abstractmethod
-    def get_provider_name(self) -> ToolProviderName:
-        raise NotImplementedError
-
-    @abstractmethod
-    def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
-        raise NotImplementedError
-
-    @abstractmethod
-    def credentials_to_func_kwargs(self) -> Optional[dict]:
-        raise NotImplementedError
-
-    @abstractmethod
-    def credentials_validate(self, credentials: dict):
-        raise NotImplementedError
-
-    def get_provider(self, must_enabled: bool = False) -> Optional[ToolProvider]:
-        """
-        Returns the Provider instance for the given tenant_id and tool_name.
-        """
-        query = db.session.query(ToolProvider).filter(
-            ToolProvider.tenant_id == self.tenant_id,
-            ToolProvider.tool_name == self.get_provider_name().value
-        )
-
-        if must_enabled:
-            query = query.filter(ToolProvider.is_enabled == True)
-
-        return query.first()
-
-    def encrypt_token(self, token) -> str:
-        tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
-        encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
-        return base64.b64encode(encrypted_token).decode()
-
-    def decrypt_token(self, token: str, obfuscated: bool = False) -> str:
-        token = rsa.decrypt(base64.b64decode(token), self.tenant_id)
-
-        if obfuscated:
-            return self._obfuscated_token(token)
-
-        return token
-
-    def _obfuscated_token(self, token: str) -> str:
-        return token[:6] + '*' * (len(token) - 8) + token[-2:]

+ 0 - 2
api/core/tool/provider/errors.py

@@ -1,2 +0,0 @@
-class ToolValidateFailedError(Exception):
-    description = "Tool Provider Validate failed"

+ 0 - 77
api/core/tool/provider/serpapi_provider.py

@@ -1,77 +0,0 @@
-from typing import Optional
-
-from core.tool.provider.base import BaseToolProvider
-from core.tool.provider.errors import ToolValidateFailedError
-from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper
-from models.tool import ToolProviderName
-
-
-class SerpAPIToolProvider(BaseToolProvider):
-    def get_provider_name(self) -> ToolProviderName:
-        """
-        Returns the name of the provider.
-
-        :return:
-        """
-        return ToolProviderName.SERPAPI
-
-    def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
-        """
-        Returns the credentials for SerpAPI as a dictionary.
-
-        :param obfuscated: obfuscate credentials if True
-        :return:
-        """
-        tool_provider = self.get_provider(must_enabled=True)
-        if not tool_provider:
-            return None
-
-        credentials = tool_provider.credentials
-        if not credentials:
-            return None
-
-        if credentials.get('api_key'):
-            credentials['api_key'] = self.decrypt_token(credentials.get('api_key'), obfuscated)
-
-        return credentials
-
-    def credentials_to_func_kwargs(self) -> Optional[dict]:
-        """
-        Returns the credentials function kwargs as a dictionary.
-
-        :return:
-        """
-        credentials = self.get_credentials()
-        if not credentials:
-            return None
-
-        return {
-            'serpapi_api_key': credentials.get('api_key')
-        }
-
-    def credentials_validate(self, credentials: dict):
-        """
-        Validates the given credentials.
-
-        :param credentials:
-        :return:
-        """
-        if 'api_key' not in credentials or not credentials.get('api_key'):
-            raise ToolValidateFailedError("SerpAPI api_key is required.")
-
-        api_key = credentials.get('api_key')
-
-        try:
-            OptimizedSerpAPIWrapper(serpapi_api_key=api_key).run(query='test')
-        except Exception as e:
-            raise ToolValidateFailedError("SerpAPI api_key is invalid. {}".format(e))
-
-    def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
-        """
-        Encrypts the given credentials.
-
-        :param credentials:
-        :return:
-        """
-        credentials['api_key'] = self.encrypt_token(credentials.get('api_key'))
-        return credentials

+ 0 - 43
api/core/tool/provider/tool_provider_service.py

@@ -1,43 +0,0 @@
-from typing import Optional
-
-from core.tool.provider.base import BaseToolProvider
-from core.tool.provider.serpapi_provider import SerpAPIToolProvider
-
-
-class ToolProviderService:
-
-    def __init__(self, tenant_id: str, provider_name: str):
-        self.provider = self._init_provider(tenant_id, provider_name)
-
-    def _init_provider(self, tenant_id: str, provider_name: str) -> BaseToolProvider:
-        if provider_name == 'serpapi':
-            return SerpAPIToolProvider(tenant_id)
-        else:
-            raise Exception('tool provider {} not found'.format(provider_name))
-
-    def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
-        """
-        Returns the credentials for Tool as a dictionary.
-
-        :param obfuscated:
-        :return:
-        """
-        return self.provider.get_credentials(obfuscated)
-
-    def credentials_validate(self, credentials: dict):
-        """
-        Validates the given credentials.
-
-        :param credentials:
-        :raises: ValidateFailedError
-        """
-        return self.provider.credentials_validate(credentials)
-
-    def encrypt_credentials(self, credentials: dict):
-        """
-        Encrypts the given credentials.
-
-        :param credentials:
-        :return:
-        """
-        return self.provider.encrypt_credentials(credentials)

+ 0 - 51
api/core/tool/serpapi_wrapper.py

@@ -1,51 +0,0 @@
-from langchain import SerpAPIWrapper
-from pydantic import BaseModel, Field
-
-
-class OptimizedSerpAPIInput(BaseModel):
-    query: str = Field(..., description="search query.")
-
-
-class OptimizedSerpAPIWrapper(SerpAPIWrapper):
-
-    @staticmethod
-    def _process_response(res: dict, num_results: int = 5) -> str:
-        """Process response from SerpAPI."""
-        if "error" in res.keys():
-            raise ValueError(f"Got error from SerpAPI: {res['error']}")
-        if "answer_box" in res.keys() and type(res["answer_box"]) == list:
-            res["answer_box"] = res["answer_box"][0]
-        if "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
-            toret = res["answer_box"]["answer"]
-        elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
-            toret = res["answer_box"]["snippet"]
-        elif (
-            "answer_box" in res.keys()
-            and "snippet_highlighted_words" in res["answer_box"].keys()
-        ):
-            toret = res["answer_box"]["snippet_highlighted_words"][0]
-        elif (
-            "sports_results" in res.keys()
-            and "game_spotlight" in res["sports_results"].keys()
-        ):
-            toret = res["sports_results"]["game_spotlight"]
-        elif (
-            "shopping_results" in res.keys()
-            and "title" in res["shopping_results"][0].keys()
-        ):
-            toret = res["shopping_results"][:3]
-        elif (
-            "knowledge_graph" in res.keys()
-            and "description" in res["knowledge_graph"].keys()
-        ):
-            toret = res["knowledge_graph"]["description"]
-        elif 'organic_results' in res.keys() and len(res['organic_results']) > 0:
-            toret = ""
-            for result in res["organic_results"][:num_results]:
-                if "link" in result:
-                    toret += "----------------\nlink: " + result["link"] + "\n"
-                if "snippet" in result:
-                    toret += "snippet: " + result["snippet"] + "\n"
-        else:
-            toret = "No good search result found"
-        return "search result:\n" + toret

+ 0 - 443
api/core/tool/web_reader_tool.py

@@ -1,443 +0,0 @@
-import hashlib
-import json
-import os
-import re
-import site
-import subprocess
-import tempfile
-import unicodedata
-from contextlib import contextmanager
-from typing import Any
-
-import requests
-from bs4 import BeautifulSoup, CData, Comment, NavigableString
-from langchain.chains import RefineDocumentsChain
-from langchain.chains.summarize import refine_prompts
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.tools.base import BaseTool
-from newspaper import Article
-from pydantic import BaseModel, Field
-from regex import regex
-
-from core.chain.llm_chain import LLMChain
-from core.entities.application_entities import ModelConfigEntity
-from core.rag.extractor import extract_processor
-from core.rag.extractor.extract_processor import ExtractProcessor
-from core.rag.models.document import Document
-
-FULL_TEMPLATE = """
-TITLE: {title}
-AUTHORS: {authors}
-PUBLISH DATE: {publish_date}
-TOP_IMAGE_URL: {top_image}
-TEXT:
-
-{text}
-"""
-
-
-class WebReaderToolInput(BaseModel):
-    url: str = Field(..., description="URL of the website to read")
-    summary: bool = Field(
-        default=False,
-        description="When the user's question requires extracting the summarizing content of the webpage, "
-                    "set it to true."
-    )
-    cursor: int = Field(
-        default=0,
-        description="Start reading from this character."
-        "Use when the first response was truncated"
-        "and you want to continue reading the page."
-        "The value cannot exceed 24000.",
-    )
-
-
-class WebReaderTool(BaseTool):
-    """Reader tool for getting website title and contents. Gives more control than SimpleReaderTool."""
-
-    name: str = "web_reader"
-    args_schema: type[BaseModel] = WebReaderToolInput
-    description: str = "use this to read a website. " \
-                       "If you can answer the question based on the information provided, " \
-                       "there is no need to use."
-    page_contents: str = None
-    url: str = None
-    max_chunk_length: int = 4000
-    summary_chunk_tokens: int = 4000
-    summary_chunk_overlap: int = 0
-    summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
-    continue_reading: bool = True
-    model_config: ModelConfigEntity
-    model_parameters: dict[str, Any]
-
-    def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
-        try:
-            if not self.page_contents or self.url != url:
-                page_contents = get_url(url)
-                self.page_contents = page_contents
-                self.url = url
-            else:
-                page_contents = self.page_contents
-        except Exception as e:
-            return f'Read this website failed, caused by: {str(e)}.'
-
-        if summary:
-            character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
-                chunk_size=self.summary_chunk_tokens,
-                chunk_overlap=self.summary_chunk_overlap,
-                separators=self.summary_separators
-            )
-
-            texts = character_splitter.split_text(page_contents)
-            docs = [Document(page_content=t) for t in texts]
-
-            if len(docs) == 0 or docs[0].page_content.endswith('TEXT:'):
-                return "No content found."
-
-            # only use first 5 docs
-            if len(docs) > 5:
-                docs = docs[:5]
-
-            chain = self.get_summary_chain()
-            try:
-                page_contents = chain.run(docs)
-            except Exception as e:
-                return f'Read this website failed, caused by: {str(e)}.'
-        else:
-            page_contents = page_result(page_contents, cursor, self.max_chunk_length)
-
-            if self.continue_reading and len(page_contents) >= self.max_chunk_length:
-                page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \
-                                 f"THEN DIRECT ANSWER AND STOP INVOKING web_reader TOOL, OTHERWISE USE " \
-                                 f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING."
-
-        return page_contents
-
-    async def _arun(self, url: str) -> str:
-        raise NotImplementedError
-
-    def get_summary_chain(self) -> RefineDocumentsChain:
-        initial_chain = LLMChain(
-            model_config=self.model_config,
-            prompt=refine_prompts.PROMPT,
-            parameters=self.model_parameters
-        )
-        refine_chain = LLMChain(
-            model_config=self.model_config,
-            prompt=refine_prompts.REFINE_PROMPT,
-            parameters=self.model_parameters
-        )
-        return RefineDocumentsChain(
-            initial_llm_chain=initial_chain,
-            refine_llm_chain=refine_chain,
-            document_variable_name="text",
-            initial_response_name="existing_answer",
-            callbacks=self.callbacks
-        )
-
-
-def page_result(text: str, cursor: int, max_length: int) -> str:
-    """Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
-    return text[cursor: cursor + max_length]
-
-
-def get_url(url: str) -> str:
-    """Fetch URL and return the contents as a string."""
-    headers = {
-        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
-    }
-    supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
-
-    head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10))
-
-    if head_response.status_code != 200:
-        return "URL returned status code {}.".format(head_response.status_code)
-
-    # check content-type
-    main_content_type = head_response.headers.get('Content-Type').split(';')[0].strip()
-    if main_content_type not in supported_content_types:
-        return "Unsupported content-type [{}] of URL.".format(main_content_type)
-
-    if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
-        return ExtractProcessor.load_from_url(url, return_text=True)
-
-    response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30))
-    a = extract_using_readabilipy(response.text)
-
-    if not a['plain_text'] or not a['plain_text'].strip():
-        return get_url_from_newspaper3k(url)
-
-    res = FULL_TEMPLATE.format(
-        title=a['title'],
-        authors=a['byline'],
-        publish_date=a['date'],
-        top_image="",
-        text=a['plain_text'] if a['plain_text'] else "",
-    )
-
-    return res
-
-
-def get_url_from_newspaper3k(url: str) -> str:
-
-    a = Article(url)
-    a.download()
-    a.parse()
-
-    res = FULL_TEMPLATE.format(
-        title=a.title,
-        authors=a.authors,
-        publish_date=a.publish_date,
-        top_image=a.top_image,
-        text=a.text,
-    )
-
-    return res
-
-
-def extract_using_readabilipy(html):
-    with tempfile.NamedTemporaryFile(delete=False, mode='w+') as f_html:
-        f_html.write(html)
-        f_html.close()
-    html_path = f_html.name
-
-    # Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file
-    article_json_path = html_path + ".json"
-    jsdir = os.path.join(find_module_path('readabilipy'), 'javascript')
-    with chdir(jsdir):
-        subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])
-
-    # Read output of call to Readability.parse() from JSON file and return as Python dictionary
-    with open(article_json_path, encoding="utf-8") as json_file:
-        input_json = json.loads(json_file.read())
-
-    # Deleting files after processing
-    os.unlink(article_json_path)
-    os.unlink(html_path)
-
-    article_json = {
-        "title": None,
-        "byline": None,
-        "date": None,
-        "content": None,
-        "plain_content": None,
-        "plain_text": None
-    }
-    # Populate article fields from readability fields where present
-    if input_json:
-        if "title" in input_json and input_json["title"]:
-            article_json["title"] = input_json["title"]
-        if "byline" in input_json and input_json["byline"]:
-            article_json["byline"] = input_json["byline"]
-        if "date" in input_json and input_json["date"]:
-            article_json["date"] = input_json["date"]
-        if "content" in input_json and input_json["content"]:
-            article_json["content"] = input_json["content"]
-            article_json["plain_content"] = plain_content(article_json["content"], False, False)
-            article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"])
-        if "textContent" in input_json and input_json["textContent"]:
-            article_json["plain_text"] = input_json["textContent"]
-            article_json["plain_text"] = re.sub(r'\n\s*\n', '\n', article_json["plain_text"])
-
-    return article_json
-
-
-def find_module_path(module_name):
-    for package_path in site.getsitepackages():
-        potential_path = os.path.join(package_path, module_name)
-        if os.path.exists(potential_path):
-            return potential_path
-
-    return None
-
-@contextmanager
-def chdir(path):
-    """Change directory in context and return to original on exit"""
-    # From https://stackoverflow.com/a/37996581, couldn't find a built-in
-    original_path = os.getcwd()
-    os.chdir(path)
-    try:
-        yield
-    finally:
-        os.chdir(original_path)
-
-
-def extract_text_blocks_as_plain_text(paragraph_html):
-    # Load article as DOM
-    soup = BeautifulSoup(paragraph_html, 'html.parser')
-    # Select all lists
-    list_elements = soup.find_all(['ul', 'ol'])
-    # Prefix text in all list items with "* " and make lists paragraphs
-    for list_element in list_elements:
-        plain_items = "".join(list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all('li')])))
-        list_element.string = plain_items
-        list_element.name = "p"
-    # Select all text blocks
-    text_blocks = [s.parent for s in soup.find_all(string=True)]
-    text_blocks = [plain_text_leaf_node(block) for block in text_blocks]
-    # Drop empty paragraphs
-    text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks))
-    return text_blocks
-
-
-def plain_text_leaf_node(element):
-    # Extract all text, stripped of any child HTML elements and normalise it
-    plain_text = normalise_text(element.get_text())
-    if plain_text != "" and element.name == "li":
-        plain_text = "* {}, ".format(plain_text)
-    if plain_text == "":
-        plain_text = None
-    if "data-node-index" in element.attrs:
-        plain = {"node_index": element["data-node-index"], "text": plain_text}
-    else:
-        plain = {"text": plain_text}
-    return plain
-
-
-def plain_content(readability_content, content_digests, node_indexes):
-    # Load article as DOM
-    soup = BeautifulSoup(readability_content, 'html.parser')
-    # Make all elements plain
-    elements = plain_elements(soup.contents, content_digests, node_indexes)
-    if node_indexes:
-        # Add node index attributes to nodes
-        elements = [add_node_indexes(element) for element in elements]
-    # Replace article contents with plain elements
-    soup.contents = elements
-    return str(soup)
-
-
-def plain_elements(elements, content_digests, node_indexes):
-    # Get plain content versions of all elements
-    elements = [plain_element(element, content_digests, node_indexes)
-                for element in elements]
-    if content_digests:
-        # Add content digest attribute to nodes
-        elements = [add_content_digest(element) for element in elements]
-    return elements
-
-
-def plain_element(element, content_digests, node_indexes):
-    # For lists, we make each item plain text
-    if is_leaf(element):
-        # For leaf node elements, extract the text content, discarding any HTML tags
-        # 1. Get element contents as text
-        plain_text = element.get_text()
-        # 2. Normalise the extracted text string to a canonical representation
-        plain_text = normalise_text(plain_text)
-        # 3. Update element content to be plain text
-        element.string = plain_text
-    elif is_text(element):
-        if is_non_printing(element):
-            # The simplified HTML may have come from Readability.js so might
-            # have non-printing text (e.g. Comment or CData). In this case, we
-            # keep the structure, but ensure that the string is empty.
-            element = type(element)("")
-        else:
-            plain_text = element.string
-            plain_text = normalise_text(plain_text)
-            element = type(element)(plain_text)
-    else:
-        # If not a leaf node or leaf type call recursively on child nodes, replacing
-        element.contents = plain_elements(element.contents, content_digests, node_indexes)
-    return element
-
-
-def add_node_indexes(element, node_index="0"):
-    # Can't add attributes to string types
-    if is_text(element):
-        return element
-    # Add index to current element
-    element["data-node-index"] = node_index
-    # Add index to child elements
-    for local_idx, child in enumerate(
-            [c for c in element.contents if not is_text(c)], start=1):
-        # Can't add attributes to leaf string types
-        child_index = "{stem}.{local}".format(
-            stem=node_index, local=local_idx)
-        add_node_indexes(child, node_index=child_index)
-    return element
-
-
-def normalise_text(text):
-    """Normalise unicode and whitespace."""
-    # Normalise unicode first to try and standardise whitespace characters as much as possible before normalising them
-    text = strip_control_characters(text)
-    text = normalise_unicode(text)
-    text = normalise_whitespace(text)
-    return text
-
-
-def strip_control_characters(text):
-    """Strip out unicode control characters which might break the parsing."""
-    # Unicode control characters
-    #   [Cc]: Other, Control [includes new lines]
-    #   [Cf]: Other, Format
-    #   [Cn]: Other, Not Assigned
-    #   [Co]: Other, Private Use
-    #   [Cs]: Other, Surrogate
-    control_chars = set(['Cc', 'Cf', 'Cn', 'Co', 'Cs'])
-    retained_chars = ['\t', '\n', '\r', '\f']
-
-    # Remove non-printing control characters
-    return "".join(["" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char for char in text])
-
-
-def normalise_unicode(text):
-    """Normalise unicode such that things that are visually equivalent map to the same unicode string where possible."""
-    normal_form = "NFKC"
-    text = unicodedata.normalize(normal_form, text)
-    return text
-
-
-def normalise_whitespace(text):
-    """Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed."""
-    text = regex.sub(r"\s+", " ", text)
-    # Remove leading and trailing whitespace
-    text = text.strip()
-    return text
-
-def is_leaf(element):
-    return (element.name in ['p', 'li'])
-
-
-def is_text(element):
-    return isinstance(element, NavigableString)
-
-
-def is_non_printing(element):
-    return any(isinstance(element, _e) for _e in [Comment, CData])
-
-
-def add_content_digest(element):
-    if not is_text(element):
-        element["data-content-digest"] = content_digest(element)
-    return element
-
-
-def content_digest(element):
-    if is_text(element):
-        # Hash
-        trimmed_string = element.string.strip()
-        if trimmed_string == "":
-            digest = ""
-        else:
-            digest = hashlib.sha256(trimmed_string.encode('utf-8')).hexdigest()
-    else:
-        contents = element.contents
-        num_contents = len(contents)
-        if num_contents == 0:
-            # No hash when no child elements exist
-            digest = ""
-        elif num_contents == 1:
-            # If single child, use digest of child
-            digest = content_digest(contents[0])
-        else:
-            # Build content digest from the "non-empty" digests of child nodes
-            digest = hashlib.sha256()
-            child_digests = list(
-                filter(lambda x: x != "", [content_digest(content) for content in contents]))
-            for child in child_digests:
-                digest.update(child.encode('utf-8'))
-            digest = digest.hexdigest()
-    return digest

+ 18 - 18
api/core/tools/tool/dataset_retriever_tool.py

@@ -4,7 +4,7 @@ from langchain.tools import BaseTool
 
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom
-from core.features.dataset_retrieval import DatasetRetrievalFeature
+from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter
 from core.tools.tool.tool import Tool
@@ -15,12 +15,12 @@ class DatasetRetrieverTool(Tool):
 
     @staticmethod
     def get_dataset_tools(tenant_id: str,
-                         dataset_ids: list[str],
-                         retrieve_config: DatasetRetrieveConfigEntity,
-                         return_resource: bool,
-                         invoke_from: InvokeFrom,
-                         hit_callback: DatasetIndexToolCallbackHandler
-    ) -> list['DatasetRetrieverTool']:
+                          dataset_ids: list[str],
+                          retrieve_config: DatasetRetrieveConfigEntity,
+                          return_resource: bool,
+                          invoke_from: InvokeFrom,
+                          hit_callback: DatasetIndexToolCallbackHandler
+                          ) -> list['DatasetRetrieverTool']:
         """
         get dataset tool
         """
@@ -46,7 +46,7 @@ class DatasetRetrieverTool(Tool):
         )
         # restore retrieve strategy
         retrieve_config.retrieve_strategy = original_retriever_mode
-        
+
         # convert langchain tools to Tools
         tools = []
         for langchain_tool in langchain_tools:
@@ -60,7 +60,7 @@ class DatasetRetrieverTool(Tool):
                     llm=langchain_tool.description),
                 runtime=DatasetRetrieverTool.Runtime()
             )
-            
+
             tools.append(tool)
 
         return tools
@@ -68,13 +68,13 @@ class DatasetRetrieverTool(Tool):
     def get_runtime_parameters(self) -> list[ToolParameter]:
         return [
             ToolParameter(name='query',
-                         label=I18nObject(en_US='', zh_Hans=''),
-                         human_description=I18nObject(en_US='', zh_Hans=''),
-                         type=ToolParameter.ToolParameterType.STRING,
-                         form=ToolParameter.ToolParameterForm.LLM,
-                         llm_description='Query for the dataset to be used to retrieve the dataset.',
-                         required=True,
-                         default=''),
+                          label=I18nObject(en_US='', zh_Hans=''),
+                          human_description=I18nObject(en_US='', zh_Hans=''),
+                          type=ToolParameter.ToolParameterType.STRING,
+                          form=ToolParameter.ToolParameterForm.LLM,
+                          llm_description='Query for the dataset to be used to retrieve the dataset.',
+                          required=True,
+                          default=''),
         ]
 
     def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
@@ -84,7 +84,7 @@ class DatasetRetrieverTool(Tool):
         query = tool_parameters.get('query', None)
         if not query:
             return self.create_text_message(text='please input query')
-        
+
         # invoke dataset retriever tool
         result = self.langchain_tool._run(query=query)
 
@@ -94,4 +94,4 @@ class DatasetRetrieverTool(Tool):
         """
         validate the credentials for dataset retriever tool
         """
-        pass
+        pass

+ 0 - 109
api/core/tools/utils/web_reader_tool.py

@@ -7,23 +7,14 @@ import subprocess
 import tempfile
 import unicodedata
 from contextlib import contextmanager
-from typing import Any
 
 import requests
 from bs4 import BeautifulSoup, CData, Comment, NavigableString
-from langchain.chains import RefineDocumentsChain
-from langchain.chains.summarize import refine_prompts
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.tools.base import BaseTool
 from newspaper import Article
-from pydantic import BaseModel, Field
 from regex import regex
 
-from core.chain.llm_chain import LLMChain
-from core.entities.application_entities import ModelConfigEntity
 from core.rag.extractor import extract_processor
 from core.rag.extractor.extract_processor import ExtractProcessor
-from core.rag.models.document import Document
 
 FULL_TEMPLATE = """
 TITLE: {title}
@@ -36,106 +27,6 @@ TEXT:
 """
 
 
-class WebReaderToolInput(BaseModel):
-    url: str = Field(..., description="URL of the website to read")
-    summary: bool = Field(
-        default=False,
-        description="When the user's question requires extracting the summarizing content of the webpage, "
-                    "set it to true."
-    )
-    cursor: int = Field(
-        default=0,
-        description="Start reading from this character."
-        "Use when the first response was truncated"
-        "and you want to continue reading the page."
-        "The value cannot exceed 24000.",
-    )
-
-
-class WebReaderTool(BaseTool):
-    """Reader tool for getting website title and contents. Gives more control than SimpleReaderTool."""
-
-    name: str = "web_reader"
-    args_schema: type[BaseModel] = WebReaderToolInput
-    description: str = "use this to read a website. " \
-                       "If you can answer the question based on the information provided, " \
-                       "there is no need to use."
-    page_contents: str = None
-    url: str = None
-    max_chunk_length: int = 4000
-    summary_chunk_tokens: int = 4000
-    summary_chunk_overlap: int = 0
-    summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
-    continue_reading: bool = True
-    model_config: ModelConfigEntity
-    model_parameters: dict[str, Any]
-
-    def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
-        try:
-            if not self.page_contents or self.url != url:
-                page_contents = get_url(url)
-                self.page_contents = page_contents
-                self.url = url
-            else:
-                page_contents = self.page_contents
-        except Exception as e:
-            return f'Read this website failed, caused by: {str(e)}.'
-
-        if summary:
-            character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
-                chunk_size=self.summary_chunk_tokens,
-                chunk_overlap=self.summary_chunk_overlap,
-                separators=self.summary_separators
-            )
-
-            texts = character_splitter.split_text(page_contents)
-            docs = [Document(page_content=t) for t in texts]
-
-            if len(docs) == 0 or docs[0].page_content.endswith('TEXT:'):
-                return "No content found."
-
-            # only use first 5 docs
-            if len(docs) > 5:
-                docs = docs[:5]
-
-            chain = self.get_summary_chain()
-            try:
-                page_contents = chain.run(docs)
-            except Exception as e:
-                return f'Read this website failed, caused by: {str(e)}.'
-        else:
-            page_contents = page_result(page_contents, cursor, self.max_chunk_length)
-
-            if self.continue_reading and len(page_contents) >= self.max_chunk_length:
-                page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \
-                                 f"THEN DIRECT ANSWER AND STOP INVOKING web_reader TOOL, OTHERWISE USE " \
-                                 f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING."
-
-        return page_contents
-
-    async def _arun(self, url: str) -> str:
-        raise NotImplementedError
-
-    def get_summary_chain(self) -> RefineDocumentsChain:
-        initial_chain = LLMChain(
-            model_config=self.model_config,
-            prompt=refine_prompts.PROMPT,
-            parameters=self.model_parameters
-        )
-        refine_chain = LLMChain(
-            model_config=self.model_config,
-            prompt=refine_prompts.REFINE_PROMPT,
-            parameters=self.model_parameters
-        )
-        return RefineDocumentsChain(
-            initial_llm_chain=initial_chain,
-            refine_llm_chain=refine_chain,
-            document_variable_name="text",
-            initial_response_name="existing_answer",
-            callbacks=self.callbacks
-        )
-
-
 def page_result(text: str, cursor: int, max_length: int) -> str:
     """Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
     return text[cursor: cursor + max_length]

+ 1 - 1
api/services/app_model_config_service.py

@@ -1,7 +1,7 @@
 import re
 import uuid
 
-from core.agent.agent_executor import PlanningStrategy
+from core.entities.agent_entities import PlanningStrategy
 from core.external_data_tool.factory import ExternalDataToolFactory
 from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
 from core.model_runtime.model_providers import model_provider_factory