Selaa lähdekoodia

feat: optimize completion model agent (#1364)

takatost 1 vuosi sitten
vanhempi
commit
07285e5f8b

+ 1 - 1
api/core/agent/agent/multi_dataset_router_agent.py

@@ -76,7 +76,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
             agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
             if isinstance(agent_decision, AgentAction):
                 tool_inputs = agent_decision.tool_input
-                if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
+                if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs:
                     tool_inputs['query'] = kwargs['input']
                     agent_decision.tool_input = tool_inputs
             else:

+ 80 - 10
api/core/agent/agent/structed_multi_dataset_router_agent.py

@@ -1,7 +1,7 @@
 import re
 from typing import List, Tuple, Any, Union, Sequence, Optional, cast
 
-from langchain import BasePromptTemplate
+from langchain import BasePromptTemplate, PromptTemplate
 from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
 from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
 from langchain.callbacks.base import BaseCallbackManager
@@ -12,6 +12,7 @@ from langchain.tools import BaseTool
 from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
 
 from core.chain.llm_chain import LLMChain
+from core.model_providers.models.entity.model_params import ModelMode
 from core.model_providers.models.llm.base import BaseLLM
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool
 
@@ -92,6 +93,10 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
             rst = tool.run(tool_input={'query': kwargs['input']})
             return AgentFinish(return_values={"output": rst}, log=rst)
 
+        if intermediate_steps:
+            _, observation = intermediate_steps[-1]
+            return AgentFinish(return_values={"output": observation}, log=observation)
+
         full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
 
         try:
@@ -107,6 +112,8 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
                 if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
                     tool_inputs['query'] = kwargs['input']
                     agent_decision.tool_input = tool_inputs
+                elif isinstance(tool_inputs, str):
+                    agent_decision.tool_input = kwargs['input']
             else:
                 agent_decision.return_values['output'] = ''
             return agent_decision
@@ -143,6 +150,61 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
         ]
         return ChatPromptTemplate(input_variables=input_variables, messages=messages)
 
+    @classmethod
+    def create_completion_prompt(
+            cls,
+            tools: Sequence[BaseTool],
+            prefix: str = PREFIX,
+            format_instructions: str = FORMAT_INSTRUCTIONS,
+            input_variables: Optional[List[str]] = None,
+    ) -> PromptTemplate:
+        """Build a single-string prompt template for the agent.
+
+        The template is the blank-line-separated concatenation of ``prefix``,
+        one ``name: description`` line per tool, the formatted
+        ``format_instructions`` and a fixed suffix that carries the
+        ``{input}`` and ``{agent_scratchpad}`` placeholders.
+        ``from_llm_and_tools`` uses this prompt when the model instance is
+        not in ``ModelMode.CHAT``.
+
+        Args:
+            tools: List of tools the agent will have access to, used to format the
+                prompt.
+            prefix: String to put before the list of tools.
+            format_instructions: Format string that is rendered with
+                ``str.format(tool_names=...)``, where ``tool_names`` is the
+                comma-separated list of tool names.
+            input_variables: List of input variables the final prompt will expect.
+                Defaults to ``["input", "agent_scratchpad"]`` when None.
+
+        Returns:
+            A PromptTemplate with the template assembled from the pieces here.
+        """
+        # Fixed tail of the prompt; the scratchpad is placed right after "Thought:".
+        suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
Question: {input}
Thought: {agent_scratchpad}
"""
+
+        # One "name: description" line per tool, plus the bare names for the
+        # format instructions.
+        tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
+        tool_names = ", ".join([tool.name for tool in tools])
+        format_instructions = format_instructions.format(tool_names=tool_names)
+        template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
+        if input_variables is None:
+            input_variables = ["input", "agent_scratchpad"]
+        return PromptTemplate(template=template, input_variables=input_variables)
+
+    def _construct_scratchpad(
+        self, intermediate_steps: List[Tuple[AgentAction, str]]
+    ) -> str:
+        """Render earlier agent steps as scratchpad text for the prompt.
+
+        Args:
+            intermediate_steps: (action, observation) pairs from earlier steps.
+
+        Returns:
+            An empty string when there are no steps. Otherwise the
+            concatenated step text, wrapped in a "previous work" preamble
+            when the model is in chat mode and returned unchanged otherwise.
+
+        Raises:
+            ValueError: If the assembled scratchpad is not a string.
+        """
+        agent_scratchpad = ""
+        for action, observation in intermediate_steps:
+            # Each step contributes the action's log, then the observation and
+            # the LLM prefix, each on its own line.
+            agent_scratchpad += action.log
+            agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
+
+        # NOTE(review): defensive check; it looks unreachable because the value
+        # starts as "" and is only extended with += -- confirm before relying on it.
+        if not isinstance(agent_scratchpad, str):
+            raise ValueError("agent_scratchpad should be of type string.")
+        if agent_scratchpad:
+            llm_chain = cast(LLMChain, self.llm_chain)
+            # Chat mode gets the explanatory preamble; other modes get the raw
+            # text (the completion prompt template places it after "Thought:").
+            if llm_chain.model_instance.model_mode == ModelMode.CHAT:
+                return (
+                    f"This was your previous work "
+                    f"(but I haven't seen any of it! I only see what "
+                    f"you return as final answer):\n{agent_scratchpad}"
+                )
+            else:
+                return agent_scratchpad
+        else:
+            return agent_scratchpad
+
     @classmethod
     def from_llm_and_tools(
             cls,
@@ -160,15 +222,23 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
     ) -> Agent:
         """Construct an agent from an LLM and tools."""
         cls._validate_tools(tools)
-        prompt = cls.create_prompt(
-            tools,
-            prefix=prefix,
-            suffix=suffix,
-            human_message_template=human_message_template,
-            format_instructions=format_instructions,
-            input_variables=input_variables,
-            memory_prompts=memory_prompts,
-        )
+        if model_instance.model_mode == ModelMode.CHAT:
+            prompt = cls.create_prompt(
+                tools,
+                prefix=prefix,
+                suffix=suffix,
+                human_message_template=human_message_template,
+                format_instructions=format_instructions,
+                input_variables=input_variables,
+                memory_prompts=memory_prompts,
+            )
+        else:
+            prompt = cls.create_completion_prompt(
+                tools,
+                prefix=prefix,
+                format_instructions=format_instructions,
+                input_variables=input_variables
+            )
         llm_chain = LLMChain(
             model_instance=model_instance,
             prompt=prompt,

+ 75 - 11
api/core/agent/agent/structured_chat.py

@@ -1,7 +1,7 @@
 import re
-from typing import List, Tuple, Any, Union, Sequence, Optional
+from typing import List, Tuple, Any, Union, Sequence, Optional, cast
 
-from langchain import BasePromptTemplate
+from langchain import BasePromptTemplate, PromptTemplate
 from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
 from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
 from langchain.callbacks.base import BaseCallbackManager
@@ -15,6 +15,7 @@ from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
 
 from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
 from core.chain.llm_chain import LLMChain
+from core.model_providers.models.entity.model_params import ModelMode
 from core.model_providers.models.llm.base import BaseLLM
 
 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
@@ -184,6 +185,61 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
         ]
         return ChatPromptTemplate(input_variables=input_variables, messages=messages)
 
+    @classmethod
+    def create_completion_prompt(
+            cls,
+            tools: Sequence[BaseTool],
+            prefix: str = PREFIX,
+            format_instructions: str = FORMAT_INSTRUCTIONS,
+            input_variables: Optional[List[str]] = None,
+    ) -> PromptTemplate:
+        """Build a single-string prompt template for the agent.
+
+        The template is the blank-line-separated concatenation of ``prefix``,
+        one ``name: description`` line per tool, the formatted
+        ``format_instructions`` and a fixed suffix that carries the
+        ``{input}`` and ``{agent_scratchpad}`` placeholders.
+        ``from_llm_and_tools`` uses this prompt when the model instance is
+        not in ``ModelMode.CHAT``.
+
+        Args:
+            tools: List of tools the agent will have access to, used to format the
+                prompt.
+            prefix: String to put before the list of tools.
+            format_instructions: Format string that is rendered with
+                ``str.format(tool_names=...)``, where ``tool_names`` is the
+                comma-separated list of tool names.
+            input_variables: List of input variables the final prompt will expect.
+                Defaults to ``["input", "agent_scratchpad"]`` when None.
+
+        Returns:
+            A PromptTemplate with the template assembled from the pieces here.
+        """
+        # Fixed tail of the prompt; the scratchpad is placed right after "Thought:".
+        suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
Question: {input}
Thought: {agent_scratchpad}
"""
+
+        # One "name: description" line per tool, plus the bare names for the
+        # format instructions.
+        tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
+        tool_names = ", ".join([tool.name for tool in tools])
+        format_instructions = format_instructions.format(tool_names=tool_names)
+        template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
+        if input_variables is None:
+            input_variables = ["input", "agent_scratchpad"]
+        return PromptTemplate(template=template, input_variables=input_variables)
+
+    def _construct_scratchpad(
+        self, intermediate_steps: List[Tuple[AgentAction, str]]
+    ) -> str:
+        """Render earlier agent steps as scratchpad text for the prompt.
+
+        Args:
+            intermediate_steps: (action, observation) pairs from earlier steps.
+
+        Returns:
+            An empty string when there are no steps. Otherwise the
+            concatenated step text, wrapped in a "previous work" preamble
+            when the model is in chat mode and returned unchanged otherwise.
+
+        Raises:
+            ValueError: If the assembled scratchpad is not a string.
+        """
+        agent_scratchpad = ""
+        for action, observation in intermediate_steps:
+            # Each step contributes the action's log, then the observation and
+            # the LLM prefix, each on its own line.
+            agent_scratchpad += action.log
+            agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
+
+        # NOTE(review): defensive check; it looks unreachable because the value
+        # starts as "" and is only extended with += -- confirm before relying on it.
+        if not isinstance(agent_scratchpad, str):
+            raise ValueError("agent_scratchpad should be of type string.")
+        if agent_scratchpad:
+            llm_chain = cast(LLMChain, self.llm_chain)
+            # Chat mode gets the explanatory preamble; other modes get the raw
+            # text (the completion prompt template places it after "Thought:").
+            if llm_chain.model_instance.model_mode == ModelMode.CHAT:
+                return (
+                    f"This was your previous work "
+                    f"(but I haven't seen any of it! I only see what "
+                    f"you return as final answer):\n{agent_scratchpad}"
+                )
+            else:
+                return agent_scratchpad
+        else:
+            return agent_scratchpad
+
     @classmethod
     def from_llm_and_tools(
             cls,
@@ -201,15 +257,23 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
     ) -> Agent:
         """Construct an agent from an LLM and tools."""
         cls._validate_tools(tools)
-        prompt = cls.create_prompt(
-            tools,
-            prefix=prefix,
-            suffix=suffix,
-            human_message_template=human_message_template,
-            format_instructions=format_instructions,
-            input_variables=input_variables,
-            memory_prompts=memory_prompts,
-        )
+        if model_instance.model_mode == ModelMode.CHAT:
+            prompt = cls.create_prompt(
+                tools,
+                prefix=prefix,
+                suffix=suffix,
+                human_message_template=human_message_template,
+                format_instructions=format_instructions,
+                input_variables=input_variables,
+                memory_prompts=memory_prompts,
+            )
+        else:
+            prompt = cls.create_completion_prompt(
+                tools,
+                prefix=prefix,
+                format_instructions=format_instructions,
+                input_variables=input_variables,
+            )
         llm_chain = LLMChain(
             model_instance=model_instance,
             prompt=prompt,