Browse Source

Feat: Retry on node execution errors (#11871)

Co-authored-by: Novice Lee <novicelee@NoviPro.local>
Novice 4 months ago
parent
commit
7abc7fa573

+ 17 - 0
api/core/app/apps/advanced_chat/generate_task_pipeline.py

@@ -22,6 +22,7 @@ from core.app.entities.queue_entities import (
     QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
     QueueNodeInIterationFailedEvent,
+    QueueNodeRetryEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
     QueueParallelBranchRunFailedEvent,
@@ -328,6 +329,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                     workflow_node_execution=workflow_node_execution,
                 )
 
+                if response:
+                    yield response
+            elif isinstance(
+                event,
+                QueueNodeRetryEvent,
+            ):
+                workflow_node_execution = self._handle_workflow_node_execution_retried(
+                    workflow_run=workflow_run, event=event
+                )
+
+                response = self._workflow_node_retry_to_stream_response(
+                    event=event,
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_node_execution=workflow_node_execution,
+                )
+
                 if response:
                     yield response
             elif isinstance(event, QueueParallelBranchRunStartedEvent):

+ 18 - 1
api/core/app/apps/workflow/generate_task_pipeline.py

@@ -18,6 +18,7 @@ from core.app.entities.queue_entities import (
     QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
     QueueNodeInIterationFailedEvent,
+    QueueNodeRetryEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
     QueueParallelBranchRunFailedEvent,
@@ -286,9 +287,25 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                     task_id=self._application_generate_entity.task_id,
                     workflow_node_execution=workflow_node_execution,
                 )
-
                 if node_failed_response:
                     yield node_failed_response
+            elif isinstance(
+                event,
+                QueueNodeRetryEvent,
+            ):
+                workflow_node_execution = self._handle_workflow_node_execution_retried(
+                    workflow_run=workflow_run, event=event
+                )
+
+                response = self._workflow_node_retry_to_stream_response(
+                    event=event,
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_node_execution=workflow_node_execution,
+                )
+
+                if response:
+                    yield response
+
             elif isinstance(event, QueueParallelBranchRunStartedEvent):
                 if not workflow_run:
                     raise Exception("Workflow run not initialized.")

+ 32 - 0
api/core/app/apps/workflow_app_runner.py

@@ -11,6 +11,7 @@ from core.app.entities.queue_entities import (
     QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
     QueueNodeInIterationFailedEvent,
+    QueueNodeRetryEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
     QueueParallelBranchRunFailedEvent,
@@ -38,6 +39,7 @@ from core.workflow.graph_engine.entities.event import (
     NodeRunExceptionEvent,
     NodeRunFailedEvent,
     NodeRunRetrieverResourceEvent,
+    NodeRunRetryEvent,
     NodeRunStartedEvent,
     NodeRunStreamChunkEvent,
     NodeRunSucceededEvent,
@@ -420,6 +422,36 @@ class WorkflowBasedAppRunner(AppRunner):
                     error=event.error if isinstance(event, IterationRunFailedEvent) else None,
                 )
             )
+        elif isinstance(event, NodeRunRetryEvent):
+            self._publish_event(
+                QueueNodeRetryEvent(
+                    node_execution_id=event.id,
+                    node_id=event.node_id,
+                    node_type=event.node_type,
+                    node_data=event.node_data,
+                    parallel_id=event.parallel_id,
+                    parallel_start_node_id=event.parallel_start_node_id,
+                    parent_parallel_id=event.parent_parallel_id,
+                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    start_at=event.start_at,
+                    inputs=event.route_node_state.node_run_result.inputs
+                    if event.route_node_state.node_run_result
+                    else {},
+                    process_data=event.route_node_state.node_run_result.process_data
+                    if event.route_node_state.node_run_result
+                    else {},
+                    outputs=event.route_node_state.node_run_result.outputs
+                    if event.route_node_state.node_run_result
+                    else {},
+                    error=event.error,
+                    execution_metadata=event.route_node_state.node_run_result.metadata
+                    if event.route_node_state.node_run_result
+                    else {},
+                    in_iteration_id=event.in_iteration_id,
+                    retry_index=event.retry_index,
+                    start_index=event.start_index,
+                )
+            )
 
     def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
         """

+ 32 - 0
api/core/app/entities/queue_entities.py

@@ -43,6 +43,7 @@ class QueueEvent(StrEnum):
     ERROR = "error"
     PING = "ping"
     STOP = "stop"
+    RETRY = "retry"
 
 
 class AppQueueEvent(BaseModel):
@@ -313,6 +314,37 @@ class QueueNodeSucceededEvent(AppQueueEvent):
     iteration_duration_map: Optional[dict[str, float]] = None
 
 
+class QueueNodeRetryEvent(AppQueueEvent):
+    """QueueNodeRetryEvent entity"""
+
+    event: QueueEvent = QueueEvent.RETRY
+
+    node_execution_id: str
+    node_id: str
+    node_type: NodeType
+    node_data: BaseNodeData
+    parallel_id: Optional[str] = None
+    """parallel id if node is in parallel"""
+    parallel_start_node_id: Optional[str] = None
+    """parallel start node id if node is in parallel"""
+    parent_parallel_id: Optional[str] = None
+    """parent parallel id if node is in parallel"""
+    parent_parallel_start_node_id: Optional[str] = None
+    """parent parallel start node id if node is in parallel"""
+    in_iteration_id: Optional[str] = None
+    """iteration id if node is in iteration"""
+    start_at: datetime
+
+    inputs: Optional[dict[str, Any]] = None
+    process_data: Optional[dict[str, Any]] = None
+    outputs: Optional[dict[str, Any]] = None
+    execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
+
+    error: str
+    retry_index: int  # retry index
+    start_index: int  # start index
+
+
 class QueueNodeInIterationFailedEvent(AppQueueEvent):
     """
     QueueNodeInIterationFailedEvent entity

+ 70 - 0
api/core/app/entities/task_entities.py

@@ -52,6 +52,7 @@ class StreamEvent(Enum):
     WORKFLOW_FINISHED = "workflow_finished"
     NODE_STARTED = "node_started"
     NODE_FINISHED = "node_finished"
+    NODE_RETRY = "node_retry"
     PARALLEL_BRANCH_STARTED = "parallel_branch_started"
     PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
     ITERATION_STARTED = "iteration_started"
@@ -342,6 +343,75 @@ class NodeFinishStreamResponse(StreamResponse):
         }
 
 
+class NodeRetryStreamResponse(StreamResponse):
+    """
+    NodeFinishStreamResponse entity
+    """
+
+    class Data(BaseModel):
+        """
+        Data entity
+        """
+
+        id: str
+        node_id: str
+        node_type: str
+        title: str
+        index: int
+        predecessor_node_id: Optional[str] = None
+        inputs: Optional[dict] = None
+        process_data: Optional[dict] = None
+        outputs: Optional[dict] = None
+        status: str
+        error: Optional[str] = None
+        elapsed_time: float
+        execution_metadata: Optional[dict] = None
+        created_at: int
+        finished_at: int
+        files: Optional[Sequence[Mapping[str, Any]]] = []
+        parallel_id: Optional[str] = None
+        parallel_start_node_id: Optional[str] = None
+        parent_parallel_id: Optional[str] = None
+        parent_parallel_start_node_id: Optional[str] = None
+        iteration_id: Optional[str] = None
+        retry_index: int = 0
+
+    event: StreamEvent = StreamEvent.NODE_RETRY
+    workflow_run_id: str
+    data: Data
+
+    def to_ignore_detail_dict(self):
+        return {
+            "event": self.event.value,
+            "task_id": self.task_id,
+            "workflow_run_id": self.workflow_run_id,
+            "data": {
+                "id": self.data.id,
+                "node_id": self.data.node_id,
+                "node_type": self.data.node_type,
+                "title": self.data.title,
+                "index": self.data.index,
+                "predecessor_node_id": self.data.predecessor_node_id,
+                "inputs": None,
+                "process_data": None,
+                "outputs": None,
+                "status": self.data.status,
+                "error": None,
+                "elapsed_time": self.data.elapsed_time,
+                "execution_metadata": None,
+                "created_at": self.data.created_at,
+                "finished_at": self.data.finished_at,
+                "files": [],
+                "parallel_id": self.data.parallel_id,
+                "parallel_start_node_id": self.data.parallel_start_node_id,
+                "parent_parallel_id": self.data.parent_parallel_id,
+                "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
+                "iteration_id": self.data.iteration_id,
+                "retry_index": self.data.retry_index,
+            },
+        }
+
+
 class ParallelBranchStartStreamResponse(StreamResponse):
     """
     ParallelBranchStartStreamResponse entity

+ 93 - 0
api/core/app/task_pipeline/workflow_cycle_manage.py

@@ -15,6 +15,7 @@ from core.app.entities.queue_entities import (
     QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
     QueueNodeInIterationFailedEvent,
+    QueueNodeRetryEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
     QueueParallelBranchRunFailedEvent,
@@ -26,6 +27,7 @@ from core.app.entities.task_entities import (
     IterationNodeNextStreamResponse,
     IterationNodeStartStreamResponse,
     NodeFinishStreamResponse,
+    NodeRetryStreamResponse,
     NodeStartStreamResponse,
     ParallelBranchFinishedStreamResponse,
     ParallelBranchStartStreamResponse,
@@ -423,6 +425,52 @@ class WorkflowCycleManage:
 
         return workflow_node_execution
 
+    def _handle_workflow_node_execution_retried(
+        self, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
+    ) -> WorkflowNodeExecution:
+        """
+        Workflow node execution failed
+        :param event: queue node failed event
+        :return:
+        """
+        created_at = event.start_at
+        finished_at = datetime.now(UTC).replace(tzinfo=None)
+        elapsed_time = (finished_at - created_at).total_seconds()
+        inputs = WorkflowEntry.handle_special_values(event.inputs)
+        outputs = WorkflowEntry.handle_special_values(event.outputs)
+
+        workflow_node_execution = WorkflowNodeExecution()
+        workflow_node_execution.tenant_id = workflow_run.tenant_id
+        workflow_node_execution.app_id = workflow_run.app_id
+        workflow_node_execution.workflow_id = workflow_run.workflow_id
+        workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
+        workflow_node_execution.workflow_run_id = workflow_run.id
+        workflow_node_execution.node_execution_id = event.node_execution_id
+        workflow_node_execution.node_id = event.node_id
+        workflow_node_execution.node_type = event.node_type.value
+        workflow_node_execution.title = event.node_data.title
+        workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value
+        workflow_node_execution.created_by_role = workflow_run.created_by_role
+        workflow_node_execution.created_by = workflow_run.created_by
+        workflow_node_execution.created_at = created_at
+        workflow_node_execution.finished_at = finished_at
+        workflow_node_execution.elapsed_time = elapsed_time
+        workflow_node_execution.error = event.error
+        workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
+        workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
+        workflow_node_execution.execution_metadata = json.dumps(
+            {
+                NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
+            }
+        )
+        workflow_node_execution.index = event.start_index
+
+        db.session.add(workflow_node_execution)
+        db.session.commit()
+        db.session.refresh(workflow_node_execution)
+
+        return workflow_node_execution
+
     #################################################
     #             to stream responses               #
     #################################################
@@ -587,6 +635,51 @@ class WorkflowCycleManage:
             ),
         )
 
+    def _workflow_node_retry_to_stream_response(
+        self,
+        event: QueueNodeRetryEvent,
+        task_id: str,
+        workflow_node_execution: WorkflowNodeExecution,
+    ) -> Optional[NodeFinishStreamResponse]:
+        """
+        Workflow node finish to stream response.
+        :param event: queue node succeeded or failed event
+        :param task_id: task id
+        :param workflow_node_execution: workflow node execution
+        :return:
+        """
+        if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
+            return None
+
+        return NodeRetryStreamResponse(
+            task_id=task_id,
+            workflow_run_id=workflow_node_execution.workflow_run_id,
+            data=NodeRetryStreamResponse.Data(
+                id=workflow_node_execution.id,
+                node_id=workflow_node_execution.node_id,
+                node_type=workflow_node_execution.node_type,
+                index=workflow_node_execution.index,
+                title=workflow_node_execution.title,
+                predecessor_node_id=workflow_node_execution.predecessor_node_id,
+                inputs=workflow_node_execution.inputs_dict,
+                process_data=workflow_node_execution.process_data_dict,
+                outputs=workflow_node_execution.outputs_dict,
+                status=workflow_node_execution.status,
+                error=workflow_node_execution.error,
+                elapsed_time=workflow_node_execution.elapsed_time,
+                execution_metadata=workflow_node_execution.execution_metadata_dict,
+                created_at=int(workflow_node_execution.created_at.timestamp()),
+                finished_at=int(workflow_node_execution.finished_at.timestamp()),
+                files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
+                parallel_id=event.parallel_id,
+                parallel_start_node_id=event.parallel_start_node_id,
+                parent_parallel_id=event.parent_parallel_id,
+                parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                iteration_id=event.in_iteration_id,
+                retry_index=event.retry_index,
+            ),
+        )
+
     def _workflow_parallel_branch_start_to_stream_response(
         self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
     ) -> ParallelBranchStartStreamResponse:

+ 0 - 1
api/core/helper/ssrf_proxy.py

@@ -45,7 +45,6 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
         )
 
     retries = 0
-    stream = kwargs.pop("stream", False)
     while retries <= max_retries:
         try:
             if dify_config.SSRF_PROXY_ALL_URL:

+ 3 - 0
api/core/workflow/entities/node_entities.py

@@ -45,3 +45,6 @@ class NodeRunResult(BaseModel):
 
     error: Optional[str] = None  # error message if status is failed
     error_type: Optional[str] = None  # error type if status is failed
+
+    # single step node run retry
+    retry_index: int = 0

+ 7 - 0
api/core/workflow/graph_engine/entities/event.py

@@ -97,6 +97,13 @@ class NodeInIterationFailedEvent(BaseNodeEvent):
     error: str = Field(..., description="error")
 
 
+class NodeRunRetryEvent(BaseNodeEvent):
+    error: str = Field(..., description="error")
+    retry_index: int = Field(..., description="which retry attempt is about to be performed")
+    start_at: datetime = Field(..., description="retry start time")
+    start_index: int = Field(..., description="retry start index")
+
+
 ###########################################
 # Parallel Branch Events
 ###########################################

+ 175 - 135
api/core/workflow/graph_engine/graph_engine.py

@@ -5,6 +5,7 @@ import uuid
 from collections.abc import Generator, Mapping
 from concurrent.futures import ThreadPoolExecutor, wait
 from copy import copy, deepcopy
+from datetime import UTC, datetime
 from typing import Any, Optional, cast
 
 from flask import Flask, current_app
@@ -25,6 +26,7 @@ from core.workflow.graph_engine.entities.event import (
     NodeRunExceptionEvent,
     NodeRunFailedEvent,
     NodeRunRetrieverResourceEvent,
+    NodeRunRetryEvent,
     NodeRunStartedEvent,
     NodeRunStreamChunkEvent,
     NodeRunSucceededEvent,
@@ -581,7 +583,7 @@ class GraphEngine:
 
     def _run_node(
         self,
-        node_instance: BaseNode,
+        node_instance: BaseNode[BaseNodeData],
         route_node_state: RouteNodeState,
         parallel_id: Optional[str] = None,
         parallel_start_node_id: Optional[str] = None,
@@ -607,36 +609,121 @@ class GraphEngine:
         )
 
         db.session.close()
+        max_retries = node_instance.node_data.retry_config.max_retries
+        retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
+        retries = 0
+        shoudl_continue_retry = True
+        while shoudl_continue_retry and retries <= max_retries:
+            try:
+                # run node
+                retry_start_at = datetime.now(UTC).replace(tzinfo=None)
+                generator = node_instance.run()
+                for item in generator:
+                    if isinstance(item, GraphEngineEvent):
+                        if isinstance(item, BaseIterationEvent):
+                            # add parallel info to iteration event
+                            item.parallel_id = parallel_id
+                            item.parallel_start_node_id = parallel_start_node_id
+                            item.parent_parallel_id = parent_parallel_id
+                            item.parent_parallel_start_node_id = parent_parallel_start_node_id
+
+                        yield item
+                    else:
+                        if isinstance(item, RunCompletedEvent):
+                            run_result = item.run_result
+                            if run_result.status == WorkflowNodeExecutionStatus.FAILED:
+                                if (
+                                    retries == max_retries
+                                    and node_instance.node_type == NodeType.HTTP_REQUEST
+                                    and run_result.outputs
+                                    and not node_instance.should_continue_on_error
+                                ):
+                                    run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
+                                if node_instance.should_retry and retries < max_retries:
+                                    retries += 1
+                                    self.graph_runtime_state.node_run_steps += 1
+                                    route_node_state.node_run_result = run_result
+                                    yield NodeRunRetryEvent(
+                                        id=node_instance.id,
+                                        node_id=node_instance.node_id,
+                                        node_type=node_instance.node_type,
+                                        node_data=node_instance.node_data,
+                                        route_node_state=route_node_state,
+                                        error=run_result.error,
+                                        retry_index=retries,
+                                        parallel_id=parallel_id,
+                                        parallel_start_node_id=parallel_start_node_id,
+                                        parent_parallel_id=parent_parallel_id,
+                                        parent_parallel_start_node_id=parent_parallel_start_node_id,
+                                        start_at=retry_start_at,
+                                        start_index=self.graph_runtime_state.node_run_steps,
+                                    )
+                                    time.sleep(retry_interval)
+                                    continue
+                            route_node_state.set_finished(run_result=run_result)
+
+                            if run_result.status == WorkflowNodeExecutionStatus.FAILED:
+                                if node_instance.should_continue_on_error:
+                                    # if run failed, handle error
+                                    run_result = self._handle_continue_on_error(
+                                        node_instance,
+                                        item.run_result,
+                                        self.graph_runtime_state.variable_pool,
+                                        handle_exceptions=handle_exceptions,
+                                    )
+                                    route_node_state.node_run_result = run_result
+                                    route_node_state.status = RouteNodeState.Status.EXCEPTION
+                                    if run_result.outputs:
+                                        for variable_key, variable_value in run_result.outputs.items():
+                                            # append variables to variable pool recursively
+                                            self._append_variables_recursively(
+                                                node_id=node_instance.node_id,
+                                                variable_key_list=[variable_key],
+                                                variable_value=variable_value,
+                                            )
+                                    yield NodeRunExceptionEvent(
+                                        error=run_result.error or "System Error",
+                                        id=node_instance.id,
+                                        node_id=node_instance.node_id,
+                                        node_type=node_instance.node_type,
+                                        node_data=node_instance.node_data,
+                                        route_node_state=route_node_state,
+                                        parallel_id=parallel_id,
+                                        parallel_start_node_id=parallel_start_node_id,
+                                        parent_parallel_id=parent_parallel_id,
+                                        parent_parallel_start_node_id=parent_parallel_start_node_id,
+                                    )
+                                    shoudl_continue_retry = False
+                                else:
+                                    yield NodeRunFailedEvent(
+                                        error=route_node_state.failed_reason or "Unknown error.",
+                                        id=node_instance.id,
+                                        node_id=node_instance.node_id,
+                                        node_type=node_instance.node_type,
+                                        node_data=node_instance.node_data,
+                                        route_node_state=route_node_state,
+                                        parallel_id=parallel_id,
+                                        parallel_start_node_id=parallel_start_node_id,
+                                        parent_parallel_id=parent_parallel_id,
+                                        parent_parallel_start_node_id=parent_parallel_start_node_id,
+                                    )
+                                shoudl_continue_retry = False
+                            elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
+                                if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
+                                    node_instance.node_id
+                                ):
+                                    run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
+                                if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
+                                    # plus state total_tokens
+                                    self.graph_runtime_state.total_tokens += int(
+                                        run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)  # type: ignore[arg-type]
+                                    )
 
-        try:
-            # run node
-            generator = node_instance.run()
-            for item in generator:
-                if isinstance(item, GraphEngineEvent):
-                    if isinstance(item, BaseIterationEvent):
-                        # add parallel info to iteration event
-                        item.parallel_id = parallel_id
-                        item.parallel_start_node_id = parallel_start_node_id
-                        item.parent_parallel_id = parent_parallel_id
-                        item.parent_parallel_start_node_id = parent_parallel_start_node_id
+                                if run_result.llm_usage:
+                                    # use the latest usage
+                                    self.graph_runtime_state.llm_usage += run_result.llm_usage
 
-                    yield item
-                else:
-                    if isinstance(item, RunCompletedEvent):
-                        run_result = item.run_result
-                        route_node_state.set_finished(run_result=run_result)
-
-                        if run_result.status == WorkflowNodeExecutionStatus.FAILED:
-                            if node_instance.should_continue_on_error:
-                                # if run failed, handle error
-                                run_result = self._handle_continue_on_error(
-                                    node_instance,
-                                    item.run_result,
-                                    self.graph_runtime_state.variable_pool,
-                                    handle_exceptions=handle_exceptions,
-                                )
-                                route_node_state.node_run_result = run_result
-                                route_node_state.status = RouteNodeState.Status.EXCEPTION
+                                # append node output variables to variable pool
                                 if run_result.outputs:
                                     for variable_key, variable_value in run_result.outputs.items():
                                         # append variables to variable pool recursively
@@ -645,21 +732,23 @@ class GraphEngine:
                                             variable_key_list=[variable_key],
                                             variable_value=variable_value,
                                         )
-                                yield NodeRunExceptionEvent(
-                                    error=run_result.error or "System Error",
-                                    id=node_instance.id,
-                                    node_id=node_instance.node_id,
-                                    node_type=node_instance.node_type,
-                                    node_data=node_instance.node_data,
-                                    route_node_state=route_node_state,
-                                    parallel_id=parallel_id,
-                                    parallel_start_node_id=parallel_start_node_id,
-                                    parent_parallel_id=parent_parallel_id,
-                                    parent_parallel_start_node_id=parent_parallel_start_node_id,
-                                )
-                            else:
-                                yield NodeRunFailedEvent(
-                                    error=route_node_state.failed_reason or "Unknown error.",
+
+                                # add parallel info to run result metadata
+                                if parallel_id and parallel_start_node_id:
+                                    if not run_result.metadata:
+                                        run_result.metadata = {}
+
+                                    run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
+                                    run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = (
+                                        parallel_start_node_id
+                                    )
+                                    if parent_parallel_id and parent_parallel_start_node_id:
+                                        run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
+                                        run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
+                                            parent_parallel_start_node_id
+                                        )
+
+                                yield NodeRunSucceededEvent(
                                     id=node_instance.id,
                                     node_id=node_instance.node_id,
                                     node_type=node_instance.node_type,
@@ -670,108 +759,59 @@ class GraphEngine:
                                     parent_parallel_id=parent_parallel_id,
                                     parent_parallel_start_node_id=parent_parallel_start_node_id,
                                 )
+                                shoudl_continue_retry = False
 
-                        elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
-                            if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
-                                node_instance.node_id
-                            ):
-                                run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
-                            if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
-                                # plus state total_tokens
-                                self.graph_runtime_state.total_tokens += int(
-                                    run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)  # type: ignore[arg-type]
-                                )
-
-                            if run_result.llm_usage:
-                                # use the latest usage
-                                self.graph_runtime_state.llm_usage += run_result.llm_usage
-
-                            # append node output variables to variable pool
-                            if run_result.outputs:
-                                for variable_key, variable_value in run_result.outputs.items():
-                                    # append variables to variable pool recursively
-                                    self._append_variables_recursively(
-                                        node_id=node_instance.node_id,
-                                        variable_key_list=[variable_key],
-                                        variable_value=variable_value,
-                                    )
-
-                            # add parallel info to run result metadata
-                            if parallel_id and parallel_start_node_id:
-                                if not run_result.metadata:
-                                    run_result.metadata = {}
-
-                                run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
-                                run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
-                                if parent_parallel_id and parent_parallel_start_node_id:
-                                    run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
-                                    run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
-                                        parent_parallel_start_node_id
-                                    )
-
-                            yield NodeRunSucceededEvent(
+                            break
+                        elif isinstance(item, RunStreamChunkEvent):
+                            yield NodeRunStreamChunkEvent(
                                 id=node_instance.id,
                                 node_id=node_instance.node_id,
                                 node_type=node_instance.node_type,
                                 node_data=node_instance.node_data,
+                                chunk_content=item.chunk_content,
+                                from_variable_selector=item.from_variable_selector,
                                 route_node_state=route_node_state,
                                 parallel_id=parallel_id,
                                 parallel_start_node_id=parallel_start_node_id,
                                 parent_parallel_id=parent_parallel_id,
                                 parent_parallel_start_node_id=parent_parallel_start_node_id,
                             )
-
-                        break
-                    elif isinstance(item, RunStreamChunkEvent):
-                        yield NodeRunStreamChunkEvent(
-                            id=node_instance.id,
-                            node_id=node_instance.node_id,
-                            node_type=node_instance.node_type,
-                            node_data=node_instance.node_data,
-                            chunk_content=item.chunk_content,
-                            from_variable_selector=item.from_variable_selector,
-                            route_node_state=route_node_state,
-                            parallel_id=parallel_id,
-                            parallel_start_node_id=parallel_start_node_id,
-                            parent_parallel_id=parent_parallel_id,
-                            parent_parallel_start_node_id=parent_parallel_start_node_id,
-                        )
-                    elif isinstance(item, RunRetrieverResourceEvent):
-                        yield NodeRunRetrieverResourceEvent(
-                            id=node_instance.id,
-                            node_id=node_instance.node_id,
-                            node_type=node_instance.node_type,
-                            node_data=node_instance.node_data,
-                            retriever_resources=item.retriever_resources,
-                            context=item.context,
-                            route_node_state=route_node_state,
-                            parallel_id=parallel_id,
-                            parallel_start_node_id=parallel_start_node_id,
-                            parent_parallel_id=parent_parallel_id,
-                            parent_parallel_start_node_id=parent_parallel_start_node_id,
-                        )
-        except GenerateTaskStoppedError:
-            # trigger node run failed event
-            route_node_state.status = RouteNodeState.Status.FAILED
-            route_node_state.failed_reason = "Workflow stopped."
-            yield NodeRunFailedEvent(
-                error="Workflow stopped.",
-                id=node_instance.id,
-                node_id=node_instance.node_id,
-                node_type=node_instance.node_type,
-                node_data=node_instance.node_data,
-                route_node_state=route_node_state,
-                parallel_id=parallel_id,
-                parallel_start_node_id=parallel_start_node_id,
-                parent_parallel_id=parent_parallel_id,
-                parent_parallel_start_node_id=parent_parallel_start_node_id,
-            )
-            return
-        except Exception as e:
-            logger.exception(f"Node {node_instance.node_data.title} run failed")
-            raise e
-        finally:
-            db.session.close()
+                        elif isinstance(item, RunRetrieverResourceEvent):
+                            yield NodeRunRetrieverResourceEvent(
+                                id=node_instance.id,
+                                node_id=node_instance.node_id,
+                                node_type=node_instance.node_type,
+                                node_data=node_instance.node_data,
+                                retriever_resources=item.retriever_resources,
+                                context=item.context,
+                                route_node_state=route_node_state,
+                                parallel_id=parallel_id,
+                                parallel_start_node_id=parallel_start_node_id,
+                                parent_parallel_id=parent_parallel_id,
+                                parent_parallel_start_node_id=parent_parallel_start_node_id,
+                            )
+            except GenerateTaskStoppedError:
+                # trigger node run failed event
+                route_node_state.status = RouteNodeState.Status.FAILED
+                route_node_state.failed_reason = "Workflow stopped."
+                yield NodeRunFailedEvent(
+                    error="Workflow stopped.",
+                    id=node_instance.id,
+                    node_id=node_instance.node_id,
+                    node_type=node_instance.node_type,
+                    node_data=node_instance.node_data,
+                    route_node_state=route_node_state,
+                    parallel_id=parallel_id,
+                    parallel_start_node_id=parallel_start_node_id,
+                    parent_parallel_id=parent_parallel_id,
+                    parent_parallel_start_node_id=parent_parallel_start_node_id,
+                )
+                return
+            except Exception as e:
+                logger.exception(f"Node {node_instance.node_data.title} run failed")
+                raise e
+            finally:
+                db.session.close()
 
     def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
         """

+ 13 - 0
api/core/workflow/nodes/base/entities.py

@@ -106,12 +106,25 @@ class DefaultValue(BaseModel):
         return self
 
 
+class RetryConfig(BaseModel):
+    """node retry config"""
+
+    max_retries: int = 0  # max retry times
+    retry_interval: int = 0  # retry interval in milliseconds
+    retry_enabled: bool = False  # whether retry is enabled
+
+    @property
+    def retry_interval_seconds(self) -> float:
+        return self.retry_interval / 1000
+
+
 class BaseNodeData(ABC, BaseModel):
     title: str
     desc: Optional[str] = None
     error_strategy: Optional[ErrorStrategy] = None
     default_value: Optional[list[DefaultValue]] = None
     version: str = "1"
+    retry_config: RetryConfig = RetryConfig()
 
     @property
     def default_value_dict(self):

+ 10 - 1
api/core/workflow/nodes/base/node.py

@@ -4,7 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
 from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
 
 from core.workflow.entities.node_entities import NodeRunResult
-from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, NodeType
+from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
 from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
 from models.workflow import WorkflowNodeExecutionStatus
 
@@ -147,3 +147,12 @@ class BaseNode(Generic[GenericNodeData]):
             bool: if should continue on error
         """
         return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE
+
+    @property
+    def should_retry(self) -> bool:
+        """judge if should retry
+
+        Returns:
+            bool: if should retry
+        """
+        return self.node_data.retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE

+ 1 - 0
api/core/workflow/nodes/enums.py

@@ -35,3 +35,4 @@ class FailBranchSourceHandle(StrEnum):
 
 
 CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]
+RETRY_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.TOOL, NodeType.HTTP_REQUEST]

+ 8 - 1
api/core/workflow/nodes/event/__init__.py

@@ -1,4 +1,10 @@
-from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
+from .event import (
+    ModelInvokeCompletedEvent,
+    RunCompletedEvent,
+    RunRetrieverResourceEvent,
+    RunRetryEvent,
+    RunStreamChunkEvent,
+)
 from .types import NodeEvent
 
 __all__ = [
@@ -6,5 +12,6 @@ __all__ = [
     "NodeEvent",
     "RunCompletedEvent",
     "RunRetrieverResourceEvent",
+    "RunRetryEvent",
     "RunStreamChunkEvent",
 ]

+ 25 - 0
api/core/workflow/nodes/event/event.py

@@ -1,7 +1,10 @@
+from datetime import datetime
+
 from pydantic import BaseModel, Field
 
 from core.model_runtime.entities.llm_entities import LLMUsage
 from core.workflow.entities.node_entities import NodeRunResult
+from models.workflow import WorkflowNodeExecutionStatus
 
 
 class RunCompletedEvent(BaseModel):
@@ -26,3 +29,25 @@ class ModelInvokeCompletedEvent(BaseModel):
     text: str
     usage: LLMUsage
     finish_reason: str | None = None
+
+
+class RunRetryEvent(BaseModel):
+    """Node Run Retry event"""
+
+    error: str = Field(..., description="error")
+    retry_index: int = Field(..., description="Retry attempt number")
+    start_at: datetime = Field(..., description="Retry start time")
+
+
+class SingleStepRetryEvent(BaseModel):
+    """Single step retry event"""
+
+    status: str = WorkflowNodeExecutionStatus.RETRY.value
+
+    inputs: dict | None = Field(..., description="input")
+    error: str = Field(..., description="error")
+    outputs: dict = Field(..., description="output")
+    retry_index: int = Field(..., description="Retry attempt number")
+    error: str = Field(..., description="error")
+    elapsed_time: float = Field(..., description="elapsed time")
+    execution_metadata: dict | None = Field(..., description="execution metadata")

+ 4 - 0
api/core/workflow/nodes/http_request/executor.py

@@ -45,6 +45,7 @@ class Executor:
     headers: dict[str, str]
     auth: HttpRequestNodeAuthorization
     timeout: HttpRequestNodeTimeout
+    max_retries: int
 
     boundary: str
 
@@ -54,6 +55,7 @@ class Executor:
         node_data: HttpRequestNodeData,
         timeout: HttpRequestNodeTimeout,
         variable_pool: VariablePool,
+        max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
     ):
         # If authorization API key is present, convert the API key using the variable pool
         if node_data.authorization.type == "api-key":
@@ -73,6 +75,7 @@ class Executor:
         self.files = None
         self.data = None
         self.json = None
+        self.max_retries = max_retries
 
         # init template
         self.variable_pool = variable_pool
@@ -241,6 +244,7 @@ class Executor:
             "params": self.params,
             "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
             "follow_redirects": True,
+            "max_retries": self.max_retries,
         }
         # request_args = {k: v for k, v in request_args.items() if v is not None}
         try:

+ 7 - 1
api/core/workflow/nodes/http_request/node.py

@@ -52,6 +52,11 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
                     "max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
                 },
             },
+            "retry_config": {
+                "max_retries": dify_config.SSRF_DEFAULT_MAX_RETRIES,
+                "retry_interval": 0.5 * (2**2),
+                "retry_enabled": True,
+            },
         }
 
     def _run(self) -> NodeRunResult:
@@ -61,12 +66,13 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
                 node_data=self.node_data,
                 timeout=self._get_request_timeout(self.node_data),
                 variable_pool=self.graph_runtime_state.variable_pool,
+                max_retries=0,
             )
             process_data["request"] = http_executor.to_log()
 
             response = http_executor.invoke()
             files = self.extract_files(url=http_executor.url, response=response)
-            if not response.response.is_success and self.should_continue_on_error:
+            if not response.response.is_success and (self.should_continue_on_error or self.should_retry):
                 return NodeRunResult(
                     status=WorkflowNodeExecutionStatus.FAILED,
                     outputs={

+ 14 - 0
api/fields/workflow_run_fields.py

@@ -29,6 +29,7 @@ workflow_run_for_list_fields = {
     "created_at": TimestampField,
     "finished_at": TimestampField,
     "exceptions_count": fields.Integer,
+    "retry_index": fields.Integer,
 }
 
 advanced_chat_workflow_run_for_list_fields = {
@@ -45,6 +46,7 @@ advanced_chat_workflow_run_for_list_fields = {
     "created_at": TimestampField,
     "finished_at": TimestampField,
     "exceptions_count": fields.Integer,
+    "retry_index": fields.Integer,
 }
 
 advanced_chat_workflow_run_pagination_fields = {
@@ -79,6 +81,17 @@ workflow_run_detail_fields = {
     "exceptions_count": fields.Integer,
 }
 
+retry_event_field = {
+    "error": fields.String,
+    "retry_index": fields.Integer,
+    "inputs": fields.Raw(attribute="inputs"),
+    "elapsed_time": fields.Float,
+    "execution_metadata": fields.Raw(attribute="execution_metadata_dict"),
+    "status": fields.String,
+    "outputs": fields.Raw(attribute="outputs"),
+}
+
+
 workflow_run_node_execution_fields = {
     "id": fields.String,
     "index": fields.Integer,
@@ -99,6 +112,7 @@ workflow_run_node_execution_fields = {
     "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
     "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True),
     "finished_at": TimestampField,
+    "retry_events": fields.List(fields.Nested(retry_event_field)),
 }
 
 workflow_run_node_execution_list_fields = {

+ 33 - 0
api/migrations/versions/2024_12_16_0123-348cb0a93d53_add_retry_index_field_to_node_execution_.py

@@ -0,0 +1,33 @@
+"""add retry_index field to node-execution  model
+
+Revision ID: 348cb0a93d53
+Revises: cf8f4fc45278
+Create Date: 2024-12-16 01:23:13.093432
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '348cb0a93d53'
+down_revision = 'cf8f4fc45278'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('retry_index', sa.Integer(), server_default=sa.text('0'), nullable=True))
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
+        batch_op.drop_column('retry_index')
+
+    # ### end Alembic commands ###

+ 2 - 0
api/models/workflow.py

@@ -529,6 +529,7 @@ class WorkflowNodeExecutionStatus(Enum):
     SUCCEEDED = "succeeded"
     FAILED = "failed"
     EXCEPTION = "exception"
+    RETRY = "retry"
 
     @classmethod
     def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus":
@@ -639,6 +640,7 @@ class WorkflowNodeExecution(db.Model):
     created_by_role = db.Column(db.String(255), nullable=False)
     created_by = db.Column(StringUUID, nullable=False)
     finished_at = db.Column(db.DateTime)
+    retry_index = db.Column(db.Integer, server_default=db.text("0"))
 
     @property
     def created_by_account(self):

+ 93 - 48
api/services/workflow_service.py

@@ -15,6 +15,7 @@ from core.workflow.nodes.base.entities import BaseNodeData
 from core.workflow.nodes.base.node import BaseNode
 from core.workflow.nodes.enums import ErrorStrategy
 from core.workflow.nodes.event import RunCompletedEvent
+from core.workflow.nodes.event.event import SingleStepRetryEvent
 from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
 from core.workflow.workflow_entry import WorkflowEntry
 from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
@@ -220,56 +221,99 @@ class WorkflowService:
 
         # run draft workflow node
         start_at = time.perf_counter()
+        retries = 0
+        max_retries = 0
+        should_retry = True
+        retry_events = []
 
         try:
-            node_instance, generator = WorkflowEntry.single_step_run(
-                workflow=draft_workflow,
-                node_id=node_id,
-                user_inputs=user_inputs,
-                user_id=account.id,
-            )
-            node_instance = cast(BaseNode[BaseNodeData], node_instance)
-            node_run_result: NodeRunResult | None = None
-            for event in generator:
-                if isinstance(event, RunCompletedEvent):
-                    node_run_result = event.run_result
-
-                    # sign output files
-                    node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
-                    break
-
-            if not node_run_result:
-                raise ValueError("Node run failed with no run result")
-            # single step debug mode error handling return
-            if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
-                node_error_args = {
-                    "status": WorkflowNodeExecutionStatus.EXCEPTION,
-                    "error": node_run_result.error,
-                    "inputs": node_run_result.inputs,
-                    "metadata": {"error_strategy": node_instance.node_data.error_strategy},
-                }
-                if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
-                    node_run_result = NodeRunResult(
-                        **node_error_args,
-                        outputs={
-                            **node_instance.node_data.default_value_dict,
-                            "error_message": node_run_result.error,
-                            "error_type": node_run_result.error_type,
-                        },
-                    )
-                else:
-                    node_run_result = NodeRunResult(
-                        **node_error_args,
-                        outputs={
-                            "error_message": node_run_result.error,
-                            "error_type": node_run_result.error_type,
-                        },
-                    )
-            run_succeeded = node_run_result.status in (
-                WorkflowNodeExecutionStatus.SUCCEEDED,
-                WorkflowNodeExecutionStatus.EXCEPTION,
-            )
-            error = node_run_result.error if not run_succeeded else None
+            while retries <= max_retries and should_retry:
+                retry_start_at = time.perf_counter()
+                node_instance, generator = WorkflowEntry.single_step_run(
+                    workflow=draft_workflow,
+                    node_id=node_id,
+                    user_inputs=user_inputs,
+                    user_id=account.id,
+                )
+                node_instance = cast(BaseNode[BaseNodeData], node_instance)
+                max_retries = (
+                    node_instance.node_data.retry_config.max_retries if node_instance.node_data.retry_config else 0
+                )
+                retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
+                node_run_result: NodeRunResult | None = None
+                for event in generator:
+                    if isinstance(event, RunCompletedEvent):
+                        node_run_result = event.run_result
+
+                        # sign output files
+                        node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
+                        break
+
+                if not node_run_result:
+                    raise ValueError("Node run failed with no run result")
+                # single step debug mode error handling return
+                if node_run_result.status == WorkflowNodeExecutionStatus.FAILED:
+                    if (
+                        retries == max_retries
+                        and node_instance.node_type == NodeType.HTTP_REQUEST
+                        and node_run_result.outputs
+                        and not node_instance.should_continue_on_error
+                    ):
+                        node_run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
+                        should_retry = False
+                    else:
+                        if node_instance.should_retry:
+                            node_run_result.status = WorkflowNodeExecutionStatus.RETRY
+                            retries += 1
+                            node_run_result.retry_index = retries
+                            retry_events.append(
+                                SingleStepRetryEvent(
+                                    inputs=WorkflowEntry.handle_special_values(node_run_result.inputs)
+                                    if node_run_result.inputs
+                                    else None,
+                                    error=node_run_result.error,
+                                    outputs=WorkflowEntry.handle_special_values(node_run_result.outputs)
+                                    if node_run_result.outputs
+                                    else None,
+                                    retry_index=node_run_result.retry_index,
+                                    elapsed_time=time.perf_counter() - retry_start_at,
+                                    execution_metadata=WorkflowEntry.handle_special_values(node_run_result.metadata)
+                                    if node_run_result.metadata
+                                    else None,
+                                )
+                            )
+                            time.sleep(retry_interval)
+                        else:
+                            should_retry = False
+                    if node_instance.should_continue_on_error:
+                        node_error_args = {
+                            "status": WorkflowNodeExecutionStatus.EXCEPTION,
+                            "error": node_run_result.error,
+                            "inputs": node_run_result.inputs,
+                            "metadata": {"error_strategy": node_instance.node_data.error_strategy},
+                        }
+                        if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
+                            node_run_result = NodeRunResult(
+                                **node_error_args,
+                                outputs={
+                                    **node_instance.node_data.default_value_dict,
+                                    "error_message": node_run_result.error,
+                                    "error_type": node_run_result.error_type,
+                                },
+                            )
+                        else:
+                            node_run_result = NodeRunResult(
+                                **node_error_args,
+                                outputs={
+                                    "error_message": node_run_result.error,
+                                    "error_type": node_run_result.error_type,
+                                },
+                            )
+                run_succeeded = node_run_result.status in (
+                    WorkflowNodeExecutionStatus.SUCCEEDED,
+                    WorkflowNodeExecutionStatus.EXCEPTION,
+                )
+                error = node_run_result.error if not run_succeeded else None
         except WorkflowNodeRunFailedError as e:
             node_instance = e.node_instance
             run_succeeded = False
@@ -318,6 +362,7 @@ class WorkflowService:
 
         db.session.add(workflow_node_execution)
         db.session.commit()
+        workflow_node_execution.retry_events = retry_events
 
         return workflow_node_execution
 

+ 9 - 3
api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py

@@ -2,7 +2,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
 from core.workflow.enums import SystemVariableKey
 from core.workflow.graph_engine.entities.event import (
     GraphRunPartialSucceededEvent,
-    GraphRunSucceededEvent,
     NodeRunExceptionEvent,
     NodeRunStreamChunkEvent,
 )
@@ -14,7 +13,9 @@ from models.workflow import WorkflowType
 
 class ContinueOnErrorTestHelper:
     @staticmethod
-    def get_code_node(code: str, error_strategy: str = "fail-branch", default_value: dict | None = None):
+    def get_code_node(
+        code: str, error_strategy: str = "fail-branch", default_value: dict | None = None, retry_config: dict = {}
+    ):
         """Helper method to create a code node configuration"""
         node = {
             "id": "node",
@@ -26,6 +27,7 @@ class ContinueOnErrorTestHelper:
                 "code_language": "python3",
                 "code": "\n".join([line[4:] for line in code.split("\n")]),
                 "type": "code",
+                **retry_config,
             },
         }
         if default_value:
@@ -34,7 +36,10 @@ class ContinueOnErrorTestHelper:
 
     @staticmethod
     def get_http_node(
-        error_strategy: str = "fail-branch", default_value: dict | None = None, authorization_success: bool = False
+        error_strategy: str = "fail-branch",
+        default_value: dict | None = None,
+        authorization_success: bool = False,
+        retry_config: dict = {},
     ):
         """Helper method to create a http node configuration"""
         authorization = (
@@ -65,6 +70,7 @@ class ContinueOnErrorTestHelper:
                 "body": None,
                 "type": "http-request",
                 "error_strategy": error_strategy,
+                **retry_config,
             },
         }
         if default_value:

+ 73 - 0
api/tests/unit_tests/core/workflow/nodes/test_retry.py

@@ -0,0 +1,73 @@
+from core.workflow.graph_engine.entities.event import (
+    GraphRunFailedEvent,
+    GraphRunPartialSucceededEvent,
+    GraphRunSucceededEvent,
+    NodeRunRetryEvent,
+)
+from tests.unit_tests.core.workflow.nodes.test_continue_on_error import ContinueOnErrorTestHelper
+
+DEFAULT_VALUE_EDGE = [
+    {
+        "id": "start-source-node-target",
+        "source": "start",
+        "target": "node",
+        "sourceHandle": "source",
+    },
+    {
+        "id": "node-source-answer-target",
+        "source": "node",
+        "target": "answer",
+        "sourceHandle": "source",
+    },
+]
+
+
+def test_retry_default_value_partial_success():
+    """retry default value node with partial success status"""
+    graph_config = {
+        "edges": DEFAULT_VALUE_EDGE,
+        "nodes": [
+            {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
+            {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
+            ContinueOnErrorTestHelper.get_http_node(
+                "default-value",
+                [{"key": "result", "type": "string", "value": "http node got error response"}],
+                retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
+            ),
+        ],
+    }
+
+    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
+    events = list(graph_engine.run())
+    assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
+    assert events[-1].outputs == {"answer": "http node got error response"}
+    assert any(isinstance(e, GraphRunPartialSucceededEvent) for e in events)
+    assert len(events) == 11
+
+
+def test_retry_failed():
+    """retry failed with success status"""
+    error_code = """
+    def main() -> dict:
+        return {
+            "result": 1 / 0,
+        }
+    """
+
+    graph_config = {
+        "edges": DEFAULT_VALUE_EDGE,
+        "nodes": [
+            {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
+            {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
+            ContinueOnErrorTestHelper.get_http_node(
+                None,
+                None,
+                retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
+            ),
+        ],
+    }
+    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
+    events = list(graph_engine.run())
+    assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
+    assert any(isinstance(e, GraphRunFailedEvent) for e in events)
+    assert len(events) == 8