Quellcode durchsuchen

fix: missing iterator in task pipeline (#4948)

Yeuoly vor 10 Monaten
Ursprung
Commit
80a87f36ea
1 geänderte Dateien mit 29 neuen und 2 gelöschten Zeilen
  1. 29 2
      api/core/app/apps/advanced_chat/generate_task_pipeline.py

+ 29 - 2
api/core/app/apps/advanced_chat/generate_task_pipeline.py

@@ -107,8 +107,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             usage=LLMUsage.empty_usage()
         )
 
-        self._stream_generate_routes = self._get_stream_generate_routes()
         self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
+        self._stream_generate_routes = self._get_stream_generate_routes()
         self._conversation_name_generate_thread = None
 
     def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
@@ -410,6 +410,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 ingoing_edges.append(edge)
 
         if not ingoing_edges:
+            # check if it's the first node in the iteration
+            target_node = next((node for node in nodes if node.get('id') == target_node_id), None)
+            if not target_node:
+                return []
+            
+            node_iteration_id = target_node.get('data', {}).get('iteration_id')
+            # get iteration start node id
+            for node in nodes:
+                if node.get('id') == node_iteration_id:
+                    if node.get('data', {}).get('start_node_id') == target_node_id:
+                        return [target_node_id]
+                    
             return []
 
         start_node_ids = []
@@ -514,6 +526,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 self._task_state.answer += route_chunk.text
                 yield self._message_to_stream_response(route_chunk.text, self._message.id)
             else:
+                value = None
                 route_chunk = cast(VarGenerateRouteChunk, route_chunk)
                 value_selector = route_chunk.value_selector
                 if not value_selector:
@@ -525,6 +538,20 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 if route_chunk_node_id == 'sys':
                     # system variable
                     value = self._workflow_system_variables.get(SystemVariable.value_of(value_selector[1]))
+                elif route_chunk_node_id in self._iteration_nested_relations:
+                    # it's a iteration variable
+                    if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations:
+                        continue
+                    iteration_state = self._iteration_state.current_iterations[route_chunk_node_id]
+                    iterator = iteration_state.inputs
+                    if not iterator:
+                        continue
+                    iterator_selector = iterator.get('iterator_selector', [])
+                    if value_selector[1] == 'index':
+                        value = iteration_state.current_index
+                    elif value_selector[1] == 'item':
+                        value = iterator_selector[iteration_state.current_index] if iteration_state.current_index < len(
+                            iterator_selector) else None
                 else:
                     # check chunk node id is before current node id or equal to current node id
                     if route_chunk_node_id not in self._task_state.ran_node_execution_infos:
@@ -554,7 +581,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                         else:
                             value = value.get(key)
 
-                if value:
+                if value is not None:
                     text = ''
                     if isinstance(value, str | int | float):
                         text = str(value)