Explorar el Código

Refactor/react agent (#3355)

Yeuoly hace 1 año
padre
commit
cea107b165

+ 29 - 1
api/core/agent/base_agent_runner.py

@@ -238,6 +238,34 @@ class BaseAgentRunner(AppRunner):
 
         return prompt_tool
     
+    def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
+        """
+        Init tools
+        """
+        tool_instances = {}
+        prompt_messages_tools = []
+
+        for tool in self.app_config.agent.tools if self.app_config.agent else []:
+            try:
+                prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
+            except Exception:
+                # api tool may be deleted
+                continue
+            # save tool entity
+            tool_instances[tool.tool_name] = tool_entity
+            # save prompt tool
+            prompt_messages_tools.append(prompt_tool)
+
+        # convert dataset tools into ModelRuntime Tool format
+        for dataset_tool in self.dataset_tools:
+            prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
+            # save prompt tool
+            prompt_messages_tools.append(prompt_tool)
+            # save tool entity
+            tool_instances[dataset_tool.identity.name] = dataset_tool
+
+        return tool_instances, prompt_messages_tools
+
     def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
         """
         update prompt message tool
@@ -325,7 +353,7 @@ class BaseAgentRunner(AppRunner):
                            tool_name: str,
                            tool_input: Union[str, dict],
                            thought: str, 
-                           observation: Union[str, str], 
+                           observation: Union[str, dict], 
                            tool_invoke_meta: Union[str, dict],
                            answer: str,
                            messages_ids: list[str],

+ 180 - 445
api/core/agent/cot_agent_runner.py

@@ -1,30 +1,34 @@
 import json
-import re
+from abc import ABC, abstractmethod
 from collections.abc import Generator
-from typing import Literal, Union
+from typing import Union
 
 from core.agent.base_agent_runner import BaseAgentRunner
-from core.agent.entities import AgentPromptEntity, AgentScratchpadUnit
+from core.agent.entities import AgentScratchpadUnit
+from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
 from core.app.apps.base_app_queue_manager import PublishFrom
 from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
-    PromptMessageTool,
-    SystemPromptMessage,
     ToolPromptMessage,
     UserPromptMessage,
 )
-from core.model_runtime.utils.encoders import jsonable_encoder
 from core.tools.entities.tool_entities import ToolInvokeMeta
+from core.tools.tool.tool import Tool
 from core.tools.tool_engine import ToolEngine
 from models.model import Message
 
 
-class CotAgentRunner(BaseAgentRunner):
+class CotAgentRunner(BaseAgentRunner, ABC):
     _is_first_iteration = True
     _ignore_observation_providers = ['wenxin']
+    _historic_prompt_messages: list[PromptMessage] = None
+    _agent_scratchpad: list[AgentScratchpadUnit] = None
+    _instruction: str = None
+    _query: str = None
+    _prompt_messages_tools: list[PromptMessage] = None
 
     def run(self, message: Message,
         query: str,
@@ -35,9 +39,7 @@ class CotAgentRunner(BaseAgentRunner):
         """
         app_generate_entity = self.application_generate_entity
         self._repack_app_generate_entity(app_generate_entity)
-
-        agent_scratchpad: list[AgentScratchpadUnit] = []
-        self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages)
+        self._init_react_state(query)
 
         # check model mode
         if 'Observation' not in app_generate_entity.model_config.stop:
@@ -46,38 +48,19 @@ class CotAgentRunner(BaseAgentRunner):
 
         app_config = self.app_config
 
-        # override inputs
+        # init instruction
         inputs = inputs or {}
         instruction = app_config.prompt_template.simple_prompt_template
-        instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
+        self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
 
         iteration_step = 1
         max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
 
-        prompt_messages = self.history_prompt_messages
-
         # convert tools into ModelRuntime Tool format
-        prompt_messages_tools: list[PromptMessageTool] = []
-        tool_instances = {}
-        for tool in app_config.agent.tools if app_config.agent else []:
-            try:
-                prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
-            except Exception:
-                # api tool may be deleted
-                continue
-            # save tool entity
-            tool_instances[tool.tool_name] = tool_entity
-            # save prompt tool
-            prompt_messages_tools.append(prompt_tool)
-
-        # convert dataset tools into ModelRuntime Tool format
-        for dataset_tool in self.dataset_tools:
-            prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
-            # save prompt tool
-            prompt_messages_tools.append(prompt_tool)
-            # save tool entity
-            tool_instances[dataset_tool.identity.name] = dataset_tool
+        tool_instances, self._prompt_messages_tools = self._init_prompt_tools()
 
+        prompt_messages = self._organize_prompt_messages()
+        
         function_call_state = True
         llm_usage = {
             'usage': None
@@ -102,7 +85,7 @@ class CotAgentRunner(BaseAgentRunner):
 
             if iteration_step == max_iteration_steps:
                 # the last iteration, remove all tools
-                prompt_messages_tools = []
+                self._prompt_messages_tools = []
 
             message_file_ids = []
 
@@ -119,18 +102,8 @@ class CotAgentRunner(BaseAgentRunner):
                     agent_thought_id=agent_thought.id
                 ), PublishFrom.APPLICATION_MANAGER)
 
-            # update prompt messages
-            prompt_messages = self._organize_cot_prompt_messages(
-                mode=app_generate_entity.model_config.mode,
-                prompt_messages=prompt_messages,
-                tools=prompt_messages_tools,
-                agent_scratchpad=agent_scratchpad,
-                agent_prompt_message=app_config.agent.prompt,
-                instruction=instruction,
-                input=query
-            )
-
             # recalc llm max tokens
+            prompt_messages = self._organize_prompt_messages()
             self.recalc_llm_max_tokens(self.model_config, prompt_messages)
             # invoke model
             chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
@@ -148,7 +121,7 @@ class CotAgentRunner(BaseAgentRunner):
                 raise ValueError("failed to invoke llm")
             
             usage_dict = {}
-            react_chunks = self._handle_stream_react(chunks, usage_dict)
+            react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks)
             scratchpad = AgentScratchpadUnit(
                 agent_response='',
                 thought='',
@@ -164,30 +137,12 @@ class CotAgentRunner(BaseAgentRunner):
                 ), PublishFrom.APPLICATION_MANAGER)
 
             for chunk in react_chunks:
-                if isinstance(chunk, dict):
-                    scratchpad.agent_response += json.dumps(chunk)
-                    try:
-                        if scratchpad.action:
-                            raise Exception("")
-                        scratchpad.action_str = json.dumps(chunk)
-                        scratchpad.action = AgentScratchpadUnit.Action(
-                            action_name=chunk['action'],
-                            action_input=chunk['action_input']
-                        )
-                    except:
-                        scratchpad.thought += json.dumps(chunk)
-                        yield LLMResultChunk(
-                            model=self.model_config.model,
-                            prompt_messages=prompt_messages,
-                            system_fingerprint='',
-                            delta=LLMResultChunkDelta(
-                                index=0,
-                                message=AssistantPromptMessage(
-                                    content=json.dumps(chunk, ensure_ascii=False) # if ensure_ascii=True, the text in webui maybe garbled text
-                                ),
-                                usage=None
-                            )
-                        )
+                if isinstance(chunk, AgentScratchpadUnit.Action):
+                    action = chunk
+                    # detect action
+                    scratchpad.agent_response += json.dumps(chunk.dict())
+                    scratchpad.action_str = json.dumps(chunk.dict())
+                    scratchpad.action = action
                 else:
                     scratchpad.agent_response += chunk
                     scratchpad.thought += chunk
@@ -205,27 +160,29 @@ class CotAgentRunner(BaseAgentRunner):
                     )
 
             scratchpad.thought = scratchpad.thought.strip() or 'I am thinking about how to help you'
-            agent_scratchpad.append(scratchpad)
-                        
+            self._agent_scratchpad.append(scratchpad)
+            
             # get llm usage
             if 'usage' in usage_dict:
                 increase_usage(llm_usage, usage_dict['usage'])
             else:
                 usage_dict['usage'] = LLMUsage.empty_usage()
             
-            self.save_agent_thought(agent_thought=agent_thought,
-                                    tool_name=scratchpad.action.action_name if scratchpad.action else '',
-                                    tool_input={
-                                        scratchpad.action.action_name: scratchpad.action.action_input
-                                    } if scratchpad.action else '',
-                                    tool_invoke_meta={},
-                                    thought=scratchpad.thought,
-                                    observation='',
-                                    answer=scratchpad.agent_response,
-                                    messages_ids=[],
-                                    llm_usage=usage_dict['usage'])
+            self.save_agent_thought(
+                agent_thought=agent_thought,
+                tool_name=scratchpad.action.action_name if scratchpad.action else '',
+                tool_input={
+                    scratchpad.action.action_name: scratchpad.action.action_input
+                } if scratchpad.action else {},
+                tool_invoke_meta={},
+                thought=scratchpad.thought,
+                observation='',
+                answer=scratchpad.agent_response,
+                messages_ids=[],
+                llm_usage=usage_dict['usage']
+            )
             
-            if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
+            if not scratchpad.is_final():
                 self.queue_manager.publish(QueueAgentThoughtEvent(
                     agent_thought_id=agent_thought.id
                 ), PublishFrom.APPLICATION_MANAGER)
@@ -237,106 +194,43 @@ class CotAgentRunner(BaseAgentRunner):
                 if scratchpad.action.action_name.lower() == "final answer":
                     # action is final answer, return final answer directly
                     try:
-                        final_answer = scratchpad.action.action_input if \
-                            isinstance(scratchpad.action.action_input, str) else \
-                                json.dumps(scratchpad.action.action_input)
+                        if isinstance(scratchpad.action.action_input, dict):
+                            final_answer = json.dumps(scratchpad.action.action_input)
+                        elif isinstance(scratchpad.action.action_input, str):
+                            final_answer = scratchpad.action.action_input
+                        else:
+                            final_answer = f'{scratchpad.action.action_input}'
                     except json.JSONDecodeError:
                         final_answer = f'{scratchpad.action.action_input}'
                 else:
                     function_call_state = True
-
                     # action is tool call, invoke tool
-                    tool_call_name = scratchpad.action.action_name
-                    tool_call_args = scratchpad.action.action_input
-                    tool_instance = tool_instances.get(tool_call_name)
-                    if not tool_instance:
-                        answer = f"there is not a tool named {tool_call_name}"
-                        self.save_agent_thought(
-                            agent_thought=agent_thought, 
-                            tool_name='',
-                            tool_input='',
-                            tool_invoke_meta=ToolInvokeMeta.error_instance(
-                                f"there is not a tool named {tool_call_name}"
-                            ).to_dict(),
-                            thought=None, 
-                            observation={
-                                tool_call_name: answer
-                            }, 
-                            answer=answer,
-                            messages_ids=[]
-                        )
-                        self.queue_manager.publish(QueueAgentThoughtEvent(
-                            agent_thought_id=agent_thought.id
-                        ), PublishFrom.APPLICATION_MANAGER)
-                    else:
-                        if isinstance(tool_call_args, str):
-                            try:
-                                tool_call_args = json.loads(tool_call_args)
-                            except json.JSONDecodeError:
-                                pass
-
-                        # invoke tool
-                        tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
-                            tool=tool_instance,
-                            tool_parameters=tool_call_args,
-                            user_id=self.user_id,
-                            tenant_id=self.tenant_id,
-                            message=self.message,
-                            invoke_from=self.application_generate_entity.invoke_from,
-                            agent_tool_callback=self.agent_callback
-                        )
-                        # publish files
-                        for message_file, save_as in message_files:
-                            if save_as:
-                                self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as)
-
-                            # publish message file
-                            self.queue_manager.publish(QueueMessageFileEvent(
-                                message_file_id=message_file.id
-                            ), PublishFrom.APPLICATION_MANAGER)
-                            # add message file ids
-                            message_file_ids.append(message_file.id)
-
-                        # publish files
-                        for message_file, save_as in message_files:
-                            if save_as:
-                                self.variables_pool.set_file(tool_name=tool_call_name,
-                                                                value=message_file.id,
-                                                                name=save_as)
-                            self.queue_manager.publish(QueueMessageFileEvent(
-                                message_file_id=message_file.id
-                            ), PublishFrom.APPLICATION_MANAGER)
-
-                        message_file_ids = [message_file.id for message_file, _ in message_files]
-
-                        observation = tool_invoke_response
-
-                        # save scratchpad
-                        scratchpad.observation = observation
-
-                        # save agent thought
-                        self.save_agent_thought(
-                            agent_thought=agent_thought, 
-                            tool_name=tool_call_name,
-                            tool_input={
-                                tool_call_name: tool_call_args
-                            },
-                            tool_invoke_meta={
-                                tool_call_name: tool_invoke_meta.to_dict()
-                            },
-                            thought=None,
-                            observation={
-                                tool_call_name: observation
-                            }, 
-                            answer=scratchpad.agent_response,
-                            messages_ids=message_file_ids,
-                        )
-                        self.queue_manager.publish(QueueAgentThoughtEvent(
-                            agent_thought_id=agent_thought.id
-                        ), PublishFrom.APPLICATION_MANAGER)
+                    tool_invoke_response, tool_invoke_meta = self._handle_invoke_action(
+                        action=scratchpad.action, 
+                        tool_instances=tool_instances,
+                        message_file_ids=message_file_ids
+                    )
+                    scratchpad.observation = tool_invoke_response
+                    scratchpad.agent_response = tool_invoke_response
+
+                    self.save_agent_thought(
+                        agent_thought=agent_thought,
+                        tool_name=scratchpad.action.action_name,
+                        tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
+                        thought=scratchpad.thought,
+                        observation={scratchpad.action.action_name: tool_invoke_response},
+                        tool_invoke_meta=tool_invoke_meta.to_dict(),
+                        answer=scratchpad.agent_response,
+                        messages_ids=message_file_ids,
+                        llm_usage=usage_dict['usage']
+                    )
+
+                    self.queue_manager.publish(QueueAgentThoughtEvent(
+                        agent_thought_id=agent_thought.id
+                    ), PublishFrom.APPLICATION_MANAGER)
 
                 # update prompt tool message
-                for prompt_tool in prompt_messages_tools:
+                for prompt_tool in self._prompt_messages_tools:
                     self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
 
             iteration_step += 1
@@ -378,96 +272,63 @@ class CotAgentRunner(BaseAgentRunner):
             system_fingerprint=''
         )), PublishFrom.APPLICATION_MANAGER)
 
-    def _handle_stream_react(self, llm_response: Generator[LLMResultChunk, None, None], usage: dict) \
-        -> Generator[Union[str, dict], None, None]:
-        def parse_json(json_str):
+    def _handle_invoke_action(self, action: AgentScratchpadUnit.Action, 
+                              tool_instances: dict[str, Tool],
+                              message_file_ids: list[str]) -> tuple[str, ToolInvokeMeta]:
+        """
+        handle invoke action
+        :param action: action
+        :param tool_instances: tool instances
+        :return: observation, meta
+        """
+        # action is tool call, invoke tool
+        tool_call_name = action.action_name
+        tool_call_args = action.action_input
+        tool_instance = tool_instances.get(tool_call_name)
+
+        if not tool_instance:
+            answer = f"there is not a tool named {tool_call_name}"
+            return answer, ToolInvokeMeta.error_instance(answer)
+        
+        if isinstance(tool_call_args, str):
             try:
-                return json.loads(json_str.strip())
-            except:
-                return json_str
-            
-        def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
-            code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
-            if not code_blocks:
-                return
-            for block in code_blocks:
-                json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
-                yield parse_json(json_text)
-            
-        code_block_cache = ''
-        code_block_delimiter_count = 0
-        in_code_block = False
-        json_cache = ''
-        json_quote_count = 0
-        in_json = False
-        got_json = False
-    
-        for response in llm_response:
-            response = response.delta.message.content
-            if not isinstance(response, str):
-                continue
+                tool_call_args = json.loads(tool_call_args)
+            except json.JSONDecodeError:
+                pass
+
+        # invoke tool
+        tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
+            tool=tool_instance,
+            tool_parameters=tool_call_args,
+            user_id=self.user_id,
+            tenant_id=self.tenant_id,
+            message=self.message,
+            invoke_from=self.application_generate_entity.invoke_from,
+            agent_tool_callback=self.agent_callback
+        )
 
-            # stream
-            index = 0
-            while index < len(response):
-                steps = 1
-                delta = response[index:index+steps]
-                if delta == '`':
-                    code_block_cache += delta
-                    code_block_delimiter_count += 1
-                else:
-                    if not in_code_block:
-                        if code_block_delimiter_count > 0:
-                            yield code_block_cache
-                        code_block_cache = ''
-                    else:
-                        code_block_cache += delta
-                    code_block_delimiter_count = 0
-
-                if code_block_delimiter_count == 3:
-                    if in_code_block:
-                        yield from extra_json_from_code_block(code_block_cache)
-                        code_block_cache = ''
-                        
-                    in_code_block = not in_code_block
-                    code_block_delimiter_count = 0
-
-                if not in_code_block:
-                    # handle single json
-                    if delta == '{':
-                        json_quote_count += 1
-                        in_json = True
-                        json_cache += delta
-                    elif delta == '}':
-                        json_cache += delta
-                        if json_quote_count > 0:
-                            json_quote_count -= 1
-                            if json_quote_count == 0:
-                                in_json = False
-                                got_json = True
-                                index += steps
-                                continue
-                    else:
-                        if in_json:
-                            json_cache += delta
-
-                    if got_json:
-                        got_json = False
-                        yield parse_json(json_cache)
-                        json_cache = ''
-                        json_quote_count = 0
-                        in_json = False
-                    
-                if not in_code_block and not in_json:
-                    yield delta.replace('`', '')
-
-                index += steps
-
-        if code_block_cache:
-            yield code_block_cache
-
-        if json_cache:
-            yield parse_json(json_cache)
+        # publish files
+        for message_file, save_as in message_files:
+            if save_as:
+                self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as)
+
+            # publish message file
+            self.queue_manager.publish(QueueMessageFileEvent(
+                message_file_id=message_file.id
+            ), PublishFrom.APPLICATION_MANAGER)
+            # add message file ids
+            message_file_ids.append(message_file.id)
+
+        return tool_invoke_response, tool_invoke_meta
+
+    def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action:
+        """
+        convert dict to action
+        """
+        return AgentScratchpadUnit.Action(
+            action_name=action['action'],
+            action_input=action['action_input']
+        )
 
     def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
         """
@@ -481,15 +342,46 @@ class CotAgentRunner(BaseAgentRunner):
 
         return instruction
     
-    def _init_agent_scratchpad(self, 
-                               agent_scratchpad: list[AgentScratchpadUnit],
-                               messages: list[PromptMessage]
-                               ) -> list[AgentScratchpadUnit]:
+    def _init_react_state(self, query) -> None:
         """
         init agent scratchpad
         """
+        self._query = query
+        self._agent_scratchpad = []
+        self._historic_prompt_messages = self._organize_historic_prompt_messages()
+    
+    @abstractmethod
+    def _organize_prompt_messages(self) -> list[PromptMessage]:
+        """
+            organize prompt messages
+        """
+
+    def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
+        """
+            format assistant message
+        """
+        message = ''
+        for scratchpad in agent_scratchpad:
+            if scratchpad.is_final():
+                message += f"Final Answer: {scratchpad.agent_response}"
+            else:
+                message += f"Thought: {scratchpad.thought}\n\n"
+                if scratchpad.action_str:
+                    message += f"Action: {scratchpad.action_str}\n\n"
+                if scratchpad.observation:
+                    message += f"Observation: {scratchpad.observation}\n\n"
+
+        return message
+
+    def _organize_historic_prompt_messages(self) -> list[PromptMessage]:
+        """
+            organize historic prompt messages
+        """
+        result: list[PromptMessage] = []
+        scratchpad: list[AgentScratchpadUnit] = []
         current_scratchpad: AgentScratchpadUnit = None
-        for message in messages:
+
+        for message in self.history_prompt_messages:
             if isinstance(message, AssistantPromptMessage):
                 current_scratchpad = AgentScratchpadUnit(
                     agent_response=message.content,
@@ -504,186 +396,29 @@ class CotAgentRunner(BaseAgentRunner):
                             action_name=message.tool_calls[0].function.name,
                             action_input=json.loads(message.tool_calls[0].function.arguments)
                         )
+                        current_scratchpad.action_str = json.dumps(
+                            current_scratchpad.action.to_dict()
+                        )
                     except:
                         pass
-                    
-                agent_scratchpad.append(current_scratchpad)
+                
+                scratchpad.append(current_scratchpad)
             elif isinstance(message, ToolPromptMessage):
                 if current_scratchpad:
                     current_scratchpad.observation = message.content
-        
-        return agent_scratchpad
+            elif isinstance(message, UserPromptMessage):
+                result.append(message)
 
-    def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"], 
-                                      agent_prompt_message: AgentPromptEntity,
-    ):
-        """
-            check chain of thought prompt messages, a standard prompt message is like:
-                Respond to the human as helpfully and accurately as possible. 
-
-                {{instruction}}
-
-                You have access to the following tools:
+                if scratchpad:
+                    result.append(AssistantPromptMessage(
+                        content=self._format_assistant_message(scratchpad)
+                    ))
 
-                {{tools}}
+                scratchpad = []
 
-                Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
-                Valid action values: "Final Answer" or {{tool_names}}
-
-                Provide only ONE action per $JSON_BLOB, as shown:
-
-                ```
-                {
-                "action": $TOOL_NAME,
-                "action_input": $ACTION_INPUT
-                }
-                ```
-        """
-
-        # parse agent prompt message
-        first_prompt = agent_prompt_message.first_prompt
-        next_iteration = agent_prompt_message.next_iteration
-
-        if not isinstance(first_prompt, str) or not isinstance(next_iteration, str):
-            raise ValueError("first_prompt or next_iteration is required in CoT agent mode")
-        
-        # check instruction, tools, and tool_names slots
-        if not first_prompt.find("{{instruction}}") >= 0:
-            raise ValueError("{{instruction}} is required in first_prompt")
-        if not first_prompt.find("{{tools}}") >= 0:
-            raise ValueError("{{tools}} is required in first_prompt")
-        if not first_prompt.find("{{tool_names}}") >= 0:
-            raise ValueError("{{tool_names}} is required in first_prompt")
+        if scratchpad:
+            result.append(AssistantPromptMessage(
+                content=self._format_assistant_message(scratchpad)
+            ))
         
-        if mode == "completion":
-            if not first_prompt.find("{{query}}") >= 0:
-                raise ValueError("{{query}} is required in first_prompt")
-            if not first_prompt.find("{{agent_scratchpad}}") >= 0:
-                raise ValueError("{{agent_scratchpad}} is required in first_prompt")
-        
-        if mode == "completion":
-            if not next_iteration.find("{{observation}}") >= 0:
-                raise ValueError("{{observation}} is required in next_iteration")
-            
-    def _convert_scratchpad_list_to_str(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
-        """
-            convert agent scratchpad list to str
-        """
-        next_iteration = self.app_config.agent.prompt.next_iteration
-
-        result = ''
-        for scratchpad in agent_scratchpad:
-            result += (scratchpad.thought or '') + (scratchpad.action_str or '') + \
-                next_iteration.replace("{{observation}}", scratchpad.observation or 'It seems that no response is available')
-
-        return result
-    
-    def _organize_cot_prompt_messages(self, mode: Literal["completion", "chat"],
-                                      prompt_messages: list[PromptMessage],
-                                      tools: list[PromptMessageTool], 
-                                      agent_scratchpad: list[AgentScratchpadUnit],
-                                      agent_prompt_message: AgentPromptEntity,
-                                      instruction: str,
-                                      input: str,
-        ) -> list[PromptMessage]:
-        """
-            organize chain of thought prompt messages, a standard prompt message is like:
-                Respond to the human as helpfully and accurately as possible. 
-
-                {{instruction}}
-
-                You have access to the following tools:
-
-                {{tools}}
-
-                Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
-                Valid action values: "Final Answer" or {{tool_names}}
-
-                Provide only ONE action per $JSON_BLOB, as shown:
-
-                ```
-                {{{{
-                "action": $TOOL_NAME,
-                "action_input": $ACTION_INPUT
-                }}}}
-                ```
-        """
-
-        self._check_cot_prompt_messages(mode, agent_prompt_message)
-
-        # parse agent prompt message
-        first_prompt = agent_prompt_message.first_prompt
-
-        # parse tools
-        tools_str = self._jsonify_tool_prompt_messages(tools)
-
-        # parse tools name
-        tool_names = '"' + '","'.join([tool.name for tool in tools]) + '"'
-
-        # get system message
-        system_message = first_prompt.replace("{{instruction}}", instruction) \
-                                     .replace("{{tools}}", tools_str) \
-                                     .replace("{{tool_names}}", tool_names)
-
-        # organize prompt messages
-        if mode == "chat":
-            # override system message
-            overridden = False
-            prompt_messages = prompt_messages.copy()
-            for prompt_message in prompt_messages:
-                if isinstance(prompt_message, SystemPromptMessage):
-                    prompt_message.content = system_message
-                    overridden = True
-                    break
-            
-            # convert tool prompt messages to user prompt messages
-            for idx, prompt_message in enumerate(prompt_messages):
-                if isinstance(prompt_message, ToolPromptMessage):
-                    prompt_messages[idx] = UserPromptMessage(
-                        content=prompt_message.content
-                    )
-
-            if not overridden:
-                prompt_messages.insert(0, SystemPromptMessage(
-                    content=system_message,
-                ))
-
-            # add assistant message
-            if len(agent_scratchpad) > 0 and not self._is_first_iteration:
-                prompt_messages.append(AssistantPromptMessage(
-                    content=(agent_scratchpad[-1].thought or '') + (agent_scratchpad[-1].action_str or ''),
-                ))
-            
-            # add user message
-            if len(agent_scratchpad) > 0 and not self._is_first_iteration:
-                prompt_messages.append(UserPromptMessage(
-                    content=(agent_scratchpad[-1].observation or 'It seems that no response is available'),
-                ))
-
-            self._is_first_iteration = False
-
-            return prompt_messages
-        elif mode == "completion":
-            # parse agent scratchpad
-            agent_scratchpad_str = self._convert_scratchpad_list_to_str(agent_scratchpad)
-            self._is_first_iteration = False
-            # parse prompt messages
-            return [UserPromptMessage(
-                content=first_prompt.replace("{{instruction}}", instruction)
-                                    .replace("{{tools}}", tools_str)
-                                    .replace("{{tool_names}}", tool_names)
-                                    .replace("{{query}}", input)
-                                    .replace("{{agent_scratchpad}}", agent_scratchpad_str),
-            )]
-        else:
-            raise ValueError(f"mode {mode} is not supported")
-            
-    def _jsonify_tool_prompt_messages(self, tools: list[PromptMessageTool]) -> str:
-        """
-            jsonify tool prompt messages
-        """
-        tools = jsonable_encoder(tools)
-        try:
-            return json.dumps(tools, ensure_ascii=False)
-        except json.JSONDecodeError:
-            return json.dumps(tools)
+        return result

+ 71 - 0
api/core/agent/cot_chat_agent_runner.py

@@ -0,0 +1,71 @@
+import json
+
+from core.agent.cot_agent_runner import CotAgentRunner
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.utils.encoders import jsonable_encoder
+
+
+class CotChatAgentRunner(CotAgentRunner):
+    def _organize_system_prompt(self) -> SystemPromptMessage:
+        """
+        Organize system prompt
+        """
+        prompt_entity = self.app_config.agent.prompt
+        first_prompt = prompt_entity.first_prompt
+
+        system_prompt = first_prompt \
+            .replace("{{instruction}}", self._instruction) \
+            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
+            .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
+
+        return SystemPromptMessage(content=system_prompt)
+
+    def _organize_prompt_messages(self) -> list[PromptMessage]:
+        """
+        Organize 
+        """
+        # organize system prompt
+        system_message = self._organize_system_prompt()
+
+        # organize historic prompt messages
+        historic_messages = self._historic_prompt_messages
+
+        # organize current assistant messages
+        agent_scratchpad = self._agent_scratchpad
+        if not agent_scratchpad:
+            assistant_messages = []
+        else:
+            assistant_message = AssistantPromptMessage(content='')
+            for unit in agent_scratchpad:
+                if unit.is_final():
+                    assistant_message.content += f"Final Answer: {unit.agent_response}"
+                else:
+                    assistant_message.content += f"Thought: {unit.thought}\n\n"
+                    if unit.action_str:
+                        assistant_message.content += f"Action: {unit.action_str}\n\n"
+                    if unit.observation:
+                        assistant_message.content += f"Observation: {unit.observation}\n\n"
+
+            assistant_messages = [assistant_message]
+
+        # query messages
+        query_messages = UserPromptMessage(content=self._query)
+
+        if assistant_messages:
+            messages = [
+                system_message,
+                *historic_messages,
+                query_messages,
+                *assistant_messages,
+                UserPromptMessage(content='continue')
+            ]
+        else:
+            messages = [system_message, *historic_messages, query_messages]
+
+        # join all messages
+        return messages

+ 69 - 0
api/core/agent/cot_completion_agent_runner.py

@@ -0,0 +1,69 @@
+import json
+
+from core.agent.cot_agent_runner import CotAgentRunner
+from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, UserPromptMessage
+from core.model_runtime.utils.encoders import jsonable_encoder
+
+
+class CotCompletionAgentRunner(CotAgentRunner):
+    def _organize_instruction_prompt(self) -> str:
+        """
+        Organize instruction prompt
+        """
+        prompt_entity = self.app_config.agent.prompt
+        first_prompt = prompt_entity.first_prompt
+
+        system_prompt = first_prompt.replace("{{instruction}}", self._instruction) \
+            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
+            .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
+        
+        return system_prompt
+
+    def _organize_historic_prompt(self) -> str:
+        """
+        Organize historic prompt
+        """
+        historic_prompt_messages = self._historic_prompt_messages
+        historic_prompt = ""
+
+        for message in historic_prompt_messages:
+            if isinstance(message, UserPromptMessage):
+                historic_prompt += f"Question: {message.content}\n\n"
+            elif isinstance(message, AssistantPromptMessage):
+                historic_prompt += message.content + "\n\n"
+
+        return historic_prompt
+
+    def _organize_prompt_messages(self) -> list[PromptMessage]:
+        """
+        Organize prompt messages
+        """
+        # organize system prompt
+        system_prompt = self._organize_instruction_prompt()
+
+        # organize historic prompt messages
+        historic_prompt = self._organize_historic_prompt()
+
+        # organize current assistant messages
+        agent_scratchpad = self._agent_scratchpad
+        assistant_prompt = ''
+        for unit in agent_scratchpad:
+            if unit.is_final():
+                assistant_prompt += f"Final Answer: {unit.agent_response}"
+            else:
+                assistant_prompt += f"Thought: {unit.thought}\n\n"
+                if unit.action_str:
+                    assistant_prompt += f"Action: {unit.action_str}\n\n"
+                if unit.observation:
+                    assistant_prompt += f"Observation: {unit.observation}\n\n"
+
+        # query messages
+        query_prompt = f"Question: {self._query}"
+
+        # join all messages
+        prompt = system_prompt \
+            .replace("{{historic_messages}}", historic_prompt) \
+            .replace("{{agent_scratchpad}}", assistant_prompt) \
+            .replace("{{query}}", query_prompt)
+
+        return [UserPromptMessage(content=prompt)]

+ 17 - 0
api/core/agent/entities.py

@@ -34,12 +34,29 @@ class AgentScratchpadUnit(BaseModel):
         action_name: str
         action_input: Union[dict, str]
 
+        def to_dict(self) -> dict:
+            """
+            Convert to dictionary.
+            """
+            return {
+                'action': self.action_name,
+                'action_input': self.action_input,
+            }
+
     agent_response: Optional[str] = None
     thought: Optional[str] = None
     action_str: Optional[str] = None
     observation: Optional[str] = None
     action: Optional[Action] = None
 
+    def is_final(self) -> bool:
+        """
+        Check if the scratchpad unit is final.
+        """
+        return self.action is None or (
+            'final' in self.action.action_name.lower() and 
+            'answer' in self.action.action_name.lower()
+        )
 
 class AgentEntity(BaseModel):
     """

+ 3 - 23
api/core/agent/fc_agent_runner.py

@@ -12,7 +12,6 @@ from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
     PromptMessageContentType,
-    PromptMessageTool,
     SystemPromptMessage,
     TextPromptMessageContent,
     ToolPromptMessage,
@@ -25,8 +24,8 @@ from models.model import Message
 logger = logging.getLogger(__name__)
 
 class FunctionCallAgentRunner(BaseAgentRunner):
-    def run(self, message: Message,
-                query: str,
+    def run(self, 
+            message: Message, query: str, **kwargs: Any
     ) -> Generator[LLMResultChunk, None, None]:
         """
         Run FunctionCall agent application
@@ -41,26 +40,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         prompt_messages = self._organize_user_query(query, prompt_messages)
 
         # convert tools into ModelRuntime Tool format
-        prompt_messages_tools: list[PromptMessageTool] = []
-        tool_instances = {}
-        for tool in app_config.agent.tools if app_config.agent else []:
-            try:
-                prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
-            except Exception:
-                # api tool may be deleted
-                continue
-            # save tool entity
-            tool_instances[tool.tool_name] = tool_entity
-            # save prompt tool
-            prompt_messages_tools.append(prompt_tool)
-
-        # convert dataset tools into ModelRuntime Tool format
-        for dataset_tool in self.dataset_tools:
-            prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
-            # save prompt tool
-            prompt_messages_tools.append(prompt_tool)
-            # save tool entity
-            tool_instances[dataset_tool.identity.name] = dataset_tool
+        tool_instances, prompt_messages_tools = self._init_prompt_tools()
 
         iteration_step = 1
         max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1

+ 183 - 0
api/core/agent/output_parser/cot_output_parser.py

@@ -0,0 +1,183 @@
+import json
+import re
+from collections.abc import Generator
+from typing import Union
+
+from core.agent.entities import AgentScratchpadUnit
+from core.model_runtime.entities.llm_entities import LLMResultChunk
+
+
+class CotAgentOutputParser:
+    @classmethod
+    def handle_react_stream_output(cls, llm_response: Generator[LLMResultChunk, None, None]) -> \
+        Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
+        def parse_action(json_str):
+            try:
+                action = json.loads(json_str)
+                action_name = None
+                action_input = None
+
+                for key, value in action.items():
+                    if 'input' in key.lower():
+                        action_input = value
+                    else:
+                        action_name = value
+
+                if action_name is not None and action_input is not None:
+                    return AgentScratchpadUnit.Action(
+                        action_name=action_name,
+                        action_input=action_input,
+                    )
+                else:
+                    return json_str or ''
+            except:
+                return json_str or ''
+            
+        def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
+            code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
+            if not code_blocks:
+                return
+            for block in code_blocks:
+                json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
+                yield parse_action(json_text)
+            
+        code_block_cache = ''
+        code_block_delimiter_count = 0
+        in_code_block = False
+        json_cache = ''
+        json_quote_count = 0
+        in_json = False
+        got_json = False
+
+        action_cache = ''
+        action_str = 'action:'
+        action_idx = 0
+
+        thought_cache = ''
+        thought_str = 'thought:'
+        thought_idx = 0
+
+        for response in llm_response:
+            response = response.delta.message.content
+            if not isinstance(response, str):
+                continue
+
+            # stream
+            index = 0
+            while index < len(response):
+                steps = 1
+                delta = response[index:index+steps]
+                last_character = response[index-1] if index > 0 else ''
+
+                if delta == '`':
+                    code_block_cache += delta
+                    code_block_delimiter_count += 1
+                else:
+                    if not in_code_block:
+                        if code_block_delimiter_count > 0:
+                            yield code_block_cache
+                        code_block_cache = ''
+                    else:
+                        code_block_cache += delta
+                    code_block_delimiter_count = 0
+
+                if not in_code_block and not in_json:
+                    if delta.lower() == action_str[action_idx] and action_idx == 0:
+                        if last_character not in ['\n', ' ', '']:
+                            index += steps
+                            yield delta
+                            continue
+
+                        action_cache += delta
+                        action_idx += 1
+                        if action_idx == len(action_str):
+                            action_cache = ''
+                            action_idx = 0
+                        index += steps
+                        continue
+                    elif delta.lower() == action_str[action_idx] and action_idx > 0:
+                        action_cache += delta
+                        action_idx += 1
+                        if action_idx == len(action_str):
+                            action_cache = ''
+                            action_idx = 0
+                        index += steps
+                        continue
+                    else:
+                        if action_cache:
+                            yield action_cache
+                            action_cache = ''
+                            action_idx = 0
+                    
+                    if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
+                        if last_character not in ['\n', ' ', '']:
+                            index += steps
+                            yield delta
+                            continue
+
+                        thought_cache += delta
+                        thought_idx += 1
+                        if thought_idx == len(thought_str):
+                            thought_cache = ''
+                            thought_idx = 0
+                        index += steps
+                        continue
+                    elif delta.lower() == thought_str[thought_idx] and thought_idx > 0:
+                        thought_cache += delta
+                        thought_idx += 1
+                        if thought_idx == len(thought_str):
+                            thought_cache = ''
+                            thought_idx = 0
+                        index += steps
+                        continue
+                    else:
+                        if thought_cache:
+                            yield thought_cache
+                            thought_cache = ''
+                            thought_idx = 0
+
+                if code_block_delimiter_count == 3:
+                    if in_code_block:
+                        yield from extra_json_from_code_block(code_block_cache)
+                        code_block_cache = ''
+                        
+                    in_code_block = not in_code_block
+                    code_block_delimiter_count = 0
+
+                if not in_code_block:
+                    # handle single json
+                    if delta == '{':
+                        json_quote_count += 1
+                        in_json = True
+                        json_cache += delta
+                    elif delta == '}':
+                        json_cache += delta
+                        if json_quote_count > 0:
+                            json_quote_count -= 1
+                            if json_quote_count == 0:
+                                in_json = False
+                                got_json = True
+                                index += steps
+                                continue
+                    else:
+                        if in_json:
+                            json_cache += delta
+
+                    if got_json:
+                        got_json = False
+                        yield parse_action(json_cache)
+                        json_cache = ''
+                        json_quote_count = 0
+                        in_json = False
+                    
+                if not in_code_block and not in_json:
+                    yield delta.replace('`', '')
+
+                index += steps
+
+        if code_block_cache:
+            yield code_block_cache
+
+        if json_cache:
+            yield parse_action(json_cache)
+

+ 37 - 44
api/core/app/apps/agent_chat/app_runner.py

@@ -1,7 +1,8 @@
 import logging
 from typing import cast
 
-from core.agent.cot_agent_runner import CotAgentRunner
+from core.agent.cot_chat_agent_runner import CotChatAgentRunner
+from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
 from core.agent.entities import AgentEntity
 from core.agent.fc_agent_runner import FunctionCallAgentRunner
 from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
@@ -11,8 +12,8 @@ from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, Mo
 from core.app.entities.queue_entities import QueueAnnotationReplyEvent
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
-from core.model_runtime.entities.llm_entities import LLMUsage
-from core.model_runtime.entities.model_entities import ModelFeature
+from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage
+from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.moderation.base import ModerationException
 from core.tools.entities.tool_entities import ToolRuntimeVariablePool
@@ -207,48 +208,40 @@ class AgentChatAppRunner(AppRunner):
 
         # start agent runner
         if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
-            assistant_cot_runner = CotAgentRunner(
-                tenant_id=app_config.tenant_id,
-                application_generate_entity=application_generate_entity,
-                conversation=conversation,
-                app_config=app_config,
-                model_config=application_generate_entity.model_config,
-                config=agent_entity,
-                queue_manager=queue_manager,
-                message=message,
-                user_id=application_generate_entity.user_id,
-                memory=memory,
-                prompt_messages=prompt_message,
-                variables_pool=tool_variables,
-                db_variables=tool_conversation_variables,
-                model_instance=model_instance
-            )
-            invoke_result = assistant_cot_runner.run(
-                message=message,
-                query=query,
-                inputs=inputs,
-            )
+            # check LLM mode
+            if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
+                runner_cls = CotChatAgentRunner
+            elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION.value:
+                runner_cls = CotCompletionAgentRunner
+            else:
+                raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}")
         elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING:
-            assistant_fc_runner = FunctionCallAgentRunner(
-                tenant_id=app_config.tenant_id,
-                application_generate_entity=application_generate_entity,
-                conversation=conversation,
-                app_config=app_config,
-                model_config=application_generate_entity.model_config,
-                config=agent_entity,
-                queue_manager=queue_manager,
-                message=message,
-                user_id=application_generate_entity.user_id,
-                memory=memory,
-                prompt_messages=prompt_message,
-                variables_pool=tool_variables,
-                db_variables=tool_conversation_variables,
-                model_instance=model_instance
-            )
-            invoke_result = assistant_fc_runner.run(
-                message=message,
-                query=query,
-            )
+            runner_cls = FunctionCallAgentRunner
+        else:
+            raise ValueError(f"Invalid agent strategy: {agent_entity.strategy}")
+        
+        runner = runner_cls(
+            tenant_id=app_config.tenant_id,
+            application_generate_entity=application_generate_entity,
+            conversation=conversation,
+            app_config=app_config,
+            model_config=application_generate_entity.model_config,
+            config=agent_entity,
+            queue_manager=queue_manager,
+            message=message,
+            user_id=application_generate_entity.user_id,
+            memory=memory,
+            prompt_messages=prompt_message,
+            variables_pool=tool_variables,
+            db_variables=tool_conversation_variables,
+            model_instance=model_instance
+        )
+
+        invoke_result = runner.run(
+            message=message,
+            query=query,
+            inputs=inputs,
+        )
 
         # handle invoke result
         self._handle_invoke_result(

+ 3 - 1
api/core/tools/prompt/template.py

@@ -38,8 +38,10 @@ Action:
 ```
 
 Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
+{{historic_messages}}
 Question: {{query}}
-Thought: {{agent_scratchpad}}"""
+{{agent_scratchpad}}
+Thought:"""
 
 ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES = """Observation: {{observation}}
 Thought:"""