@@ -1,6 +1,7 @@
 import json
 import logging
 import uuid
+from collections.abc import Mapping, Sequence
 from datetime import datetime, timezone
 from typing import Optional, Union, cast

@@ -45,22 +46,25 @@ from models.tools import ToolConversationVariables

 logger = logging.getLogger(__name__)

+
 class BaseAgentRunner(AppRunner):
-    def __init__(self, tenant_id: str,
-                 application_generate_entity: AgentChatAppGenerateEntity,
-                 conversation: Conversation,
-                 app_config: AgentChatAppConfig,
-                 model_config: ModelConfigWithCredentialsEntity,
-                 config: AgentEntity,
-                 queue_manager: AppQueueManager,
-                 message: Message,
-                 user_id: str,
-                 memory: Optional[TokenBufferMemory] = None,
-                 prompt_messages: Optional[list[PromptMessage]] = None,
-                 variables_pool: Optional[ToolRuntimeVariablePool] = None,
-                 db_variables: Optional[ToolConversationVariables] = None,
-                 model_instance: ModelInstance = None
-                 ) -> None:
+    def __init__(
+        self,
+        tenant_id: str,
+        application_generate_entity: AgentChatAppGenerateEntity,
+        conversation: Conversation,
+        app_config: AgentChatAppConfig,
+        model_config: ModelConfigWithCredentialsEntity,
+        config: AgentEntity,
+        queue_manager: AppQueueManager,
+        message: Message,
+        user_id: str,
+        memory: Optional[TokenBufferMemory] = None,
+        prompt_messages: Optional[list[PromptMessage]] = None,
+        variables_pool: Optional[ToolRuntimeVariablePool] = None,
+        db_variables: Optional[ToolConversationVariables] = None,
+        model_instance: ModelInstance = None,
+    ) -> None:
         """
         Agent runner
         :param tenant_id: tenant id
@@ -88,9 +92,7 @@ class BaseAgentRunner(AppRunner):
         self.message = message
         self.user_id = user_id
         self.memory = memory
-        self.history_prompt_messages = self.organize_agent_history(
-            prompt_messages=prompt_messages or []
-        )
+        self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
         self.variables_pool = variables_pool
         self.db_variables_pool = db_variables
         self.model_instance = model_instance
@@ -111,12 +113,16 @@ class BaseAgentRunner(AppRunner):
             retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
             return_resource=app_config.additional_features.show_retrieve_source,
             invoke_from=application_generate_entity.invoke_from,
-            hit_callback=hit_callback
+            hit_callback=hit_callback,
         )
         # get how many agent thoughts have been created
-        self.agent_thought_count = db.session.query(MessageAgentThought).filter(
-            MessageAgentThought.message_id == self.message.id,
-        ).count()
+        self.agent_thought_count = (
+            db.session.query(MessageAgentThought)
+            .filter(
+                MessageAgentThought.message_id == self.message.id,
+            )
+            .count()
+        )
         db.session.close()

         # check if model supports stream tool call
@@ -135,25 +141,26 @@ class BaseAgentRunner(AppRunner):
         self.query = None
         self._current_thoughts: list[PromptMessage] = []

-    def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
-            -> AgentChatAppGenerateEntity:
+    def _repack_app_generate_entity(
+        self, app_generate_entity: AgentChatAppGenerateEntity
+    ) -> AgentChatAppGenerateEntity:
         """
         Repack app generate entity
         """
         if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
-            app_generate_entity.app_config.prompt_template.simple_prompt_template = ''
+            app_generate_entity.app_config.prompt_template.simple_prompt_template = ""

         return app_generate_entity
-
+
     def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
         """
-            convert tool to prompt message tool
+        convert tool to prompt message tool
         """
         tool_entity = ToolManager.get_agent_tool_runtime(
             tenant_id=self.tenant_id,
             app_id=self.app_config.app_id,
             agent_tool=tool,
-            invoke_from=self.application_generate_entity.invoke_from
+            invoke_from=self.application_generate_entity.invoke_from,
         )
         tool_entity.load_variables(self.variables_pool)

@@ -164,7 +171,7 @@ class BaseAgentRunner(AppRunner):
                 "type": "object",
                 "properties": {},
                 "required": [],
-            }
+            },
         )

         parameters = tool_entity.get_all_runtime_parameters()
@@ -177,19 +184,19 @@ class BaseAgentRunner(AppRunner):
             if parameter.type == ToolParameter.ToolParameterType.SELECT:
                 enum = [option.value for option in parameter.options]

-            message_tool.parameters['properties'][parameter.name] = {
+            message_tool.parameters["properties"][parameter.name] = {
                 "type": parameter_type,
-                "description": parameter.llm_description or '',
+                "description": parameter.llm_description or "",
             }

             if len(enum) > 0:
-                message_tool.parameters['properties'][parameter.name]['enum'] = enum
+                message_tool.parameters["properties"][parameter.name]["enum"] = enum

             if parameter.required:
-                message_tool.parameters['required'].append(parameter.name)
+                message_tool.parameters["required"].append(parameter.name)

         return message_tool, tool_entity
-
+
     def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
         """
         convert dataset retriever tool to prompt message tool
@@ -201,24 +208,24 @@ class BaseAgentRunner(AppRunner):
                 "type": "object",
                 "properties": {},
                 "required": [],
-            }
+            },
         )

         for parameter in tool.get_runtime_parameters():
-            parameter_type = 'string'
-
-            prompt_tool.parameters['properties'][parameter.name] = {
+            parameter_type = "string"
+
+            prompt_tool.parameters["properties"][parameter.name] = {
                 "type": parameter_type,
-                "description": parameter.llm_description or '',
+                "description": parameter.llm_description or "",
             }

             if parameter.required:
-                if parameter.name not in prompt_tool.parameters['required']:
-                    prompt_tool.parameters['required'].append(parameter.name)
+                if parameter.name not in prompt_tool.parameters["required"]:
+                    prompt_tool.parameters["required"].append(parameter.name)

         return prompt_tool
-
-    def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
+
+    def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]:
         """
         Init tools
         """
@@ -261,51 +268,51 @@ class BaseAgentRunner(AppRunner):
             enum = []
             if parameter.type == ToolParameter.ToolParameterType.SELECT:
                 enum = [option.value for option in parameter.options]
-
-            prompt_tool.parameters['properties'][parameter.name] = {
+
+            prompt_tool.parameters["properties"][parameter.name] = {
                 "type": parameter_type,
-                "description": parameter.llm_description or '',
+                "description": parameter.llm_description or "",
             }

             if len(enum) > 0:
-                prompt_tool.parameters['properties'][parameter.name]['enum'] = enum
+                prompt_tool.parameters["properties"][parameter.name]["enum"] = enum

             if parameter.required:
-                if parameter.name not in prompt_tool.parameters['required']:
-                    prompt_tool.parameters['required'].append(parameter.name)
+                if parameter.name not in prompt_tool.parameters["required"]:
+                    prompt_tool.parameters["required"].append(parameter.name)

         return prompt_tool
-
-    def create_agent_thought(self, message_id: str, message: str,
-                             tool_name: str, tool_input: str, messages_ids: list[str]
-                             ) -> MessageAgentThought:
+
+    def create_agent_thought(
+        self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
+    ) -> MessageAgentThought:
         """
         Create agent thought
         """
         thought = MessageAgentThought(
             message_id=message_id,
             message_chain_id=None,
-            thought='',
+            thought="",
             tool=tool_name,
-            tool_labels_str='{}',
-            tool_meta_str='{}',
+            tool_labels_str="{}",
+            tool_meta_str="{}",
             tool_input=tool_input,
             message=message,
             message_token=0,
             message_unit_price=0,
             message_price_unit=0,
-            message_files=json.dumps(messages_ids) if messages_ids else '',
-            answer='',
-            observation='',
+            message_files=json.dumps(messages_ids) if messages_ids else "",
+            answer="",
+            observation="",
             answer_token=0,
             answer_unit_price=0,
             answer_price_unit=0,
             tokens=0,
             total_price=0,
             position=self.agent_thought_count + 1,
-            currency='USD',
+            currency="USD",
             latency=0,
-            created_by_role='account',
+            created_by_role="account",
             created_by=self.user_id,
         )

@@ -318,22 +325,22 @@ class BaseAgentRunner(AppRunner):

         return thought

-    def save_agent_thought(self,
-                           agent_thought: MessageAgentThought,
-                           tool_name: str,
-                           tool_input: Union[str, dict],
-                           thought: str,
-                           observation: Union[str, dict],
-                           tool_invoke_meta: Union[str, dict],
-                           answer: str,
-                           messages_ids: list[str],
-                           llm_usage: LLMUsage = None) -> MessageAgentThought:
+    def save_agent_thought(
+        self,
+        agent_thought: MessageAgentThought,
+        tool_name: str,
+        tool_input: Union[str, dict],
+        thought: str,
+        observation: Union[str, dict],
+        tool_invoke_meta: Union[str, dict],
+        answer: str,
+        messages_ids: list[str],
+        llm_usage: LLMUsage = None,
+    ) -> MessageAgentThought:
         """
         Save agent thought
         """
-        agent_thought = db.session.query(MessageAgentThought).filter(
-            MessageAgentThought.id == agent_thought.id
-        ).first()
+        agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()

         if thought is not None:
             agent_thought.thought = thought
@@ -356,7 +363,7 @@ class BaseAgentRunner(AppRunner):
                     observation = json.dumps(observation, ensure_ascii=False)
                 except Exception as e:
                     observation = json.dumps(observation)
-
+
             agent_thought.observation = observation

         if answer is not None:
@@ -364,7 +371,7 @@ class BaseAgentRunner(AppRunner):

         if messages_ids is not None and len(messages_ids) > 0:
             agent_thought.message_files = json.dumps(messages_ids)
-
+
         if llm_usage:
             agent_thought.message_token = llm_usage.prompt_tokens
             agent_thought.message_price_unit = llm_usage.prompt_price_unit
@@ -377,7 +384,7 @@ class BaseAgentRunner(AppRunner):

         # check if tool labels is not empty
         labels = agent_thought.tool_labels or {}
-        tools = agent_thought.tool.split(';') if agent_thought.tool else []
+        tools = agent_thought.tool.split(";") if agent_thought.tool else []
         for tool in tools:
             if not tool:
                 continue
@@ -386,7 +393,7 @@ class BaseAgentRunner(AppRunner):
             if tool_label:
                 labels[tool] = tool_label.to_dict()
             else:
-                labels[tool] = {'en_US': tool, 'zh_Hans': tool}
+                labels[tool] = {"en_US": tool, "zh_Hans": tool}

         agent_thought.tool_labels_str = json.dumps(labels)

@@ -401,14 +408,18 @@ class BaseAgentRunner(AppRunner):

         db.session.commit()
         db.session.close()
-
+
     def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
         """
         convert tool variables to db variables
         """
-        db_variables = db.session.query(ToolConversationVariables).filter(
-            ToolConversationVariables.conversation_id == self.message.conversation_id,
-        ).first()
+        db_variables = (
+            db.session.query(ToolConversationVariables)
+            .filter(
+                ToolConversationVariables.conversation_id == self.message.conversation_id,
+            )
+            .first()
+        )

         db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
         db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
@@ -425,9 +436,14 @@ class BaseAgentRunner(AppRunner):
             if isinstance(prompt_message, SystemPromptMessage):
                 result.append(prompt_message)

-        messages: list[Message] = db.session.query(Message).filter(
-            Message.conversation_id == self.message.conversation_id,
-        ).order_by(Message.created_at.asc()).all()
+        messages: list[Message] = (
+            db.session.query(Message)
+            .filter(
+                Message.conversation_id == self.message.conversation_id,
+            )
+            .order_by(Message.created_at.asc())
+            .all()
+        )

         for message in messages:
             if message.id == self.message.id:
@@ -439,13 +455,13 @@ class BaseAgentRunner(AppRunner):
                 for agent_thought in agent_thoughts:
                     tools = agent_thought.tool
                     if tools:
-                        tools = tools.split(';')
+                        tools = tools.split(";")
                         tool_calls: list[AssistantPromptMessage.ToolCall] = []
                         tool_call_response: list[ToolPromptMessage] = []
                         try:
                             tool_inputs = json.loads(agent_thought.tool_input)
                         except Exception as e:
-                            tool_inputs = { tool: {} for tool in tools }
+                            tool_inputs = {tool: {} for tool in tools}
                         try:
                             tool_responses = json.loads(agent_thought.observation)
                         except Exception as e:
@@ -454,27 +470,33 @@ class BaseAgentRunner(AppRunner):
                         for tool in tools:
                             # generate a uuid for tool call
                             tool_call_id = str(uuid.uuid4())
-                            tool_calls.append(AssistantPromptMessage.ToolCall(
-                                id=tool_call_id,
-                                type='function',
-                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            tool_calls.append(
+                                AssistantPromptMessage.ToolCall(
+                                    id=tool_call_id,
+                                    type="function",
+                                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                        name=tool,
+                                        arguments=json.dumps(tool_inputs.get(tool, {})),
+                                    ),
+                                )
+                            )
+                            tool_call_response.append(
+                                ToolPromptMessage(
+                                    content=tool_responses.get(tool, agent_thought.observation),
                                     name=tool,
-                                    arguments=json.dumps(tool_inputs.get(tool, {})),
+                                    tool_call_id=tool_call_id,
                                 )
-                            ))
-                            tool_call_response.append(ToolPromptMessage(
-                                content=tool_responses.get(tool, agent_thought.observation),
-                                name=tool,
-                                tool_call_id=tool_call_id,
-                            ))
-
-                        result.extend([
-                            AssistantPromptMessage(
-                                content=agent_thought.thought,
-                                tool_calls=tool_calls,
-                            ),
-                            *tool_call_response
-                        ])
+                            )
+
+                        result.extend(
+                            [
+                                AssistantPromptMessage(
+                                    content=agent_thought.thought,
+                                    tool_calls=tool_calls,
+                                ),
+                                *tool_call_response,
+                            ]
+                        )
                         if not tools:
                             result.append(AssistantPromptMessage(content=agent_thought.thought))
                     else:
@@ -496,10 +518,7 @@ class BaseAgentRunner(AppRunner):
             file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())

             if file_extra_config:
-                file_objs = message_file_parser.transform_message_files(
-                    files,
-                    file_extra_config
-                )
+                file_objs = message_file_parser.transform_message_files(files, file_extra_config)
             else:
                 file_objs = []
