
Feat: Retry on node execution errors (#11871)

Co-authored-by: Novice Lee <novicelee@NoviPro.local>
Novice, 4 months ago
Parent commit: 7abc7fa573

+ 17 - 0
api/core/app/apps/advanced_chat/generate_task_pipeline.py

@@ -22,6 +22,7 @@ from core.app.entities.queue_entities import (
     QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
     QueueNodeInIterationFailedEvent,
+    QueueNodeRetryEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
     QueueParallelBranchRunFailedEvent,
@@ -328,6 +329,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                     workflow_node_execution=workflow_node_execution,
                 )
 
+                if response:
+                    yield response
+            elif isinstance(
+                event,
+                QueueNodeRetryEvent,
+            ):
+                workflow_node_execution = self._handle_workflow_node_execution_retried(
+                    workflow_run=workflow_run, event=event
+                )
+
+                response = self._workflow_node_retry_to_stream_response(
+                    event=event,
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_node_execution=workflow_node_execution,
+                )
+
                 if response:
                     yield response
             elif isinstance(event, QueueParallelBranchRunStartedEvent):

+ 18 - 1
api/core/app/apps/workflow/generate_task_pipeline.py

@@ -18,6 +18,7 @@ from core.app.entities.queue_entities import (
     QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
     QueueNodeInIterationFailedEvent,
+    QueueNodeRetryEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
     QueueParallelBranchRunFailedEvent,
@@ -286,9 +287,25 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                     task_id=self._application_generate_entity.task_id,
                     workflow_node_execution=workflow_node_execution,
                 )
-
                 if node_failed_response:
                     yield node_failed_response
+            elif isinstance(
+                event,
+                QueueNodeRetryEvent,
+            ):
+                workflow_node_execution = self._handle_workflow_node_execution_retried(
+                    workflow_run=workflow_run, event=event
+                )
+
+                response = self._workflow_node_retry_to_stream_response(
+                    event=event,
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_node_execution=workflow_node_execution,
+                )
+
+                if response:
+                    yield response
+
             elif isinstance(event, QueueParallelBranchRunStartedEvent):
                 if not workflow_run:
                     raise Exception("Workflow run not initialized.")

+ 32 - 0
api/core/app/apps/workflow_app_runner.py

@@ -11,6 +11,7 @@ from core.app.entities.queue_entities import (
     QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
     QueueNodeInIterationFailedEvent,
+    QueueNodeRetryEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
     QueueParallelBranchRunFailedEvent,
@@ -38,6 +39,7 @@ from core.workflow.graph_engine.entities.event import (
     NodeRunExceptionEvent,
     NodeRunFailedEvent,
     NodeRunRetrieverResourceEvent,
+    NodeRunRetryEvent,
     NodeRunStartedEvent,
     NodeRunStreamChunkEvent,
     NodeRunSucceededEvent,
@@ -420,6 +422,36 @@ class WorkflowBasedAppRunner(AppRunner):
                     error=event.error if isinstance(event, IterationRunFailedEvent) else None,
                 )
             )
+        elif isinstance(event, NodeRunRetryEvent):
+            self._publish_event(
+                QueueNodeRetryEvent(
+                    node_execution_id=event.id,
+                    node_id=event.node_id,
+                    node_type=event.node_type,
+                    node_data=event.node_data,
+                    parallel_id=event.parallel_id,
+                    parallel_start_node_id=event.parallel_start_node_id,
+                    parent_parallel_id=event.parent_parallel_id,
+                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    start_at=event.start_at,
+                    inputs=event.route_node_state.node_run_result.inputs
+                    if event.route_node_state.node_run_result
+                    else {},
+                    process_data=event.route_node_state.node_run_result.process_data
+                    if event.route_node_state.node_run_result
+                    else {},
+                    outputs=event.route_node_state.node_run_result.outputs
+                    if event.route_node_state.node_run_result
+                    else {},
+                    error=event.error,
+                    execution_metadata=event.route_node_state.node_run_result.metadata
+                    if event.route_node_state.node_run_result
+                    else {},
+                    in_iteration_id=event.in_iteration_id,
+                    retry_index=event.retry_index,
+                    start_index=event.start_index,
+                )
+            )
 
 
     def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
         """

+ 32 - 0
api/core/app/entities/queue_entities.py

@@ -43,6 +43,7 @@ class QueueEvent(StrEnum):
     ERROR = "error"
     ERROR = "error"
     PING = "ping"
     PING = "ping"
     STOP = "stop"
     STOP = "stop"
+    RETRY = "retry"
 
 
 
 
 class AppQueueEvent(BaseModel):
 class AppQueueEvent(BaseModel):
@@ -313,6 +314,37 @@ class QueueNodeSucceededEvent(AppQueueEvent):
     iteration_duration_map: Optional[dict[str, float]] = None
 
 
+class QueueNodeRetryEvent(AppQueueEvent):
+    """QueueNodeRetryEvent entity"""
+
+    event: QueueEvent = QueueEvent.RETRY
+
+    node_execution_id: str
+    node_id: str
+    node_type: NodeType
+    node_data: BaseNodeData
+    parallel_id: Optional[str] = None
+    """parallel id if node is in parallel"""
+    parallel_start_node_id: Optional[str] = None
+    """parallel start node id if node is in parallel"""
+    parent_parallel_id: Optional[str] = None
+    """parent parallel id if node is in parallel"""
+    parent_parallel_start_node_id: Optional[str] = None
+    """parent parallel start node id if node is in parallel"""
+    in_iteration_id: Optional[str] = None
+    """iteration id if node is in iteration"""
+    start_at: datetime
+
+    inputs: Optional[dict[str, Any]] = None
+    process_data: Optional[dict[str, Any]] = None
+    outputs: Optional[dict[str, Any]] = None
+    execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
+
+    error: str
+    retry_index: int  # retry index
+    start_index: int  # start index
+
+
 class QueueNodeInIterationFailedEvent(AppQueueEvent):
     """
     QueueNodeInIterationFailedEvent entity

+ 70 - 0
api/core/app/entities/task_entities.py

@@ -52,6 +52,7 @@ class StreamEvent(Enum):
     WORKFLOW_FINISHED = "workflow_finished"
     WORKFLOW_FINISHED = "workflow_finished"
     NODE_STARTED = "node_started"
     NODE_STARTED = "node_started"
     NODE_FINISHED = "node_finished"
     NODE_FINISHED = "node_finished"
+    NODE_RETRY = "node_retry"
     PARALLEL_BRANCH_STARTED = "parallel_branch_started"
     PARALLEL_BRANCH_STARTED = "parallel_branch_started"
     PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
     PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
     ITERATION_STARTED = "iteration_started"
     ITERATION_STARTED = "iteration_started"
@@ -342,6 +343,75 @@ class NodeFinishStreamResponse(StreamResponse):
         }
 
 
+class NodeRetryStreamResponse(StreamResponse):
+    """
+    NodeRetryStreamResponse entity
+    """
+
+    class Data(BaseModel):
+        """
+        Data entity
+        """
+
+        id: str
+        node_id: str
+        node_type: str
+        title: str
+        index: int
+        predecessor_node_id: Optional[str] = None
+        inputs: Optional[dict] = None
+        process_data: Optional[dict] = None
+        outputs: Optional[dict] = None
+        status: str
+        error: Optional[str] = None
+        elapsed_time: float
+        execution_metadata: Optional[dict] = None
+        created_at: int
+        finished_at: int
+        files: Optional[Sequence[Mapping[str, Any]]] = []
+        parallel_id: Optional[str] = None
+        parallel_start_node_id: Optional[str] = None
+        parent_parallel_id: Optional[str] = None
+        parent_parallel_start_node_id: Optional[str] = None
+        iteration_id: Optional[str] = None
+        retry_index: int = 0
+
+    event: StreamEvent = StreamEvent.NODE_RETRY
+    workflow_run_id: str
+    data: Data
+
+    def to_ignore_detail_dict(self):
+        return {
+            "event": self.event.value,
+            "task_id": self.task_id,
+            "workflow_run_id": self.workflow_run_id,
+            "data": {
+                "id": self.data.id,
+                "node_id": self.data.node_id,
+                "node_type": self.data.node_type,
+                "title": self.data.title,
+                "index": self.data.index,
+                "predecessor_node_id": self.data.predecessor_node_id,
+                "inputs": None,
+                "process_data": None,
+                "outputs": None,
+                "status": self.data.status,
+                "error": None,
+                "elapsed_time": self.data.elapsed_time,
+                "execution_metadata": None,
+                "created_at": self.data.created_at,
+                "finished_at": self.data.finished_at,
+                "files": [],
+                "parallel_id": self.data.parallel_id,
+                "parallel_start_node_id": self.data.parallel_start_node_id,
+                "parent_parallel_id": self.data.parent_parallel_id,
+                "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
+                "iteration_id": self.data.iteration_id,
+                "retry_index": self.data.retry_index,
+            },
+        }
+
+
 class ParallelBranchStartStreamResponse(StreamResponse):
     """
     ParallelBranchStartStreamResponse entity
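
For reference, a node_retry stream event serialized from NodeRetryStreamResponse would look roughly like this (values invented for illustration; optional fields omitted):

    {
        "event": "node_retry",
        "task_id": "task-123",
        "workflow_run_id": "run-123",
        "data": {
            "id": "exec-456",
            "node_id": "http_request_1",
            "node_type": "http-request",
            "title": "HTTP Request",
            "index": 3,
            "status": "retry",
            "error": "Request timed out",
            "elapsed_time": 1.2,
            "created_at": 1734307200,
            "finished_at": 1734307201,
            "retry_index": 1
        }
    }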

+ 93 - 0
api/core/app/task_pipeline/workflow_cycle_manage.py

@@ -15,6 +15,7 @@ from core.app.entities.queue_entities import (
     QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
     QueueNodeInIterationFailedEvent,
+    QueueNodeRetryEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
     QueueParallelBranchRunFailedEvent,
@@ -26,6 +27,7 @@ from core.app.entities.task_entities import (
     IterationNodeNextStreamResponse,
     IterationNodeStartStreamResponse,
     NodeFinishStreamResponse,
+    NodeRetryStreamResponse,
     NodeStartStreamResponse,
     ParallelBranchFinishedStreamResponse,
     ParallelBranchStartStreamResponse,
@@ -423,6 +425,52 @@ class WorkflowCycleManage:
 
 
         return workflow_node_execution
 
+    def _handle_workflow_node_execution_retried(
+        self, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
+    ) -> WorkflowNodeExecution:
+        """
+        Workflow node execution retried
+        :param event: queue node retry event
+        :return:
+        """
+        created_at = event.start_at
+        finished_at = datetime.now(UTC).replace(tzinfo=None)
+        elapsed_time = (finished_at - created_at).total_seconds()
+        inputs = WorkflowEntry.handle_special_values(event.inputs)
+        outputs = WorkflowEntry.handle_special_values(event.outputs)
+
+        workflow_node_execution = WorkflowNodeExecution()
+        workflow_node_execution.tenant_id = workflow_run.tenant_id
+        workflow_node_execution.app_id = workflow_run.app_id
+        workflow_node_execution.workflow_id = workflow_run.workflow_id
+        workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
+        workflow_node_execution.workflow_run_id = workflow_run.id
+        workflow_node_execution.node_execution_id = event.node_execution_id
+        workflow_node_execution.node_id = event.node_id
+        workflow_node_execution.node_type = event.node_type.value
+        workflow_node_execution.title = event.node_data.title
+        workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value
+        workflow_node_execution.created_by_role = workflow_run.created_by_role
+        workflow_node_execution.created_by = workflow_run.created_by
+        workflow_node_execution.created_at = created_at
+        workflow_node_execution.finished_at = finished_at
+        workflow_node_execution.elapsed_time = elapsed_time
+        workflow_node_execution.error = event.error
+        workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
+        workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
+        workflow_node_execution.execution_metadata = json.dumps(
+            {
+                NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
+            }
+        )
+        workflow_node_execution.index = event.start_index
+
+        db.session.add(workflow_node_execution)
+        db.session.commit()
+        db.session.refresh(workflow_node_execution)
+
+        return workflow_node_execution
+
     #################################################
     #             to stream responses               #
     #################################################
@@ -587,6 +635,51 @@ class WorkflowCycleManage:
             ),
         )
 
+    def _workflow_node_retry_to_stream_response(
+        self,
+        event: QueueNodeRetryEvent,
+        task_id: str,
+        workflow_node_execution: WorkflowNodeExecution,
+    ) -> Optional[NodeRetryStreamResponse]:
+        """
+        Workflow node retry to stream response.
+        :param event: queue node retry event
+        :param task_id: task id
+        :param workflow_node_execution: workflow node execution
+        :return:
+        """
+        if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
+            return None
+
+        return NodeRetryStreamResponse(
+            task_id=task_id,
+            workflow_run_id=workflow_node_execution.workflow_run_id,
+            data=NodeRetryStreamResponse.Data(
+                id=workflow_node_execution.id,
+                node_id=workflow_node_execution.node_id,
+                node_type=workflow_node_execution.node_type,
+                index=workflow_node_execution.index,
+                title=workflow_node_execution.title,
+                predecessor_node_id=workflow_node_execution.predecessor_node_id,
+                inputs=workflow_node_execution.inputs_dict,
+                process_data=workflow_node_execution.process_data_dict,
+                outputs=workflow_node_execution.outputs_dict,
+                status=workflow_node_execution.status,
+                error=workflow_node_execution.error,
+                elapsed_time=workflow_node_execution.elapsed_time,
+                execution_metadata=workflow_node_execution.execution_metadata_dict,
+                created_at=int(workflow_node_execution.created_at.timestamp()),
+                finished_at=int(workflow_node_execution.finished_at.timestamp()),
+                files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
+                parallel_id=event.parallel_id,
+                parallel_start_node_id=event.parallel_start_node_id,
+                parent_parallel_id=event.parent_parallel_id,
+                parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                iteration_id=event.in_iteration_id,
+                retry_index=event.retry_index,
+            ),
+        )
+
     def _workflow_parallel_branch_start_to_stream_response(
         self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
     ) -> ParallelBranchStartStreamResponse:
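
Worth noting: _handle_workflow_node_execution_retried persists one WorkflowNodeExecution row per retry attempt, with status "retry". A minimal sketch of reading those rows back (illustrative only, not part of the commit):

    retry_records = (
        db.session.query(WorkflowNodeExecution)
        .filter(
            WorkflowNodeExecution.workflow_run_id == workflow_run.id,
            WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RETRY.value,
        )
        .order_by(WorkflowNodeExecution.created_at)  # one row per attempt, oldest first
        .all()
    )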

+ 0 - 1
api/core/helper/ssrf_proxy.py

@@ -45,7 +45,6 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
         )
 
     retries = 0
-    stream = kwargs.pop("stream", False)
     while retries <= max_retries:
         try:
             if dify_config.SSRF_PROXY_ALL_URL:
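
The removed line appears to have popped "stream" out of kwargs without ever using it, so with it gone a stream argument now reaches the underlying HTTP client. The call shape is unchanged; a minimal hedged usage sketch (URL invented):

    response = make_request("GET", "https://example.com/health", max_retries=3)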

+ 3 - 0
api/core/workflow/entities/node_entities.py

@@ -45,3 +45,6 @@ class NodeRunResult(BaseModel):
 
 
     error: Optional[str] = None  # error message if status is failed
     error_type: Optional[str] = None  # error type if status is failed
+
+    # single step node run retry
+    retry_index: int = 0

+ 7 - 0
api/core/workflow/graph_engine/entities/event.py

@@ -97,6 +97,13 @@ class NodeInIterationFailedEvent(BaseNodeEvent):
     error: str = Field(..., description="error")
     error: str = Field(..., description="error")
 
 
 
 
+class NodeRunRetryEvent(BaseNodeEvent):
+    error: str = Field(..., description="error")
+    retry_index: int = Field(..., description="which retry attempt is about to be performed")
+    start_at: datetime = Field(..., description="retry start time")
+    start_index: int = Field(..., description="retry start index")
+
+
 ###########################################
 # Parallel Branch Events
 ###########################################

+ 175 - 135
api/core/workflow/graph_engine/graph_engine.py

@@ -5,6 +5,7 @@ import uuid
 from collections.abc import Generator, Mapping
 from concurrent.futures import ThreadPoolExecutor, wait
 from copy import copy, deepcopy
+from datetime import UTC, datetime
 from typing import Any, Optional, cast
 
 from flask import Flask, current_app
@@ -25,6 +26,7 @@ from core.workflow.graph_engine.entities.event import (
     NodeRunExceptionEvent,
     NodeRunFailedEvent,
     NodeRunRetrieverResourceEvent,
+    NodeRunRetryEvent,
     NodeRunStartedEvent,
     NodeRunStreamChunkEvent,
     NodeRunSucceededEvent,
@@ -581,7 +583,7 @@ class GraphEngine:
 
 
     def _run_node(
         self,
-        node_instance: BaseNode,
+        node_instance: BaseNode[BaseNodeData],
         route_node_state: RouteNodeState,
         parallel_id: Optional[str] = None,
         parallel_start_node_id: Optional[str] = None,
@@ -607,36 +609,121 @@ class GraphEngine:
         )
 
         db.session.close()
+        max_retries = node_instance.node_data.retry_config.max_retries
+        retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
+        retries = 0
+        should_continue_retry = True
+        while should_continue_retry and retries <= max_retries:
+            try:
+                # run node
+                retry_start_at = datetime.now(UTC).replace(tzinfo=None)
+                generator = node_instance.run()
+                for item in generator:
+                    if isinstance(item, GraphEngineEvent):
+                        if isinstance(item, BaseIterationEvent):
+                            # add parallel info to iteration event
+                            item.parallel_id = parallel_id
+                            item.parallel_start_node_id = parallel_start_node_id
+                            item.parent_parallel_id = parent_parallel_id
+                            item.parent_parallel_start_node_id = parent_parallel_start_node_id
+
+                        yield item
+                    else:
+                        if isinstance(item, RunCompletedEvent):
+                            run_result = item.run_result
+                            if run_result.status == WorkflowNodeExecutionStatus.FAILED:
+                                if (
+                                    retries == max_retries
+                                    and node_instance.node_type == NodeType.HTTP_REQUEST
+                                    and run_result.outputs
+                                    and not node_instance.should_continue_on_error
+                                ):
+                                    run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
+                                if node_instance.should_retry and retries < max_retries:
+                                    retries += 1
+                                    self.graph_runtime_state.node_run_steps += 1
+                                    route_node_state.node_run_result = run_result
+                                    yield NodeRunRetryEvent(
+                                        id=node_instance.id,
+                                        node_id=node_instance.node_id,
+                                        node_type=node_instance.node_type,
+                                        node_data=node_instance.node_data,
+                                        route_node_state=route_node_state,
+                                        error=run_result.error,
+                                        retry_index=retries,
+                                        parallel_id=parallel_id,
+                                        parallel_start_node_id=parallel_start_node_id,
+                                        parent_parallel_id=parent_parallel_id,
+                                        parent_parallel_start_node_id=parent_parallel_start_node_id,
+                                        start_at=retry_start_at,
+                                        start_index=self.graph_runtime_state.node_run_steps,
+                                    )
+                                    time.sleep(retry_interval)
+                                    continue
+                            route_node_state.set_finished(run_result=run_result)
+
+                            if run_result.status == WorkflowNodeExecutionStatus.FAILED:
+                                if node_instance.should_continue_on_error:
+                                    # if run failed, handle error
+                                    run_result = self._handle_continue_on_error(
+                                        node_instance,
+                                        item.run_result,
+                                        self.graph_runtime_state.variable_pool,
+                                        handle_exceptions=handle_exceptions,
+                                    )
+                                    route_node_state.node_run_result = run_result
+                                    route_node_state.status = RouteNodeState.Status.EXCEPTION
+                                    if run_result.outputs:
+                                        for variable_key, variable_value in run_result.outputs.items():
+                                            # append variables to variable pool recursively
+                                            self._append_variables_recursively(
+                                                node_id=node_instance.node_id,
+                                                variable_key_list=[variable_key],
+                                                variable_value=variable_value,
+                                            )
+                                    yield NodeRunExceptionEvent(
+                                        error=run_result.error or "System Error",
+                                        id=node_instance.id,
+                                        node_id=node_instance.node_id,
+                                        node_type=node_instance.node_type,
+                                        node_data=node_instance.node_data,
+                                        route_node_state=route_node_state,
+                                        parallel_id=parallel_id,
+                                        parallel_start_node_id=parallel_start_node_id,
+                                        parent_parallel_id=parent_parallel_id,
+                                        parent_parallel_start_node_id=parent_parallel_start_node_id,
+                                    )
+                                    should_continue_retry = False
+                                else:
+                                    yield NodeRunFailedEvent(
+                                        error=route_node_state.failed_reason or "Unknown error.",
+                                        id=node_instance.id,
+                                        node_id=node_instance.node_id,
+                                        node_type=node_instance.node_type,
+                                        node_data=node_instance.node_data,
+                                        route_node_state=route_node_state,
+                                        parallel_id=parallel_id,
+                                        parallel_start_node_id=parallel_start_node_id,
+                                        parent_parallel_id=parent_parallel_id,
+                                        parent_parallel_start_node_id=parent_parallel_start_node_id,
+                                    )
+                                should_continue_retry = False
+                            elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
+                                if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
+                                    node_instance.node_id
+                                ):
+                                    run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
+                                if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
+                                    # plus state total_tokens
+                                    self.graph_runtime_state.total_tokens += int(
+                                        run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)  # type: ignore[arg-type]
+                                    )
 
 
-        try:
-            # run node
-            generator = node_instance.run()
-            for item in generator:
-                if isinstance(item, GraphEngineEvent):
-                    if isinstance(item, BaseIterationEvent):
-                        # add parallel info to iteration event
-                        item.parallel_id = parallel_id
-                        item.parallel_start_node_id = parallel_start_node_id
-                        item.parent_parallel_id = parent_parallel_id
-                        item.parent_parallel_start_node_id = parent_parallel_start_node_id
+                                if run_result.llm_usage:
+                                    # use the latest usage
+                                    self.graph_runtime_state.llm_usage += run_result.llm_usage
 
 
-                    yield item
-                else:
-                    if isinstance(item, RunCompletedEvent):
-                        run_result = item.run_result
-                        route_node_state.set_finished(run_result=run_result)
-
-                        if run_result.status == WorkflowNodeExecutionStatus.FAILED:
-                            if node_instance.should_continue_on_error:
-                                # if run failed, handle error
-                                run_result = self._handle_continue_on_error(
-                                    node_instance,
-                                    item.run_result,
-                                    self.graph_runtime_state.variable_pool,
-                                    handle_exceptions=handle_exceptions,
-                                )
-                                route_node_state.node_run_result = run_result
-                                route_node_state.status = RouteNodeState.Status.EXCEPTION
+                                # append node output variables to variable pool
                                 if run_result.outputs:
                                     for variable_key, variable_value in run_result.outputs.items():
                                         # append variables to variable pool recursively
@@ -645,21 +732,23 @@ class GraphEngine:
                                             variable_key_list=[variable_key],
                                             variable_value=variable_value,
                                         )
-                                yield NodeRunExceptionEvent(
-                                    error=run_result.error or "System Error",
-                                    id=node_instance.id,
-                                    node_id=node_instance.node_id,
-                                    node_type=node_instance.node_type,
-                                    node_data=node_instance.node_data,
-                                    route_node_state=route_node_state,
-                                    parallel_id=parallel_id,
-                                    parallel_start_node_id=parallel_start_node_id,
-                                    parent_parallel_id=parent_parallel_id,
-                                    parent_parallel_start_node_id=parent_parallel_start_node_id,
-                                )
-                            else:
-                                yield NodeRunFailedEvent(
-                                    error=route_node_state.failed_reason or "Unknown error.",
+
+                                # add parallel info to run result metadata
+                                if parallel_id and parallel_start_node_id:
+                                    if not run_result.metadata:
+                                        run_result.metadata = {}
+
+                                    run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
+                                    run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = (
+                                        parallel_start_node_id
+                                    )
+                                    if parent_parallel_id and parent_parallel_start_node_id:
+                                        run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
+                                        run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
+                                            parent_parallel_start_node_id
+                                        )
+
+                                yield NodeRunSucceededEvent(
                                     id=node_instance.id,
                                     node_id=node_instance.node_id,
                                     node_type=node_instance.node_type,
@@ -670,108 +759,59 @@ class GraphEngine:
                                     parent_parallel_id=parent_parallel_id,
                                     parent_parallel_start_node_id=parent_parallel_start_node_id,
                                 )
+                                should_continue_retry = False
 
 
-                        elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
-                            if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
-                                node_instance.node_id
-                            ):
-                                run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
-                            if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
-                                # plus state total_tokens
-                                self.graph_runtime_state.total_tokens += int(
-                                    run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)  # type: ignore[arg-type]
-                                )
-
-                            if run_result.llm_usage:
-                                # use the latest usage
-                                self.graph_runtime_state.llm_usage += run_result.llm_usage
-
-                            # append node output variables to variable pool
-                            if run_result.outputs:
-                                for variable_key, variable_value in run_result.outputs.items():
-                                    # append variables to variable pool recursively
-                                    self._append_variables_recursively(
-                                        node_id=node_instance.node_id,
-                                        variable_key_list=[variable_key],
-                                        variable_value=variable_value,
-                                    )
-
-                            # add parallel info to run result metadata
-                            if parallel_id and parallel_start_node_id:
-                                if not run_result.metadata:
-                                    run_result.metadata = {}
-
-                                run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
-                                run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
-                                if parent_parallel_id and parent_parallel_start_node_id:
-                                    run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
-                                    run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
-                                        parent_parallel_start_node_id
-                                    )
-
-                            yield NodeRunSucceededEvent(
+                            break
+                        elif isinstance(item, RunStreamChunkEvent):
+                            yield NodeRunStreamChunkEvent(
                                 id=node_instance.id,
                                 node_id=node_instance.node_id,
                                 node_type=node_instance.node_type,
                                 node_data=node_instance.node_data,
+                                chunk_content=item.chunk_content,
+                                from_variable_selector=item.from_variable_selector,
                                 route_node_state=route_node_state,
                                 parallel_id=parallel_id,
                                 parallel_start_node_id=parallel_start_node_id,
                                 parent_parallel_id=parent_parallel_id,
                                 parent_parallel_start_node_id=parent_parallel_start_node_id,
                             )
-
-                        break
-                    elif isinstance(item, RunStreamChunkEvent):
-                        yield NodeRunStreamChunkEvent(
-                            id=node_instance.id,
-                            node_id=node_instance.node_id,
-                            node_type=node_instance.node_type,
-                            node_data=node_instance.node_data,
-                            chunk_content=item.chunk_content,
-                            from_variable_selector=item.from_variable_selector,
-                            route_node_state=route_node_state,
-                            parallel_id=parallel_id,
-                            parallel_start_node_id=parallel_start_node_id,
-                            parent_parallel_id=parent_parallel_id,
-                            parent_parallel_start_node_id=parent_parallel_start_node_id,
-                        )
-                    elif isinstance(item, RunRetrieverResourceEvent):
-                        yield NodeRunRetrieverResourceEvent(
-                            id=node_instance.id,
-                            node_id=node_instance.node_id,
-                            node_type=node_instance.node_type,
-                            node_data=node_instance.node_data,
-                            retriever_resources=item.retriever_resources,
-                            context=item.context,
-                            route_node_state=route_node_state,
-                            parallel_id=parallel_id,
-                            parallel_start_node_id=parallel_start_node_id,
-                            parent_parallel_id=parent_parallel_id,
-                            parent_parallel_start_node_id=parent_parallel_start_node_id,
-                        )
-        except GenerateTaskStoppedError:
-            # trigger node run failed event
-            route_node_state.status = RouteNodeState.Status.FAILED
-            route_node_state.failed_reason = "Workflow stopped."
-            yield NodeRunFailedEvent(
-                error="Workflow stopped.",
-                id=node_instance.id,
-                node_id=node_instance.node_id,
-                node_type=node_instance.node_type,
-                node_data=node_instance.node_data,
-                route_node_state=route_node_state,
-                parallel_id=parallel_id,
-                parallel_start_node_id=parallel_start_node_id,
-                parent_parallel_id=parent_parallel_id,
-                parent_parallel_start_node_id=parent_parallel_start_node_id,
-            )
-            return
-        except Exception as e:
-            logger.exception(f"Node {node_instance.node_data.title} run failed")
-            raise e
-        finally:
-            db.session.close()
+                        elif isinstance(item, RunRetrieverResourceEvent):
+                            yield NodeRunRetrieverResourceEvent(
+                                id=node_instance.id,
+                                node_id=node_instance.node_id,
+                                node_type=node_instance.node_type,
+                                node_data=node_instance.node_data,
+                                retriever_resources=item.retriever_resources,
+                                context=item.context,
+                                route_node_state=route_node_state,
+                                parallel_id=parallel_id,
+                                parallel_start_node_id=parallel_start_node_id,
+                                parent_parallel_id=parent_parallel_id,
+                                parent_parallel_start_node_id=parent_parallel_start_node_id,
+                            )
+            except GenerateTaskStoppedError:
+                # trigger node run failed event
+                route_node_state.status = RouteNodeState.Status.FAILED
+                route_node_state.failed_reason = "Workflow stopped."
+                yield NodeRunFailedEvent(
+                    error="Workflow stopped.",
+                    id=node_instance.id,
+                    node_id=node_instance.node_id,
+                    node_type=node_instance.node_type,
+                    node_data=node_instance.node_data,
+                    route_node_state=route_node_state,
+                    parallel_id=parallel_id,
+                    parallel_start_node_id=parallel_start_node_id,
+                    parent_parallel_id=parent_parallel_id,
+                    parent_parallel_start_node_id=parent_parallel_start_node_id,
+                )
+                return
+            except Exception as e:
+                logger.exception(f"Node {node_instance.node_data.title} run failed")
+                raise e
+            finally:
+                db.session.close()
 
 
     def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
         """

+ 13 - 0
api/core/workflow/nodes/base/entities.py

@@ -106,12 +106,25 @@ class DefaultValue(BaseModel):
         return self
 
 
+class RetryConfig(BaseModel):
+    """node retry config"""
+
+    max_retries: int = 0  # max retry times
+    retry_interval: int = 0  # retry interval in milliseconds
+    retry_enabled: bool = False  # whether retry is enabled
+
+    @property
+    def retry_interval_seconds(self) -> float:
+        return self.retry_interval / 1000
+
+
 class BaseNodeData(ABC, BaseModel):
     title: str
     desc: Optional[str] = None
     error_strategy: Optional[ErrorStrategy] = None
     default_value: Optional[list[DefaultValue]] = None
     version: str = "1"
+    retry_config: RetryConfig = RetryConfig()
 
 
     @property
     def default_value_dict(self):
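
A quick sketch of how RetryConfig converts its interval (values invented):

    cfg = RetryConfig(max_retries=3, retry_interval=1000, retry_enabled=True)
    assert cfg.retry_interval_seconds == 1.0  # retry_interval is stored in milliseconds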

+ 10 - 1
api/core/workflow/nodes/base/node.py

@@ -4,7 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
 from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
 
 from core.workflow.entities.node_entities import NodeRunResult
-from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, NodeType
+from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
 from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
 from models.workflow import WorkflowNodeExecutionStatus
 
 
@@ -147,3 +147,12 @@ class BaseNode(Generic[GenericNodeData]):
             bool: if should continue on error
         """
         return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE
+
+    @property
+    def should_retry(self) -> bool:
+        """judge if should retry
+
+        Returns:
+            bool: if should retry
+        """
+        return self.node_data.retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE
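
Taken together with RETRY_ON_ERROR_NODE_TYPE below, should_retry is true only for LLM, Tool, and HTTP Request nodes that have retry_enabled set. A hypothetical caller gates on it like this (mirroring the check in graph_engine._run_node):

    if node_instance.should_retry and retries < node_instance.node_data.retry_config.max_retries:
        retries += 1  # another attempt is allowed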

+ 1 - 0
api/core/workflow/nodes/enums.py

@@ -35,3 +35,4 @@ class FailBranchSourceHandle(StrEnum):
 
 
 
 
 CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]
+RETRY_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.TOOL, NodeType.HTTP_REQUEST]

+ 8 - 1
api/core/workflow/nodes/event/__init__.py

@@ -1,4 +1,10 @@
-from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
+from .event import (
+    ModelInvokeCompletedEvent,
+    RunCompletedEvent,
+    RunRetrieverResourceEvent,
+    RunRetryEvent,
+    RunStreamChunkEvent,
+)
 from .types import NodeEvent
 
 __all__ = [
@@ -6,5 +12,6 @@ __all__ = [
     "NodeEvent",
     "NodeEvent",
     "RunCompletedEvent",
     "RunCompletedEvent",
     "RunRetrieverResourceEvent",
     "RunRetrieverResourceEvent",
+    "RunRetryEvent",
     "RunStreamChunkEvent",
     "RunStreamChunkEvent",
 ]
 ]

+ 25 - 0
api/core/workflow/nodes/event/event.py

@@ -1,7 +1,10 @@
+from datetime import datetime
+
 from pydantic import BaseModel, Field
 
 from core.model_runtime.entities.llm_entities import LLMUsage
 from core.workflow.entities.node_entities import NodeRunResult
+from models.workflow import WorkflowNodeExecutionStatus
 
 
 
 
 class RunCompletedEvent(BaseModel):
@@ -26,3 +29,25 @@ class ModelInvokeCompletedEvent(BaseModel):
     text: str
     usage: LLMUsage
     finish_reason: str | None = None
+
+
+class RunRetryEvent(BaseModel):
+    """Node Run Retry event"""
+
+    error: str = Field(..., description="error")
+    retry_index: int = Field(..., description="Retry attempt number")
+    start_at: datetime = Field(..., description="Retry start time")
+
+
+class SingleStepRetryEvent(BaseModel):
+    """Single step retry event"""
+
+    status: str = WorkflowNodeExecutionStatus.RETRY.value
+
+    inputs: dict | None = Field(..., description="input")
+    outputs: dict = Field(..., description="output")
+    retry_index: int = Field(..., description="Retry attempt number")
+    error: str = Field(..., description="error")
+    elapsed_time: float = Field(..., description="elapsed time")
+    execution_metadata: dict | None = Field(..., description="execution metadata")
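
A minimal instantiation for reference (field values invented; the duplicated "error" field in the original listing was dropped above, since Pydantic only keeps one):

    event = SingleStepRetryEvent(
        inputs={"query": "ping"},
        outputs={},
        retry_index=1,
        error="Request timed out",
        elapsed_time=1.2,
        execution_metadata=None,
    )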

+ 4 - 0
api/core/workflow/nodes/http_request/executor.py

@@ -45,6 +45,7 @@ class Executor:
     headers: dict[str, str]
     auth: HttpRequestNodeAuthorization
     timeout: HttpRequestNodeTimeout
+    max_retries: int
 
 
     boundary: str
     boundary: str
 
 
@@ -54,6 +55,7 @@ class Executor:
         node_data: HttpRequestNodeData,
         timeout: HttpRequestNodeTimeout,
         variable_pool: VariablePool,
+        max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
     ):
         # If authorization API key is present, convert the API key using the variable pool
         if node_data.authorization.type == "api-key":
@@ -73,6 +75,7 @@ class Executor:
         self.files = None
         self.data = None
         self.json = None
+        self.max_retries = max_retries
 
 
         # init template
         self.variable_pool = variable_pool
@@ -241,6 +244,7 @@ class Executor:
             "params": self.params,
             "params": self.params,
             "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
             "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
             "follow_redirects": True,
             "follow_redirects": True,
+            "max_retries": self.max_retries,
         }
         # request_args = {k: v for k, v in request_args.items() if v is not None}
         try:

+ 7 - 1
api/core/workflow/nodes/http_request/node.py

@@ -52,6 +52,11 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
                     "max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
                     "max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
                 },
                 },
             },
             },
+            "retry_config": {
+                "max_retries": dify_config.SSRF_DEFAULT_MAX_RETRIES,
+                "retry_interval": 0.5 * (2**2),
+                "retry_enabled": True,
+            },
         }
 
     def _run(self) -> NodeRunResult:
@@ -61,12 +66,13 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
                 node_data=self.node_data,
                 timeout=self._get_request_timeout(self.node_data),
                 variable_pool=self.graph_runtime_state.variable_pool,
+                max_retries=0,
             )
             process_data["request"] = http_executor.to_log()
 
             response = http_executor.invoke()
             files = self.extract_files(url=http_executor.url, response=response)
-            if not response.response.is_success and self.should_continue_on_error:
+            if not response.response.is_success and (self.should_continue_on_error or self.should_retry):
                 return NodeRunResult(
                     status=WorkflowNodeExecutionStatus.FAILED,
                     outputs={
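
Note the interplay here: the node now passes max_retries=0 to its Executor, which forwards it to ssrf_proxy.make_request, so transport-level retries are switched off and the engine's retry_config loop becomes the single retry mechanism. Roughly (mirroring the diff):

    http_executor = Executor(
        node_data=self.node_data,
        timeout=self._get_request_timeout(self.node_data),
        variable_pool=self.graph_runtime_state.variable_pool,
        max_retries=0,  # transport retries off; RetryConfig in the graph engine drives reruns
    )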

+ 14 - 0
api/fields/workflow_run_fields.py

@@ -29,6 +29,7 @@ workflow_run_for_list_fields = {
     "created_at": TimestampField,
     "created_at": TimestampField,
     "finished_at": TimestampField,
     "finished_at": TimestampField,
     "exceptions_count": fields.Integer,
     "exceptions_count": fields.Integer,
+    "retry_index": fields.Integer,
 }
 
 advanced_chat_workflow_run_for_list_fields = {
@@ -45,6 +46,7 @@ advanced_chat_workflow_run_for_list_fields = {
     "created_at": TimestampField,
     "created_at": TimestampField,
     "finished_at": TimestampField,
     "finished_at": TimestampField,
     "exceptions_count": fields.Integer,
     "exceptions_count": fields.Integer,
+    "retry_index": fields.Integer,
 }
 
 advanced_chat_workflow_run_pagination_fields = {
@@ -79,6 +81,17 @@ workflow_run_detail_fields = {
     "exceptions_count": fields.Integer,
     "exceptions_count": fields.Integer,
 }
 }
 
 
+retry_event_field = {
+    "error": fields.String,
+    "retry_index": fields.Integer,
+    "inputs": fields.Raw(attribute="inputs"),
+    "elapsed_time": fields.Float,
+    "execution_metadata": fields.Raw(attribute="execution_metadata_dict"),
+    "status": fields.String,
+    "outputs": fields.Raw(attribute="outputs"),
+}
+
+
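
With retry_event_field nested under retry_events, a node-execution API response could carry entries shaped roughly like this (illustrative values):

    {
        "retry_events": [
            {
                "error": "Request timed out",
                "retry_index": 1,
                "inputs": {"url": "https://example.com"},
                "elapsed_time": 1.2,
                "execution_metadata": None,
                "status": "retry",
                "outputs": {"status_code": 504}
            }
        ]
    }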
 workflow_run_node_execution_fields = {
     "id": fields.String,
     "index": fields.Integer,
@@ -99,6 +112,7 @@ workflow_run_node_execution_fields = {
     "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
     "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
     "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True),
     "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True),
     "finished_at": TimestampField,
     "finished_at": TimestampField,
+    "retry_events": fields.List(fields.Nested(retry_event_field)),
 }
 
 workflow_run_node_execution_list_fields = {

+ 33 - 0
api/migrations/versions/2024_12_16_0123-348cb0a93d53_add_retry_index_field_to_node_execution_.py

@@ -0,0 +1,33 @@
+"""add retry_index field to node-execution  model
+
+Revision ID: 348cb0a93d53
+Revises: cf8f4fc45278
+Create Date: 2024-12-16 01:23:13.093432
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '348cb0a93d53'
+down_revision = 'cf8f4fc45278'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('retry_index', sa.Integer(), server_default=sa.text('0'), nullable=True))
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
+        batch_op.drop_column('retry_index')
+
+    # ### end Alembic commands ###
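
Note: server_default=sa.text('0') makes the database itself fill existing rows with 0 as the column is added, so no separate backfill migration is needed. A runnable sketch of that behavior on a trimmed-down table (SQLite here for brevity; the names mirror the migration but this is not project code):

    import sqlalchemy as sa

    engine = sa.create_engine("sqlite://")
    meta = sa.MetaData()
    t = sa.Table(
        "workflow_node_executions",
        meta,
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("retry_index", sa.Integer, server_default=sa.text("0"), nullable=True),
    )
    meta.create_all(engine)

    with engine.begin() as conn:
        conn.execute(t.insert().values(id=1))  # retry_index not supplied
        print(conn.execute(sa.select(t)).one().retry_index)  # 0, from the server default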

+ 2 - 0
api/models/workflow.py

@@ -529,6 +529,7 @@ class WorkflowNodeExecutionStatus(Enum):
     SUCCEEDED = "succeeded"
     FAILED = "failed"
     EXCEPTION = "exception"
+    RETRY = "retry"
 
     @classmethod
     def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus":
@@ -639,6 +640,7 @@ class WorkflowNodeExecution(db.Model):
     created_by_role = db.Column(db.String(255), nullable=False)
     created_by = db.Column(StringUUID, nullable=False)
     finished_at = db.Column(db.DateTime)
+    retry_index = db.Column(db.Integer, server_default=db.text("0"))
 
     @property
     def created_by_account(self):
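
Note: the new RETRY member travels through the enum's existing value_of lookup, so queue events and API payloads can round-trip the string "retry". A sketch with the enum trimmed to the members shown above (the value_of body is paraphrased, not copied from the model):

    from enum import Enum

    class WorkflowNodeExecutionStatus(Enum):
        SUCCEEDED = "succeeded"
        FAILED = "failed"
        EXCEPTION = "exception"
        RETRY = "retry"

        @classmethod
        def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus":
            for status in cls:
                if status.value == value:
                    return status
            raise ValueError(f"invalid workflow node execution status value {value}")

    print(WorkflowNodeExecutionStatus.value_of("retry"))  # WorkflowNodeExecutionStatus.RETRY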

+ 93 - 48
api/services/workflow_service.py

@@ -15,6 +15,7 @@ from core.workflow.nodes.base.entities import BaseNodeData
 from core.workflow.nodes.base.node import BaseNode
 from core.workflow.nodes.enums import ErrorStrategy
 from core.workflow.nodes.event import RunCompletedEvent
+from core.workflow.nodes.event.event import SingleStepRetryEvent
 from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
 from core.workflow.workflow_entry import WorkflowEntry
 from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
@@ -220,56 +221,99 @@ class WorkflowService:
 
         # run draft workflow node
         start_at = time.perf_counter()
+        retries = 0
+        max_retries = 0
+        should_retry = True
+        retry_events = []
 
         try:
-            node_instance, generator = WorkflowEntry.single_step_run(
-                workflow=draft_workflow,
-                node_id=node_id,
-                user_inputs=user_inputs,
-                user_id=account.id,
-            )
-            node_instance = cast(BaseNode[BaseNodeData], node_instance)
-            node_run_result: NodeRunResult | None = None
-            for event in generator:
-                if isinstance(event, RunCompletedEvent):
-                    node_run_result = event.run_result
-
-                    # sign output files
-                    node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
-                    break
-
-            if not node_run_result:
-                raise ValueError("Node run failed with no run result")
-            # single step debug mode error handling return
-            if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
-                node_error_args = {
-                    "status": WorkflowNodeExecutionStatus.EXCEPTION,
-                    "error": node_run_result.error,
-                    "inputs": node_run_result.inputs,
-                    "metadata": {"error_strategy": node_instance.node_data.error_strategy},
-                }
-                if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
-                    node_run_result = NodeRunResult(
-                        **node_error_args,
-                        outputs={
-                            **node_instance.node_data.default_value_dict,
-                            "error_message": node_run_result.error,
-                            "error_type": node_run_result.error_type,
-                        },
-                    )
-                else:
-                    node_run_result = NodeRunResult(
-                        **node_error_args,
-                        outputs={
-                            "error_message": node_run_result.error,
-                            "error_type": node_run_result.error_type,
-                        },
-                    )
-            run_succeeded = node_run_result.status in (
-                WorkflowNodeExecutionStatus.SUCCEEDED,
-                WorkflowNodeExecutionStatus.EXCEPTION,
-            )
-            error = node_run_result.error if not run_succeeded else None
+            while retries <= max_retries and should_retry:
+                retry_start_at = time.perf_counter()
+                node_instance, generator = WorkflowEntry.single_step_run(
+                    workflow=draft_workflow,
+                    node_id=node_id,
+                    user_inputs=user_inputs,
+                    user_id=account.id,
+                )
+                node_instance = cast(BaseNode[BaseNodeData], node_instance)
+                max_retries = (
+                    node_instance.node_data.retry_config.max_retries if node_instance.node_data.retry_config else 0
+                )
+                retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
+                node_run_result: NodeRunResult | None = None
+                for event in generator:
+                    if isinstance(event, RunCompletedEvent):
+                        node_run_result = event.run_result
+
+                        # sign output files
+                        node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
+                        break
+
+                if not node_run_result:
+                    raise ValueError("Node run failed with no run result")
+                # single step debug mode error handling return
+                if node_run_result.status == WorkflowNodeExecutionStatus.FAILED:
+                    if (
+                        retries == max_retries
+                        and node_instance.node_type == NodeType.HTTP_REQUEST
+                        and node_run_result.outputs
+                        and not node_instance.should_continue_on_error
+                    ):
+                        node_run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
+                        should_retry = False
+                    else:
+                        if node_instance.should_retry:
+                            node_run_result.status = WorkflowNodeExecutionStatus.RETRY
+                            retries += 1
+                            node_run_result.retry_index = retries
+                            retry_events.append(
+                                SingleStepRetryEvent(
+                                    inputs=WorkflowEntry.handle_special_values(node_run_result.inputs)
+                                    if node_run_result.inputs
+                                    else None,
+                                    error=node_run_result.error,
+                                    outputs=WorkflowEntry.handle_special_values(node_run_result.outputs)
+                                    if node_run_result.outputs
+                                    else None,
+                                    retry_index=node_run_result.retry_index,
+                                    elapsed_time=time.perf_counter() - retry_start_at,
+                                    execution_metadata=WorkflowEntry.handle_special_values(node_run_result.metadata)
+                                    if node_run_result.metadata
+                                    else None,
+                                )
+                            )
+                            time.sleep(retry_interval)
+                        else:
+                            should_retry = False
+                    if node_instance.should_continue_on_error:
+                        node_error_args = {
+                            "status": WorkflowNodeExecutionStatus.EXCEPTION,
+                            "error": node_run_result.error,
+                            "inputs": node_run_result.inputs,
+                            "metadata": {"error_strategy": node_instance.node_data.error_strategy},
+                        }
+                        if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
+                            node_run_result = NodeRunResult(
+                                **node_error_args,
+                                outputs={
+                                    **node_instance.node_data.default_value_dict,
+                                    "error_message": node_run_result.error,
+                                    "error_type": node_run_result.error_type,
+                                },
+                            )
+                        else:
+                            node_run_result = NodeRunResult(
+                                **node_error_args,
+                                outputs={
+                                    "error_message": node_run_result.error,
+                                    "error_type": node_run_result.error_type,
+                                },
+                            )
+                run_succeeded = node_run_result.status in (
+                    WorkflowNodeExecutionStatus.SUCCEEDED,
+                    WorkflowNodeExecutionStatus.EXCEPTION,
+                )
+                error = node_run_result.error if not run_succeeded else None
         except WorkflowNodeRunFailedError as e:
             node_instance = e.node_instance
             run_succeeded = False
@@ -318,6 +362,7 @@ class WorkflowService:
 
         db.session.add(workflow_node_execution)
         db.session.commit()
+        workflow_node_execution.retry_events = retry_events
 
         return workflow_node_execution
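
Note: the rewritten block wraps the old single-step run in a while loop: each failed attempt is re-labeled RETRY, recorded as a SingleStepRetryEvent, and re-run after retry_interval until the budget is spent; retry_events is then attached to the saved execution for the API layer. The shape of that loop, condensed (helper names here are stand-ins, not the service's real collaborators):

    import time

    def run_with_retries(run_once, max_retries: int, retry_interval: float):
        # run_once() -> (succeeded: bool, result); mirrors how the service
        # accumulates one SingleStepRetryEvent record per failed attempt.
        retry_events = []
        attempt = 0
        while True:
            started = time.perf_counter()
            succeeded, result = run_once()
            if succeeded or attempt >= max_retries:
                return result, retry_events
            attempt += 1
            retry_events.append({"retry_index": attempt, "elapsed_time": time.perf_counter() - started})
            time.sleep(retry_interval)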
 
 

+ 9 - 3
api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py

@@ -2,7 +2,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
 from core.workflow.enums import SystemVariableKey
 from core.workflow.graph_engine.entities.event import (
     GraphRunPartialSucceededEvent,
-    GraphRunSucceededEvent,
     NodeRunExceptionEvent,
     NodeRunStreamChunkEvent,
 )
@@ -14,7 +13,9 @@ from models.workflow import WorkflowType
 
 
 class ContinueOnErrorTestHelper:
     @staticmethod
-    def get_code_node(code: str, error_strategy: str = "fail-branch", default_value: dict | None = None):
+    def get_code_node(
+        code: str, error_strategy: str = "fail-branch", default_value: dict | None = None, retry_config: dict = {}
+    ):
         """Helper method to create a code node configuration"""
         node = {
             "id": "node",
@@ -26,6 +27,7 @@ class ContinueOnErrorTestHelper:
                 "code_language": "python3",
                 "code_language": "python3",
                 "code": "\n".join([line[4:] for line in code.split("\n")]),
                 "code": "\n".join([line[4:] for line in code.split("\n")]),
                 "type": "code",
                 "type": "code",
+                **retry_config,
             },
             },
         }
         }
         if default_value:
         if default_value:
@@ -34,7 +36,10 @@ class ContinueOnErrorTestHelper:
 
 
     @staticmethod
     @staticmethod
     def get_http_node(
     def get_http_node(
-        error_strategy: str = "fail-branch", default_value: dict | None = None, authorization_success: bool = False
+        error_strategy: str = "fail-branch",
+        default_value: dict | None = None,
+        authorization_success: bool = False,
+        retry_config: dict = {},
     ):
     ):
         """Helper method to create a http node configuration"""
         """Helper method to create a http node configuration"""
         authorization = (
         authorization = (
@@ -65,6 +70,7 @@ class ContinueOnErrorTestHelper:
                 "body": None,
                 "body": None,
                 "type": "http-request",
                 "type": "http-request",
                 "error_strategy": error_strategy,
                 "error_strategy": error_strategy,
+                **retry_config,
             },
             },
         }
         }
         if default_value:
         if default_value:
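
Note: both helpers take retry_config: dict = {} — a mutable default argument. It is harmless here because the dict is only unpacked, never mutated, but a None default is the safer idiom; a suggested variant (an editor's sketch, not the commit's code):

    def get_http_node(
        error_strategy: str = "fail-branch",
        default_value: dict | None = None,
        authorization_success: bool = False,
        retry_config: dict | None = None,
    ):
        retry_config = retry_config or {}
        ...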

+ 73 - 0
api/tests/unit_tests/core/workflow/nodes/test_retry.py

@@ -0,0 +1,73 @@
+from core.workflow.graph_engine.entities.event import (
+    GraphRunFailedEvent,
+    GraphRunPartialSucceededEvent,
+    GraphRunSucceededEvent,
+    NodeRunRetryEvent,
+)
+from tests.unit_tests.core.workflow.nodes.test_continue_on_error import ContinueOnErrorTestHelper
+
+DEFAULT_VALUE_EDGE = [
+    {
+        "id": "start-source-node-target",
+        "source": "start",
+        "target": "node",
+        "sourceHandle": "source",
+    },
+    {
+        "id": "node-source-answer-target",
+        "source": "node",
+        "target": "answer",
+        "sourceHandle": "source",
+    },
+]
+
+
+def test_retry_default_value_partial_success():
+    """retry default value node with partial success status"""
+    graph_config = {
+        "edges": DEFAULT_VALUE_EDGE,
+        "nodes": [
+            {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
+            {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
+            ContinueOnErrorTestHelper.get_http_node(
+                "default-value",
+                [{"key": "result", "type": "string", "value": "http node got error response"}],
+                retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
+            ),
+        ],
+    }
+
+    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
+    events = list(graph_engine.run())
+    assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
+    assert events[-1].outputs == {"answer": "http node got error response"}
+    assert any(isinstance(e, GraphRunPartialSucceededEvent) for e in events)
+    assert len(events) == 11
+
+
+def test_retry_failed():
+    """retry failed with success status"""
+    error_code = """
+    def main() -> dict:
+        return {
+            "result": 1 / 0,
+        }
+    """
+
+    graph_config = {
+        "edges": DEFAULT_VALUE_EDGE,
+        "nodes": [
+            {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
+            {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
+            ContinueOnErrorTestHelper.get_http_node(
+                None,
+                None,
+                retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
+            ),
+        ],
+    }
+    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
+    events = list(graph_engine.run())
+    assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
+    assert any(isinstance(e, GraphRunFailedEvent) for e in events)
+    assert len(events) == 8
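
Note: both tests assert exact event counts, which is brittle if the engine ever gains new event types. When such an assertion fails, a per-type tally is the quickest diagnostic (sketch; events comes from graph_engine.run() as above):

    from collections import Counter

    def summarize(events):
        return Counter(type(e).__name__ for e in events)

    # e.g. Counter({'NodeRunStartedEvent': 3, 'NodeRunRetryEvent': 2, ...})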