Browse Source

fix: convert tool messages into user messages in react mode and fill … (#2584)

Yeuoly 1 year ago
parent
commit
3a34370422
2 changed files with 46 additions and 33 deletions
  1. 36 30
      api/core/features/assistant_base_runner.py
  2. 10 3
      api/core/features/assistant_cot_runner.py

+ 36 - 30
api/core/features/assistant_base_runner.py

@@ -606,36 +606,42 @@ class BaseAssistantApplicationRunner(AppRunner):
         for message in messages:
             result.append(UserPromptMessage(content=message.query))
             agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
-            for agent_thought in agent_thoughts:
-                tools = agent_thought.tool
-                if tools:
-                    tools = tools.split(';')
-                    tool_calls: list[AssistantPromptMessage.ToolCall] = []
-                    tool_call_response: list[ToolPromptMessage] = []
-                    tool_inputs = json.loads(agent_thought.tool_input)
-                    for tool in tools:
-                        # generate a uuid for tool call
-                        tool_call_id = str(uuid.uuid4())
-                        tool_calls.append(AssistantPromptMessage.ToolCall(
-                            id=tool_call_id,
-                            type='function',
-                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+            if agent_thoughts:
+                for agent_thought in agent_thoughts:
+                    tools = agent_thought.tool
+                    if tools:
+                        tools = tools.split(';')
+                        tool_calls: list[AssistantPromptMessage.ToolCall] = []
+                        tool_call_response: list[ToolPromptMessage] = []
+                        tool_inputs = json.loads(agent_thought.tool_input)
+                        for tool in tools:
+                            # generate a uuid for tool call
+                            tool_call_id = str(uuid.uuid4())
+                            tool_calls.append(AssistantPromptMessage.ToolCall(
+                                id=tool_call_id,
+                                type='function',
+                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                    name=tool,
+                                    arguments=json.dumps(tool_inputs.get(tool, {})),
+                                )
+                            ))
+                            tool_call_response.append(ToolPromptMessage(
+                                content=agent_thought.observation,
                                 name=tool,
-                                arguments=json.dumps(tool_inputs.get(tool, {})),
-                            )
-                        ))
-                        tool_call_response.append(ToolPromptMessage(
-                            content=agent_thought.observation,
-                            name=tool,
-                            tool_call_id=tool_call_id,
-                        ))
-
-                    result.extend([
-                        AssistantPromptMessage(
-                            content=agent_thought.thought,
-                            tool_calls=tool_calls,
-                        ),
-                        *tool_call_response
-                    ])
+                                tool_call_id=tool_call_id,
+                            ))
+
+                        result.extend([
+                            AssistantPromptMessage(
+                                content=agent_thought.thought,
+                                tool_calls=tool_calls,
+                            ),
+                            *tool_call_response
+                        ])
+                    if not tools:
+                        result.append(AssistantPromptMessage(content=agent_thought.thought))
+            else:
+                if message.answer:
+                    result.append(AssistantPromptMessage(content=message.answer))
 
         return result

+ 10 - 3
api/core/features/assistant_cot_runner.py

@@ -154,7 +154,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
                 thought='',
                 action_str='',
                 observation='',
-                action=None
+                action=None,
             )
 
             # publish agent thought if it's first iteration
@@ -469,7 +469,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
                     thought=message.content,
                     action_str='',
                     action=None,
-                    observation=None
+                    observation=None,
                 )
                 if message.tool_calls:
                     try:
@@ -484,7 +484,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
             elif isinstance(message, ToolPromptMessage):
                 if current_scratchpad:
                     current_scratchpad.observation = message.content
-
+        
         return agent_scratchpad
 
     def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"], 
@@ -607,6 +607,13 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
                     prompt_message.content = system_message
                     overridden = True
                     break
+            
+            # convert tool prompt messages to user prompt messages
+            for idx, prompt_message in enumerate(prompt_messages):
+                if isinstance(prompt_message, ToolPromptMessage):
+                    prompt_messages[idx] = UserPromptMessage(
+                        content=prompt_message.content
+                    )
 
             if not overridden:
                 prompt_messages.insert(0, SystemPromptMessage(