@@ -1,6 +1,7 @@
 import json
 import logging
 from collections.abc import Generator
+from copy import deepcopy
 from typing import Any, Union
 
 from core.agent.base_agent_runner import BaseAgentRunner
@@ -10,20 +11,21 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
+    PromptMessageContentType,
     PromptMessageTool,
     SystemPromptMessage,
+    TextPromptMessageContent,
     ToolPromptMessage,
     UserPromptMessage,
 )
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool_engine import ToolEngine
-from models.model import Conversation, Message, MessageAgentThought
+from models.model import Message
 
 logger = logging.getLogger(__name__)
 
 class FunctionCallAgentRunner(BaseAgentRunner):
-    def run(self, conversation: Conversation,
-            message: Message,
+    def run(self, message: Message,
             query: str,
     ) -> Generator[LLMResultChunk, None, None]:
         """
@@ -35,11 +37,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
 
         prompt_template = app_config.prompt_template.simple_prompt_template or ''
         prompt_messages = self.history_prompt_messages
-        prompt_messages = self.organize_prompt_messages(
-            prompt_template=prompt_template,
-            query=query,
-            prompt_messages=prompt_messages
-        )
+        prompt_messages = self._init_system_message(prompt_template, prompt_messages)
+        prompt_messages = self._organize_user_query(query, prompt_messages)
 
         # convert tools into ModelRuntime Tool format
         prompt_messages_tools: list[PromptMessageTool] = []
@@ -68,7 +67,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
 
         # continue to run until there is not any tool call
         function_call_state = True
-        agent_thoughts: list[MessageAgentThought] = []
         llm_usage = {
             'usage': None
         }
@@ -287,9 +285,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 }
 
                 tool_responses.append(tool_response)
-                prompt_messages = self.organize_prompt_messages(
-                    prompt_template=prompt_template,
-                    query=None,
+                prompt_messages = self._organize_assistant_message(
                     tool_call_id=tool_call_id,
                     tool_call_name=tool_call_name,
                     tool_response=tool_response['tool_response'],
@@ -324,6 +320,8 @@
 
             iteration_step += 1
 
+            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
+
         self.update_db_variables(self.variables_pool, self.db_variables_pool)
         # publish end event
         self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
@@ -386,29 +384,68 @@
 
         return tool_calls
 
-    def organize_prompt_messages(self, prompt_template: str,
-                                 query: str = None,
-                                 tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
-                                 prompt_messages: list[PromptMessage] = None
-                                 ) -> list[PromptMessage]:
+    def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
         """
-        Organize prompt messages
+        Initialize system message
         """
-
-        if not prompt_messages:
-            prompt_messages = [
+        if not prompt_messages and prompt_template:
+            return [
                 SystemPromptMessage(content=prompt_template),
-                UserPromptMessage(content=query),
             ]
+
+        if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
+            prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
+
+        return prompt_messages
+
+    def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+        """
+        Organize user query
+        """
+        if self.files:
+            prompt_message_contents = [TextPromptMessageContent(data=query)]
+            for file_obj in self.files:
+                prompt_message_contents.append(file_obj.prompt_message_content)
+
+            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
         else:
+            prompt_messages.append(UserPromptMessage(content=query))
+
+        return prompt_messages
+
+    def _organize_assistant_message(self, tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
+                                    prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+        """
+        Organize assistant message
+        """
+        prompt_messages = deepcopy(prompt_messages)
+
+        if tool_response is not None:
+            prompt_messages.append(
+                ToolPromptMessage(
+                    content=tool_response,
+                    tool_call_id=tool_call_id,
+                    name=tool_call_name,
                 )
+            )
+
+        return prompt_messages
+
+    def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
+        """
+        For now, GPT supports both function calling and vision only at the first iteration,
+        so we replace image content in user messages with text placeholders after the first iteration.
+        """
+        prompt_messages = deepcopy(prompt_messages)
+
+        for prompt_message in prompt_messages:
+            if isinstance(prompt_message, UserPromptMessage):
+                if isinstance(prompt_message.content, list):
+                    prompt_message.content = '\n'.join([
+                        content.data if content.type == PromptMessageContentType.TEXT else
+                        '[image]' if content.type == PromptMessageContentType.IMAGE else
+                        '[file]'
+                        for content in prompt_message.content
+                    ])
 
         return prompt_messages