|
@@ -5,7 +5,7 @@ from typing import Union, Generator, Dict, Any, Tuple, List
|
|
|
|
|
|
from core.model_runtime.entities.message_entities import PromptMessage, UserPromptMessage,\
|
|
|
SystemPromptMessage, AssistantPromptMessage, ToolPromptMessage, PromptMessageTool
|
|
|
-from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage
|
|
|
+from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage, LLMResultChunkDelta
|
|
|
from core.model_manager import ModelInstance
|
|
|
from core.application_queue_manager import PublishFrom
|
|
|
|
|
@@ -20,8 +20,7 @@ from models.model import Conversation, Message, MessageAgentThought
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
|
|
- def run(self, model_instance: ModelInstance,
|
|
|
- conversation: Conversation,
|
|
|
+ def run(self, conversation: Conversation,
|
|
|
message: Message,
|
|
|
query: str,
|
|
|
) -> Generator[LLMResultChunk, None, None]:
|
|
@@ -81,6 +80,8 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
|
|
llm_usage.prompt_price += usage.prompt_price
|
|
|
llm_usage.completion_price += usage.completion_price
|
|
|
|
|
|
+ model_instance = self.model_instance
|
|
|
+
|
|
|
while function_call_state and iteration_step <= max_iteration_steps:
|
|
|
function_call_state = False
|
|
|
|
|
@@ -101,12 +102,12 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
|
|
# recale llm max tokens
|
|
|
self.recale_llm_max_tokens(self.model_config, prompt_messages)
|
|
|
# invoke model
|
|
|
- chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
|
|
|
+ chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
|
|
|
prompt_messages=prompt_messages,
|
|
|
model_parameters=app_orchestration_config.model_config.parameters,
|
|
|
tools=prompt_messages_tools,
|
|
|
stop=app_orchestration_config.model_config.stop,
|
|
|
- stream=True,
|
|
|
+ stream=self.stream_tool_call,
|
|
|
user=self.user_id,
|
|
|
callbacks=[],
|
|
|
)
|
|
@@ -122,11 +123,41 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
|
|
|
|
|
current_llm_usage = None
|
|
|
|
|
|
- for chunk in chunks:
|
|
|
+ if self.stream_tool_call:
|
|
|
+ for chunk in chunks:
|
|
|
+ # check if there is any tool call
|
|
|
+ if self.check_tool_calls(chunk):
|
|
|
+ function_call_state = True
|
|
|
+ tool_calls.extend(self.extract_tool_calls(chunk))
|
|
|
+ tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
|
|
|
+ try:
|
|
|
+ tool_call_inputs = json.dumps({
|
|
|
+ tool_call[1]: tool_call[2] for tool_call in tool_calls
|
|
|
+ }, ensure_ascii=False)
|
|
|
+ except json.JSONDecodeError as e:
|
|
|
+ # ensure ascii to avoid encoding error
|
|
|
+ tool_call_inputs = json.dumps({
|
|
|
+ tool_call[1]: tool_call[2] for tool_call in tool_calls
|
|
|
+ })
|
|
|
+
|
|
|
+ if chunk.delta.message and chunk.delta.message.content:
|
|
|
+ if isinstance(chunk.delta.message.content, list):
|
|
|
+ for content in chunk.delta.message.content:
|
|
|
+ response += content.data
|
|
|
+ else:
|
|
|
+ response += chunk.delta.message.content
|
|
|
+
|
|
|
+ if chunk.delta.usage:
|
|
|
+ increase_usage(llm_usage, chunk.delta.usage)
|
|
|
+ current_llm_usage = chunk.delta.usage
|
|
|
+
|
|
|
+ yield chunk
|
|
|
+ else:
|
|
|
+ result: LLMResult = chunks
|
|
|
# check if there is any tool call
|
|
|
- if self.check_tool_calls(chunk):
|
|
|
+ if self.check_blocking_tool_calls(result):
|
|
|
function_call_state = True
|
|
|
- tool_calls.extend(self.extract_tool_calls(chunk))
|
|
|
+ tool_calls.extend(self.extract_blocking_tool_calls(result))
|
|
|
tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
|
|
|
try:
|
|
|
tool_call_inputs = json.dumps({
|
|
@@ -138,18 +169,44 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
|
|
tool_call[1]: tool_call[2] for tool_call in tool_calls
|
|
|
})
|
|
|
|
|
|
- if chunk.delta.message and chunk.delta.message.content:
|
|
|
- if isinstance(chunk.delta.message.content, list):
|
|
|
- for content in chunk.delta.message.content:
|
|
|
+ if result.usage:
|
|
|
+ increase_usage(llm_usage, result.usage)
|
|
|
+ current_llm_usage = result.usage
|
|
|
+
|
|
|
+ if result.message and result.message.content:
|
|
|
+ if isinstance(result.message.content, list):
|
|
|
+ for content in result.message.content:
|
|
|
response += content.data
|
|
|
else:
|
|
|
- response += chunk.delta.message.content
|
|
|
-
|
|
|
- if chunk.delta.usage:
|
|
|
- increase_usage(llm_usage, chunk.delta.usage)
|
|
|
- current_llm_usage = chunk.delta.usage
|
|
|
+ response += result.message.content
|
|
|
+
|
|
|
+ if not result.message.content:
|
|
|
+ result.message.content = ''
|
|
|
+
|
|
|
+ yield LLMResultChunk(
|
|
|
+ model=model_instance.model,
|
|
|
+ prompt_messages=result.prompt_messages,
|
|
|
+ system_fingerprint=result.system_fingerprint,
|
|
|
+ delta=LLMResultChunkDelta(
|
|
|
+ index=0,
|
|
|
+ message=result.message,
|
|
|
+ usage=result.usage,
|
|
|
+ )
|
|
|
+ )
|
|
|
|
|
|
- yield chunk
|
|
|
+ if tool_calls:
|
|
|
+ prompt_messages.append(AssistantPromptMessage(
|
|
|
+ content='',
|
|
|
+ name='',
|
|
|
+ tool_calls=[AssistantPromptMessage.ToolCall(
|
|
|
+ id=tool_call[0],
|
|
|
+ type='function',
|
|
|
+ function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
|
|
+ name=tool_call[1],
|
|
|
+ arguments=json.dumps(tool_call[2], ensure_ascii=False)
|
|
|
+ )
|
|
|
+ ) for tool_call in tool_calls]
|
|
|
+ ))
|
|
|
|
|
|
# save thought
|
|
|
self.save_agent_thought(
|
|
@@ -167,6 +224,12 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
|
|
|
|
|
final_answer += response + '\n'
|
|
|
|
|
|
+ # update prompt messages
|
|
|
+ if response.strip():
|
|
|
+ prompt_messages.append(AssistantPromptMessage(
|
|
|
+ content=response,
|
|
|
+ ))
|
|
|
+
|
|
|
# call tools
|
|
|
tool_responses = []
|
|
|
for tool_call_id, tool_call_name, tool_call_args in tool_calls:
|
|
@@ -256,12 +319,6 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
|
|
)
|
|
|
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
|
|
|
|
|
|
- # update prompt messages
|
|
|
- if response.strip():
|
|
|
- prompt_messages.append(AssistantPromptMessage(
|
|
|
- content=response,
|
|
|
- ))
|
|
|
-
|
|
|
# update prompt tool
|
|
|
for prompt_tool in prompt_messages_tools:
|
|
|
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
|
|
@@ -287,6 +344,14 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
|
|
if llm_result_chunk.delta.message.tool_calls:
|
|
|
return True
|
|
|
return False
|
|
|
+
|
|
|
+ def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
|
|
|
+ """
|
|
|
+ Check if there is any blocking tool call in llm result
|
|
|
+ """
|
|
|
+ if llm_result.message.tool_calls:
|
|
|
+ return True
|
|
|
+ return False
|
|
|
|
|
|
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
|
|
|
"""
|
|
@@ -304,6 +369,23 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
|
|
))
|
|
|
|
|
|
return tool_calls
|
|
|
+
|
|
|
+ def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
|
|
|
+ """
|
|
|
+ Extract blocking tool calls from llm result
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
|
|
|
+ """
|
|
|
+ tool_calls = []
|
|
|
+ for prompt_message in llm_result.message.tool_calls:
|
|
|
+ tool_calls.append((
|
|
|
+ prompt_message.id,
|
|
|
+ prompt_message.function.name,
|
|
|
+ json.loads(prompt_message.function.arguments),
|
|
|
+ ))
|
|
|
+
|
|
|
+ return tool_calls
|
|
|
|
|
|
def organize_prompt_messages(self, prompt_template: str,
|
|
|
query: str = None,
|