Browse Source

feat: remove llm client use (#1316)

takatost 1 year ago
parent
commit
cbf095465c

+ 57 - 7
api/core/agent/agent/multi_dataset_router_agent.py

@@ -2,14 +2,18 @@ import json
 from typing import Tuple, List, Any, Union, Sequence, Optional, cast
 
 from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
+from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
 from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import AgentAction, AgentFinish, SystemMessage
+from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.tools import BaseTool
+from pydantic import root_validator
 
+from core.model_providers.models.entity.message import to_prompt_messages
 from core.model_providers.models.llm.base import BaseLLM
+from core.third_party.langchain.llms.fake import FakeLLM
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool
 
 
@@ -24,6 +28,10 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
 
         arbitrary_types_allowed = True
 
+    @root_validator
+    def validate_llm(cls, values: dict) -> dict:
+        return values
+
     def should_use_agent(self, query: str):
         """
         return should use agent
@@ -65,7 +73,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
             return AgentFinish(return_values={"output": observation}, log=observation)
 
         try:
-            agent_decision = super().plan(intermediate_steps, callbacks, **kwargs)
+            agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
             if isinstance(agent_decision, AgentAction):
                 tool_inputs = agent_decision.tool_input
                 if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
@@ -76,6 +84,44 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
             new_exception = self.model_instance.handle_exceptions(e)
             raise new_exception
 
+    def real_plan(
+        self,
+        intermediate_steps: List[Tuple[AgentAction, str]],
+        callbacks: Callbacks = None,
+        **kwargs: Any,
+    ) -> Union[AgentAction, AgentFinish]:
+        """Given input, decided what to do.
+
+        Args:
+            intermediate_steps: Steps the LLM has taken to date, along with observations
+            **kwargs: User inputs.
+
+        Returns:
+            Action specifying what tool to use.
+        """
+        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
+        selected_inputs = {
+            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
+        }
+        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
+        prompt = self.prompt.format_prompt(**full_inputs)
+        messages = prompt.to_messages()
+        prompt_messages = to_prompt_messages(messages)
+        result = self.model_instance.run(
+            messages=prompt_messages,
+            functions=self.functions,
+        )
+
+        ai_message = AIMessage(
+            content=result.content,
+            additional_kwargs={
+                'function_call': result.function_call
+            }
+        )
+
+        agent_decision = _parse_ai_message(ai_message)
+        return agent_decision
+
     async def aplan(
             self,
             intermediate_steps: List[Tuple[AgentAction, str]],
@@ -87,7 +133,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
     @classmethod
     def from_llm_and_tools(
             cls,
-            llm: BaseLanguageModel,
+            model_instance: BaseLLM,
             tools: Sequence[BaseTool],
             callback_manager: Optional[BaseCallbackManager] = None,
             extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
@@ -96,11 +142,15 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
             ),
             **kwargs: Any,
     ) -> BaseSingleActionAgent:
-        return super().from_llm_and_tools(
-            llm=llm,
-            tools=tools,
-            callback_manager=callback_manager,
+        prompt = cls.create_prompt(
             extra_prompt_messages=extra_prompt_messages,
             system_message=system_message,
+        )
+        return cls(
+            model_instance=model_instance,
+            llm=FakeLLM(response=''),
+            prompt=prompt,
+            tools=tools,
+            callback_manager=callback_manager,
             **kwargs,
         )

+ 193 - 19
api/core/agent/agent/openai_function_call.py

@@ -5,21 +5,40 @@ from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
     _format_intermediate_steps
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
+from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken
+from langchain.memory.prompt import SUMMARY_PROMPT
 from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import AgentAction, AgentFinish, SystemMessage
-from langchain.schema.language_model import BaseLanguageModel
+from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage, HumanMessage, BaseMessage, \
+    get_buffer_string
 from langchain.tools import BaseTool
+from pydantic import root_validator
 
-from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
-from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
+from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
+from core.chain.llm_chain import LLMChain
+from core.model_providers.models.entity.message import to_prompt_messages
+from core.model_providers.models.llm.base import BaseLLM
+from core.third_party.langchain.llms.fake import FakeLLM
 
 
-class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
+class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
+    moving_summary_buffer: str = ""
+    moving_summary_index: int = 0
+    summary_model_instance: BaseLLM = None
+    model_instance: BaseLLM
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True
+
+    @root_validator
+    def validate_llm(cls, values: dict) -> dict:
+        return values
 
     @classmethod
     def from_llm_and_tools(
             cls,
-            llm: BaseLanguageModel,
+            model_instance: BaseLLM,
             tools: Sequence[BaseTool],
             callback_manager: Optional[BaseCallbackManager] = None,
             extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
@@ -28,12 +47,16 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
             ),
             **kwargs: Any,
     ) -> BaseSingleActionAgent:
-        return super().from_llm_and_tools(
-            llm=llm,
+        prompt = cls.create_prompt(
+            extra_prompt_messages=extra_prompt_messages,
+            system_message=system_message,
+        )
+        return cls(
+            model_instance=model_instance,
+            llm=FakeLLM(response=''),
+            prompt=prompt,
             tools=tools,
             callback_manager=callback_manager,
-            extra_prompt_messages=extra_prompt_messages,
-            system_message=cls.get_system_message(),
             **kwargs,
         )
 
@@ -44,23 +67,26 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
         :param query:
         :return:
         """
-        original_max_tokens = self.llm.max_tokens
-        self.llm.max_tokens = 40
+        original_max_tokens = self.model_instance.model_kwargs.max_tokens
+        self.model_instance.model_kwargs.max_tokens = 40
 
         prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
         messages = prompt.to_messages()
 
         try:
-            predicted_message = self.llm.predict_messages(
-                messages, functions=self.functions, callbacks=None
+            prompt_messages = to_prompt_messages(messages)
+            result = self.model_instance.run(
+                messages=prompt_messages,
+                functions=self.functions,
+                callbacks=None
             )
         except Exception as e:
             new_exception = self.model_instance.handle_exceptions(e)
             raise new_exception
 
-        function_call = predicted_message.additional_kwargs.get("function_call", {})
+        function_call = result.function_call
 
-        self.llm.max_tokens = original_max_tokens
+        self.model_instance.model_kwargs.max_tokens = original_max_tokens
 
         return True if function_call else False
 
@@ -93,10 +119,19 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
         except ExceededLLMTokensLimitError as e:
             return AgentFinish(return_values={"output": str(e)}, log=str(e))
 
-        predicted_message = self.llm.predict_messages(
-            messages, functions=self.functions, callbacks=callbacks
+        prompt_messages = to_prompt_messages(messages)
+        result = self.model_instance.run(
+            messages=prompt_messages,
+            functions=self.functions,
+        )
+
+        ai_message = AIMessage(
+            content=result.content,
+            additional_kwargs={
+                'function_call': result.function_call
+            }
         )
-        agent_decision = _parse_ai_message(predicted_message)
+        agent_decision = _parse_ai_message(ai_message)
 
         if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
             tool_inputs = agent_decision.tool_input
@@ -122,3 +157,142 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
             return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
         except ValueError:
             return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
+
+    def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
+        # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
+        rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
+        rest_tokens = rest_tokens - 20  # to deal with the inaccuracy of rest_tokens
+        if rest_tokens >= 0:
+            return messages
+
+        system_message = None
+        human_message = None
+        should_summary_messages = []
+        for message in messages:
+            if isinstance(message, SystemMessage):
+                system_message = message
+            elif isinstance(message, HumanMessage):
+                human_message = message
+            else:
+                should_summary_messages.append(message)
+
+        if len(should_summary_messages) > 2:
+            ai_message = should_summary_messages[-2]
+            function_message = should_summary_messages[-1]
+            should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
+            self.moving_summary_index = len(should_summary_messages)
+        else:
+            error_msg = "Exceeded LLM tokens limit, stopped."
+            raise ExceededLLMTokensLimitError(error_msg)
+
+        new_messages = [system_message, human_message]
+
+        if self.moving_summary_index == 0:
+            should_summary_messages.insert(0, human_message)
+
+        self.moving_summary_buffer = self.predict_new_summary(
+            messages=should_summary_messages,
+            existing_summary=self.moving_summary_buffer
+        )
+
+        new_messages.append(AIMessage(content=self.moving_summary_buffer))
+        new_messages.append(ai_message)
+        new_messages.append(function_message)
+
+        return new_messages
+
+    def predict_new_summary(
+        self, messages: List[BaseMessage], existing_summary: str
+    ) -> str:
+        new_lines = get_buffer_string(
+            messages,
+            human_prefix="Human",
+            ai_prefix="AI",
+        )
+
+        chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
+        return chain.predict(summary=existing_summary, new_lines=new_lines)
+
+    def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
+        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
+
+        Official documentation: https://github.com/openai/openai-cookbook/blob/
+        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
+        if model_instance.model_provider.provider_name == 'azure_openai':
+            model = model_instance.base_model_name
+            model = model.replace("gpt-35", "gpt-3.5")
+        else:
+            model = model_instance.base_model_name
+
+        tiktoken_ = _import_tiktoken()
+        try:
+            encoding = tiktoken_.encoding_for_model(model)
+        except KeyError:
+            model = "cl100k_base"
+            encoding = tiktoken_.get_encoding(model)
+
+        if model.startswith("gpt-3.5-turbo"):
+            # every message follows <im_start>{role/name}\n{content}<im_end>\n
+            tokens_per_message = 4
+            # if there's a name, the role is omitted
+            tokens_per_name = -1
+        elif model.startswith("gpt-4"):
+            tokens_per_message = 3
+            tokens_per_name = 1
+        else:
+            raise NotImplementedError(
+                f"get_num_tokens_from_messages() is not presently implemented "
+                f"for model {model}."
+                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
+                "information on how messages are converted to tokens."
+            )
+        num_tokens = 0
+        for m in messages:
+            message = _convert_message_to_dict(m)
+            num_tokens += tokens_per_message
+            for key, value in message.items():
+                if key == "function_call":
+                    for f_key, f_value in value.items():
+                        num_tokens += len(encoding.encode(f_key))
+                        num_tokens += len(encoding.encode(f_value))
+                else:
+                    num_tokens += len(encoding.encode(value))
+
+                if key == "name":
+                    num_tokens += tokens_per_name
+        # every reply is primed with <im_start>assistant
+        num_tokens += 3
+
+        if kwargs.get('functions'):
+            for function in kwargs.get('functions'):
+                num_tokens += len(encoding.encode('name'))
+                num_tokens += len(encoding.encode(function.get("name")))
+                num_tokens += len(encoding.encode('description'))
+                num_tokens += len(encoding.encode(function.get("description")))
+                parameters = function.get("parameters")
+                num_tokens += len(encoding.encode('parameters'))
+                if 'title' in parameters:
+                    num_tokens += len(encoding.encode('title'))
+                    num_tokens += len(encoding.encode(parameters.get("title")))
+                num_tokens += len(encoding.encode('type'))
+                num_tokens += len(encoding.encode(parameters.get("type")))
+                if 'properties' in parameters:
+                    num_tokens += len(encoding.encode('properties'))
+                    for key, value in parameters.get('properties').items():
+                        num_tokens += len(encoding.encode(key))
+                        for field_key, field_value in value.items():
+                            num_tokens += len(encoding.encode(field_key))
+                            if field_key == 'enum':
+                                for enum_field in field_value:
+                                    num_tokens += 3
+                                    num_tokens += len(encoding.encode(enum_field))
+                            else:
+                                num_tokens += len(encoding.encode(field_key))
+                                num_tokens += len(encoding.encode(str(field_value)))
+                if 'required' in parameters:
+                    num_tokens += len(encoding.encode('required'))
+                    for required_field in parameters['required']:
+                        num_tokens += 3
+                        num_tokens += len(encoding.encode(required_field))
+
+        return num_tokens

+ 0 - 140
api/core/agent/agent/openai_function_call_summarize_mixin.py

@@ -1,140 +0,0 @@
-from typing import cast, List
-
-from langchain.chat_models import ChatOpenAI
-from langchain.chat_models.openai import _convert_message_to_dict
-from langchain.memory.summary import SummarizerMixin
-from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage
-from langchain.schema.language_model import BaseLanguageModel
-from pydantic import BaseModel
-
-from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
-from core.model_providers.models.llm.base import BaseLLM
-
-
-class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
-    moving_summary_buffer: str = ""
-    moving_summary_index: int = 0
-    summary_llm: BaseLanguageModel = None
-    model_instance: BaseLLM
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        arbitrary_types_allowed = True
-
-    def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
-        # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
-        rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
-        rest_tokens = rest_tokens - 20  # to deal with the inaccuracy of rest_tokens
-        if rest_tokens >= 0:
-            return messages
-
-        system_message = None
-        human_message = None
-        should_summary_messages = []
-        for message in messages:
-            if isinstance(message, SystemMessage):
-                system_message = message
-            elif isinstance(message, HumanMessage):
-                human_message = message
-            else:
-                should_summary_messages.append(message)
-
-        if len(should_summary_messages) > 2:
-            ai_message = should_summary_messages[-2]
-            function_message = should_summary_messages[-1]
-            should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
-            self.moving_summary_index = len(should_summary_messages)
-        else:
-            error_msg = "Exceeded LLM tokens limit, stopped."
-            raise ExceededLLMTokensLimitError(error_msg)
-
-        new_messages = [system_message, human_message]
-
-        if self.moving_summary_index == 0:
-            should_summary_messages.insert(0, human_message)
-
-        summary_handler = SummarizerMixin(llm=self.summary_llm)
-        self.moving_summary_buffer = summary_handler.predict_new_summary(
-            messages=should_summary_messages,
-            existing_summary=self.moving_summary_buffer
-        )
-
-        new_messages.append(AIMessage(content=self.moving_summary_buffer))
-        new_messages.append(ai_message)
-        new_messages.append(function_message)
-
-        return new_messages
-
-    def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
-        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
-
-        Official documentation: https://github.com/openai/openai-cookbook/blob/
-        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
-        llm = cast(ChatOpenAI, model_instance.client)
-        model, encoding = llm._get_encoding_model()
-        if model.startswith("gpt-3.5-turbo"):
-            # every message follows <im_start>{role/name}\n{content}<im_end>\n
-            tokens_per_message = 4
-            # if there's a name, the role is omitted
-            tokens_per_name = -1
-        elif model.startswith("gpt-4"):
-            tokens_per_message = 3
-            tokens_per_name = 1
-        else:
-            raise NotImplementedError(
-                f"get_num_tokens_from_messages() is not presently implemented "
-                f"for model {model}."
-                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
-                "information on how messages are converted to tokens."
-            )
-        num_tokens = 0
-        for m in messages:
-            message = _convert_message_to_dict(m)
-            num_tokens += tokens_per_message
-            for key, value in message.items():
-                if key == "function_call":
-                    for f_key, f_value in value.items():
-                        num_tokens += len(encoding.encode(f_key))
-                        num_tokens += len(encoding.encode(f_value))
-                else:
-                    num_tokens += len(encoding.encode(value))
-
-                if key == "name":
-                    num_tokens += tokens_per_name
-        # every reply is primed with <im_start>assistant
-        num_tokens += 3
-
-        if kwargs.get('functions'):
-            for function in kwargs.get('functions'):
-                num_tokens += len(encoding.encode('name'))
-                num_tokens += len(encoding.encode(function.get("name")))
-                num_tokens += len(encoding.encode('description'))
-                num_tokens += len(encoding.encode(function.get("description")))
-                parameters = function.get("parameters")
-                num_tokens += len(encoding.encode('parameters'))
-                if 'title' in parameters:
-                    num_tokens += len(encoding.encode('title'))
-                    num_tokens += len(encoding.encode(parameters.get("title")))
-                num_tokens += len(encoding.encode('type'))
-                num_tokens += len(encoding.encode(parameters.get("type")))
-                if 'properties' in parameters:
-                    num_tokens += len(encoding.encode('properties'))
-                    for key, value in parameters.get('properties').items():
-                        num_tokens += len(encoding.encode(key))
-                        for field_key, field_value in value.items():
-                            num_tokens += len(encoding.encode(field_key))
-                            if field_key == 'enum':
-                                for enum_field in field_value:
-                                    num_tokens += 3
-                                    num_tokens += len(encoding.encode(enum_field))
-                            else:
-                                num_tokens += len(encoding.encode(field_key))
-                                num_tokens += len(encoding.encode(str(field_value)))
-                if 'required' in parameters:
-                    num_tokens += len(encoding.encode('required'))
-                    for required_field in parameters['required']:
-                        num_tokens += 3
-                        num_tokens += len(encoding.encode(required_field))
-
-        return num_tokens

+ 0 - 107
api/core/agent/agent/openai_multi_function_call.py

@@ -1,107 +0,0 @@
-from typing import List, Tuple, Any, Union, Sequence, Optional
-
-from langchain.agents import BaseMultiActionAgent
-from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFunctionsAgent, _format_intermediate_steps, \
-    _parse_ai_message
-from langchain.callbacks.base import BaseCallbackManager
-from langchain.callbacks.manager import Callbacks
-from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import AgentAction, AgentFinish, SystemMessage
-from langchain.schema.language_model import BaseLanguageModel
-from langchain.tools import BaseTool
-
-from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
-from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
-
-
-class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
-
-    @classmethod
-    def from_llm_and_tools(
-            cls,
-            llm: BaseLanguageModel,
-            tools: Sequence[BaseTool],
-            callback_manager: Optional[BaseCallbackManager] = None,
-            extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
-            system_message: Optional[SystemMessage] = SystemMessage(
-                content="You are a helpful AI assistant."
-            ),
-            **kwargs: Any,
-    ) -> BaseMultiActionAgent:
-        return super().from_llm_and_tools(
-            llm=llm,
-            tools=tools,
-            callback_manager=callback_manager,
-            extra_prompt_messages=extra_prompt_messages,
-            system_message=cls.get_system_message(),
-            **kwargs,
-        )
-
-    def should_use_agent(self, query: str):
-        """
-        return should use agent
-
-        :param query:
-        :return:
-        """
-        original_max_tokens = self.llm.max_tokens
-        self.llm.max_tokens = 15
-
-        prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
-        messages = prompt.to_messages()
-
-        try:
-            predicted_message = self.llm.predict_messages(
-                messages, functions=self.functions, callbacks=None
-            )
-        except Exception as e:
-            new_exception = self.model_instance.handle_exceptions(e)
-            raise new_exception
-
-        function_call = predicted_message.additional_kwargs.get("function_call", {})
-
-        self.llm.max_tokens = original_max_tokens
-
-        return True if function_call else False
-
-    def plan(
-            self,
-            intermediate_steps: List[Tuple[AgentAction, str]],
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> Union[AgentAction, AgentFinish]:
-        """Given input, decided what to do.
-
-        Args:
-            intermediate_steps: Steps the LLM has taken to date, along with observations
-            **kwargs: User inputs.
-
-        Returns:
-            Action specifying what tool to use.
-        """
-        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
-        selected_inputs = {
-            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
-        }
-        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
-        prompt = self.prompt.format_prompt(**full_inputs)
-        messages = prompt.to_messages()
-
-        # summarize messages if rest_tokens < 0
-        try:
-            messages = self.summarize_messages_if_needed(messages, functions=self.functions)
-        except ExceededLLMTokensLimitError as e:
-            return AgentFinish(return_values={"output": str(e)}, log=str(e))
-
-        predicted_message = self.llm.predict_messages(
-            messages, functions=self.functions, callbacks=callbacks
-        )
-        agent_decision = _parse_ai_message(predicted_message)
-        return agent_decision
-
-    @classmethod
-    def get_system_message(cls):
-        # get current time
-        return SystemMessage(content="You are a helpful AI assistant.\n"
-                                     "The current date or current time you know is wrong.\n"
-                                     "Respond directly if appropriate.")

+ 19 - 9
api/core/agent/agent/structed_multi_dataset_router_agent.py

@@ -4,7 +4,6 @@ from typing import List, Tuple, Any, Union, Sequence, Optional, cast
 from langchain import BasePromptTemplate
 from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
 from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
-from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
 from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
@@ -12,6 +11,7 @@ from langchain.schema import AgentAction, AgentFinish, OutputParserException
 from langchain.tools import BaseTool
 from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
 
+from core.chain.llm_chain import LLMChain
 from core.model_providers.models.llm.base import BaseLLM
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool
 
@@ -49,7 +49,6 @@ Action:
 
 
 class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
-    model_instance: BaseLLM
     dataset_tools: Sequence[BaseTool]
 
     class Config:
@@ -98,7 +97,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
         try:
             full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
         except Exception as e:
-            new_exception = self.model_instance.handle_exceptions(e)
+            new_exception = self.llm_chain.model_instance.handle_exceptions(e)
             raise new_exception
 
         try:
@@ -145,7 +144,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
     @classmethod
     def from_llm_and_tools(
             cls,
-            llm: BaseLanguageModel,
+            model_instance: BaseLLM,
             tools: Sequence[BaseTool],
             callback_manager: Optional[BaseCallbackManager] = None,
             output_parser: Optional[AgentOutputParser] = None,
@@ -157,17 +156,28 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
             memory_prompts: Optional[List[BasePromptTemplate]] = None,
             **kwargs: Any,
     ) -> Agent:
-        return super().from_llm_and_tools(
-            llm=llm,
-            tools=tools,
-            callback_manager=callback_manager,
-            output_parser=output_parser,
+        """Construct an agent from an LLM and tools."""
+        cls._validate_tools(tools)
+        prompt = cls.create_prompt(
+            tools,
             prefix=prefix,
             suffix=suffix,
             human_message_template=human_message_template,
             format_instructions=format_instructions,
             input_variables=input_variables,
             memory_prompts=memory_prompts,
+        )
+        llm_chain = LLMChain(
+            model_instance=model_instance,
+            prompt=prompt,
+            callback_manager=callback_manager,
+        )
+        tool_names = [tool.name for tool in tools]
+        _output_parser = output_parser
+        return cls(
+            llm_chain=llm_chain,
+            allowed_tools=tool_names,
+            output_parser=_output_parser,
             dataset_tools=tools,
             **kwargs,
         )

+ 38 - 16
api/core/agent/agent/structured_chat.py

@@ -4,16 +4,17 @@ from typing import List, Tuple, Any, Union, Sequence, Optional
 from langchain import BasePromptTemplate
 from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
 from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
-from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
-from langchain.memory.summary import SummarizerMixin
+from langchain.memory.prompt import SUMMARY_PROMPT
 from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
-from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException
+from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException, BaseMessage, \
+    get_buffer_string
 from langchain.tools import BaseTool
 from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
 
 from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
+from core.chain.llm_chain import LLMChain
 from core.model_providers.models.llm.base import BaseLLM
 
 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
@@ -52,8 +53,7 @@ Action:
 class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
     moving_summary_buffer: str = ""
     moving_summary_index: int = 0
-    summary_llm: BaseLanguageModel = None
-    model_instance: BaseLLM
+    summary_model_instance: BaseLLM = None
 
     class Config:
         """Configuration for this pydantic object."""
@@ -95,14 +95,14 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
         if prompts:
             messages = prompts[0].to_messages()
 
-        rest_tokens = self.get_message_rest_tokens(self.model_instance, messages)
+        rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_instance, messages)
         if rest_tokens < 0:
             full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
 
         try:
             full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
         except Exception as e:
-            new_exception = self.model_instance.handle_exceptions(e)
+            new_exception = self.llm_chain.model_instance.handle_exceptions(e)
             raise new_exception
 
         try:
@@ -118,7 +118,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
                                           "I don't know how to respond to that."}, "")
 
     def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
-        if len(intermediate_steps) >= 2 and self.summary_llm:
+        if len(intermediate_steps) >= 2 and self.summary_model_instance:
             should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
             should_summary_messages = [AIMessage(content=observation)
                                        for _, observation in should_summary_intermediate_steps]
@@ -130,11 +130,10 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
             error_msg = "Exceeded LLM tokens limit, stopped."
             raise ExceededLLMTokensLimitError(error_msg)
 
-        summary_handler = SummarizerMixin(llm=self.summary_llm)
         if self.moving_summary_buffer and 'chat_history' in kwargs:
             kwargs["chat_history"].pop()
 
-        self.moving_summary_buffer = summary_handler.predict_new_summary(
+        self.moving_summary_buffer = self.predict_new_summary(
             messages=should_summary_messages,
             existing_summary=self.moving_summary_buffer
         )
@@ -144,6 +143,18 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
 
         return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
 
+    def predict_new_summary(
+        self, messages: List[BaseMessage], existing_summary: str
+    ) -> str:
+        new_lines = get_buffer_string(
+            messages,
+            human_prefix="Human",
+            ai_prefix="AI",
+        )
+
+        chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
+        return chain.predict(summary=existing_summary, new_lines=new_lines)
+
     @classmethod
     def create_prompt(
             cls,
@@ -176,7 +187,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
     @classmethod
     def from_llm_and_tools(
             cls,
-            llm: BaseLanguageModel,
+            model_instance: BaseLLM,
             tools: Sequence[BaseTool],
             callback_manager: Optional[BaseCallbackManager] = None,
             output_parser: Optional[AgentOutputParser] = None,
@@ -188,16 +199,27 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
             memory_prompts: Optional[List[BasePromptTemplate]] = None,
             **kwargs: Any,
     ) -> Agent:
-        return super().from_llm_and_tools(
-            llm=llm,
-            tools=tools,
-            callback_manager=callback_manager,
-            output_parser=output_parser,
+        """Construct an agent from an LLM and tools."""
+        cls._validate_tools(tools)
+        prompt = cls.create_prompt(
+            tools,
             prefix=prefix,
             suffix=suffix,
             human_message_template=human_message_template,
             format_instructions=format_instructions,
             input_variables=input_variables,
             memory_prompts=memory_prompts,
+        )
+        llm_chain = LLMChain(
+            model_instance=model_instance,
+            prompt=prompt,
+            callback_manager=callback_manager,
+        )
+        tool_names = [tool.name for tool in tools]
+        _output_parser = output_parser
+        return cls(
+            llm_chain=llm_chain,
+            allowed_tools=tool_names,
+            output_parser=_output_parser,
             **kwargs,
         )

+ 2 - 18
api/core/agent/agent_executor.py

@@ -10,7 +10,6 @@ from pydantic import BaseModel, Extra
 
 from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
 from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
-from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
 from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
 from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
 from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
@@ -27,7 +26,6 @@ class PlanningStrategy(str, enum.Enum):
     REACT_ROUTER = 'react_router'
     REACT = 'react'
     FUNCTION_CALL = 'function_call'
-    MULTI_FUNCTION_CALL = 'multi_function_call'
 
 
 class AgentConfiguration(BaseModel):
@@ -64,30 +62,18 @@ class AgentExecutor:
         if self.configuration.strategy == PlanningStrategy.REACT:
             agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
                 model_instance=self.configuration.model_instance,
-                llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 output_parser=StructuredChatOutputParser(),
-                summary_llm=self.configuration.summary_model_instance.client
+                summary_model_instance=self.configuration.summary_model_instance
                 if self.configuration.summary_model_instance else None,
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
             agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
                 model_instance=self.configuration.model_instance,
-                llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used for read chat histories memory
-                summary_llm=self.configuration.summary_model_instance.client
-                if self.configuration.summary_model_instance else None,
-                verbose=True
-            )
-        elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
-            agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
-                model_instance=self.configuration.model_instance,
-                llm=self.configuration.model_instance.client,
-                tools=self.configuration.tools,
-                extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used for read chat histories memory
-                summary_llm=self.configuration.summary_model_instance.client
+                summary_model_instance=self.configuration.summary_model_instance
                 if self.configuration.summary_model_instance else None,
                 verbose=True
             )
@@ -95,7 +81,6 @@ class AgentExecutor:
             self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
             agent = MultiDatasetRouterAgent.from_llm_and_tools(
                 model_instance=self.configuration.model_instance,
-                llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
                 verbose=True
@@ -104,7 +89,6 @@ class AgentExecutor:
             self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
             agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
                 model_instance=self.configuration.model_instance,
-                llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 output_parser=StructuredChatOutputParser(),
                 verbose=True

+ 36 - 0
api/core/chain/llm_chain.py

@@ -0,0 +1,36 @@
+from typing import List, Dict, Any, Optional
+
+from langchain import LLMChain as LCLLMChain
+from langchain.callbacks.manager import CallbackManagerForChainRun
+from langchain.schema import LLMResult, Generation
+from langchain.schema.language_model import BaseLanguageModel
+
+from core.model_providers.models.entity.message import to_prompt_messages
+from core.model_providers.models.llm.base import BaseLLM
+from core.third_party.langchain.llms.fake import FakeLLM
+
+
+class LLMChain(LCLLMChain):
+    model_instance: BaseLLM
+    """The language model instance to use."""
+    llm: BaseLanguageModel = FakeLLM(response="")
+
+    def generate(
+        self,
+        input_list: List[Dict[str, Any]],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> LLMResult:
+        """Generate LLM result from inputs."""
+        prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
+        messages = prompts[0].to_messages()
+        prompt_messages = to_prompt_messages(messages)
+        result = self.model_instance.run(
+            messages=prompt_messages,
+            stop=stop
+        )
+
+        generations = [
+            [Generation(text=result.content)]
+        ]
+
+        return LLMResult(generations=generations)

+ 18 - 3
api/core/model_providers/models/entity/message.py

@@ -1,6 +1,6 @@
 import enum
 
-from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage
+from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage
 from pydantic import BaseModel
 
 
@@ -9,6 +9,7 @@ class LLMRunResult(BaseModel):
     prompt_tokens: int
     completion_tokens: int
     source: list = None
+    function_call: dict = None
 
 
 class MessageType(enum.Enum):
@@ -20,6 +21,7 @@ class MessageType(enum.Enum):
 class PromptMessage(BaseModel):
     type: MessageType = MessageType.HUMAN
     content: str = ''
+    function_call: dict = None
 
 
 def to_lc_messages(messages: list[PromptMessage]):
@@ -28,7 +30,10 @@ def to_lc_messages(messages: list[PromptMessage]):
         if message.type == MessageType.HUMAN:
             lc_messages.append(HumanMessage(content=message.content))
         elif message.type == MessageType.ASSISTANT:
-            lc_messages.append(AIMessage(content=message.content))
+            additional_kwargs = {}
+            if message.function_call:
+                additional_kwargs['function_call'] = message.function_call
+            lc_messages.append(AIMessage(content=message.content, additional_kwargs=additional_kwargs))
         elif message.type == MessageType.SYSTEM:
             lc_messages.append(SystemMessage(content=message.content))
 
@@ -41,9 +46,19 @@ def to_prompt_messages(messages: list[BaseMessage]):
         if isinstance(message, HumanMessage):
             prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
         elif isinstance(message, AIMessage):
-            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.ASSISTANT))
+            message_kwargs = {
+                'content': message.content,
+                'type': MessageType.ASSISTANT
+            }
+
+            if 'function_call' in message.additional_kwargs:
+                message_kwargs['function_call'] = message.additional_kwargs['function_call']
+
+            prompt_messages.append(PromptMessage(**message_kwargs))
         elif isinstance(message, SystemMessage):
             prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM))
+        elif isinstance(message, FunctionMessage):
+            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
     return prompt_messages
 
 

+ 14 - 1
api/core/model_providers/models/llm/azure_openai_model.py

@@ -81,7 +81,20 @@ class AzureOpenAIModel(BaseLLM):
         :return:
         """
         prompts = self._get_prompt_from_messages(messages)
-        return self._client.generate([prompts], stop, callbacks)
+        generate_kwargs = {
+            'stop': stop,
+            'callbacks': callbacks
+        }
+
+        if isinstance(prompts, str):
+            generate_kwargs['prompts'] = [prompts]
+        else:
+            generate_kwargs['messages'] = [prompts]
+
+        if 'functions' in kwargs:
+            generate_kwargs['functions'] = kwargs['functions']
+
+        return self._client.generate(**generate_kwargs)
     
     @property
     def base_model_name(self) -> str:

+ 8 - 12
api/core/model_providers/models/llm/base.py

@@ -13,7 +13,8 @@ from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage,
 from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
 from core.helper import moderation
 from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
+from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages, \
+    to_lc_messages
 from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
 from core.model_providers.providers.base import BaseModelProvider
 from core.prompt.prompt_builder import PromptBuilder
@@ -157,8 +158,11 @@ class BaseLLM(BaseProviderModel):
             except Exception as ex:
                 raise self.handle_exceptions(ex)
 
+        function_call = None
         if isinstance(result.generations[0][0], ChatGeneration):
             completion_content = result.generations[0][0].message.content
+            if 'function_call' in result.generations[0][0].message.additional_kwargs:
+                function_call = result.generations[0][0].message.additional_kwargs.get('function_call')
         else:
             completion_content = result.generations[0][0].text
 
@@ -191,7 +195,8 @@ class BaseLLM(BaseProviderModel):
         return LLMRunResult(
             content=completion_content,
             prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens
+            completion_tokens=completion_tokens,
+            function_call=function_call
         )
 
     @abstractmethod
@@ -442,16 +447,7 @@ class BaseLLM(BaseProviderModel):
             if len(messages) == 0:
                 return []
 
-            chat_messages = []
-            for message in messages:
-                if message.type == MessageType.HUMAN:
-                    chat_messages.append(HumanMessage(content=message.content))
-                elif message.type == MessageType.ASSISTANT:
-                    chat_messages.append(AIMessage(content=message.content))
-                elif message.type == MessageType.SYSTEM:
-                    chat_messages.append(SystemMessage(content=message.content))
-
-            return chat_messages
+            return to_lc_messages(messages)
 
     def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
         """

+ 15 - 1
api/core/model_providers/models/llm/openai_model.py

@@ -106,7 +106,21 @@ class OpenAIModel(BaseLLM):
             raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.")
 
         prompts = self._get_prompt_from_messages(messages)
-        return self._client.generate([prompts], stop, callbacks)
+
+        generate_kwargs = {
+            'stop': stop,
+            'callbacks': callbacks
+        }
+
+        if isinstance(prompts, str):
+            generate_kwargs['prompts'] = [prompts]
+        else:
+            generate_kwargs['messages'] = [prompts]
+
+        if 'functions' in kwargs:
+            generate_kwargs['functions'] = kwargs['functions']
+
+        return self._client.generate(**generate_kwargs)
 
     def get_num_tokens(self, messages: List[PromptMessage]) -> int:
         """

+ 10 - 14
api/core/orchestrator_rule_parser.py

@@ -1,7 +1,6 @@
 import math
 from typing import Optional
 
-from flask import current_app
 from langchain import WikipediaAPIWrapper
 from langchain.callbacks.manager import Callbacks
 from langchain.memory.chat_memory import BaseChatMemory
@@ -27,7 +26,6 @@ from core.tool.web_reader_tool import WebReaderTool
 from extensions.ext_database import db
 from models.dataset import Dataset, DatasetProcessRule
 from models.model import AppModelConfig
-from models.provider import ProviderType
 
 
 class OrchestratorRuleParser:
@@ -77,7 +75,7 @@ class OrchestratorRuleParser:
             # only OpenAI chat model (include Azure) support function call, use ReACT instead
             if agent_model_instance.model_mode != ModelMode.CHAT \
                     or agent_model_instance.model_provider.provider_name not in ['openai', 'azure_openai']:
-                if planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]:
+                if planning_strategy == PlanningStrategy.FUNCTION_CALL:
                     planning_strategy = PlanningStrategy.REACT
                 elif planning_strategy == PlanningStrategy.ROUTER:
                     planning_strategy = PlanningStrategy.REACT_ROUTER
@@ -207,7 +205,10 @@ class OrchestratorRuleParser:
                 tool = self.to_current_datetime_tool()
 
             if tool:
-                tool.callbacks.extend(callbacks)
+                if tool.callbacks is not None:
+                    tool.callbacks.extend(callbacks)
+                else:
+                    tool.callbacks = callbacks
                 tools.append(tool)
 
         return tools
@@ -269,10 +270,9 @@ class OrchestratorRuleParser:
             summary_model_instance = None
 
         tool = WebReaderTool(
-            llm=summary_model_instance.client if summary_model_instance else None,
+            model_instance=summary_model_instance if summary_model_instance else None,
             max_chunk_length=4000,
-            continue_reading=True,
-            callbacks=[DifyStdOutCallbackHandler()]
+            continue_reading=True
         )
 
         return tool
@@ -290,16 +290,13 @@ class OrchestratorRuleParser:
                         "is not up to date. "
                         "Input should be a search query.",
             func=OptimizedSerpAPIWrapper(**func_kwargs).run,
-            args_schema=OptimizedSerpAPIInput,
-            callbacks=[DifyStdOutCallbackHandler()]
+            args_schema=OptimizedSerpAPIInput
         )
 
         return tool
 
     def to_current_datetime_tool(self) -> Optional[BaseTool]:
-        tool = DatetimeTool(
-            callbacks=[DifyStdOutCallbackHandler()]
-        )
+        tool = DatetimeTool()
 
         return tool
 
@@ -310,8 +307,7 @@ class OrchestratorRuleParser:
         return WikipediaQueryRun(
             name="wikipedia",
             api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
-            args_schema=WikipediaInput,
-            callbacks=[DifyStdOutCallbackHandler()]
+            args_schema=WikipediaInput
         )
 
     @classmethod

+ 24 - 6
api/core/tool/web_reader_tool.py

@@ -11,8 +11,8 @@ from typing import Type
 
 import requests
 from bs4 import BeautifulSoup, NavigableString, Comment, CData
-from langchain.base_language import BaseLanguageModel
-from langchain.chains.summarize import load_summarize_chain
+from langchain.chains import RefineDocumentsChain
+from langchain.chains.summarize import refine_prompts
 from langchain.schema import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.tools.base import BaseTool
@@ -20,8 +20,10 @@ from newspaper import Article
 from pydantic import BaseModel, Field
 from regex import regex
 
+from core.chain.llm_chain import LLMChain
 from core.data_loader import file_extractor
 from core.data_loader.file_extractor import FileExtractor
+from core.model_providers.models.llm.base import BaseLLM
 
 FULL_TEMPLATE = """
 TITLE: {title}
@@ -65,7 +67,7 @@ class WebReaderTool(BaseTool):
     summary_chunk_overlap: int = 0
     summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
     continue_reading: bool = True
-    llm: BaseLanguageModel = None
+    model_instance: BaseLLM = None
 
     def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
         try:
@@ -78,7 +80,7 @@ class WebReaderTool(BaseTool):
         except Exception as e:
             return f'Read this website failed, caused by: {str(e)}.'
 
-        if summary and self.llm:
+        if summary and self.model_instance:
             character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                 chunk_size=self.summary_chunk_tokens,
                 chunk_overlap=self.summary_chunk_overlap,
@@ -95,10 +97,9 @@ class WebReaderTool(BaseTool):
             if len(docs) > 5:
                 docs = docs[:5]
 
-            chain = load_summarize_chain(self.llm, chain_type="refine", callbacks=self.callbacks)
+            chain = self.get_summary_chain()
             try:
                 page_contents = chain.run(docs)
-                # todo use cache
             except Exception as e:
                 return f'Read this website failed, caused by: {str(e)}.'
         else:
@@ -114,6 +115,23 @@ class WebReaderTool(BaseTool):
     async def _arun(self, url: str) -> str:
         raise NotImplementedError
 
+    def get_summary_chain(self) -> RefineDocumentsChain:
+        initial_chain = LLMChain(
+            model_instance=self.model_instance,
+            prompt=refine_prompts.PROMPT
+        )
+        refine_chain = LLMChain(
+            model_instance=self.model_instance,
+            prompt=refine_prompts.REFINE_PROMPT
+        )
+        return RefineDocumentsChain(
+            initial_llm_chain=initial_chain,
+            refine_llm_chain=refine_chain,
+            document_variable_name="text",
+            initial_response_name="existing_answer",
+            callbacks=self.callbacks
+        )
+
 
 def page_result(text: str, cursor: int, max_length: int) -> str:
     """Page through `text` and return a substring of `max_length` characters starting from `cursor`."""