Quellcode durchsuchen

feat(api/workflow): Add `Conversation.dialogue_count` (#7275)

-LAN- vor 8 Monaten
Ursprung
Commit
32dc963556
29 geänderte Dateien mit 205 neuen und 259 gelöschten Zeilen
  1. 5 1
      api/contexts/__init__.py
  2. 84 24
      api/core/app/apps/advanced_chat/app_generator.py
  3. 2 51
      api/core/app/apps/advanced_chat/app_runner.py
  4. 10 4
      api/core/app/apps/advanced_chat/generate_task_pipeline.py
  5. 4 1
      api/core/app/apps/message_based_app_generator.py
  6. 1 1
      api/core/app/apps/workflow/app_runner.py
  7. 3 3
      api/core/app/apps/workflow/generate_task_pipeline.py
  8. 0 6
      api/core/app/segments/__init__.py
  9. 0 12
      api/core/app/segments/factory.py
  10. 0 13
      api/core/app/segments/segments.py
  11. 0 2
      api/core/app/segments/types.py
  12. 0 9
      api/core/app/segments/variables.py
  13. 2 2
      api/core/app/task_pipeline/workflow_cycle_state_manager.py
  14. 4 24
      api/core/workflow/entities/node_entities.py
  15. 1 1
      api/core/workflow/entities/variable_pool.py
  16. 25 0
      api/core/workflow/enums.py
  17. 6 5
      api/core/workflow/nodes/llm/llm_node.py
  18. 6 5
      api/core/workflow/nodes/tool/tool_node.py
  19. 4 1
      api/core/workflow/workflow_engine_manager.py
  20. 33 0
      api/migrations/versions/2024_08_14_1354-8782057ff0dc_add_conversations_dialogue_count.py
  21. 3 3
      api/models/__init__.py
  22. 3 2
      api/models/model.py
  23. 2 2
      api/tests/integration_tests/workflow/nodes/test_llm.py
  24. 3 3
      api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
  25. 0 80
      api/tests/unit_tests/core/app/segments/test_factory.py
  26. 1 1
      api/tests/unit_tests/core/app/segments/test_segment.py
  27. 1 1
      api/tests/unit_tests/core/workflow/nodes/test_answer.py
  28. 1 1
      api/tests/unit_tests/core/workflow/nodes/test_if_else.py
  29. 1 1
      api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py

+ 5 - 1
api/contexts/__init__.py

@@ -1,3 +1,7 @@
 from contextvars import ContextVar
 
-tenant_id: ContextVar[str] = ContextVar('tenant_id')
+from core.workflow.entities.variable_pool import VariablePool
+
+tenant_id: ContextVar[str] = ContextVar('tenant_id')
+
+workflow_variable_pool: ContextVar[VariablePool] = ContextVar('workflow_variable_pool')

+ 84 - 24
api/core/app/apps/advanced_chat/app_generator.py

@@ -8,6 +8,8 @@ from typing import Union
 
 from flask import Flask, current_app
 from pydantic import ValidationError
+from sqlalchemy import select
+from sqlalchemy.orm import Session
 
 import contexts
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
@@ -18,15 +20,20 @@ from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGe
 from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
 from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
 from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
-from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
+from core.app.entities.app_invoke_entities import (
+    AdvancedChatAppGenerateEntity,
+    InvokeFrom,
+)
 from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
 from core.file.message_file_parser import MessageFileParser
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.ops.ops_trace_manager import TraceQueueManager
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariable
 from extensions.ext_database import db
 from models.account import Account
 from models.model import App, Conversation, EndUser, Message
-from models.workflow import Workflow
+from models.workflow import ConversationVariable, Workflow
 
 logger = logging.getLogger(__name__)
 
@@ -120,7 +127,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             conversation=conversation,
             stream=stream
         )
-    
+
     def single_iteration_generate(self, app_model: App,
                                   workflow: Workflow,
                                   node_id: str,
@@ -140,10 +147,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         """
         if not node_id:
             raise ValueError('node_id is required')
-        
+
         if args.get('inputs') is None:
             raise ValueError('inputs is required')
-        
+
         extras = {
             "auto_generate_conversation_name": False
         }
@@ -209,7 +216,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             # update conversation features
             conversation.override_model_configs = workflow.features
             db.session.commit()
-            db.session.refresh(conversation)
+            # db.session.refresh(conversation)
 
         # init queue manager
         queue_manager = MessageBasedAppQueueManager(
@@ -221,15 +228,69 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             message_id=message.id
         )
 
+        # Init conversation variables
+        stmt = select(ConversationVariable).where(
+            ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
+        )
+        with Session(db.engine) as session:
+            conversation_variables = session.scalars(stmt).all()
+            if not conversation_variables:
+                # Create conversation variables if they don't exist.
+                conversation_variables = [
+                    ConversationVariable.from_variable(
+                        app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
+                    )
+                    for variable in workflow.conversation_variables
+                ]
+                session.add_all(conversation_variables)
+            # Convert database entities to variables.
+            conversation_variables = [item.to_variable() for item in conversation_variables]
+
+            session.commit()
+
+            # Increment dialogue count.
+            conversation.dialogue_count += 1
+
+            conversation_id = conversation.id
+            conversation_dialogue_count = conversation.dialogue_count
+            db.session.commit()
+            db.session.refresh(conversation)
+
+        inputs = application_generate_entity.inputs
+        query = application_generate_entity.query
+        files = application_generate_entity.files
+
+        user_id = None
+        if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
+            end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
+            if end_user:
+                user_id = end_user.session_id
+        else:
+            user_id = application_generate_entity.user_id
+
+        # Create a variable pool.
+        system_inputs = {
+            SystemVariable.QUERY: query,
+            SystemVariable.FILES: files,
+            SystemVariable.CONVERSATION_ID: conversation_id,
+            SystemVariable.USER_ID: user_id,
+            SystemVariable.DIALOGUE_COUNT: conversation_dialogue_count,
+        }
+        variable_pool = VariablePool(
+            system_variables=system_inputs,
+            user_inputs=inputs,
+            environment_variables=workflow.environment_variables,
+            conversation_variables=conversation_variables,
+        )
+        contexts.workflow_variable_pool.set(variable_pool)
+
         # new thread
         worker_thread = threading.Thread(target=self._generate_worker, kwargs={
             'flask_app': current_app._get_current_object(),
             'application_generate_entity': application_generate_entity,
             'queue_manager': queue_manager,
-            'conversation_id': conversation.id,
             'message_id': message.id,
-            'user': user,
-            'context': contextvars.copy_context()
+            'context': contextvars.copy_context(),
         })
 
         worker_thread.start()
@@ -242,7 +303,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             conversation=conversation,
             message=message,
             user=user,
-            stream=stream
+            stream=stream,
         )
 
         return AdvancedChatAppGenerateResponseConverter.convert(
@@ -253,9 +314,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
     def _generate_worker(self, flask_app: Flask,
                          application_generate_entity: AdvancedChatAppGenerateEntity,
                          queue_manager: AppQueueManager,
-                         conversation_id: str,
                          message_id: str,
-                         user: Account,
                          context: contextvars.Context) -> None:
         """
         Generate worker in a new thread.
@@ -282,8 +341,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
                         user_id=application_generate_entity.user_id
                     )
                 else:
-                    # get conversation and message
-                    conversation = self._get_conversation(conversation_id)
+                    # get message
                     message = self._get_message(message_id)
 
                     # chatbot app
@@ -291,7 +349,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
                     runner.run(
                         application_generate_entity=application_generate_entity,
                         queue_manager=queue_manager,
-                        conversation=conversation,
                         message=message
                     )
             except GenerateTaskStoppedException:
@@ -314,14 +371,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             finally:
                 db.session.close()
 
-    def _handle_advanced_chat_response(self, application_generate_entity: AdvancedChatAppGenerateEntity,
-                                       workflow: Workflow,
-                                       queue_manager: AppQueueManager,
-                                       conversation: Conversation,
-                                       message: Message,
-                                       user: Union[Account, EndUser],
-                                       stream: bool = False) \
-            -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
+    def _handle_advanced_chat_response(
+        self,
+        *,
+        application_generate_entity: AdvancedChatAppGenerateEntity,
+        workflow: Workflow,
+        queue_manager: AppQueueManager,
+        conversation: Conversation,
+        message: Message,
+        user: Union[Account, EndUser],
+        stream: bool = False,
+    ) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
         """
         Handle response.
         :param application_generate_entity: application generate entity
@@ -341,7 +401,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             conversation=conversation,
             message=message,
             user=user,
-            stream=stream
+            stream=stream,
         )
 
         try:

+ 2 - 51
api/core/app/apps/advanced_chat/app_runner.py

@@ -4,9 +4,6 @@ import time
 from collections.abc import Mapping
 from typing import Any, Optional, cast
 
-from sqlalchemy import select
-from sqlalchemy.orm import Session
-
 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
 from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -19,13 +16,10 @@ from core.app.entities.app_invoke_entities import (
 from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
 from core.moderation.base import ModerationException
 from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
-from core.workflow.entities.node_entities import SystemVariable
-from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import UserFrom
 from core.workflow.workflow_engine_manager import WorkflowEngineManager
 from extensions.ext_database import db
-from models.model import App, Conversation, EndUser, Message
-from models.workflow import ConversationVariable, Workflow
+from models import App, Message, Workflow
 
 logger = logging.getLogger(__name__)
 
@@ -39,7 +33,6 @@ class AdvancedChatAppRunner(AppRunner):
         self,
         application_generate_entity: AdvancedChatAppGenerateEntity,
         queue_manager: AppQueueManager,
-        conversation: Conversation,
         message: Message,
     ) -> None:
         """
@@ -63,15 +56,6 @@ class AdvancedChatAppRunner(AppRunner):
 
         inputs = application_generate_entity.inputs
         query = application_generate_entity.query
-        files = application_generate_entity.files
-
-        user_id = None
-        if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
-            end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
-            if end_user:
-                user_id = end_user.session_id
-        else:
-            user_id = application_generate_entity.user_id
 
         # moderation
         if self.handle_input_moderation(
@@ -103,38 +87,6 @@ class AdvancedChatAppRunner(AppRunner):
         if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
             workflow_callbacks.append(WorkflowLoggingCallback())
 
-        # Init conversation variables
-        stmt = select(ConversationVariable).where(
-            ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
-        )
-        with Session(db.engine) as session:
-            conversation_variables = session.scalars(stmt).all()
-            if not conversation_variables:
-                conversation_variables = [
-                    ConversationVariable.from_variable(
-                        app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
-                    )
-                    for variable in workflow.conversation_variables
-                ]
-                session.add_all(conversation_variables)
-                session.commit()
-            # Convert database entities to variables
-            conversation_variables = [item.to_variable() for item in conversation_variables]
-
-        # Create a variable pool.
-        system_inputs = {
-            SystemVariable.QUERY: query,
-            SystemVariable.FILES: files,
-            SystemVariable.CONVERSATION_ID: conversation.id,
-            SystemVariable.USER_ID: user_id,
-        }
-        variable_pool = VariablePool(
-            system_variables=system_inputs,
-            user_inputs=inputs,
-            environment_variables=workflow.environment_variables,
-            conversation_variables=conversation_variables,
-        )
-
         # RUN WORKFLOW
         workflow_engine_manager = WorkflowEngineManager()
         workflow_engine_manager.run_workflow(
@@ -146,7 +98,6 @@ class AdvancedChatAppRunner(AppRunner):
             invoke_from=application_generate_entity.invoke_from,
             callbacks=workflow_callbacks,
             call_depth=application_generate_entity.call_depth,
-            variable_pool=variable_pool,
         )
 
     def single_iteration_run(
@@ -155,7 +106,7 @@ class AdvancedChatAppRunner(AppRunner):
         """
         Single iteration run
         """
-        app_record: App = db.session.query(App).filter(App.id == app_id).first()
+        app_record = db.session.query(App).filter(App.id == app_id).first()
         if not app_record:
             raise ValueError('App not found')
 

+ 10 - 4
api/core/app/apps/advanced_chat/generate_task_pipeline.py

@@ -4,6 +4,7 @@ import time
 from collections.abc import Generator
 from typing import Any, Optional, Union, cast
 
+import contexts
 from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
 from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -47,7 +48,8 @@ from core.file.file_obj import FileVar
 from core.model_runtime.entities.llm_entities import LLMUsage
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.ops.ops_trace_manager import TraceQueueManager
-from core.workflow.entities.node_entities import NodeType, SystemVariable
+from core.workflow.entities.node_entities import NodeType
+from core.workflow.enums import SystemVariable
 from core.workflow.nodes.answer.answer_node import AnswerNode
 from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
 from events.message_event import message_was_created
@@ -71,6 +73,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
     _application_generate_entity: AdvancedChatAppGenerateEntity
     _workflow: Workflow
     _user: Union[Account, EndUser]
+    # Deprecated
     _workflow_system_variables: dict[SystemVariable, Any]
     _iteration_nested_relations: dict[str, list[str]]
 
@@ -81,7 +84,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             conversation: Conversation,
             message: Message,
             user: Union[Account, EndUser],
-            stream: bool
+            stream: bool,
     ) -> None:
         """
         Initialize AdvancedChatAppGenerateTaskPipeline.
@@ -103,11 +106,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         self._workflow = workflow
         self._conversation = conversation
         self._message = message
+        # Deprecated
         self._workflow_system_variables = {
             SystemVariable.QUERY: message.query,
             SystemVariable.FILES: application_generate_entity.files,
             SystemVariable.CONVERSATION_ID: conversation.id,
-            SystemVariable.USER_ID: user_id
+            SystemVariable.USER_ID: user_id,
         }
 
         self._task_state = AdvancedChatTaskState(
@@ -613,7 +617,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
 
                 if route_chunk_node_id == 'sys':
                     # system variable
-                    value = self._workflow_system_variables.get(SystemVariable.value_of(value_selector[1]))
+                    value = contexts.workflow_variable_pool.get().get(value_selector)
+                    if value:
+                        value = value.text
                 elif route_chunk_node_id in self._iteration_nested_relations:
                     # it's a iteration variable
                     if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations:

+ 4 - 1
api/core/app/apps/message_based_app_generator.py

@@ -258,7 +258,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
 
         return introduction
 
-    def _get_conversation(self, conversation_id: str) -> Conversation:
+    def _get_conversation(self, conversation_id: str):
         """
         Get conversation by conversation id
         :param conversation_id: conversation id
@@ -270,6 +270,9 @@ class MessageBasedAppGenerator(BaseAppGenerator):
             .first()
         )
 
+        if not conversation:
+            raise ConversationNotExistsError()
+
         return conversation
 
     def _get_message(self, message_id: str) -> Message:

+ 1 - 1
api/core/app/apps/workflow/app_runner.py

@@ -11,8 +11,8 @@ from core.app.entities.app_invoke_entities import (
     WorkflowAppGenerateEntity,
 )
 from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
-from core.workflow.entities.node_entities import SystemVariable
 from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariable
 from core.workflow.nodes.base_node import UserFrom
 from core.workflow.workflow_engine_manager import WorkflowEngineManager
 from extensions.ext_database import db

+ 3 - 3
api/core/app/apps/workflow/generate_task_pipeline.py

@@ -42,7 +42,8 @@ from core.app.entities.task_entities import (
 from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
 from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
 from core.ops.ops_trace_manager import TraceQueueManager
-from core.workflow.entities.node_entities import NodeType, SystemVariable
+from core.workflow.entities.node_entities import NodeType
+from core.workflow.enums import SystemVariable
 from core.workflow.nodes.end.end_node import EndNode
 from extensions.ext_database import db
 from models.account import Account
@@ -519,7 +520,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
         """
         nodes = graph.get('nodes')
 
-        iteration_ids = [node.get('id') for node in nodes 
+        iteration_ids = [node.get('id') for node in nodes
                          if node.get('data', {}).get('type') in [
                              NodeType.ITERATION.value,
                              NodeType.LOOP.value,
@@ -530,4 +531,3 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
             ] for iteration_id in iteration_ids
         }
-    

+ 0 - 6
api/core/app/segments/__init__.py

@@ -2,7 +2,6 @@ from .segment_group import SegmentGroup
 from .segments import (
     ArrayAnySegment,
     ArraySegment,
-    FileSegment,
     FloatSegment,
     IntegerSegment,
     NoneSegment,
@@ -13,11 +12,9 @@ from .segments import (
 from .types import SegmentType
 from .variables import (
     ArrayAnyVariable,
-    ArrayFileVariable,
     ArrayNumberVariable,
     ArrayObjectVariable,
     ArrayStringVariable,
-    FileVariable,
     FloatVariable,
     IntegerVariable,
     NoneVariable,
@@ -32,7 +29,6 @@ __all__ = [
     'FloatVariable',
     'ObjectVariable',
     'SecretVariable',
-    'FileVariable',
     'StringVariable',
     'ArrayAnyVariable',
     'Variable',
@@ -45,11 +41,9 @@ __all__ = [
     'FloatSegment',
     'ObjectSegment',
     'ArrayAnySegment',
-    'FileSegment',
     'StringSegment',
     'ArrayStringVariable',
     'ArrayNumberVariable',
     'ArrayObjectVariable',
-    'ArrayFileVariable',
     'ArraySegment',
 ]

+ 0 - 12
api/core/app/segments/factory.py

@@ -2,12 +2,10 @@ from collections.abc import Mapping
 from typing import Any
 
 from configs import dify_config
-from core.file.file_obj import FileVar
 
 from .exc import VariableError
 from .segments import (
     ArrayAnySegment,
-    FileSegment,
     FloatSegment,
     IntegerSegment,
     NoneSegment,
@@ -17,11 +15,9 @@ from .segments import (
 )
 from .types import SegmentType
 from .variables import (
-    ArrayFileVariable,
     ArrayNumberVariable,
     ArrayObjectVariable,
     ArrayStringVariable,
-    FileVariable,
     FloatVariable,
     IntegerVariable,
     ObjectVariable,
@@ -49,8 +45,6 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
             result = FloatVariable.model_validate(mapping)
         case SegmentType.NUMBER if not isinstance(value, float | int):
             raise VariableError(f'invalid number value {value}')
-        case SegmentType.FILE:
-            result = FileVariable.model_validate(mapping)
         case SegmentType.OBJECT if isinstance(value, dict):
             result = ObjectVariable.model_validate(mapping)
         case SegmentType.ARRAY_STRING if isinstance(value, list):
@@ -59,10 +53,6 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
             result = ArrayNumberVariable.model_validate(mapping)
         case SegmentType.ARRAY_OBJECT if isinstance(value, list):
             result = ArrayObjectVariable.model_validate(mapping)
-        case SegmentType.ARRAY_FILE if isinstance(value, list):
-            mapping = dict(mapping)
-            mapping['value'] = [{'value': v} for v in value]
-            result = ArrayFileVariable.model_validate(mapping)
         case _:
             raise VariableError(f'not supported value type {value_type}')
     if result.size > dify_config.MAX_VARIABLE_SIZE:
@@ -83,6 +73,4 @@ def build_segment(value: Any, /) -> Segment:
         return ObjectSegment(value=value)
     if isinstance(value, list):
         return ArrayAnySegment(value=value)
-    if isinstance(value, FileVar):
-        return FileSegment(value=value)
     raise ValueError(f'not supported value {value}')

+ 0 - 13
api/core/app/segments/segments.py

@@ -5,8 +5,6 @@ from typing import Any
 
 from pydantic import BaseModel, ConfigDict, field_validator
 
-from core.file.file_obj import FileVar
-
 from .types import SegmentType
 
 
@@ -78,14 +76,7 @@ class IntegerSegment(Segment):
     value: int
 
 
-class FileSegment(Segment):
-    value_type: SegmentType = SegmentType.FILE
-    # TODO: embed FileVar in this model.
-    value: FileVar
 
-    @property
-    def markdown(self) -> str:
-        return self.value.to_markdown()
 
 
 class ObjectSegment(Segment):
@@ -130,7 +121,3 @@ class ArrayObjectSegment(ArraySegment):
     value_type: SegmentType = SegmentType.ARRAY_OBJECT
     value: Sequence[Mapping[str, Any]]
 
-
-class ArrayFileSegment(ArraySegment):
-    value_type: SegmentType = SegmentType.ARRAY_FILE
-    value: Sequence[FileSegment]

+ 0 - 2
api/core/app/segments/types.py

@@ -10,8 +10,6 @@ class SegmentType(str, Enum):
     ARRAY_STRING = 'array[string]'
     ARRAY_NUMBER = 'array[number]'
     ARRAY_OBJECT = 'array[object]'
-    ARRAY_FILE = 'array[file]'
     OBJECT = 'object'
-    FILE = 'file'
 
     GROUP = 'group'

+ 0 - 9
api/core/app/segments/variables.py

@@ -4,11 +4,9 @@ from core.helper import encrypter
 
 from .segments import (
     ArrayAnySegment,
-    ArrayFileSegment,
     ArrayNumberSegment,
     ArrayObjectSegment,
     ArrayStringSegment,
-    FileSegment,
     FloatSegment,
     IntegerSegment,
     NoneSegment,
@@ -44,10 +42,6 @@ class IntegerVariable(IntegerSegment, Variable):
     pass
 
 
-class FileVariable(FileSegment, Variable):
-    pass
-
-
 class ObjectVariable(ObjectSegment, Variable):
     pass
 
@@ -68,9 +62,6 @@ class ArrayObjectVariable(ArrayObjectSegment, Variable):
     pass
 
 
-class ArrayFileVariable(ArrayFileSegment, Variable):
-    pass
-
 
 class SecretVariable(StringVariable):
     value_type: SegmentType = SegmentType.SECRET

+ 2 - 2
api/core/app/task_pipeline/workflow_cycle_state_manager.py

@@ -2,7 +2,7 @@ from typing import Any, Union
 
 from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
 from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
-from core.workflow.entities.node_entities import SystemVariable
+from core.workflow.enums import SystemVariable
 from models.account import Account
 from models.model import EndUser
 from models.workflow import Workflow
@@ -13,4 +13,4 @@ class WorkflowCycleStateManager:
     _workflow: Workflow
     _user: Union[Account, EndUser]
     _task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
-    _workflow_system_variables: dict[SystemVariable, Any]
+    _workflow_system_variables: dict[SystemVariable, Any]

+ 4 - 24
api/core/workflow/entities/node_entities.py

@@ -4,13 +4,14 @@ from typing import Any, Optional
 
 from pydantic import BaseModel
 
-from models.workflow import WorkflowNodeExecutionStatus
+from models import WorkflowNodeExecutionStatus
 
 
 class NodeType(Enum):
     """
     Node Types.
     """
+
     START = 'start'
     END = 'end'
     ANSWER = 'answer'
@@ -44,33 +45,11 @@ class NodeType(Enum):
         raise ValueError(f'invalid node type value {value}')
 
 
-class SystemVariable(Enum):
-    """
-    System Variables.
-    """
-    QUERY = 'query'
-    FILES = 'files'
-    CONVERSATION_ID = 'conversation_id'
-    USER_ID = 'user_id'
-
-    @classmethod
-    def value_of(cls, value: str) -> 'SystemVariable':
-        """
-        Get value of given system variable.
-
-        :param value: system variable value
-        :return: system variable
-        """
-        for system_variable in cls:
-            if system_variable.value == value:
-                return system_variable
-        raise ValueError(f'invalid system variable value {value}')
-
-
 class NodeRunMetadataKey(Enum):
     """
     Node Run Metadata Key.
     """
+
     TOTAL_TOKENS = 'total_tokens'
     TOTAL_PRICE = 'total_price'
     CURRENCY = 'currency'
@@ -83,6 +62,7 @@ class NodeRunResult(BaseModel):
     """
     Node Run Result.
     """
+
     status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
 
     inputs: Optional[Mapping[str, Any]] = None  # node inputs

+ 1 - 1
api/core/workflow/entities/variable_pool.py

@@ -6,7 +6,7 @@ from typing_extensions import deprecated
 
 from core.app.segments import Segment, Variable, factory
 from core.file.file_obj import FileVar
-from core.workflow.entities.node_entities import SystemVariable
+from core.workflow.enums import SystemVariable
 
 VariableValue = Union[str, int, float, dict, list, FileVar]
 

+ 25 - 0
api/core/workflow/enums.py

@@ -0,0 +1,25 @@
+from enum import Enum
+
+
+class SystemVariable(str, Enum):
+    """
+    System Variables.
+    """
+    QUERY = 'query'
+    FILES = 'files'
+    CONVERSATION_ID = 'conversation_id'
+    USER_ID = 'user_id'
+    DIALOGUE_COUNT = 'dialogue_count'
+
+    @classmethod
+    def value_of(cls, value: str):
+        """
+        Get value of given system variable.
+
+        :param value: system variable value
+        :return: system variable
+        """
+        for system_variable in cls:
+            if system_variable.value == value:
+                return system_variable
+        raise ValueError(f'invalid system variable value {value}')

+ 6 - 5
api/core/workflow/nodes/llm/llm_node.py

@@ -23,8 +23,9 @@ from core.model_runtime.utils.encoders import jsonable_encoder
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
-from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
+from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariable
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.llm.entities import (
     LLMNodeChatModelMessage,
@@ -201,8 +202,8 @@ class LLMNode(BaseNode):
             usage = LLMUsage.empty_usage()
 
         return full_text, usage
-    
-    def _transform_chat_messages(self, 
+
+    def _transform_chat_messages(self,
         messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
     ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
         """
@@ -249,13 +250,13 @@ class LLMNode(BaseNode):
                 # check if it's a context structure
                 if 'metadata' in d and '_source' in d['metadata'] and 'content' in d:
                     return d['content']
-                
+
                 # else, parse the dict
                 try:
                     return json.dumps(d, ensure_ascii=False)
                 except Exception:
                     return str(d)
-                
+
             if isinstance(value, str):
                 value = value
             elif isinstance(value, list):

+ 6 - 5
api/core/workflow/nodes/tool/tool_node.py

@@ -2,19 +2,20 @@ from collections.abc import Mapping, Sequence
 from os import path
 from typing import Any, cast
 
-from core.app.segments import parser
+from core.app.segments import ArrayAnyVariable, parser
 from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
 from core.file.file_obj import FileTransferMethod, FileType, FileVar
 from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
 from core.tools.tool_engine import ToolEngine
 from core.tools.tool_manager import ToolManager
 from core.tools.utils.message_transformer import ToolFileMessageTransformer
-from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
+from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariable
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.tool.entities import ToolNodeData
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
-from models.workflow import WorkflowNodeExecutionStatus
+from models import WorkflowNodeExecutionStatus
 
 
 class ToolNode(BaseNode):
@@ -140,9 +141,9 @@ class ToolNode(BaseNode):
         return result
 
     def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
-        # FIXME: ensure this is a ArrayVariable contains FileVariable.
         variable = variable_pool.get(['sys', SystemVariable.FILES.value])
-        return [file_var.value for file_var in variable.value] if variable else []
+        assert isinstance(variable, ArrayAnyVariable)
+        return list(variable.value) if variable else []
 
     def _convert_tool_messages(self, messages: list[ToolInvokeMessage]):
         """

+ 4 - 1
api/core/workflow/workflow_engine_manager.py

@@ -3,6 +3,7 @@ import time
 from collections.abc import Mapping, Sequence
 from typing import Any, Optional, cast
 
+import contexts
 from configs import dify_config
 from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
 from core.app.entities.app_invoke_entities import InvokeFrom
@@ -97,7 +98,7 @@ class WorkflowEngineManager:
         invoke_from: InvokeFrom,
         callbacks: Sequence[WorkflowCallback],
         call_depth: int = 0,
-        variable_pool: VariablePool,
+        variable_pool: VariablePool | None = None,
     ) -> None:
         """
         :param workflow: Workflow instance
@@ -128,6 +129,8 @@ class WorkflowEngineManager:
             raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth))
 
         # init workflow run state
+        if not variable_pool:
+            variable_pool = contexts.workflow_variable_pool.get()
         workflow_run_state = WorkflowRunState(
             workflow=workflow,
             start_at=time.perf_counter(),

+ 33 - 0
api/migrations/versions/2024_08_14_1354-8782057ff0dc_add_conversations_dialogue_count.py

@@ -0,0 +1,33 @@
+"""add conversations.dialogue_count
+
+Revision ID: 8782057ff0dc
+Revises: 63a83fcf12ba
+Create Date: 2024-08-14 13:54:25.161324
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+import models as models
+
+# revision identifiers, used by Alembic.
+revision = '8782057ff0dc'
+down_revision = '63a83fcf12ba'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('conversations', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('dialogue_count', sa.Integer(), server_default='0', nullable=False))
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('conversations', schema=None) as batch_op:
+        batch_op.drop_column('dialogue_count')
+
+    # ### end Alembic commands ###

+ 3 - 3
api/models/__init__.py

@@ -1,10 +1,10 @@
 from enum import Enum
 
-from .model import AppMode
+from .model import App, AppMode, Message
 from .types import StringUUID
-from .workflow import ConversationVariable, WorkflowNodeExecutionStatus
+from .workflow import ConversationVariable, Workflow, WorkflowNodeExecutionStatus
 
-__all__ = ['ConversationVariable', 'StringUUID', 'AppMode', 'WorkflowNodeExecutionStatus']
+__all__ = ['ConversationVariable', 'StringUUID', 'AppMode', 'WorkflowNodeExecutionStatus', 'Workflow', 'App', 'Message']
 
 
 class CreatedByRole(Enum):

+ 3 - 2
api/models/model.py

@@ -7,6 +7,7 @@ from typing import Optional
 from flask import request
 from flask_login import UserMixin
 from sqlalchemy import Float, func, text
+from sqlalchemy.orm import Mapped, mapped_column
 
 from configs import dify_config
 from core.file.tool_file_parser import ToolFileParser
@@ -512,12 +513,12 @@ class Conversation(db.Model):
     from_account_id = db.Column(StringUUID)
     read_at = db.Column(db.DateTime)
     read_account_id = db.Column(StringUUID)
+    dialogue_count: Mapped[int] = mapped_column(default=0)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
 
     messages = db.relationship("Message", backref="conversation", lazy='select', passive_deletes="all")
-    message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select',
-                                          passive_deletes="all")
+    message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all")
 
     is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
 

+ 2 - 2
api/tests/integration_tests/workflow/nodes/test_llm.py

@@ -10,8 +10,8 @@ from core.entities.provider_entities import CustomConfiguration, CustomProviderC
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers import ModelProviderFactory
-from core.workflow.entities.node_entities import SystemVariable
 from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariable
 from core.workflow.nodes.base_node import UserFrom
 from core.workflow.nodes.llm.llm_node import LLMNode
 from extensions.ext_database import db
@@ -236,4 +236,4 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
 
     assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
     assert 'sunny' in json.dumps(result.process_data)
-    assert 'what\'s the weather today?' in json.dumps(result.process_data)
+    assert 'what\'s the weather today?' in json.dumps(result.process_data)

+ 3 - 3
api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py

@@ -12,8 +12,8 @@ from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
-from core.workflow.entities.node_entities import SystemVariable
 from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariable
 from core.workflow.nodes.base_node import UserFrom
 from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
 from extensions.ext_database import db
@@ -363,7 +363,7 @@ def test_extract_json_response():
         {
             "location": "kawaii"
         }
-        hello world.                          
+        hello world.
     """)
 
     assert result['location'] == 'kawaii'
@@ -445,4 +445,4 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock):
             assert latest_role != prompt.get('role')
 
         if prompt.get('role') in ['user', 'assistant']:
-            latest_role = prompt.get('role')
+            latest_role = prompt.get('role')

+ 0 - 80
api/tests/unit_tests/core/app/segments/test_factory.py

@@ -3,12 +3,9 @@ from uuid import uuid4
 import pytest
 
 from core.app.segments import (
-    ArrayFileVariable,
     ArrayNumberVariable,
     ArrayObjectVariable,
     ArrayStringVariable,
-    FileSegment,
-    FileVariable,
     FloatVariable,
     IntegerVariable,
     ObjectSegment,
@@ -149,83 +146,6 @@ def test_array_object_variable():
     assert isinstance(variable.value[1]['key2'], int)
 
 
-def test_file_variable():
-    mapping = {
-        'id': str(uuid4()),
-        'value_type': 'file',
-        'name': 'test_file',
-        'description': 'Description of the variable.',
-        'value': {
-            'id': str(uuid4()),
-            'tenant_id': 'tenant_id',
-            'type': 'image',
-            'transfer_method': 'local_file',
-            'url': 'url',
-            'related_id': 'related_id',
-            'extra_config': {
-                'image_config': {
-                    'width': 100,
-                    'height': 100,
-                },
-            },
-            'filename': 'filename',
-            'extension': 'extension',
-            'mime_type': 'mime_type',
-        },
-    }
-    variable = factory.build_variable_from_mapping(mapping)
-    assert isinstance(variable, FileVariable)
-
-
-def test_array_file_variable():
-    mapping = {
-        'id': str(uuid4()),
-        'value_type': 'array[file]',
-        'name': 'test_array_file',
-        'description': 'Description of the variable.',
-        'value': [
-            {
-                'id': str(uuid4()),
-                'tenant_id': 'tenant_id',
-                'type': 'image',
-                'transfer_method': 'local_file',
-                'url': 'url',
-                'related_id': 'related_id',
-                'extra_config': {
-                    'image_config': {
-                        'width': 100,
-                        'height': 100,
-                    },
-                },
-                'filename': 'filename',
-                'extension': 'extension',
-                'mime_type': 'mime_type',
-            },
-            {
-                'id': str(uuid4()),
-                'tenant_id': 'tenant_id',
-                'type': 'image',
-                'transfer_method': 'local_file',
-                'url': 'url',
-                'related_id': 'related_id',
-                'extra_config': {
-                    'image_config': {
-                        'width': 100,
-                        'height': 100,
-                    },
-                },
-                'filename': 'filename',
-                'extension': 'extension',
-                'mime_type': 'mime_type',
-            },
-        ],
-    }
-    variable = factory.build_variable_from_mapping(mapping)
-    assert isinstance(variable, ArrayFileVariable)
-    assert isinstance(variable.value[0], FileSegment)
-    assert isinstance(variable.value[1], FileSegment)
-
-
 def test_variable_cannot_large_than_5_kb():
     with pytest.raises(VariableError):
         factory.build_variable_from_mapping(

+ 1 - 1
api/tests/unit_tests/core/app/segments/test_segment.py

@@ -1,7 +1,7 @@
 from core.app.segments import SecretVariable, StringSegment, parser
 from core.helper import encrypter
-from core.workflow.entities.node_entities import SystemVariable
 from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariable
 
 
 def test_segment_group_to_text():

+ 1 - 1
api/tests/unit_tests/core/workflow/nodes/test_answer.py

@@ -1,8 +1,8 @@
 from unittest.mock import MagicMock
 
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.workflow.entities.node_entities import SystemVariable
 from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariable
 from core.workflow.nodes.answer.answer_node import AnswerNode
 from core.workflow.nodes.base_node import UserFrom
 from extensions.ext_database import db

+ 1 - 1
api/tests/unit_tests/core/workflow/nodes/test_if_else.py

@@ -1,8 +1,8 @@
 from unittest.mock import MagicMock
 
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.workflow.entities.node_entities import SystemVariable
 from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariable
 from core.workflow.nodes.base_node import UserFrom
 from core.workflow.nodes.if_else.if_else_node import IfElseNode
 from extensions.ext_database import db

+ 1 - 1
api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py

@@ -3,8 +3,8 @@ from uuid import uuid4
 
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.segments import ArrayStringVariable, StringVariable
-from core.workflow.entities.node_entities import SystemVariable
 from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariable
 from core.workflow.nodes.base_node import UserFrom
 from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode