소스 검색

Fix/agent react output parser (#2689)

Yeuoly 1 년 전
부모
커밋
38e5952417
2개의 변경된 파일35개의 추가작업 그리고 12개의 파일을 삭제
  1. 23 11
      api/core/features/assistant_cot_runner.py
  2. 12 1
      api/core/tools/tool/tool.py

+ 23 - 11
api/core/features/assistant_cot_runner.py

@@ -28,6 +28,9 @@ from models.model import Conversation, Message
 
 
 class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
+    _is_first_iteration = True
+    _ignore_observation_providers = ['wenxin']
+
     def run(self, conversation: Conversation,
         message: Message,
         query: str,
@@ -42,10 +45,8 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
         agent_scratchpad: list[AgentScratchpadUnit] = []
         self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages)
 
-        # check model mode
-        if self.app_orchestration_config.model_config.mode == "completion":
-            # TODO: stop words
-            if 'Observation' not in app_orchestration_config.model_config.stop:
+        if 'Observation' not in app_orchestration_config.model_config.stop:
+            if app_orchestration_config.model_config.provider not in self._ignore_observation_providers:
                 app_orchestration_config.model_config.stop.append('Observation')
 
         # override inputs
@@ -202,6 +203,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
                         )
                     )
 
+            scratchpad.thought = scratchpad.thought.strip() or 'I am thinking about how to help you'
             agent_scratchpad.append(scratchpad)
                         
             # get llm usage
@@ -255,9 +257,15 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
                         # invoke tool
                         error_response = None
                         try:
+                            if isinstance(tool_call_args, str):
+                                try:
+                                    tool_call_args = json.loads(tool_call_args)
+                                except json.JSONDecodeError:
+                                    pass
+                            
                             tool_response = tool_instance.invoke(
                                 user_id=self.user_id, 
-                                tool_parameters=tool_call_args if isinstance(tool_call_args, dict) else json.loads(tool_call_args)
+                                tool_parameters=tool_call_args
                             )
                             # transform tool response to llm friendly response
                             tool_response = self.transform_tool_invoke_messages(tool_response)
@@ -466,7 +474,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
             if isinstance(message, AssistantPromptMessage):
                 current_scratchpad = AgentScratchpadUnit(
                     agent_response=message.content,
-                    thought=message.content,
+                    thought=message.content or 'I am thinking about how to help you',
                     action_str='',
                     action=None,
                     observation=None,
@@ -546,7 +554,8 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
 
         result = ''
         for scratchpad in agent_scratchpad:
-            result += scratchpad.thought + next_iteration.replace("{{observation}}", scratchpad.observation or '') + "\n"
+            result += (scratchpad.thought or '') + (scratchpad.action_str or '') + \
+                next_iteration.replace("{{observation}}", scratchpad.observation or 'It seems that no response is available')
 
         return result
     
@@ -621,21 +630,24 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
                 ))
 
             # add assistant message
-            if len(agent_scratchpad) > 0:
+            if len(agent_scratchpad) > 0 and not self._is_first_iteration:
                 prompt_messages.append(AssistantPromptMessage(
-                    content=(agent_scratchpad[-1].thought or '')
+                    content=(agent_scratchpad[-1].thought or '') + (agent_scratchpad[-1].action_str or ''),
                 ))
             
             # add user message
-            if len(agent_scratchpad) > 0:
+            if len(agent_scratchpad) > 0 and not self._is_first_iteration:
                 prompt_messages.append(UserPromptMessage(
-                    content=(agent_scratchpad[-1].observation or ''),
+                    content=(agent_scratchpad[-1].observation or 'It seems that no response is available'),
                 ))
 
+            self._is_first_iteration = False
+
             return prompt_messages
         elif mode == "completion":
             # parse agent scratchpad
             agent_scratchpad_str = self._convert_scratchpad_list_to_str(agent_scratchpad)
+            self._is_first_iteration = False
             # parse prompt messages
             return [UserPromptMessage(
                 content=first_prompt.replace("{{instruction}}", instruction)

+ 12 - 1
api/core/tools/tool/tool.py

@@ -174,7 +174,18 @@ class Tool(BaseModel, ABC):
 
         return result
 
-    def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
+    def invoke(self, user_id: str, tool_parameters: Union[dict[str, Any], str]) -> list[ToolInvokeMessage]:
+        # check if tool_parameters is a string
+        if isinstance(tool_parameters, str):
+            # check if this tool has only one parameter
+            parameters = [parameter for parameter in self.parameters if parameter.form == ToolParameter.ToolParameterForm.LLM]
+            if parameters and len(parameters) == 1:
+                tool_parameters = {
+                    parameters[0].name: tool_parameters
+                }
+            else:
+                raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
+
         # update tool_parameters
         if self.runtime.runtime_parameters:
             tool_parameters.update(self.runtime.runtime_parameters)