فهرست منبع

fix: `dialogue_count` incorrect in chatflow when there's... (#11175)

Hash Brown 4 ماه پیش
والد
کامیت
c4fad66f2a

+ 8 - 0
api/core/app/apps/advanced_chat/app_generator.py

@@ -23,6 +23,7 @@ from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity,
 from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.ops.ops_trace_manager import TraceQueueManager
+from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
 from extensions.ext_database import db
 from factories import file_factory
 from models.account import Account
@@ -33,6 +34,8 @@ logger = logging.getLogger(__name__)
 
 
 class AdvancedChatAppGenerator(MessageBasedAppGenerator):
+    _dialogue_count: int
+
     def generate(
         self,
         app_model: App,
@@ -211,6 +214,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             db.session.commit()
             db.session.refresh(conversation)
 
+        # get conversation dialogue count
+        self._dialogue_count = get_thread_messages_length(conversation.id)
+
         # init queue manager
         queue_manager = MessageBasedAppQueueManager(
             task_id=application_generate_entity.task_id,
@@ -281,6 +287,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
                     queue_manager=queue_manager,
                     conversation=conversation,
                     message=message,
+                    dialogue_count=self._dialogue_count,
                 )
 
                 runner.run()
@@ -334,6 +341,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             message=message,
             user=user,
             stream=stream,
+            dialogue_count=self._dialogue_count,
         )
 
         try:

+ 3 - 7
api/core/app/apps/advanced_chat/app_runner.py

@@ -39,12 +39,14 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         queue_manager: AppQueueManager,
         conversation: Conversation,
         message: Message,
+        dialogue_count: int,
     ) -> None:
         super().__init__(queue_manager)
 
         self.application_generate_entity = application_generate_entity
         self.conversation = conversation
         self.message = message
+        self._dialogue_count = dialogue_count
 
     def run(self) -> None:
         app_config = self.application_generate_entity.app_config
@@ -122,19 +124,13 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
 
                 session.commit()
 
-            # Increment dialogue count.
-            self.conversation.dialogue_count += 1
-
-            conversation_dialogue_count = self.conversation.dialogue_count
-            db.session.commit()
-
             # Create a variable pool.
             system_inputs = {
                 SystemVariableKey.QUERY: query,
                 SystemVariableKey.FILES: files,
                 SystemVariableKey.CONVERSATION_ID: self.conversation.id,
                 SystemVariableKey.USER_ID: user_id,
-                SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
+                SystemVariableKey.DIALOGUE_COUNT: self._dialogue_count,
                 SystemVariableKey.APP_ID: app_config.app_id,
                 SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
                 SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id,

+ 3 - 1
api/core/app/apps/advanced_chat/generate_task_pipeline.py

@@ -88,6 +88,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         message: Message,
         user: Union[Account, EndUser],
         stream: bool,
+        dialogue_count: int,
     ) -> None:
         """
         Initialize AdvancedChatAppGenerateTaskPipeline.
@@ -98,6 +99,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         :param message: message
         :param user: user
         :param stream: stream
+        :param dialogue_count: dialogue count
         """
         super().__init__(application_generate_entity, queue_manager, user, stream)
 
@@ -114,7 +116,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             SystemVariableKey.FILES: application_generate_entity.files,
             SystemVariableKey.CONVERSATION_ID: conversation.id,
             SystemVariableKey.USER_ID: user_id,
-            SystemVariableKey.DIALOGUE_COUNT: conversation.dialogue_count,
+            SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
             SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
             SystemVariableKey.WORKFLOW_ID: workflow.id,
             SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,

+ 32 - 0
api/core/prompt/utils/get_thread_messages_length.py

@@ -0,0 +1,32 @@
+from core.prompt.utils.extract_thread_messages import extract_thread_messages
+from extensions.ext_database import db
+from models.model import Message
+
+
+def get_thread_messages_length(conversation_id: str) -> int:
+    """
+    Get the number of thread messages based on the parent message id.
+    """
+    # Fetch all messages related to the conversation
+    query = (
+        db.session.query(
+            Message.id,
+            Message.parent_message_id,
+            Message.answer,
+        )
+        .filter(
+            Message.conversation_id == conversation_id,
+        )
+        .order_by(Message.created_at.desc())
+    )
+
+    messages = query.all()
+
+    # Extract thread messages
+    thread_messages = extract_thread_messages(messages)
+
+    # Exclude the newly created message with an empty answer
+    if thread_messages and not thread_messages[0].answer:
+        thread_messages.pop(0)
+
+    return len(thread_messages)