Browse Source

feat: remove llm client use (#1316)

takatost 1 year ago
parent
commit
cbf095465c

+ 57 - 7
api/core/agent/agent/multi_dataset_router_agent.py

@@ -2,14 +2,18 @@ import json
 from typing import Tuple, List, Any, Union, Sequence, Optional, cast
 
 from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
+from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
 from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import AgentAction, AgentFinish, SystemMessage
+from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.tools import BaseTool
+from pydantic import root_validator
 
+from core.model_providers.models.entity.message import to_prompt_messages
 from core.model_providers.models.llm.base import BaseLLM
+from core.third_party.langchain.llms.fake import FakeLLM
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool
 
 
@@ -24,6 +28,10 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
 
         arbitrary_types_allowed = True
 
+    @root_validator
+    def validate_llm(cls, values: dict) -> dict:
+        return values
+
     def should_use_agent(self, query: str):
         """
         return should use agent
@@ -65,7 +73,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
             return AgentFinish(return_values={"output": observation}, log=observation)
 
         try:
-            agent_decision = super().plan(intermediate_steps, callbacks, **kwargs)
+            agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
             if isinstance(agent_decision, AgentAction):
                 tool_inputs = agent_decision.tool_input
                 if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
@@ -76,6 +84,44 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
             new_exception = self.model_instance.handle_exceptions(e)
             raise new_exception
 
+    def real_plan(
+        self,
+        intermediate_steps: List[Tuple[AgentAction, str]],
+        callbacks: Callbacks = None,
+        **kwargs: Any,
+    ) -> Union[AgentAction, AgentFinish]:
+        """Given input, decided what to do.
+
+        Args:
+            intermediate_steps: Steps the LLM has taken to date, along with observations
+            **kwargs: User inputs.
+
+        Returns:
+            Action specifying what tool to use.
+        """
+        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
+        selected_inputs = {
+            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
+        }
+        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
+        prompt = self.prompt.format_prompt(**full_inputs)
+        messages = prompt.to_messages()
+        prompt_messages = to_prompt_messages(messages)
+        result = self.model_instance.run(
+            messages=prompt_messages,
+            functions=self.functions,
+        )
+
+        ai_message = AIMessage(
+            content=result.content,
+            additional_kwargs={
+                'function_call': result.function_call
+            }
+        )
+
+        agent_decision = _parse_ai_message(ai_message)
+        return agent_decision
+
     async def aplan(
             self,
             intermediate_steps: List[Tuple[AgentAction, str]],
@@ -87,7 +133,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
     @classmethod
     def from_llm_and_tools(
             cls,
-            llm: BaseLanguageModel,
+            model_instance: BaseLLM,
             tools: Sequence[BaseTool],
             callback_manager: Optional[BaseCallbackManager] = None,
             extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
@@ -96,11 +142,15 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
             ),
             **kwargs: Any,
     ) -> BaseSingleActionAgent:
-        return super().from_llm_and_tools(
-            llm=llm,
-            tools=tools,
-            callback_manager=callback_manager,
+        prompt = cls.create_prompt(
             extra_prompt_messages=extra_prompt_messages,
             system_message=system_message,
+        )
+        return cls(
+            model_instance=model_instance,
+            llm=FakeLLM(response=''),
+            prompt=prompt,
+            tools=tools,
+            callback_manager=callback_manager,
             **kwargs,
         )

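Note: from_llm_and_tools no longer accepts a LangChain llm; the agent is now built directly, with FakeLLM(response='') merely satisfying the llm field that OpenAIFunctionsAgent's pydantic schema still requires, while all real calls go through model_instance.run() in real_plan(). A rough sketch of the resulting call site (model_instance and dataset_tools are stand-ins for objects created elsewhere in the app, not part of this commit):

    from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent

    # model_instance: a configured BaseLLM from the model-provider layer (assumed)
    # dataset_tools: a list of DatasetRetrieverTool instances (assumed)
    agent = MultiDatasetRouterAgent.from_llm_and_tools(
        model_instance=model_instance,
        tools=dataset_tools,
        extra_prompt_messages=None,
        verbose=True,
    )
    # real_plan() formats the prompt, converts it via to_prompt_messages(), and
    # calls model_instance.run(messages=..., functions=...) instead of the client.
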
+ 193 - 19
api/core/agent/agent/openai_function_call.py

@@ -5,21 +5,40 @@ from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
     _format_intermediate_steps
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
+from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken
+from langchain.memory.prompt import SUMMARY_PROMPT
 from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import AgentAction, AgentFinish, SystemMessage
-from langchain.schema.language_model import BaseLanguageModel
+from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage, HumanMessage, BaseMessage, \
+    get_buffer_string
 from langchain.tools import BaseTool
+from pydantic import root_validator
 
-from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
-from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
+from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
+from core.chain.llm_chain import LLMChain
+from core.model_providers.models.entity.message import to_prompt_messages
+from core.model_providers.models.llm.base import BaseLLM
+from core.third_party.langchain.llms.fake import FakeLLM
 
 
-class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
+class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
+    moving_summary_buffer: str = ""
+    moving_summary_index: int = 0
+    summary_model_instance: BaseLLM = None
+    model_instance: BaseLLM
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True
+
+    @root_validator
+    def validate_llm(cls, values: dict) -> dict:
+        return values
 
     @classmethod
     def from_llm_and_tools(
             cls,
-            llm: BaseLanguageModel,
+            model_instance: BaseLLM,
             tools: Sequence[BaseTool],
             callback_manager: Optional[BaseCallbackManager] = None,
             extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
@@ -28,12 +47,16 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
             ),
             **kwargs: Any,
     ) -> BaseSingleActionAgent:
-        return super().from_llm_and_tools(
-            llm=llm,
+        prompt = cls.create_prompt(
+            extra_prompt_messages=extra_prompt_messages,
+            system_message=system_message,
+        )
+        return cls(
+            model_instance=model_instance,
+            llm=FakeLLM(response=''),
+            prompt=prompt,
             tools=tools,
             callback_manager=callback_manager,
-            extra_prompt_messages=extra_prompt_messages,
-            system_message=cls.get_system_message(),
             **kwargs,
         )
 
@@ -44,23 +67,26 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
         :param query:
         :return:
         """
-        original_max_tokens = self.llm.max_tokens
-        self.llm.max_tokens = 40
+        original_max_tokens = self.model_instance.model_kwargs.max_tokens
+        self.model_instance.model_kwargs.max_tokens = 40
 
         prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
         messages = prompt.to_messages()
 
         try:
-            predicted_message = self.llm.predict_messages(
-                messages, functions=self.functions, callbacks=None
+            prompt_messages = to_prompt_messages(messages)
+            result = self.model_instance.run(
+                messages=prompt_messages,
+                functions=self.functions,
+                callbacks=None
             )
         except Exception as e:
             new_exception = self.model_instance.handle_exceptions(e)
             raise new_exception
 
-        function_call = predicted_message.additional_kwargs.get("function_call", {})
+        function_call = result.function_call
 
-        self.llm.max_tokens = original_max_tokens
+        self.model_instance.model_kwargs.max_tokens = original_max_tokens
 
         return True if function_call else False
 
@@ -93,10 +119,19 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
         except ExceededLLMTokensLimitError as e:
             return AgentFinish(return_values={"output": str(e)}, log=str(e))
 
-        predicted_message = self.llm.predict_messages(
-            messages, functions=self.functions, callbacks=callbacks
+        prompt_messages = to_prompt_messages(messages)
+        result = self.model_instance.run(
+            messages=prompt_messages,
+            functions=self.functions,
+        )
+
+        ai_message = AIMessage(
+            content=result.content,
+            additional_kwargs={
+                'function_call': result.function_call
+            }
         )
-        agent_decision = _parse_ai_message(predicted_message)
+        agent_decision = _parse_ai_message(ai_message)
 
         if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
             tool_inputs = agent_decision.tool_input
@@ -122,3 +157,142 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
             return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
         except ValueError:
             return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
+
+    def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
+        # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
+        rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
+        rest_tokens = rest_tokens - 20  # to deal with the inaccuracy of rest_tokens
+        if rest_tokens >= 0:
+            return messages
+
+        system_message = None
+        human_message = None
+        should_summary_messages = []
+        for message in messages:
+            if isinstance(message, SystemMessage):
+                system_message = message
+            elif isinstance(message, HumanMessage):
+                human_message = message
+            else:
+                should_summary_messages.append(message)
+
+        if len(should_summary_messages) > 2:
+            ai_message = should_summary_messages[-2]
+            function_message = should_summary_messages[-1]
+            should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
+            self.moving_summary_index = len(should_summary_messages)
+        else:
+            error_msg = "Exceeded LLM tokens limit, stopped."
+            raise ExceededLLMTokensLimitError(error_msg)
+
+        new_messages = [system_message, human_message]
+
+        if self.moving_summary_index == 0:
+            should_summary_messages.insert(0, human_message)
+
+        self.moving_summary_buffer = self.predict_new_summary(
+            messages=should_summary_messages,
+            existing_summary=self.moving_summary_buffer
+        )
+
+        new_messages.append(AIMessage(content=self.moving_summary_buffer))
+        new_messages.append(ai_message)
+        new_messages.append(function_message)
+
+        return new_messages
+
+    def predict_new_summary(
+        self, messages: List[BaseMessage], existing_summary: str
+    ) -> str:
+        new_lines = get_buffer_string(
+            messages,
+            human_prefix="Human",
+            ai_prefix="AI",
+        )
+
+        chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
+        return chain.predict(summary=existing_summary, new_lines=new_lines)
+
+    def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
+        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
+
+        Official documentation: https://github.com/openai/openai-cookbook/blob/
+        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
+        if model_instance.model_provider.provider_name == 'azure_openai':
+            model = model_instance.base_model_name
+            model = model.replace("gpt-35", "gpt-3.5")
+        else:
+            model = model_instance.base_model_name
+
+        tiktoken_ = _import_tiktoken()
+        try:
+            encoding = tiktoken_.encoding_for_model(model)
+        except KeyError:
+            model = "cl100k_base"
+            encoding = tiktoken_.get_encoding(model)
+
+        if model.startswith("gpt-3.5-turbo"):
+            # every message follows <im_start>{role/name}\n{content}<im_end>\n
+            tokens_per_message = 4
+            # if there's a name, the role is omitted
+            tokens_per_name = -1
+        elif model.startswith("gpt-4"):
+            tokens_per_message = 3
+            tokens_per_name = 1
+        else:
+            raise NotImplementedError(
+                f"get_num_tokens_from_messages() is not presently implemented "
+                f"for model {model}."
+                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
+                "information on how messages are converted to tokens."
+            )
+        num_tokens = 0
+        for m in messages:
+            message = _convert_message_to_dict(m)
+            num_tokens += tokens_per_message
+            for key, value in message.items():
+                if key == "function_call":
+                    for f_key, f_value in value.items():
+                        num_tokens += len(encoding.encode(f_key))
+                        num_tokens += len(encoding.encode(f_value))
+                else:
+                    num_tokens += len(encoding.encode(value))
+
+                if key == "name":
+                    num_tokens += tokens_per_name
+        # every reply is primed with <im_start>assistant
+        num_tokens += 3
+
+        if kwargs.get('functions'):
+            for function in kwargs.get('functions'):
+                num_tokens += len(encoding.encode('name'))
+                num_tokens += len(encoding.encode(function.get("name")))
+                num_tokens += len(encoding.encode('description'))
+                num_tokens += len(encoding.encode(function.get("description")))
+                parameters = function.get("parameters")
+                num_tokens += len(encoding.encode('parameters'))
+                if 'title' in parameters:
+                    num_tokens += len(encoding.encode('title'))
+                    num_tokens += len(encoding.encode(parameters.get("title")))
+                num_tokens += len(encoding.encode('type'))
+                num_tokens += len(encoding.encode(parameters.get("type")))
+                if 'properties' in parameters:
+                    num_tokens += len(encoding.encode('properties'))
+                    for key, value in parameters.get('properties').items():
+                        num_tokens += len(encoding.encode(key))
+                        for field_key, field_value in value.items():
+                            num_tokens += len(encoding.encode(field_key))
+                            if field_key == 'enum':
+                                for enum_field in field_value:
+                                    num_tokens += 3
+                                    num_tokens += len(encoding.encode(enum_field))
+                            else:
+                                num_tokens += len(encoding.encode(field_key))
+                                num_tokens += len(encoding.encode(str(field_value)))
+                if 'required' in parameters:
+                    num_tokens += len(encoding.encode('required'))
+                    for required_field in parameters['required']:
+                        num_tokens += 3
+                        num_tokens += len(encoding.encode(required_field))
+
+        return num_tokens

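Note: the new get_num_tokens_from_messages ports OpenAI's published token accounting onto BaseLLM, replacing the removed mixin's reliance on the LangChain client. As a rough standalone illustration of the per-message overhead it applies (the model name and messages below are illustrative, not from this commit):

    import tiktoken

    encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
    messages = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": "What is the weather like?"},
    ]

    num_tokens = 0
    for message in messages:
        num_tokens += 4  # every message follows <im_start>{role/name}\n{content}<im_end>\n
        for value in message.values():
            num_tokens += len(encoding.encode(value))
    num_tokens += 3  # every reply is primed with <im_start>assistant
    print(num_tokens)
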
+ 0 - 140
api/core/agent/agent/openai_function_call_summarize_mixin.py

@@ -1,140 +0,0 @@
-from typing import cast, List
-
-from langchain.chat_models import ChatOpenAI
-from langchain.chat_models.openai import _convert_message_to_dict
-from langchain.memory.summary import SummarizerMixin
-from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage
-from langchain.schema.language_model import BaseLanguageModel
-from pydantic import BaseModel
-
-from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
-from core.model_providers.models.llm.base import BaseLLM
-
-
-class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
-    moving_summary_buffer: str = ""
-    moving_summary_index: int = 0
-    summary_llm: BaseLanguageModel = None
-    model_instance: BaseLLM
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        arbitrary_types_allowed = True
-
-    def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
-        # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
-        rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
-        rest_tokens = rest_tokens - 20  # to deal with the inaccuracy of rest_tokens
-        if rest_tokens >= 0:
-            return messages
-
-        system_message = None
-        human_message = None
-        should_summary_messages = []
-        for message in messages:
-            if isinstance(message, SystemMessage):
-                system_message = message
-            elif isinstance(message, HumanMessage):
-                human_message = message
-            else:
-                should_summary_messages.append(message)
-
-        if len(should_summary_messages) > 2:
-            ai_message = should_summary_messages[-2]
-            function_message = should_summary_messages[-1]
-            should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
-            self.moving_summary_index = len(should_summary_messages)
-        else:
-            error_msg = "Exceeded LLM tokens limit, stopped."
-            raise ExceededLLMTokensLimitError(error_msg)
-
-        new_messages = [system_message, human_message]
-
-        if self.moving_summary_index == 0:
-            should_summary_messages.insert(0, human_message)
-
-        summary_handler = SummarizerMixin(llm=self.summary_llm)
-        self.moving_summary_buffer = summary_handler.predict_new_summary(
-            messages=should_summary_messages,
-            existing_summary=self.moving_summary_buffer
-        )
-
-        new_messages.append(AIMessage(content=self.moving_summary_buffer))
-        new_messages.append(ai_message)
-        new_messages.append(function_message)
-
-        return new_messages
-
-    def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
-        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
-
-        Official documentation: https://github.com/openai/openai-cookbook/blob/
-        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
-        llm = cast(ChatOpenAI, model_instance.client)
-        model, encoding = llm._get_encoding_model()
-        if model.startswith("gpt-3.5-turbo"):
-            # every message follows <im_start>{role/name}\n{content}<im_end>\n
-            tokens_per_message = 4
-            # if there's a name, the role is omitted
-            tokens_per_name = -1
-        elif model.startswith("gpt-4"):
-            tokens_per_message = 3
-            tokens_per_name = 1
-        else:
-            raise NotImplementedError(
-                f"get_num_tokens_from_messages() is not presently implemented "
-                f"for model {model}."
-                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
-                "information on how messages are converted to tokens."
-            )
-        num_tokens = 0
-        for m in messages:
-            message = _convert_message_to_dict(m)
-            num_tokens += tokens_per_message
-            for key, value in message.items():
-                if key == "function_call":
-                    for f_key, f_value in value.items():
-                        num_tokens += len(encoding.encode(f_key))
-                        num_tokens += len(encoding.encode(f_value))
-                else:
-                    num_tokens += len(encoding.encode(value))
-
-                if key == "name":
-                    num_tokens += tokens_per_name
-        # every reply is primed with <im_start>assistant
-        num_tokens += 3
-
-        if kwargs.get('functions'):
-            for function in kwargs.get('functions'):
-                num_tokens += len(encoding.encode('name'))
-                num_tokens += len(encoding.encode(function.get("name")))
-                num_tokens += len(encoding.encode('description'))
-                num_tokens += len(encoding.encode(function.get("description")))
-                parameters = function.get("parameters")
-                num_tokens += len(encoding.encode('parameters'))
-                if 'title' in parameters:
-                    num_tokens += len(encoding.encode('title'))
-                    num_tokens += len(encoding.encode(parameters.get("title")))
-                num_tokens += len(encoding.encode('type'))
-                num_tokens += len(encoding.encode(parameters.get("type")))
-                if 'properties' in parameters:
-                    num_tokens += len(encoding.encode('properties'))
-                    for key, value in parameters.get('properties').items():
-                        num_tokens += len(encoding.encode(key))
-                        for field_key, field_value in value.items():
-                            num_tokens += len(encoding.encode(field_key))
-                            if field_key == 'enum':
-                                for enum_field in field_value:
-                                    num_tokens += 3
-                                    num_tokens += len(encoding.encode(enum_field))
-                            else:
-                                num_tokens += len(encoding.encode(field_key))
-                                num_tokens += len(encoding.encode(str(field_value)))
-                if 'required' in parameters:
-                    num_tokens += len(encoding.encode('required'))
-                    for required_field in parameters['required']:
-                        num_tokens += 3
-                        num_tokens += len(encoding.encode(required_field))
-
-        return num_tokens

+ 0 - 107
api/core/agent/agent/openai_multi_function_call.py

@@ -1,107 +0,0 @@
-from typing import List, Tuple, Any, Union, Sequence, Optional
-
-from langchain.agents import BaseMultiActionAgent
-from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFunctionsAgent, _format_intermediate_steps, \
-    _parse_ai_message
-from langchain.callbacks.base import BaseCallbackManager
-from langchain.callbacks.manager import Callbacks
-from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import AgentAction, AgentFinish, SystemMessage
-from langchain.schema.language_model import BaseLanguageModel
-from langchain.tools import BaseTool
-
-from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
-from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
-
-
-class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
-
-    @classmethod
-    def from_llm_and_tools(
-            cls,
-            llm: BaseLanguageModel,
-            tools: Sequence[BaseTool],
-            callback_manager: Optional[BaseCallbackManager] = None,
-            extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
-            system_message: Optional[SystemMessage] = SystemMessage(
-                content="You are a helpful AI assistant."
-            ),
-            **kwargs: Any,
-    ) -> BaseMultiActionAgent:
-        return super().from_llm_and_tools(
-            llm=llm,
-            tools=tools,
-            callback_manager=callback_manager,
-            extra_prompt_messages=extra_prompt_messages,
-            system_message=cls.get_system_message(),
-            **kwargs,
-        )
-
-    def should_use_agent(self, query: str):
-        """
-        return should use agent
-
-        :param query:
-        :return:
-        """
-        original_max_tokens = self.llm.max_tokens
-        self.llm.max_tokens = 15
-
-        prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
-        messages = prompt.to_messages()
-
-        try:
-            predicted_message = self.llm.predict_messages(
-                messages, functions=self.functions, callbacks=None
-            )
-        except Exception as e:
-            new_exception = self.model_instance.handle_exceptions(e)
-            raise new_exception
-
-        function_call = predicted_message.additional_kwargs.get("function_call", {})
-
-        self.llm.max_tokens = original_max_tokens
-
-        return True if function_call else False
-
-    def plan(
-            self,
-            intermediate_steps: List[Tuple[AgentAction, str]],
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> Union[AgentAction, AgentFinish]:
-        """Given input, decided what to do.
-
-        Args:
-            intermediate_steps: Steps the LLM has taken to date, along with observations
-            **kwargs: User inputs.
-
-        Returns:
-            Action specifying what tool to use.
-        """
-        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
-        selected_inputs = {
-            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
-        }
-        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
-        prompt = self.prompt.format_prompt(**full_inputs)
-        messages = prompt.to_messages()
-
-        # summarize messages if rest_tokens < 0
-        try:
-            messages = self.summarize_messages_if_needed(messages, functions=self.functions)
-        except ExceededLLMTokensLimitError as e:
-            return AgentFinish(return_values={"output": str(e)}, log=str(e))
-
-        predicted_message = self.llm.predict_messages(
-            messages, functions=self.functions, callbacks=callbacks
-        )
-        agent_decision = _parse_ai_message(predicted_message)
-        return agent_decision
-
-    @classmethod
-    def get_system_message(cls):
-        # get current time
-        return SystemMessage(content="You are a helpful AI assistant.\n"
-                                     "The current date or current time you know is wrong.\n"
-                                     "Respond directly if appropriate.")

+ 19 - 9
api/core/agent/agent/structed_multi_dataset_router_agent.py

@@ -4,7 +4,6 @@ from typing import List, Tuple, Any, Union, Sequence, Optional, cast
 from langchain import BasePromptTemplate
 from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
 from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
-from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
 from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
@@ -12,6 +11,7 @@ from langchain.schema import AgentAction, AgentFinish, OutputParserException
 from langchain.tools import BaseTool
 from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
 
+from core.chain.llm_chain import LLMChain
 from core.model_providers.models.llm.base import BaseLLM
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool
 
@@ -49,7 +49,6 @@ Action:
 
 
 class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
-    model_instance: BaseLLM
     dataset_tools: Sequence[BaseTool]
 
     class Config:
@@ -98,7 +97,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
         try:
             full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
         except Exception as e:
-            new_exception = self.model_instance.handle_exceptions(e)
+            new_exception = self.llm_chain.model_instance.handle_exceptions(e)
             raise new_exception
 
         try:
@@ -145,7 +144,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
     @classmethod
     def from_llm_and_tools(
             cls,
-            llm: BaseLanguageModel,
+            model_instance: BaseLLM,
            tools: Sequence[BaseTool],
             callback_manager: Optional[BaseCallbackManager] = None,
             output_parser: Optional[AgentOutputParser] = None,
@@ -157,17 +156,28 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
             memory_prompts: Optional[List[BasePromptTemplate]] = None,
             **kwargs: Any,
     ) -> Agent:
-        return super().from_llm_and_tools(
-            llm=llm,
-            tools=tools,
-            callback_manager=callback_manager,
-            output_parser=output_parser,
+        """Construct an agent from an LLM and tools."""
+        cls._validate_tools(tools)
+        prompt = cls.create_prompt(
+            tools,
             prefix=prefix,
             suffix=suffix,
             human_message_template=human_message_template,
             format_instructions=format_instructions,
             input_variables=input_variables,
             memory_prompts=memory_prompts,
+        )
+        llm_chain = LLMChain(
+            model_instance=model_instance,
+            prompt=prompt,
+            callback_manager=callback_manager,
+        )
+        tool_names = [tool.name for tool in tools]
+        _output_parser = output_parser
+        return cls(
+            llm_chain=llm_chain,
+            allowed_tools=tool_names,
+            output_parser=_output_parser,
            dataset_tools=tools,
             **kwargs,
         )

+ 38 - 16
api/core/agent/agent/structured_chat.py

@@ -4,16 +4,17 @@ from typing import List, Tuple, Any, Union, Sequence, Optional
 from langchain import BasePromptTemplate
 from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
 from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
-from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
-from langchain.memory.summary import SummarizerMixin
+from langchain.memory.prompt import SUMMARY_PROMPT
 from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
-from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException
+from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException, BaseMessage, \
+    get_buffer_string
 from langchain.tools import BaseTool
 from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
 
 from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
+from core.chain.llm_chain import LLMChain
 from core.model_providers.models.llm.base import BaseLLM
 
 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
@@ -52,8 +53,7 @@ Action:
 class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
     moving_summary_buffer: str = ""
     moving_summary_index: int = 0
-    summary_llm: BaseLanguageModel = None
-    model_instance: BaseLLM
+    summary_model_instance: BaseLLM = None
 
     class Config:
         """Configuration for this pydantic object."""
@@ -95,14 +95,14 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
         if prompts:
             messages = prompts[0].to_messages()
 
-        rest_tokens = self.get_message_rest_tokens(self.model_instance, messages)
+        rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_instance, messages)
         if rest_tokens < 0:
             full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
 
         try:
             full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
         except Exception as e:
-            new_exception = self.model_instance.handle_exceptions(e)
+            new_exception = self.llm_chain.model_instance.handle_exceptions(e)
             raise new_exception
 
         try:
@@ -118,7 +118,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
                                           "I don't know how to respond to that."}, "")
 
     def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
-        if len(intermediate_steps) >= 2 and self.summary_llm:
+        if len(intermediate_steps) >= 2 and self.summary_model_instance:
             should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
             should_summary_messages = [AIMessage(content=observation)
                                        for _, observation in should_summary_intermediate_steps]
@@ -130,11 +130,10 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
             error_msg = "Exceeded LLM tokens limit, stopped."
             raise ExceededLLMTokensLimitError(error_msg)
 
-        summary_handler = SummarizerMixin(llm=self.summary_llm)
         if self.moving_summary_buffer and 'chat_history' in kwargs:
             kwargs["chat_history"].pop()
 
-        self.moving_summary_buffer = summary_handler.predict_new_summary(
+        self.moving_summary_buffer = self.predict_new_summary(
             messages=should_summary_messages,
             existing_summary=self.moving_summary_buffer
         )
@@ -144,6 +143,18 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
 
         return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
 
+    def predict_new_summary(
+        self, messages: List[BaseMessage], existing_summary: str
+    ) -> str:
+        new_lines = get_buffer_string(
+            messages,
+            human_prefix="Human",
+            ai_prefix="AI",
+        )
+
+        chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
+        return chain.predict(summary=existing_summary, new_lines=new_lines)
+
     @classmethod
     def create_prompt(
             cls,
@@ -176,7 +187,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
     @classmethod
     def from_llm_and_tools(
             cls,
-            llm: BaseLanguageModel,
+            model_instance: BaseLLM,
             tools: Sequence[BaseTool],
             callback_manager: Optional[BaseCallbackManager] = None,
             output_parser: Optional[AgentOutputParser] = None,
@@ -188,16 +199,27 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
             memory_prompts: Optional[List[BasePromptTemplate]] = None,
             **kwargs: Any,
     ) -> Agent:
-        return super().from_llm_and_tools(
-            llm=llm,
-            tools=tools,
-            callback_manager=callback_manager,
-            output_parser=output_parser,
+        """Construct an agent from an LLM and tools."""
+        cls._validate_tools(tools)
+        prompt = cls.create_prompt(
+            tools,
             prefix=prefix,
             suffix=suffix,
             human_message_template=human_message_template,
             format_instructions=format_instructions,
             input_variables=input_variables,
             memory_prompts=memory_prompts,
+        )
+        llm_chain = LLMChain(
+            model_instance=model_instance,
+            prompt=prompt,
+            callback_manager=callback_manager,
+        )
+        tool_names = [tool.name for tool in tools]
+        _output_parser = output_parser
+        return cls(
+            llm_chain=llm_chain,
+            allowed_tools=tool_names,
+            output_parser=_output_parser,
             **kwargs,
         )

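Note: summary_llm (a raw LangChain model) is replaced by summary_model_instance (a BaseLLM), and SummarizerMixin by the new predict_new_summary, which routes LangChain's stock SUMMARY_PROMPT through the project's LLMChain. A minimal sketch of that path, assuming summary_model_instance is a configured BaseLLM (the observation content is illustrative):

    from langchain.memory.prompt import SUMMARY_PROMPT
    from langchain.schema import AIMessage, get_buffer_string

    from core.chain.llm_chain import LLMChain

    # observations accumulated from earlier tool calls (illustrative)
    observations = [AIMessage(content="Tool returned 3 matching documents.")]
    new_lines = get_buffer_string(observations, human_prefix="Human", ai_prefix="AI")

    chain = LLMChain(model_instance=summary_model_instance, prompt=SUMMARY_PROMPT)
    rolling_summary = chain.predict(summary="", new_lines=new_lines)
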
+ 2 - 18
api/core/agent/agent_executor.py

@@ -10,7 +10,6 @@ from pydantic import BaseModel, Extra
 
 from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
 from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
-from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
 from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
 from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
 from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
@@ -27,7 +26,6 @@ class PlanningStrategy(str, enum.Enum):
     REACT_ROUTER = 'react_router'
     REACT = 'react'
     FUNCTION_CALL = 'function_call'
-    MULTI_FUNCTION_CALL = 'multi_function_call'
 
 
 class AgentConfiguration(BaseModel):
@@ -64,30 +62,18 @@ class AgentExecutor:
         if self.configuration.strategy == PlanningStrategy.REACT:
             agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
                 model_instance=self.configuration.model_instance,
-                llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 output_parser=StructuredChatOutputParser(),
-                summary_llm=self.configuration.summary_model_instance.client
+                summary_model_instance=self.configuration.summary_model_instance
                 if self.configuration.summary_model_instance else None,
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
             agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
                 model_instance=self.configuration.model_instance,
-                llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used for read chat histories memory
-                summary_llm=self.configuration.summary_model_instance.client
-                if self.configuration.summary_model_instance else None,
-                verbose=True
-            )
-        elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
-            agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
-                model_instance=self.configuration.model_instance,
-                llm=self.configuration.model_instance.client,
-                tools=self.configuration.tools,
-                extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used for read chat histories memory
-                summary_llm=self.configuration.summary_model_instance.client
+                summary_model_instance=self.configuration.summary_model_instance
                 if self.configuration.summary_model_instance else None,
                 verbose=True
             )
@@ -95,7 +81,6 @@ class AgentExecutor:
             self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
             agent = MultiDatasetRouterAgent.from_llm_and_tools(
                 model_instance=self.configuration.model_instance,
-                llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
                 verbose=True
@@ -104,7 +89,6 @@ class AgentExecutor:
             self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
             agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
                 model_instance=self.configuration.model_instance,
-                llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 output_parser=StructuredChatOutputParser(),
                 verbose=True

+ 36 - 0
api/core/chain/llm_chain.py

@@ -0,0 +1,36 @@
+from typing import List, Dict, Any, Optional
+
+from langchain import LLMChain as LCLLMChain
+from langchain.callbacks.manager import CallbackManagerForChainRun
+from langchain.schema import LLMResult, Generation
+from langchain.schema.language_model import BaseLanguageModel
+
+from core.model_providers.models.entity.message import to_prompt_messages
+from core.model_providers.models.llm.base import BaseLLM
+from core.third_party.langchain.llms.fake import FakeLLM
+
+
+class LLMChain(LCLLMChain):
+    model_instance: BaseLLM
+    """The language model instance to use."""
+    llm: BaseLanguageModel = FakeLLM(response="")
+
+    def generate(
+        self,
+        input_list: List[Dict[str, Any]],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> LLMResult:
+        """Generate LLM result from inputs."""
+        prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
+        messages = prompts[0].to_messages()
+        prompt_messages = to_prompt_messages(messages)
+        result = self.model_instance.run(
+            messages=prompt_messages,
+            stop=stop
+        )
+
+        generations = [
+            [Generation(text=result.content)]
+        ]
+
+        return LLMResult(generations=generations)
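
Note: this wrapper keeps LangChain's LLMChain interface (prep_prompts, predict) but delegates generation to model_instance.run(); the FakeLLM default exists only so the parent class validates. A minimal usage sketch, assuming model_instance is any configured BaseLLM (not part of this commit):

    from langchain import PromptTemplate

    from core.chain.llm_chain import LLMChain

    prompt = PromptTemplate.from_template("Translate to French: {text}")
    chain = LLMChain(model_instance=model_instance, prompt=prompt)
    print(chain.predict(text="Hello, world."))

Note that generate() here returns only the completion text; token usage and chat-style generations are not propagated back through the LLMResult.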

+ 18 - 3
api/core/model_providers/models/entity/message.py

@@ -1,6 +1,6 @@
 import enum
 
-from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage
+from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage
 from pydantic import BaseModel
 
 
@@ -9,6 +9,7 @@ class LLMRunResult(BaseModel):
     prompt_tokens: int
     completion_tokens: int
     source: list = None
+    function_call: dict = None
 
 
 class MessageType(enum.Enum):
@@ -20,6 +21,7 @@ class MessageType(enum.Enum):
 class PromptMessage(BaseModel):
     type: MessageType = MessageType.HUMAN
     content: str = ''
+    function_call: dict = None
 
 
 def to_lc_messages(messages: list[PromptMessage]):
@@ -28,7 +30,10 @@ def to_lc_messages(messages: list[PromptMessage]):
         if message.type == MessageType.HUMAN:
             lc_messages.append(HumanMessage(content=message.content))
         elif message.type == MessageType.ASSISTANT:
-            lc_messages.append(AIMessage(content=message.content))
+            additional_kwargs = {}
+            if message.function_call:
+                additional_kwargs['function_call'] = message.function_call
+            lc_messages.append(AIMessage(content=message.content, additional_kwargs=additional_kwargs))
         elif message.type == MessageType.SYSTEM:
             lc_messages.append(SystemMessage(content=message.content))
 
@@ -41,9 +46,19 @@ def to_prompt_messages(messages: list[BaseMessage]):
         if isinstance(message, HumanMessage):
             prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
         elif isinstance(message, AIMessage):
-            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.ASSISTANT))
+            message_kwargs = {
+                'content': message.content,
+                'type': MessageType.ASSISTANT
+            }
+
+            if 'function_call' in message.additional_kwargs:
+                message_kwargs['function_call'] = message.additional_kwargs['function_call']
+
+            prompt_messages.append(PromptMessage(**message_kwargs))
         elif isinstance(message, SystemMessage):
             prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM))
+        elif isinstance(message, FunctionMessage):
+            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
     return prompt_messages
 

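Note: with function_call now carried on both PromptMessage and LLMRunResult, an assistant message's function call survives the round trip between LangChain messages and PromptMessage. A small check built only from the code above (the function name and arguments are illustrative):

    from langchain.schema import AIMessage

    from core.model_providers.models.entity.message import to_lc_messages, to_prompt_messages

    ai = AIMessage(
        content='',
        additional_kwargs={'function_call': {'name': 'dataset', 'arguments': '{"query": "pricing"}'}}
    )

    prompt_messages = to_prompt_messages([ai])
    assert prompt_messages[0].function_call['name'] == 'dataset'

    lc_messages = to_lc_messages(prompt_messages)
    assert lc_messages[0].additional_kwargs['function_call']['name'] == 'dataset'
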
+ 14 - 1
api/core/model_providers/models/llm/azure_openai_model.py

@@ -81,7 +81,20 @@ class AzureOpenAIModel(BaseLLM):
         :return:
         """
         prompts = self._get_prompt_from_messages(messages)
-        return self._client.generate([prompts], stop, callbacks)
+        generate_kwargs = {
+            'stop': stop,
+            'callbacks': callbacks
+        }
+
+        if isinstance(prompts, str):
+            generate_kwargs['prompts'] = [prompts]
+        else:
+            generate_kwargs['messages'] = [prompts]
+
+        if 'functions' in kwargs:
+            generate_kwargs['functions'] = kwargs['functions']
+
+        return self._client.generate(**generate_kwargs)
     
     @property
     def base_model_name(self) -> str:

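The keyword-based call replaces the old positional `generate([prompts], stop, callbacks)` because completion-mode clients take a list of prompt strings (`prompts=`) while chat-mode clients take a list of message lists (`messages=`), and only chat clients accept OpenAI function schemas. A hedged sketch of the two call shapes against langchain's stock clients (the schema is illustrative, and both calls assume OPENAI_API_KEY is configured):

    from langchain.chat_models import ChatOpenAI
    from langchain.llms import OpenAI
    from langchain.schema import HumanMessage

    # Completion mode: a list of prompt strings.
    OpenAI().generate(prompts=['Say hi'], stop=None, callbacks=None)

    # Chat mode: a list of message lists, plus optional OpenAI function
    # schemas forwarded through **kwargs to the API.
    ChatOpenAI().generate(
        messages=[[HumanMessage(content='What is the weather in Berlin?')]],
        functions=[{
            'name': 'get_weather',
            'parameters': {'type': 'object', 'properties': {'city': {'type': 'string'}}},
        }],
    )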
+ 8 - 12
api/core/model_providers/models/llm/base.py

@@ -13,7 +13,8 @@ from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage,
 from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
 from core.helper import moderation
 from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
+from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages, \
+    to_lc_messages
 from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
 from core.model_providers.providers.base import BaseModelProvider
 from core.prompt.prompt_builder import PromptBuilder
@@ -157,8 +158,11 @@ class BaseLLM(BaseProviderModel):
             except Exception as ex:
                 raise self.handle_exceptions(ex)
 
+        function_call = None
         if isinstance(result.generations[0][0], ChatGeneration):
             completion_content = result.generations[0][0].message.content
+            if 'function_call' in result.generations[0][0].message.additional_kwargs:
+                function_call = result.generations[0][0].message.additional_kwargs.get('function_call')
         else:
             completion_content = result.generations[0][0].text
 
@@ -191,7 +195,8 @@
         return LLMRunResult(
             content=completion_content,
             prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens
+            completion_tokens=completion_tokens,
+            function_call=function_call
         )
 
     @abstractmethod
@@ -442,16 +447,7 @@
             if len(messages) == 0:
                 return []
 
-            chat_messages = []
-            for message in messages:
-                if message.type == MessageType.HUMAN:
-                    chat_messages.append(HumanMessage(content=message.content))
-                elif message.type == MessageType.ASSISTANT:
-                    chat_messages.append(AIMessage(content=message.content))
-                elif message.type == MessageType.SYSTEM:
-                    chat_messages.append(SystemMessage(content=message.content))
-
-            return chat_messages
+            return to_lc_messages(messages)
 
     def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
         """

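Net effect: when a chat generation carries a `function_call` in its `additional_kwargs`, `BaseLLM` now surfaces it on `LLMRunResult` instead of dropping it. A hedged consumption sketch (assuming `run(...)` forwards extra kwargs such as `functions` down to `_run`; `model_instance` stands for any configured `BaseLLM` and `function_schemas` for a list of OpenAI function definitions, as in the earlier sketch):

    import json

    from core.model_providers.models.entity.message import PromptMessage

    result = model_instance.run(
        messages=[PromptMessage(content='What is the weather in Berlin?')],
        functions=function_schemas,
    )

    if result.function_call:  # stays None unless the model chose to call a function
        name = result.function_call['name']
        arguments = json.loads(result.function_call.get('arguments') or '{}')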
+ 15 - 1
api/core/model_providers/models/llm/openai_model.py

@@ -106,7 +106,21 @@ class OpenAIModel(BaseLLM):
             raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.")
 
         prompts = self._get_prompt_from_messages(messages)
-        return self._client.generate([prompts], stop, callbacks)
+
+        generate_kwargs = {
+            'stop': stop,
+            'callbacks': callbacks
+        }
+
+        if isinstance(prompts, str):
+            generate_kwargs['prompts'] = [prompts]
+        else:
+            generate_kwargs['messages'] = [prompts]
+
+        if 'functions' in kwargs:
+            generate_kwargs['functions'] = kwargs['functions']
+
+        return self._client.generate(**generate_kwargs)
 
     def get_num_tokens(self, messages: List[PromptMessage]) -> int:
         """

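This is the same kwargs-building block as in `AzureOpenAIModel` above, duplicated verbatim; if more OpenAI-compatible providers gain function-call support, hoisting it into a shared `BaseLLM` helper would keep the copies from drifting apart.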
+ 10 - 14
api/core/orchestrator_rule_parser.py

@@ -1,7 +1,6 @@
 import math
 from typing import Optional
 
-from flask import current_app
 from langchain import WikipediaAPIWrapper
 from langchain.callbacks.manager import Callbacks
 from langchain.memory.chat_memory import BaseChatMemory
@@ -27,7 +26,6 @@ from core.tool.web_reader_tool import WebReaderTool
 from extensions.ext_database import db
 from models.dataset import Dataset, DatasetProcessRule
 from models.model import AppModelConfig
-from models.provider import ProviderType
 
 
 class OrchestratorRuleParser:
@@ -77,7 +75,7 @@ class OrchestratorRuleParser:
             # only OpenAI chat model (include Azure) support function call, use ReACT instead
             if agent_model_instance.model_mode != ModelMode.CHAT \
                     or agent_model_instance.model_provider.provider_name not in ['openai', 'azure_openai']:
-                if planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]:
+                if planning_strategy == PlanningStrategy.FUNCTION_CALL:
                     planning_strategy = PlanningStrategy.REACT
                 elif planning_strategy == PlanningStrategy.ROUTER:
                     planning_strategy = PlanningStrategy.REACT_ROUTER
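With `MULTI_FUNCTION_CALL` dropped from the check, only two strategies still need a downgrade on providers without function calling. The fallback is equivalent to a small lookup table (sketch; the `PlanningStrategy` import path is an assumption based on this repo's layout):

    from core.agent.agent_executor import PlanningStrategy  # assumed import path

    REACT_FALLBACKS = {
        PlanningStrategy.FUNCTION_CALL: PlanningStrategy.REACT,
        PlanningStrategy.ROUTER: PlanningStrategy.REACT_ROUTER,
    }
    planning_strategy = REACT_FALLBACKS.get(planning_strategy, planning_strategy)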
@@ -207,7 +205,10 @@ class OrchestratorRuleParser:
                 tool = self.to_current_datetime_tool()
 
             if tool:
-                tool.callbacks.extend(callbacks)
+                if tool.callbacks is not None:
+                    tool.callbacks.extend(callbacks)
+                else:
+                    tool.callbacks = callbacks
                 tools.append(tool)
 
         return tools
@@ -269,10 +270,9 @@
             summary_model_instance = None
 
         tool = WebReaderTool(
-            llm=summary_model_instance.client if summary_model_instance else None,
+            model_instance=summary_model_instance if summary_model_instance else None,
             max_chunk_length=4000,
-            continue_reading=True,
-            callbacks=[DifyStdOutCallbackHandler()]
+            continue_reading=True
         )
 
         return tool
@@ -290,16 +290,13 @@
                         "is not up to date. "
                         "Input should be a search query.",
             func=OptimizedSerpAPIWrapper(**func_kwargs).run,
-            args_schema=OptimizedSerpAPIInput,
-            callbacks=[DifyStdOutCallbackHandler()]
+            args_schema=OptimizedSerpAPIInput
         )
 
         return tool
 
     def to_current_datetime_tool(self) -> Optional[BaseTool]:
-        tool = DatetimeTool(
-            callbacks=[DifyStdOutCallbackHandler()]
-        )
+        tool = DatetimeTool()
 
         return tool
 
@@ -310,8 +307,7 @@
         return WikipediaQueryRun(
             name="wikipedia",
             api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
-            args_schema=WikipediaInput,
-            callbacks=[DifyStdOutCallbackHandler()]
+            args_schema=WikipediaInput
         )
 
     @classmethod

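Two behavioral notes on this file: tools are no longer constructed with a hard-wired `DifyStdOutCallbackHandler`, so the caller-supplied `callbacks` list is the only wiring left, and the new guard covers tools whose `callbacks` field is still `None` (the pydantic default on langchain's `BaseTool`), where the old unconditional `extend` raised `AttributeError`. A minimal sketch of the failure mode the guard closes (the `DatetimeTool` import path is an assumption):

    from core.tool.datetime_tool import DatetimeTool  # assumed import path

    callbacks = []                      # caller-supplied handlers
    tool = DatetimeTool()               # tool.callbacks is None by default
    # tool.callbacks.extend(callbacks)  # old behavior: AttributeError on None
    if tool.callbacks is not None:
        tool.callbacks.extend(callbacks)
    else:
        tool.callbacks = callbacks      # adopt the caller's list wholesale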
+ 24 - 6
api/core/tool/web_reader_tool.py

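The headline change in the file below: `WebReaderTool` now summarizes through a Dify `model_instance` (`BaseLLM`) instead of a raw langchain `llm`, and the new `get_summary_chain` at the end of the diff hand-assembles the refine pipeline that `load_summarize_chain(..., chain_type="refine")` used to build, on top of the project's own `LLMChain` and langchain's stock refine prompts. A hedged usage sketch (`summary_model_instance` stands for any configured `BaseLLM`):

    from core.tool.web_reader_tool import WebReaderTool

    tool = WebReaderTool(
        model_instance=summary_model_instance,  # None disables summarization
        max_chunk_length=4000,
        continue_reading=True,
    )
    # summary=True routes the page text through the refine chain built below.
    print(tool.run({'url': 'https://example.com', 'summary': True}))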
@@ -11,8 +11,8 @@ from typing import Type
 
 import requests
 from bs4 import BeautifulSoup, NavigableString, Comment, CData
-from langchain.base_language import BaseLanguageModel
-from langchain.chains.summarize import load_summarize_chain
+from langchain.chains import RefineDocumentsChain
+from langchain.chains.summarize import refine_prompts
 from langchain.schema import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.tools.base import BaseTool
@@ -20,8 +20,10 @@ from newspaper import Article
 from pydantic import BaseModel, Field
 from regex import regex
 
+from core.chain.llm_chain import LLMChain
 from core.data_loader import file_extractor
 from core.data_loader.file_extractor import FileExtractor
+from core.model_providers.models.llm.base import BaseLLM
 
 FULL_TEMPLATE = """
 TITLE: {title}
@@ -65,7 +67,7 @@ class WebReaderTool(BaseTool):
     summary_chunk_overlap: int = 0
     summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
     continue_reading: bool = True
-    llm: BaseLanguageModel = None
+    model_instance: BaseLLM = None
 
     def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
         try:
@@ -78,7 +80,7 @@ class WebReaderTool(BaseTool):
         except Exception as e:
             return f'Read this website failed, caused by: {str(e)}.'
 
-        if summary and self.llm:
+        if summary and self.model_instance:
             character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                 chunk_size=self.summary_chunk_tokens,
                 chunk_overlap=self.summary_chunk_overlap,
@@ -95,10 +97,9 @@
             if len(docs) > 5:
                 docs = docs[:5]
 
-            chain = load_summarize_chain(self.llm, chain_type="refine", callbacks=self.callbacks)
+            chain = self.get_summary_chain()
             try:
                 page_contents = chain.run(docs)
-                # todo use cache
             except Exception as e:
                 return f'Read this website failed, caused by: {str(e)}.'
         else:
@@ -114,6 +115,23 @@
     async def _arun(self, url: str) -> str:
         raise NotImplementedError
 
+    def get_summary_chain(self) -> RefineDocumentsChain:
+        initial_chain = LLMChain(
+            model_instance=self.model_instance,
+            prompt=refine_prompts.PROMPT
+        )
+        refine_chain = LLMChain(
+            model_instance=self.model_instance,
+            prompt=refine_prompts.REFINE_PROMPT
+        )
+        return RefineDocumentsChain(
+            initial_llm_chain=initial_chain,
+            refine_llm_chain=refine_chain,
+            document_variable_name="text",
+            initial_response_name="existing_answer",
+            callbacks=self.callbacks
+        )
+
 
 def page_result(text: str, cursor: int, max_length: int) -> str:
     """Page through `text` and return a substring of `max_length` characters starting from `cursor`."""