|
@@ -133,61 +133,95 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
|
|
# recale llm max tokens
|
|
|
self.recale_llm_max_tokens(self.model_config, prompt_messages)
|
|
|
# invoke model
|
|
|
- llm_result: LLMResult = model_instance.invoke_llm(
|
|
|
+ chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
|
|
|
prompt_messages=prompt_messages,
|
|
|
model_parameters=app_orchestration_config.model_config.parameters,
|
|
|
tools=[],
|
|
|
stop=app_orchestration_config.model_config.stop,
|
|
|
- stream=False,
|
|
|
+ stream=True,
|
|
|
user=self.user_id,
|
|
|
callbacks=[],
|
|
|
)
|
|
|
|
|
|
# check llm result
|
|
|
- if not llm_result:
|
|
|
+ if not chunks:
|
|
|
raise ValueError("failed to invoke llm")
|
|
|
-
|
|
|
- # get scratchpad
|
|
|
- scratchpad = self._extract_response_scratchpad(llm_result.message.content)
|
|
|
- agent_scratchpad.append(scratchpad)
|
|
|
-
|
|
|
- # get llm usage
|
|
|
- if llm_result.usage:
|
|
|
- increase_usage(llm_usage, llm_result.usage)
|
|
|
|
|
|
+ usage_dict = {}
|
|
|
+ react_chunks = self._handle_stream_react(chunks, usage_dict)
|
|
|
+ scratchpad = AgentScratchpadUnit(
|
|
|
+ agent_response='',
|
|
|
+ thought='',
|
|
|
+ action_str='',
|
|
|
+ observation='',
|
|
|
+ action=None
|
|
|
+ )
|
|
|
+
|
|
|
# publish agent thought if it's first iteration
|
|
|
if iteration_step == 1:
|
|
|
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
|
|
|
|
|
|
+ for chunk in react_chunks:
|
|
|
+ if isinstance(chunk, dict):
|
|
|
+ scratchpad.agent_response += json.dumps(chunk)
|
|
|
+ try:
|
|
|
+ if scratchpad.action:
|
|
|
+ raise Exception("")
|
|
|
+ scratchpad.action_str = json.dumps(chunk)
|
|
|
+ scratchpad.action = AgentScratchpadUnit.Action(
|
|
|
+ action_name=chunk['action'],
|
|
|
+ action_input=chunk['action_input']
|
|
|
+ )
|
|
|
+ except:
|
|
|
+ scratchpad.thought += json.dumps(chunk)
|
|
|
+ yield LLMResultChunk(
|
|
|
+ model=self.model_config.model,
|
|
|
+ prompt_messages=prompt_messages,
|
|
|
+ system_fingerprint='',
|
|
|
+ delta=LLMResultChunkDelta(
|
|
|
+ index=0,
|
|
|
+ message=AssistantPromptMessage(
|
|
|
+ content=json.dumps(chunk)
|
|
|
+ ),
|
|
|
+ usage=None
|
|
|
+ )
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ scratchpad.agent_response += chunk
|
|
|
+ scratchpad.thought += chunk
|
|
|
+ yield LLMResultChunk(
|
|
|
+ model=self.model_config.model,
|
|
|
+ prompt_messages=prompt_messages,
|
|
|
+ system_fingerprint='',
|
|
|
+ delta=LLMResultChunkDelta(
|
|
|
+ index=0,
|
|
|
+ message=AssistantPromptMessage(
|
|
|
+ content=chunk
|
|
|
+ ),
|
|
|
+ usage=None
|
|
|
+ )
|
|
|
+ )
|
|
|
+
|
|
|
+ agent_scratchpad.append(scratchpad)
|
|
|
+
|
|
|
+ # get llm usage
|
|
|
+ if 'usage' in usage_dict:
|
|
|
+ increase_usage(llm_usage, usage_dict['usage'])
|
|
|
+ else:
|
|
|
+ usage_dict['usage'] = LLMUsage.empty_usage()
|
|
|
+
|
|
|
self.save_agent_thought(agent_thought=agent_thought,
|
|
|
tool_name=scratchpad.action.action_name if scratchpad.action else '',
|
|
|
tool_input=scratchpad.action.action_input if scratchpad.action else '',
|
|
|
thought=scratchpad.thought,
|
|
|
observation='',
|
|
|
- answer=llm_result.message.content,
|
|
|
+ answer=scratchpad.agent_response,
|
|
|
messages_ids=[],
|
|
|
- llm_usage=llm_result.usage)
|
|
|
+ llm_usage=usage_dict['usage'])
|
|
|
|
|
|
if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
|
|
|
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
|
|
|
|
|
|
- # publish agent thought if it's not empty and there is a action
|
|
|
- if scratchpad.thought and scratchpad.action:
|
|
|
- # check if final answer
|
|
|
- if not scratchpad.action.action_name.lower() == "final answer":
|
|
|
- yield LLMResultChunk(
|
|
|
- model=model_instance.model,
|
|
|
- prompt_messages=prompt_messages,
|
|
|
- delta=LLMResultChunkDelta(
|
|
|
- index=0,
|
|
|
- message=AssistantPromptMessage(
|
|
|
- content=scratchpad.thought
|
|
|
- ),
|
|
|
- usage=llm_result.usage,
|
|
|
- ),
|
|
|
- system_fingerprint=''
|
|
|
- )
|
|
|
-
|
|
|
if not scratchpad.action:
|
|
|
# failed to extract action, return final answer directly
|
|
|
final_answer = scratchpad.agent_response or ''
|
|
@@ -262,7 +296,6 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
|
|
|
|
|
# save scratchpad
|
|
|
scratchpad.observation = observation
|
|
|
- scratchpad.agent_response = llm_result.message.content
|
|
|
|
|
|
# save agent thought
|
|
|
self.save_agent_thought(
|
|
@@ -271,7 +304,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
|
|
tool_input=tool_call_args,
|
|
|
thought=None,
|
|
|
observation=observation,
|
|
|
- answer=llm_result.message.content,
|
|
|
+ answer=scratchpad.agent_response,
|
|
|
messages_ids=message_file_ids,
|
|
|
)
|
|
|
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
|
|
@@ -318,6 +351,97 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
|
|
system_fingerprint=''
|
|
|
), PublishFrom.APPLICATION_MANAGER)
|
|
|
|
|
|
+ def _handle_stream_react(self, llm_response: Generator[LLMResultChunk, None, None], usage: dict) \
|
|
|
+ -> Generator[Union[str, dict], None, None]:
|
|
|
+ def parse_json(json_str):
|
|
|
+ try:
|
|
|
+ return json.loads(json_str.strip())
|
|
|
+ except:
|
|
|
+ return json_str
|
|
|
+
|
|
|
+ def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
|
|
|
+ code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
|
|
|
+ if not code_blocks:
|
|
|
+ return
|
|
|
+ for block in code_blocks:
|
|
|
+ json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
|
|
|
+ yield parse_json(json_text)
|
|
|
+
|
|
|
+ code_block_cache = ''
|
|
|
+ code_block_delimiter_count = 0
|
|
|
+ in_code_block = False
|
|
|
+ json_cache = ''
|
|
|
+ json_quote_count = 0
|
|
|
+ in_json = False
|
|
|
+ got_json = False
|
|
|
+
|
|
|
+ for response in llm_response:
|
|
|
+ response = response.delta.message.content
|
|
|
+ if not isinstance(response, str):
|
|
|
+ continue
|
|
|
+
|
|
|
+ # stream
|
|
|
+ index = 0
|
|
|
+ while index < len(response):
|
|
|
+ steps = 1
|
|
|
+ delta = response[index:index+steps]
|
|
|
+ if delta == '`':
|
|
|
+ code_block_cache += delta
|
|
|
+ code_block_delimiter_count += 1
|
|
|
+ else:
|
|
|
+ if not in_code_block:
|
|
|
+ if code_block_delimiter_count > 0:
|
|
|
+ yield code_block_cache
|
|
|
+ code_block_cache = ''
|
|
|
+ else:
|
|
|
+ code_block_cache += delta
|
|
|
+ code_block_delimiter_count = 0
|
|
|
+
|
|
|
+ if code_block_delimiter_count == 3:
|
|
|
+ if in_code_block:
|
|
|
+ yield from extra_json_from_code_block(code_block_cache)
|
|
|
+ code_block_cache = ''
|
|
|
+
|
|
|
+ in_code_block = not in_code_block
|
|
|
+ code_block_delimiter_count = 0
|
|
|
+
|
|
|
+ if not in_code_block:
|
|
|
+ # handle single json
|
|
|
+ if delta == '{':
|
|
|
+ json_quote_count += 1
|
|
|
+ in_json = True
|
|
|
+ json_cache += delta
|
|
|
+ elif delta == '}':
|
|
|
+ json_cache += delta
|
|
|
+ if json_quote_count > 0:
|
|
|
+ json_quote_count -= 1
|
|
|
+ if json_quote_count == 0:
|
|
|
+ in_json = False
|
|
|
+ got_json = True
|
|
|
+ index += steps
|
|
|
+ continue
|
|
|
+ else:
|
|
|
+ if in_json:
|
|
|
+ json_cache += delta
|
|
|
+
|
|
|
+ if got_json:
|
|
|
+ got_json = False
|
|
|
+ yield parse_json(json_cache)
|
|
|
+ json_cache = ''
|
|
|
+ json_quote_count = 0
|
|
|
+ in_json = False
|
|
|
+
|
|
|
+ if not in_code_block and not in_json:
|
|
|
+ yield delta.replace('`', '')
|
|
|
+
|
|
|
+ index += steps
|
|
|
+
|
|
|
+ if code_block_cache:
|
|
|
+ yield code_block_cache
|
|
|
+
|
|
|
+ if json_cache:
|
|
|
+ yield parse_json(json_cache)
|
|
|
+
|
|
|
def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
|
|
|
"""
|
|
|
fill in inputs from external data tools
|
|
@@ -363,121 +487,6 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
|
|
|
|
|
return agent_scratchpad
|
|
|
|
|
|
- def _extract_response_scratchpad(self, content: str) -> AgentScratchpadUnit:
|
|
|
- """
|
|
|
- extract response from llm response
|
|
|
- """
|
|
|
- def extra_quotes() -> AgentScratchpadUnit:
|
|
|
- agent_response = content
|
|
|
- # try to extract all quotes
|
|
|
- pattern = re.compile(r'```(.*?)```', re.DOTALL)
|
|
|
- quotes = pattern.findall(content)
|
|
|
-
|
|
|
- # try to extract action from end to start
|
|
|
- for i in range(len(quotes) - 1, 0, -1):
|
|
|
- """
|
|
|
- 1. use json load to parse action
|
|
|
- 2. use plain text `Action: xxx` to parse action
|
|
|
- """
|
|
|
- try:
|
|
|
- action = json.loads(quotes[i].replace('```', ''))
|
|
|
- action_name = action.get("action")
|
|
|
- action_input = action.get("action_input")
|
|
|
- agent_thought = agent_response.replace(quotes[i], '')
|
|
|
-
|
|
|
- if action_name and action_input:
|
|
|
- return AgentScratchpadUnit(
|
|
|
- agent_response=content,
|
|
|
- thought=agent_thought,
|
|
|
- action_str=quotes[i],
|
|
|
- action=AgentScratchpadUnit.Action(
|
|
|
- action_name=action_name,
|
|
|
- action_input=action_input,
|
|
|
- )
|
|
|
- )
|
|
|
- except:
|
|
|
- # try to parse action from plain text
|
|
|
- action_name = re.findall(r'action: (.*)', quotes[i], re.IGNORECASE)
|
|
|
- action_input = re.findall(r'action input: (.*)', quotes[i], re.IGNORECASE)
|
|
|
- # delete action from agent response
|
|
|
- agent_thought = agent_response.replace(quotes[i], '')
|
|
|
- # remove extra quotes
|
|
|
- agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
|
|
|
- # remove Action: xxx from agent thought
|
|
|
- agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
|
|
|
-
|
|
|
- if action_name and action_input:
|
|
|
- return AgentScratchpadUnit(
|
|
|
- agent_response=content,
|
|
|
- thought=agent_thought,
|
|
|
- action_str=quotes[i],
|
|
|
- action=AgentScratchpadUnit.Action(
|
|
|
- action_name=action_name[0],
|
|
|
- action_input=action_input[0],
|
|
|
- )
|
|
|
- )
|
|
|
-
|
|
|
- def extra_json():
|
|
|
- agent_response = content
|
|
|
- # try to extract all json
|
|
|
- structures, pair_match_stack = [], []
|
|
|
- started_at, end_at = 0, 0
|
|
|
- for i in range(len(content)):
|
|
|
- if content[i] == '{':
|
|
|
- pair_match_stack.append(i)
|
|
|
- if len(pair_match_stack) == 1:
|
|
|
- started_at = i
|
|
|
- elif content[i] == '}':
|
|
|
- begin = pair_match_stack.pop()
|
|
|
- if not pair_match_stack:
|
|
|
- end_at = i + 1
|
|
|
- structures.append((content[begin:i+1], (started_at, end_at)))
|
|
|
-
|
|
|
- # handle the last character
|
|
|
- if pair_match_stack:
|
|
|
- end_at = len(content)
|
|
|
- structures.append((content[pair_match_stack[0]:], (started_at, end_at)))
|
|
|
-
|
|
|
- for i in range(len(structures), 0, -1):
|
|
|
- try:
|
|
|
- json_content, (started_at, end_at) = structures[i - 1]
|
|
|
- action = json.loads(json_content)
|
|
|
- action_name = action.get("action")
|
|
|
- action_input = action.get("action_input")
|
|
|
- # delete json content from agent response
|
|
|
- agent_thought = agent_response[:started_at] + agent_response[end_at:]
|
|
|
- # remove extra quotes like ```(json)*\n\n```
|
|
|
- agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
|
|
|
- # remove Action: xxx from agent thought
|
|
|
- agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
|
|
|
-
|
|
|
- if action_name and action_input is not None:
|
|
|
- return AgentScratchpadUnit(
|
|
|
- agent_response=content,
|
|
|
- thought=agent_thought,
|
|
|
- action_str=json_content,
|
|
|
- action=AgentScratchpadUnit.Action(
|
|
|
- action_name=action_name,
|
|
|
- action_input=action_input,
|
|
|
- )
|
|
|
- )
|
|
|
- except:
|
|
|
- pass
|
|
|
-
|
|
|
- agent_scratchpad = extra_quotes()
|
|
|
- if agent_scratchpad:
|
|
|
- return agent_scratchpad
|
|
|
- agent_scratchpad = extra_json()
|
|
|
- if agent_scratchpad:
|
|
|
- return agent_scratchpad
|
|
|
-
|
|
|
- return AgentScratchpadUnit(
|
|
|
- agent_response=content,
|
|
|
- thought=content,
|
|
|
- action_str='',
|
|
|
- action=None
|
|
|
- )
|
|
|
-
|
|
|
def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"],
|
|
|
agent_prompt_message: AgentPromptEntity,
|
|
|
):
|
|
@@ -591,15 +600,15 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
|
|
# organize prompt messages
|
|
|
if mode == "chat":
|
|
|
# override system message
|
|
|
- overrided = False
|
|
|
+ overridden = False
|
|
|
prompt_messages = prompt_messages.copy()
|
|
|
for prompt_message in prompt_messages:
|
|
|
if isinstance(prompt_message, SystemPromptMessage):
|
|
|
prompt_message.content = system_message
|
|
|
- overrided = True
|
|
|
+ overridden = True
|
|
|
break
|
|
|
|
|
|
- if not overrided:
|
|
|
+ if not overridden:
|
|
|
prompt_messages.insert(0, SystemPromptMessage(
|
|
|
content=system_message,
|
|
|
))
|