
fix: got unknown type of prompt message in multi-round ReAct agent chat (#5245)

sino, 10 months ago
parent
commit
edffa5666d
1 changed file with 45 additions and 36 deletions

api/core/agent/cot_agent_runner.py  +45 -36

@@ -32,9 +32,9 @@ class CotAgentRunner(BaseAgentRunner, ABC):
     _prompt_messages_tools: list[PromptMessage] = None
 
     def run(self, message: Message,
-        query: str,
-        inputs: dict[str, str],
-    ) -> Union[Generator, LLMResult]:
+            query: str,
+            inputs: dict[str, str],
+            ) -> Union[Generator, LLMResult]:
         """
         Run Cot agent application
         """
@@ -52,7 +52,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         # init instruction
         inputs = inputs or {}
         instruction = app_config.prompt_template.simple_prompt_template
-        self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
+        self._instruction = self._fill_in_inputs_from_external_data_tools(
+            instruction, inputs)
 
         iteration_step = 1
         max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
@@ -61,7 +62,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         tool_instances, self._prompt_messages_tools = self._init_prompt_tools()
 
         prompt_messages = self._organize_prompt_messages()
-        
+
         function_call_state = True
         llm_usage = {
             'usage': None
@@ -120,9 +121,10 @@ class CotAgentRunner(BaseAgentRunner, ABC):
             # check llm result
             if not chunks:
                 raise ValueError("failed to invoke llm")
-            
+
             usage_dict = {}
-            react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
+            react_chunks = CotAgentOutputParser.handle_react_stream_output(
+                chunks, usage_dict)
             scratchpad = AgentScratchpadUnit(
                 agent_response='',
                 thought='',
@@ -160,15 +162,16 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                         )
                     )
 
-            scratchpad.thought = scratchpad.thought.strip() or 'I am thinking about how to help you'
+            scratchpad.thought = scratchpad.thought.strip(
+            ) or 'I am thinking about how to help you'
             self._agent_scratchpad.append(scratchpad)
-            
+
             # get llm usage
             if 'usage' in usage_dict:
                 increase_usage(llm_usage, usage_dict['usage'])
             else:
                 usage_dict['usage'] = LLMUsage.empty_usage()
-            
+
             self.save_agent_thought(
                 agent_thought=agent_thought,
                 tool_name=scratchpad.action.action_name if scratchpad.action else '',
@@ -182,7 +185,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 messages_ids=[],
                 llm_usage=usage_dict['usage']
             )
-            
+
             if not scratchpad.is_final():
                 self.queue_manager.publish(QueueAgentThoughtEvent(
                     agent_thought_id=agent_thought.id
@@ -196,7 +199,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                     # action is final answer, return final answer directly
                     try:
                         if isinstance(scratchpad.action.action_input, dict):
-                            final_answer = json.dumps(scratchpad.action.action_input)
+                            final_answer = json.dumps(
+                                scratchpad.action.action_input)
                         elif isinstance(scratchpad.action.action_input, str):
                             final_answer = scratchpad.action.action_input
                         else:
@@ -207,7 +211,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                     function_call_state = True
                     # action is tool call, invoke tool
                     tool_invoke_response, tool_invoke_meta = self._handle_invoke_action(
-                        action=scratchpad.action, 
+                        action=scratchpad.action,
                         tool_instances=tool_instances,
                         message_file_ids=message_file_ids
                     )
@@ -217,10 +221,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                     self.save_agent_thought(
                         agent_thought=agent_thought,
                         tool_name=scratchpad.action.action_name,
-                        tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
+                        tool_input={
+                            scratchpad.action.action_name: scratchpad.action.action_input},
                         thought=scratchpad.thought,
-                        observation={scratchpad.action.action_name: tool_invoke_response},
-                        tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
+                        observation={
+                            scratchpad.action.action_name: tool_invoke_response},
+                        tool_invoke_meta={
+                            scratchpad.action.action_name: tool_invoke_meta.to_dict()},
                         answer=scratchpad.agent_response,
                         messages_ids=message_file_ids,
                         llm_usage=usage_dict['usage']
@@ -232,7 +239,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
 
                 # update prompt tool message
                 for prompt_tool in self._prompt_messages_tools:
-                    self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
+                    self.update_prompt_message_tool(
+                        tool_instances[prompt_tool.name], prompt_tool)
 
             iteration_step += 1
 
@@ -251,12 +259,12 @@ class CotAgentRunner(BaseAgentRunner, ABC):
 
         # save agent thought
         self.save_agent_thought(
-            agent_thought=agent_thought, 
+            agent_thought=agent_thought,
             tool_name='',
             tool_input={},
             tool_invoke_meta={},
             thought=final_answer,
-            observation={}, 
+            observation={},
             answer=final_answer,
             messages_ids=[]
         )
@@ -269,11 +277,12 @@ class CotAgentRunner(BaseAgentRunner, ABC):
             message=AssistantPromptMessage(
                 content=final_answer
             ),
-            usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
+            usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(
+            ),
             system_fingerprint=''
         )), PublishFrom.APPLICATION_MANAGER)
 
-    def _handle_invoke_action(self, action: AgentScratchpadUnit.Action, 
+    def _handle_invoke_action(self, action: AgentScratchpadUnit.Action,
                               tool_instances: dict[str, Tool],
                               message_file_ids: list[str]) -> tuple[str, ToolInvokeMeta]:
         """
@@ -290,7 +299,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         if not tool_instance:
             answer = f"there is not a tool named {tool_call_name}"
             return answer, ToolInvokeMeta.error_instance(answer)
-        
+
         if isinstance(tool_call_args, str):
             try:
                 tool_call_args = json.loads(tool_call_args)
@@ -311,7 +320,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         # publish files
         for message_file, save_as in message_files:
             if save_as:
-                self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as)
+                self.variables_pool.set_file(
+                    tool_name=tool_call_name, value=message_file.id, name=save_as)
 
             # publish message file
             self.queue_manager.publish(QueueMessageFileEvent(
@@ -342,7 +352,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 continue
 
         return instruction
-    
+
     def _init_react_state(self, query) -> None:
         """
         init agent scratchpad
@@ -350,7 +360,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         self._query = query
         self._agent_scratchpad = []
         self._historic_prompt_messages = self._organize_historic_prompt_messages()
-    
+
     @abstractmethod
     def _organize_prompt_messages(self) -> list[PromptMessage]:
         """
@@ -382,13 +392,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         scratchpads: list[AgentScratchpadUnit] = []
         current_scratchpad: AgentScratchpadUnit = None
 
-        self.history_prompt_messages = AgentHistoryPromptTransform(
-            model_config=self.model_config,
-            prompt_messages=current_session_messages or [],
-            history_messages=self.history_prompt_messages,
-            memory=self.memory
-        ).get_prompt()
-
         for message in self.history_prompt_messages:
             if isinstance(message, AssistantPromptMessage):
                 if not current_scratchpad:
@@ -404,7 +407,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                     try:
                         current_scratchpad.action = AgentScratchpadUnit.Action(
                             action_name=message.tool_calls[0].function.name,
-                            action_input=json.loads(message.tool_calls[0].function.arguments)
+                            action_input=json.loads(
+                                message.tool_calls[0].function.arguments)
                         )
                         current_scratchpad.action_str = json.dumps(
                             current_scratchpad.action.to_dict()
@@ -424,10 +428,15 @@ class CotAgentRunner(BaseAgentRunner, ABC):
 
                 result.append(message)
 
-
         if scratchpads:
             result.append(AssistantPromptMessage(
                 content=self._format_assistant_message(scratchpads)
             ))
-        
-        return result
+
+        historic_prompts = AgentHistoryPromptTransform(
+            model_config=self.model_config,
+            prompt_messages=current_session_messages or [],
+            history_messages=result,
+            memory=self.memory
+        ).get_prompt()
+        return historic_prompts
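
Note on the fix: everything above except the last two hunks is whitespace and line-wrapping churn. The substantive change is in `_organize_historic_prompt_messages()`: previously `AgentHistoryPromptTransform` ran on the raw `self.history_prompt_messages` before the loop that folds assistant `tool_calls` and their tool observations into ReAct-style scratchpad messages; now that loop runs first, the transform is applied to the normalized `result` list, and its output is returned directly. A minimal sketch of the new ordering follows. The class name and keyword arguments (`model_config`, `prompt_messages`, `history_messages`, `memory`) are copied from the diff; `normalize_history` is a hypothetical stand-in for the inline scratchpad-building loop, and the method signature is assumed from its call sites.

    # Sketch of the fixed ordering (not the verbatim method body).
    def _organize_historic_prompt_messages(self, current_session_messages=None):
        # 1. Fold raw multi-round history (assistant tool_calls plus the
        #    matching tool observations) into plain AssistantPromptMessage
        #    entries. `normalize_history` is a hypothetical helper standing
        #    in for the loop that builds `result` in the diff.
        result = normalize_history(self.history_prompt_messages)

        # 2. Only then apply the history transform, so it never has to
        #    process raw tool-call/tool-result message types.
        historic_prompts = AgentHistoryPromptTransform(
            model_config=self.model_config,
            prompt_messages=current_session_messages or [],
            history_messages=result,
            memory=self.memory
        ).get_prompt()
        return historic_prompts

With the old ordering the transform had to handle raw tool-call and tool-result messages from earlier ReAct rounds, which is presumably what surfaced the "unknown type of prompt message" error named in the commit title. Returning `historic_prompts` directly also removes the old side effect of overwriting `self.history_prompt_messages` in place.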