Quellcode durchsuchen

Feat/stream react (#2498)

Yeuoly vor 1 Jahr
Ursprung
Commit
edb86f5f5a
1 geänderte Dateien mit 159 neuen und 150 gelöschten Zeilen
  1. 159 150
      api/core/features/assistant_cot_runner.py

+ 159 - 150
api/core/features/assistant_cot_runner.py

@@ -133,61 +133,95 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
             # recale llm max tokens
             self.recale_llm_max_tokens(self.model_config, prompt_messages)
             # invoke model
-            llm_result: LLMResult = model_instance.invoke_llm(
+            chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
                 prompt_messages=prompt_messages,
                 model_parameters=app_orchestration_config.model_config.parameters,
                 tools=[],
                 stop=app_orchestration_config.model_config.stop,
-                stream=False,
+                stream=True,
                 user=self.user_id,
                 callbacks=[],
             )
 
             # check llm result
-            if not llm_result:
+            if not chunks:
                 raise ValueError("failed to invoke llm")
-
-            # get scratchpad
-            scratchpad = self._extract_response_scratchpad(llm_result.message.content)
-            agent_scratchpad.append(scratchpad)
-                        
-            # get llm usage
-            if llm_result.usage:
-                increase_usage(llm_usage, llm_result.usage)
             
+            usage_dict = {}
+            react_chunks = self._handle_stream_react(chunks, usage_dict)
+            scratchpad = AgentScratchpadUnit(
+                agent_response='',
+                thought='',
+                action_str='',
+                observation='',
+                action=None
+            )
+
             # publish agent thought if it's first iteration
             if iteration_step == 1:
                 self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
 
+            for chunk in react_chunks:
+                if isinstance(chunk, dict):
+                    scratchpad.agent_response += json.dumps(chunk)
+                    try:
+                        if scratchpad.action:
+                            raise Exception("")
+                        scratchpad.action_str = json.dumps(chunk)
+                        scratchpad.action = AgentScratchpadUnit.Action(
+                            action_name=chunk['action'],
+                            action_input=chunk['action_input']
+                        )
+                    except:
+                        scratchpad.thought += json.dumps(chunk)
+                        yield LLMResultChunk(
+                            model=self.model_config.model,
+                            prompt_messages=prompt_messages,
+                            system_fingerprint='',
+                            delta=LLMResultChunkDelta(
+                                index=0,
+                                message=AssistantPromptMessage(
+                                    content=json.dumps(chunk)
+                                ),
+                                usage=None
+                            )
+                        )
+                else:
+                    scratchpad.agent_response += chunk
+                    scratchpad.thought += chunk
+                    yield LLMResultChunk(
+                        model=self.model_config.model,
+                        prompt_messages=prompt_messages,
+                        system_fingerprint='',
+                        delta=LLMResultChunkDelta(
+                            index=0,
+                            message=AssistantPromptMessage(
+                                content=chunk
+                            ),
+                            usage=None
+                        )
+                    )
+
+            agent_scratchpad.append(scratchpad)
+                        
+            # get llm usage
+            if 'usage' in usage_dict:
+                increase_usage(llm_usage, usage_dict['usage'])
+            else:
+                usage_dict['usage'] = LLMUsage.empty_usage()
+            
             self.save_agent_thought(agent_thought=agent_thought,
                                     tool_name=scratchpad.action.action_name if scratchpad.action else '',
                                     tool_input=scratchpad.action.action_input if scratchpad.action else '',
                                     thought=scratchpad.thought,
                                     observation='',
-                                    answer=llm_result.message.content,
+                                    answer=scratchpad.agent_response,
                                     messages_ids=[],
-                                    llm_usage=llm_result.usage)
+                                    llm_usage=usage_dict['usage'])
             
             if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
                 self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
 
-            # publish agent thought if it's not empty and there is a action
-            if scratchpad.thought and scratchpad.action:
-                # check if final answer
-                if not scratchpad.action.action_name.lower() == "final answer":
-                    yield LLMResultChunk(
-                        model=model_instance.model,
-                        prompt_messages=prompt_messages,
-                        delta=LLMResultChunkDelta(
-                            index=0,
-                            message=AssistantPromptMessage(
-                                content=scratchpad.thought
-                            ),
-                            usage=llm_result.usage,
-                        ),
-                        system_fingerprint=''
-                    )
-
             if not scratchpad.action:
                 # failed to extract action, return final answer directly
                 final_answer = scratchpad.agent_response or ''
@@ -262,7 +296,6 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
 
                         # save scratchpad
                         scratchpad.observation = observation
-                        scratchpad.agent_response = llm_result.message.content
 
                         # save agent thought
                         self.save_agent_thought(
@@ -271,7 +304,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
                             tool_input=tool_call_args,
                             thought=None,
                             observation=observation, 
-                            answer=llm_result.message.content,
+                            answer=scratchpad.agent_response,
                             messages_ids=message_file_ids,
                         )
                         self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
@@ -318,6 +351,97 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
             system_fingerprint=''
         ), PublishFrom.APPLICATION_MANAGER)
 
+    def _handle_stream_react(self, llm_response: Generator[LLMResultChunk, None, None], usage: dict) \
+        -> Generator[Union[str, dict], None, None]:
+        def parse_json(json_str):
+            try:
+                return json.loads(json_str.strip())
+            except:
+                return json_str
+            
+        def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
+            code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
+            if not code_blocks:
+                return
+            for block in code_blocks:
+                json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
+                yield parse_json(json_text)
+            
+        code_block_cache = ''
+        code_block_delimiter_count = 0
+        in_code_block = False
+        json_cache = ''
+        json_quote_count = 0
+        in_json = False
+        got_json = False
+    
+        for response in llm_response:
+            response = response.delta.message.content
+            if not isinstance(response, str):
+                continue
+
+            # stream
+            index = 0
+            while index < len(response):
+                steps = 1
+                delta = response[index:index+steps]
+                if delta == '`':
+                    code_block_cache += delta
+                    code_block_delimiter_count += 1
+                else:
+                    if not in_code_block:
+                        if code_block_delimiter_count > 0:
+                            yield code_block_cache
+                        code_block_cache = ''
+                    else:
+                        code_block_cache += delta
+                    code_block_delimiter_count = 0
+
+                if code_block_delimiter_count == 3:
+                    if in_code_block:
+                        yield from extra_json_from_code_block(code_block_cache)
+                        code_block_cache = ''
+                        
+                    in_code_block = not in_code_block
+                    code_block_delimiter_count = 0
+
+                if not in_code_block:
+                    # handle single json
+                    if delta == '{':
+                        json_quote_count += 1
+                        in_json = True
+                        json_cache += delta
+                    elif delta == '}':
+                        json_cache += delta
+                        if json_quote_count > 0:
+                            json_quote_count -= 1
+                            if json_quote_count == 0:
+                                in_json = False
+                                got_json = True
+                                index += steps
+                                continue
+                    else:
+                        if in_json:
+                            json_cache += delta
+
+                    if got_json:
+                        got_json = False
+                        yield parse_json(json_cache)
+                        json_cache = ''
+                        json_quote_count = 0
+                        in_json = False
+                    
+                if not in_code_block and not in_json:
+                    yield delta.replace('`', '')
+
+                index += steps
+
+        if code_block_cache:
+            yield code_block_cache
+
+        if json_cache:
+            yield parse_json(json_cache)
+
     def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
         """
         fill in inputs from external data tools
@@ -363,121 +487,6 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
 
         return agent_scratchpad
 
-    def _extract_response_scratchpad(self, content: str) -> AgentScratchpadUnit:
-        """
-        extract response from llm response
-        """
-        def extra_quotes() -> AgentScratchpadUnit:
-            agent_response = content
-            # try to extract all quotes
-            pattern = re.compile(r'```(.*?)```', re.DOTALL)
-            quotes = pattern.findall(content)
-
-            # try to extract action from end to start
-            for i in range(len(quotes) - 1, 0, -1):
-                """
-                    1. use json load to parse action
-                    2. use plain text `Action: xxx` to parse action
-                """
-                try:
-                    action = json.loads(quotes[i].replace('```', ''))
-                    action_name = action.get("action")
-                    action_input = action.get("action_input")
-                    agent_thought = agent_response.replace(quotes[i], '')
-
-                    if action_name and action_input:
-                        return AgentScratchpadUnit(
-                            agent_response=content,
-                            thought=agent_thought,
-                            action_str=quotes[i],
-                            action=AgentScratchpadUnit.Action(
-                                action_name=action_name,
-                                action_input=action_input,
-                            )
-                        )
-                except:
-                    # try to parse action from plain text
-                    action_name = re.findall(r'action: (.*)', quotes[i], re.IGNORECASE)
-                    action_input = re.findall(r'action input: (.*)', quotes[i], re.IGNORECASE)
-                    # delete action from agent response
-                    agent_thought = agent_response.replace(quotes[i], '')
-                    # remove extra quotes
-                    agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
-                    # remove Action: xxx from agent thought
-                    agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
-
-                    if action_name and action_input:
-                        return AgentScratchpadUnit(
-                            agent_response=content,
-                            thought=agent_thought,
-                            action_str=quotes[i],
-                            action=AgentScratchpadUnit.Action(
-                                action_name=action_name[0],
-                                action_input=action_input[0],
-                            )
-                        )
-
-        def extra_json():
-            agent_response = content
-            # try to extract all json
-            structures, pair_match_stack = [], []
-            started_at, end_at = 0, 0
-            for i in range(len(content)):
-                if content[i] == '{':
-                    pair_match_stack.append(i)
-                    if len(pair_match_stack) == 1:
-                        started_at = i
-                elif content[i] == '}':
-                    begin = pair_match_stack.pop()
-                    if not pair_match_stack:
-                        end_at = i + 1
-                        structures.append((content[begin:i+1], (started_at, end_at)))
-
-            # handle the last character
-            if pair_match_stack:
-                end_at = len(content)
-                structures.append((content[pair_match_stack[0]:], (started_at, end_at)))
-            
-            for i in range(len(structures), 0, -1):
-                try:
-                    json_content, (started_at, end_at) = structures[i - 1]
-                    action = json.loads(json_content)
-                    action_name = action.get("action")
-                    action_input = action.get("action_input")
-                    # delete json content from agent response
-                    agent_thought = agent_response[:started_at] + agent_response[end_at:]
-                    # remove extra quotes like ```(json)*\n\n```
-                    agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
-                    # remove Action: xxx from agent thought
-                    agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
-
-                    if action_name and action_input is not None:
-                        return AgentScratchpadUnit(
-                            agent_response=content,
-                            thought=agent_thought,
-                            action_str=json_content,
-                            action=AgentScratchpadUnit.Action(
-                                action_name=action_name,
-                                action_input=action_input,
-                            )
-                        )
-                except:
-                    pass
-        
-        agent_scratchpad = extra_quotes()
-        if agent_scratchpad:
-            return agent_scratchpad
-        agent_scratchpad = extra_json()
-        if agent_scratchpad:
-            return agent_scratchpad
-        
-        return AgentScratchpadUnit(
-            agent_response=content,
-            thought=content,
-            action_str='',
-            action=None
-        )
-        
     def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"], 
                                       agent_prompt_message: AgentPromptEntity,
     ):
@@ -591,15 +600,15 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
         # organize prompt messages
         if mode == "chat":
             # override system message
-            overrided = False
+            overridden = False
             prompt_messages = prompt_messages.copy()
             for prompt_message in prompt_messages:
                 if isinstance(prompt_message, SystemPromptMessage):
                     prompt_message.content = system_message
-                    overrided = True
+                    overridden = True
                     break
 
-            if not overrided:
+            if not overridden:
                 prompt_messages.insert(0, SystemPromptMessage(
                     content=system_message,
                 ))