Sfoglia il codice sorgente

feat: add from_variable_selector for stream chunk / message event (#8228)

takatost 7 mesi fa
parent
commit
cee0c51dbb

+ 3 - 1
api/core/app/apps/advanced_chat/generate_task_pipeline.py

@@ -451,7 +451,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                     tts_publisher.publish(message=queue_message)
 
                 self._task_state.answer += delta_text
-                yield self._message_to_stream_response(delta_text, self._message.id)
+                yield self._message_to_stream_response(
+                    answer=delta_text, message_id=self._message.id, from_variable_selector=event.from_variable_selector
+                )
             elif isinstance(event, QueueMessageReplaceEvent):
                 # published by moderation
                 yield self._message_replace_to_stream_response(answer=event.text)

+ 8 - 3
api/core/app/apps/workflow/generate_task_pipeline.py

@@ -376,7 +376,9 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                     tts_publisher.publish(message=queue_message)
 
                 self._task_state.answer += delta_text
-                yield self._text_chunk_to_stream_response(delta_text)
+                yield self._text_chunk_to_stream_response(
+                    delta_text, from_variable_selector=event.from_variable_selector
+                )
             else:
                 continue
 
@@ -412,14 +414,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
         db.session.commit()
         db.session.close()
 
-    def _text_chunk_to_stream_response(self, text: str) -> TextChunkStreamResponse:
+    def _text_chunk_to_stream_response(
+        self, text: str, from_variable_selector: Optional[list[str]] = None
+    ) -> TextChunkStreamResponse:
         """
         Handle completed event.
         :param text: text
         :return:
         """
         response = TextChunkStreamResponse(
-            task_id=self._application_generate_entity.task_id, data=TextChunkStreamResponse.Data(text=text)
+            task_id=self._application_generate_entity.task_id,
+            data=TextChunkStreamResponse.Data(text=text, from_variable_selector=from_variable_selector),
         )
 
         return response

+ 2 - 0
api/core/app/entities/task_entities.py

@@ -90,6 +90,7 @@ class MessageStreamResponse(StreamResponse):
     event: StreamEvent = StreamEvent.MESSAGE
     id: str
     answer: str
+    from_variable_selector: Optional[list[str]] = None
 
 
 class MessageAudioStreamResponse(StreamResponse):
@@ -479,6 +480,7 @@ class TextChunkStreamResponse(StreamResponse):
         """
 
         text: str
+        from_variable_selector: Optional[list[str]] = None
 
     event: StreamEvent = StreamEvent.TEXT_CHUNK
     data: Data

+ 9 - 2
api/core/app/task_pipeline/message_cycle_manage.py

@@ -153,14 +153,21 @@ class MessageCycleManage:
 
         return None
 
-    def _message_to_stream_response(self, answer: str, message_id: str) -> MessageStreamResponse:
+    def _message_to_stream_response(
+        self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None
+    ) -> MessageStreamResponse:
         """
         Message to stream response.
         :param answer: answer
         :param message_id: message id
         :return:
         """
-        return MessageStreamResponse(task_id=self._application_generate_entity.task_id, id=message_id, answer=answer)
+        return MessageStreamResponse(
+            task_id=self._application_generate_entity.task_id,
+            id=message_id,
+            answer=answer,
+            from_variable_selector=from_variable_selector,
+        )
 
     def _message_replace_to_stream_response(self, answer: str) -> MessageReplaceStreamResponse:
         """

+ 1 - 0
api/core/workflow/nodes/answer/answer_stream_processor.py

@@ -108,6 +108,7 @@ class AnswerStreamProcessor(StreamProcessor):
                         route_node_state=event.route_node_state,
                         parallel_id=event.parallel_id,
                         parallel_start_node_id=event.parallel_start_node_id,
+                        from_variable_selector=[answer_node_id, "answer"],
                     )
                 else:
                     route_chunk = cast(VarGenerateRouteChunk, route_chunk)