Просмотр исходного кода

fix organize agent's history messages without recalculating tokens (#4324)

Co-authored-by: chenyongzhao <chenyz@mama.cn>
zeroameli 10 месяцев назад
Родитель
Сommit
afed3610fc

+ 4 - 1
api/core/agent/base_agent_runner.py

@@ -128,6 +128,8 @@ class BaseAgentRunner(AppRunner):
             self.files = application_generate_entity.files
         else:
             self.files = []
+        self.query = None
+        self._current_thoughts: list[PromptMessage] = []
 
     def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
             -> AgentChatAppGenerateEntity:
@@ -464,7 +466,7 @@ class BaseAgentRunner(AppRunner):
         for message in messages:
             if message.id == self.message.id:
                 continue
-            
+
             result.append(self.organize_agent_user_prompt(message))
             agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
             if agent_thoughts:
@@ -545,3 +547,4 @@ class BaseAgentRunner(AppRunner):
                 return UserPromptMessage(content=prompt_message_contents)
         else:
             return UserPromptMessage(content=message.query)
+         

+ 9 - 1
api/core/agent/cot_agent_runner.py

@@ -15,6 +15,7 @@ from core.model_runtime.entities.message_entities import (
     ToolPromptMessage,
     UserPromptMessage,
 )
+from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool.tool import Tool
 from core.tools.tool_engine import ToolEngine
@@ -373,7 +374,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
 
         return message
 
-    def _organize_historic_prompt_messages(self) -> list[PromptMessage]:
+    def _organize_historic_prompt_messages(self, current_session_messages: list[PromptMessage] = None) -> list[PromptMessage]:
         """
             organize historic prompt messages
         """
@@ -381,6 +382,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         scratchpad: list[AgentScratchpadUnit] = []
         current_scratchpad: AgentScratchpadUnit = None
 
+        self.history_prompt_messages = AgentHistoryPromptTransform(
+            model_config=self.model_config,
+            prompt_messages=current_session_messages or [],
+            history_messages=self.history_prompt_messages,
+            memory=self.memory
+        ).get_prompt()
+
         for message in self.history_prompt_messages:
             if isinstance(message, AssistantPromptMessage):
                 current_scratchpad = AgentScratchpadUnit(

+ 9 - 3
api/core/agent/cot_chat_agent_runner.py

@@ -32,9 +32,6 @@ class CotChatAgentRunner(CotAgentRunner):
         # organize system prompt
         system_message = self._organize_system_prompt()
 
-        # organize historic prompt messages
-        historic_messages = self._historic_prompt_messages
-
         # organize current assistant messages
         agent_scratchpad = self._agent_scratchpad
         if not agent_scratchpad:
@@ -57,6 +54,13 @@ class CotChatAgentRunner(CotAgentRunner):
         query_messages = UserPromptMessage(content=self._query)
 
         if assistant_messages:
+            # organize historic prompt messages
+            historic_messages = self._organize_historic_prompt_messages([
+                system_message,
+                query_messages,
+                *assistant_messages,
+                UserPromptMessage(content='continue')
+            ])            
             messages = [
                 system_message,
                 *historic_messages,
@@ -65,6 +69,8 @@ class CotChatAgentRunner(CotAgentRunner):
                 UserPromptMessage(content='continue')
             ]
         else:
+            # organize historic prompt messages
+            historic_messages = self._organize_historic_prompt_messages([system_message, query_messages])
             messages = [system_message, *historic_messages, query_messages]
 
         # join all messages

+ 2 - 2
api/core/agent/cot_completion_agent_runner.py

@@ -19,11 +19,11 @@ class CotCompletionAgentRunner(CotAgentRunner):
         
         return system_prompt
 
-    def _organize_historic_prompt(self) -> str:
+    def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str:
         """
         Organize historic prompt
         """
-        historic_prompt_messages = self._historic_prompt_messages
+        historic_prompt_messages = self._organize_historic_prompt_messages(current_session_messages)
         historic_prompt = ""
 
         for message in historic_prompt_messages:

+ 36 - 33
api/core/agent/fc_agent_runner.py

@@ -17,6 +17,7 @@ from core.model_runtime.entities.message_entities import (
     ToolPromptMessage,
     UserPromptMessage,
 )
+from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool_engine import ToolEngine
 from models.model import Message
@@ -24,21 +25,18 @@ from models.model import Message
 logger = logging.getLogger(__name__)
 
 class FunctionCallAgentRunner(BaseAgentRunner):
+
     def run(self, 
             message: Message, query: str, **kwargs: Any
     ) -> Generator[LLMResultChunk, None, None]:
         """
         Run FunctionCall agent application
         """
+        self.query = query
         app_generate_entity = self.application_generate_entity
 
         app_config = self.app_config
 
-        prompt_template = app_config.prompt_template.simple_prompt_template or ''
-        prompt_messages = self.history_prompt_messages
-        prompt_messages = self._init_system_message(prompt_template, prompt_messages)
-        prompt_messages = self._organize_user_query(query, prompt_messages)
-
         # convert tools into ModelRuntime Tool format
         tool_instances, prompt_messages_tools = self._init_prompt_tools()
 
@@ -81,6 +79,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             )
 
             # recalc llm max tokens
+            prompt_messages = self._organize_prompt_messages()
             self.recalc_llm_max_tokens(self.model_config, prompt_messages)
             # invoke model
             chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
@@ -203,7 +202,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             else:
                 assistant_message.content = response
             
-            prompt_messages.append(assistant_message)
+            self._current_thoughts.append(assistant_message)
 
             # save thought
             self.save_agent_thought(
@@ -265,12 +264,14 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                     }
                 
                 tool_responses.append(tool_response)
-                prompt_messages = self._organize_assistant_message(
-                    tool_call_id=tool_call_id,
-                    tool_call_name=tool_call_name,
-                    tool_response=tool_response['tool_response'],
-                    prompt_messages=prompt_messages,
-                )
+                if tool_response['tool_response'] is not None:
+                    self._current_thoughts.append(
+                        ToolPromptMessage(
+                            content=tool_response['tool_response'],
+                            tool_call_id=tool_call_id,
+                            name=tool_call_name,
+                        )
+                    ) 
 
             if len(tool_responses) > 0:
                 # save agent thought
@@ -300,8 +301,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
 
             iteration_step += 1
 
-            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
-
         self.update_db_variables(self.variables_pool, self.db_variables_pool)
         # publish end event
         self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
@@ -393,24 +392,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
 
         return prompt_messages
     
-    def _organize_assistant_message(self, tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None, 
-                                    prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
-        """
-        Organize assistant message
-        """
-        prompt_messages = deepcopy(prompt_messages)
-
-        if tool_response is not None:
-            prompt_messages.append(
-                ToolPromptMessage(
-                    content=tool_response,
-                    tool_call_id=tool_call_id,
-                    name=tool_call_name,
-                )
-            )
-
-        return prompt_messages
-    
     def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
         """
         As for now, gpt supports both fc and vision at the first iteration.
@@ -428,4 +409,26 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                         for content in prompt_message.content 
                     ])
 
-        return prompt_messages
+        return prompt_messages
+
+    def _organize_prompt_messages(self):
+        prompt_template = self.app_config.prompt_template.simple_prompt_template or ''
+        self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
+        query_prompt_messages = self._organize_user_query(self.query, [])
+
+        self.history_prompt_messages = AgentHistoryPromptTransform(
+            model_config=self.model_config,
+            prompt_messages=[*query_prompt_messages, *self._current_thoughts],
+            history_messages=self.history_prompt_messages,
+            memory=self.memory
+        ).get_prompt()
+
+        prompt_messages = [
+            *self.history_prompt_messages,
+            *query_prompt_messages,
+            *self._current_thoughts
+        ]
+        if len(self._current_thoughts) != 0:
+            # clear messages after the first iteration
+            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
+        return prompt_messages

+ 82 - 0
api/core/prompt/agent_history_prompt_transform.py

@@ -0,0 +1,82 @@
+from typing import Optional, cast
+
+from core.app.entities.app_invoke_entities import (
+    ModelConfigWithCredentialsEntity,
+)
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_runtime.entities.message_entities import (
+    PromptMessage,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.prompt.prompt_transform import PromptTransform
+
+
+class AgentHistoryPromptTransform(PromptTransform):
+    """
+    History Prompt Transform for Agent App
+    """
+    def __init__(self,
+                 model_config: ModelConfigWithCredentialsEntity,
+                 prompt_messages: list[PromptMessage],
+                 history_messages: list[PromptMessage],
+                 memory: Optional[TokenBufferMemory] = None,
+                 ):
+        self.model_config = model_config
+        self.prompt_messages = prompt_messages
+        self.history_messages = history_messages
+        self.memory = memory
+
+    def get_prompt(self) -> list[PromptMessage]:
+        prompt_messages = []
+        num_system = 0
+        for prompt_message in self.history_messages:
+            if isinstance(prompt_message, SystemPromptMessage):
+                prompt_messages.append(prompt_message)
+                num_system += 1
+
+        if not self.memory:
+            return prompt_messages
+
+        max_token_limit = self._calculate_rest_token(self.prompt_messages, self.model_config)
+
+        model_type_instance = self.model_config.provider_model_bundle.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+        curr_message_tokens = model_type_instance.get_num_tokens(
+            self.memory.model_instance.model,
+            self.memory.model_instance.credentials,
+            self.history_messages
+        )
+        if curr_message_tokens <= max_token_limit:
+            return self.history_messages
+
+        # number of prompt has been appended in current message
+        num_prompt = 0
+        # append prompt messages in desc order
+        for prompt_message in self.history_messages[::-1]:
+            if isinstance(prompt_message, SystemPromptMessage):
+                continue
+            prompt_messages.append(prompt_message)
+            num_prompt += 1
+            # a message is start with UserPromptMessage
+            if isinstance(prompt_message, UserPromptMessage):
+                curr_message_tokens = model_type_instance.get_num_tokens(
+                    self.memory.model_instance.model,
+                    self.memory.model_instance.credentials,
+                    prompt_messages
+                )
+                # if current message token is overflow, drop all the prompts in current message and break
+                if curr_message_tokens > max_token_limit:
+                    prompt_messages = prompt_messages[:-num_prompt]
+                    break
+                num_prompt = 0
+        # return prompt messages in asc order
+        message_prompts = prompt_messages[num_system:]
+        message_prompts.reverse()
+
+        # merge system and message prompt
+        prompt_messages = prompt_messages[:num_system]
+        prompt_messages.extend(message_prompts)
+        return prompt_messages

+ 77 - 0
api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py

@@ -0,0 +1,77 @@
+from unittest.mock import MagicMock
+
+from core.app.entities.app_invoke_entities import (
+    ModelConfigWithCredentialsEntity,
+)
+from core.entities.provider_configuration import ProviderModelBundle
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    SystemPromptMessage,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
+from models.model import Conversation
+
+
+def test_get_prompt():
+    prompt_messages = [
+        SystemPromptMessage(content='System Template'),
+        UserPromptMessage(content='User Query'),
+    ]
+    history_messages = [
+        SystemPromptMessage(content='System Prompt 1'),
+        UserPromptMessage(content='User Prompt 1'),
+        AssistantPromptMessage(content='Assistant Thought 1'),
+        ToolPromptMessage(content='Tool 1-1', name='Tool 1-1', tool_call_id='1'),
+        ToolPromptMessage(content='Tool 1-2', name='Tool 1-2', tool_call_id='2'),
+        SystemPromptMessage(content='System Prompt 2'),
+        UserPromptMessage(content='User Prompt 2'),
+        AssistantPromptMessage(content='Assistant Thought 2'),
+        ToolPromptMessage(content='Tool 2-1', name='Tool 2-1', tool_call_id='3'),
+        ToolPromptMessage(content='Tool 2-2', name='Tool 2-2', tool_call_id='4'),
+        UserPromptMessage(content='User Prompt 3'),
+        AssistantPromptMessage(content='Assistant Thought 3'),
+    ]
+
+    # use message number instead of token for testing
+    def side_effect_get_num_tokens(*args):
+        return len(args[2])
+    large_language_model_mock = MagicMock(spec=LargeLanguageModel)
+    large_language_model_mock.get_num_tokens = MagicMock(side_effect=side_effect_get_num_tokens)
+
+    provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle)
+    provider_model_bundle_mock.model_type_instance = large_language_model_mock
+
+    model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
+    model_config_mock.model = 'openai'
+    model_config_mock.credentials = {}
+    model_config_mock.provider_model_bundle = provider_model_bundle_mock
+
+    memory = TokenBufferMemory(
+        conversation=Conversation(),
+        model_instance=model_config_mock
+    )
+
+    transform = AgentHistoryPromptTransform(
+        model_config=model_config_mock,
+        prompt_messages=prompt_messages,
+        history_messages=history_messages,
+        memory=memory
+    )
+
+    max_token_limit = 5
+    transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
+    result = transform.get_prompt()
+
+    assert len(result) <= max_token_limit
+    assert len(result) == 4
+
+    max_token_limit = 20
+    transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
+    result = transform.get_prompt()
+
+    assert len(result) <= max_token_limit
+    assert len(result) == 12