
chore(api/core): apply ruff reformatting (#7624)

Bowen Liang 7 months ago
Parent
Commit
2cf1187b32
100 changed files with 2661 additions and 2885 deletions
1. api/core/__init__.py (+1 -1)
2. api/core/agent/cot_agent_runner.py (+94 -121)
3. api/core/agent/cot_chat_agent_runner.py (+12 -14)
4. api/core/agent/cot_completion_agent_runner.py (+12 -9)
5. api/core/agent/entities.py (+10 -6)
6. api/core/agent/fc_agent_runner.py (+115 -120)
7. api/core/agent/output_parser/cot_output_parser.py (+37 -37)
8. api/core/agent/prompt/template.py (+9 -9)
9. api/core/app/app_config/base_app_config_manager.py (+8 -18)
10. api/core/app/app_config/common/sensitive_word_avoidance/manager.py (+9 -14)
11. api/core/app/app_config/easy_ui_based_app/agent/manager.py (+33 -30)
12. api/core/app/app_config/easy_ui_based_app/dataset/manager.py (+40 -48)
13. api/core/app/app_config/easy_ui_based_app/model_config/converter.py (+9 -21)
14. api/core/app/app_config/easy_ui_based_app/model_config/manager.py (+16 -17)
15. api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py (+26 -31)
16. api/core/app/app_config/easy_ui_based_app/variables/manager.py (+22 -30)
17. api/core/app/app_config/entities.py (+29 -17)
18. api/core/app/app_config/features/file_upload/manager.py (+13 -15)
19. api/core/app/app_config/features/more_like_this/manager.py (+3 -5)
20. api/core/app/app_config/features/opening_statement/manager.py (+2 -4)
21. api/core/app/app_config/features/retrieval_resource/manager.py (+3 -5)
22. api/core/app/app_config/features/speech_to_text/manager.py (+3 -5)
23. api/core/app/app_config/features/suggested_questions_after_answer/manager.py (+7 -7)
24. api/core/app/app_config/features/text_to_speech/manager.py (+6 -10)
25. api/core/app/apps/advanced_chat/app_config_manager.py (+9 -17)
26. api/core/app/apps/advanced_chat/app_generator.py (+76 -90)
27. api/core/app/apps/advanced_chat/app_generator_tts_publisher.py (+16 -21)
28. api/core/app/apps/advanced_chat/app_runner.py (+35 -49)
29. api/core/app/apps/advanced_chat/generate_response_converter.py (+29 -25)
30. api/core/app/apps/advanced_chat/generate_task_pipeline.py (+74 -93)
31. api/core/app/apps/agent_chat/app_config_manager.py (+27 -32)
32. api/core/app/apps/agent_chat/app_generator.py (+43 -59)
33. api/core/app/apps/agent_chat/app_runner.py (+49 -44)
34. api/core/app/apps/agent_chat/generate_response_converter.py (+29 -27)
35. api/core/app/apps/base_app_generate_response_converter.py (+45 -40)
36. api/core/app/apps/base_app_generator.py (+6 -6)
37. api/core/app/apps/base_app_queue_manager.py (+14 -16)
38. api/core/app/apps/base_app_runner.py (+125 -147)
39. api/core/app/apps/chat/app_config_manager.py (+22 -23)
40. api/core/app/apps/chat/app_generator.py (+48 -58)
41. api/core/app/apps/chat/app_runner.py (+23 -28)
42. api/core/app/apps/chat/generate_response_converter.py (+29 -27)
43. api/core/app/apps/completion/app_config_manager.py (+15 -20)
44. api/core/app/apps/completion/app_generator.py (+80 -95)
45. api/core/app/apps/completion/app_runner.py (+15 -21)
46. api/core/app/apps/completion/generate_response_converter.py (+26 -24)
47. api/core/app/apps/message_based_app_generator.py (+50 -54)
48. api/core/app/apps/message_based_app_queue_manager.py (+8 -13)
49. api/core/app/apps/workflow/app_config_manager.py (+6 -12)
50. api/core/app/apps/workflow/app_generator.py (+59 -70)
51. api/core/app/apps/workflow/app_queue_manager.py (+10 -14)
52. api/core/app/apps/workflow/app_runner.py (+10 -13)
53. api/core/app/apps/workflow/generate_response_converter.py (+12 -10)
54. api/core/app/apps/workflow/generate_task_pipeline.py (+60 -73)
55. api/core/app/apps/workflow_app_runner.py (+71 -79)
56. api/core/app/apps/workflow_logging_callback.py (+85 -115)
57. api/core/app/entities/app_invoke_entities.py (+23 -11)
58. api/core/app/entities/queue_entities.py (+44 -10)
59. api/core/app/entities/task_entities.py (+37 -3)
60. api/core/app/features/annotation_reply/annotation_reply.py (+27 -33)
61. api/core/app/features/hosting_moderation/hosting_moderation.py (+4 -6)
62. api/core/app/features/rate_limiting/rate_limit.py (+14 -9)
63. api/core/app/segments/__init__.py (+21 -21)
64. api/core/app/segments/factory.py (+10 -10)
65. api/core/app/segments/parser.py (+2 -2)
66. api/core/app/segments/segment_group.py (+3 -3)
67. api/core/app/segments/segments.py (+10 -14)
68. api/core/app/segments/types.py (+10 -10)
69. api/core/app/segments/variables.py (+2 -3)
70. api/core/app/task_pipeline/based_generate_task_pipeline.py (+20 -22)
71. api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py (+76 -99)
72. api/core/app/task_pipeline/message_cycle_manage.py (+27 -48)
73. api/core/app/task_pipeline/workflow_cycle_manage.py (+61 -56)
74. api/core/callback_handler/agent_tool_callback_handler.py (+17 -21)
75. api/core/callback_handler/index_tool_callback_handler.py (+28 -34)
76. api/core/callback_handler/workflow_tool_callback_handler.py (+1 -1)
77. api/core/embedding/cached_embedding.py (+27 -23)
78. api/core/entities/agent_entities.py (+4 -4)
79. api/core/entities/message_entities.py (+3 -3)
80. api/core/entities/model_entities.py (+7 -1)
81. api/core/entities/provider_configuration.py (+201 -167)
82. api/core/entities/provider_entities.py (+14 -6)
83. api/core/errors/error.py (+7 -0)
84. api/core/extension/api_based_extension_requestor.py (+9 -16)
85. api/core/extension/extensible.py (+26 -22)
86. api/core/extension/extension.py (+1 -4)
87. api/core/external_data_tool/api/api.py (+36 -40)
88. api/core/external_data_tool/external_data_fetch.py (+20 -21)
89. api/core/external_data_tool/factory.py (+1 -5)
90. api/core/file/file_obj.py (+31 -28)
91. api/core/file/message_file_parser.py (+52 -48)
92. api/core/file/tool_file_parser.py (+4 -5)
93. api/core/file/upload_file_parser.py (+6 -6)
94. api/core/helper/code_executor/code_executor.py (+38 -34)
95. api/core/helper/code_executor/code_node_provider.py (+3 -17)
96. api/core/helper/code_executor/javascript/javascript_code_provider.py (+2 -1)
97. api/core/helper/code_executor/javascript/javascript_transformer.py (+2 -1)
98. api/core/helper/code_executor/jinja2/jinja2_formatter.py (+2 -4)
99. api/core/helper/code_executor/jinja2/jinja2_transformer.py (+1 -3)
100. api/core/helper/code_executor/python3/python3_code_provider.py (+2 -1)
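
The hunks below reproduce the formatter's output; no behavioral changes are intended. As a rough sketch of how a commit like this is produced locally (assuming ruff is installed and the project's ruff configuration lives under api/, e.g. in pyproject.toml; the exact ruff version pinned by the project may differ):

    # format only the api/core package from the repository root
    cd api
    ruff format core/
    # verify nothing is left to reformat (exits non-zero if files would change)
    ruff format --check core/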

+ 1 - 1
api/core/__init__.py

@@ -1 +1 @@
-import core.moderation.base
+import core.moderation.base

+ 94 - 121
api/core/agent/cot_agent_runner.py

@@ -25,17 +25,19 @@ from models.model import Message
 
 class CotAgentRunner(BaseAgentRunner, ABC):
     _is_first_iteration = True
-    _ignore_observation_providers = ['wenxin']
+    _ignore_observation_providers = ["wenxin"]
     _historic_prompt_messages: list[PromptMessage] = None
     _agent_scratchpad: list[AgentScratchpadUnit] = None
     _instruction: str = None
     _query: str = None
     _prompt_messages_tools: list[PromptMessage] = None
 
-    def run(self, message: Message,
-            query: str,
-            inputs: dict[str, str],
-            ) -> Union[Generator, LLMResult]:
+    def run(
+        self,
+        message: Message,
+        query: str,
+        inputs: dict[str, str],
+    ) -> Union[Generator, LLMResult]:
         """
         Run Cot agent application
         """
@@ -46,17 +48,16 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         trace_manager = app_generate_entity.trace_manager
 
         # check model mode
-        if 'Observation' not in app_generate_entity.model_conf.stop:
+        if "Observation" not in app_generate_entity.model_conf.stop:
             if app_generate_entity.model_conf.provider not in self._ignore_observation_providers:
-                app_generate_entity.model_conf.stop.append('Observation')
+                app_generate_entity.model_conf.stop.append("Observation")
 
         app_config = self.app_config
 
         # init instruction
         inputs = inputs or {}
         instruction = app_config.prompt_template.simple_prompt_template
-        self._instruction = self._fill_in_inputs_from_external_data_tools(
-            instruction, inputs)
+        self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
 
         iteration_step = 1
         max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
@@ -65,16 +66,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         tool_instances, self._prompt_messages_tools = self._init_prompt_tools()
 
         function_call_state = True
-        llm_usage = {
-            'usage': None
-        }
-        final_answer = ''
+        llm_usage = {"usage": None}
+        final_answer = ""
 
         def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
-            if not final_llm_usage_dict['usage']:
-                final_llm_usage_dict['usage'] = usage
+            if not final_llm_usage_dict["usage"]:
+                final_llm_usage_dict["usage"] = usage
             else:
-                llm_usage = final_llm_usage_dict['usage']
+                llm_usage = final_llm_usage_dict["usage"]
                 llm_usage.prompt_tokens += usage.prompt_tokens
                 llm_usage.completion_tokens += usage.completion_tokens
                 llm_usage.prompt_price += usage.prompt_price
@@ -94,17 +93,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
             message_file_ids = []
 
             agent_thought = self.create_agent_thought(
-                message_id=message.id,
-                message='',
-                tool_name='',
-                tool_input='',
-                messages_ids=message_file_ids
+                message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
             )
 
             if iteration_step > 1:
-                self.queue_manager.publish(QueueAgentThoughtEvent(
-                    agent_thought_id=agent_thought.id
-                ), PublishFrom.APPLICATION_MANAGER)
+                self.queue_manager.publish(
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                )
 
             # recalc llm max tokens
             prompt_messages = self._organize_prompt_messages()
@@ -125,21 +120,20 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 raise ValueError("failed to invoke llm")
 
             usage_dict = {}
-            react_chunks = CotAgentOutputParser.handle_react_stream_output(
-                chunks, usage_dict)
+            react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
             scratchpad = AgentScratchpadUnit(
-                agent_response='',
-                thought='',
-                action_str='',
-                observation='',
+                agent_response="",
+                thought="",
+                action_str="",
+                observation="",
                 action=None,
             )
 
             # publish agent thought if it's first iteration
             if iteration_step == 1:
-                self.queue_manager.publish(QueueAgentThoughtEvent(
-                    agent_thought_id=agent_thought.id
-                ), PublishFrom.APPLICATION_MANAGER)
+                self.queue_manager.publish(
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                )
 
             for chunk in react_chunks:
                 if isinstance(chunk, AgentScratchpadUnit.Action):
@@ -154,61 +148,51 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                     yield LLMResultChunk(
                         model=self.model_config.model,
                         prompt_messages=prompt_messages,
-                        system_fingerprint='',
-                        delta=LLMResultChunkDelta(
-                            index=0,
-                            message=AssistantPromptMessage(
-                                content=chunk
-                            ),
-                            usage=None
-                        )
+                        system_fingerprint="",
+                        delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
                     )
 
-            scratchpad.thought = scratchpad.thought.strip(
-            ) or 'I am thinking about how to help you'
+            scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
             self._agent_scratchpad.append(scratchpad)
 
             # get llm usage
-            if 'usage' in usage_dict:
-                increase_usage(llm_usage, usage_dict['usage'])
+            if "usage" in usage_dict:
+                increase_usage(llm_usage, usage_dict["usage"])
             else:
-                usage_dict['usage'] = LLMUsage.empty_usage()
+                usage_dict["usage"] = LLMUsage.empty_usage()
 
             self.save_agent_thought(
                 agent_thought=agent_thought,
-                tool_name=scratchpad.action.action_name if scratchpad.action else '',
-                tool_input={
-                    scratchpad.action.action_name: scratchpad.action.action_input
-                } if scratchpad.action else {},
+                tool_name=scratchpad.action.action_name if scratchpad.action else "",
+                tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
                 tool_invoke_meta={},
                 thought=scratchpad.thought,
-                observation='',
+                observation="",
                 answer=scratchpad.agent_response,
                 messages_ids=[],
-                llm_usage=usage_dict['usage']
+                llm_usage=usage_dict["usage"],
             )
 
             if not scratchpad.is_final():
-                self.queue_manager.publish(QueueAgentThoughtEvent(
-                    agent_thought_id=agent_thought.id
-                ), PublishFrom.APPLICATION_MANAGER)
+                self.queue_manager.publish(
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                )
 
             if not scratchpad.action:
                 # failed to extract action, return final answer directly
-                final_answer = ''
+                final_answer = ""
             else:
                 if scratchpad.action.action_name.lower() == "final answer":
                     # action is final answer, return final answer directly
                     try:
                         if isinstance(scratchpad.action.action_input, dict):
-                            final_answer = json.dumps(
-                                scratchpad.action.action_input)
+                            final_answer = json.dumps(scratchpad.action.action_input)
                         elif isinstance(scratchpad.action.action_input, str):
                             final_answer = scratchpad.action.action_input
                         else:
-                            final_answer = f'{scratchpad.action.action_input}'
+                            final_answer = f"{scratchpad.action.action_input}"
                     except json.JSONDecodeError:
-                        final_answer = f'{scratchpad.action.action_input}'
+                        final_answer = f"{scratchpad.action.action_input}"
                 else:
                     function_call_state = True
                     # action is tool call, invoke tool
@@ -224,21 +208,18 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                     self.save_agent_thought(
                         agent_thought=agent_thought,
                         tool_name=scratchpad.action.action_name,
-                        tool_input={
-                            scratchpad.action.action_name: scratchpad.action.action_input},
+                        tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
                         thought=scratchpad.thought,
-                        observation={
-                            scratchpad.action.action_name: tool_invoke_response},
-                        tool_invoke_meta={
-                            scratchpad.action.action_name: tool_invoke_meta.to_dict()},
+                        observation={scratchpad.action.action_name: tool_invoke_response},
+                        tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
                         answer=scratchpad.agent_response,
                         messages_ids=message_file_ids,
-                        llm_usage=usage_dict['usage']
+                        llm_usage=usage_dict["usage"],
                     )
 
-                    self.queue_manager.publish(QueueAgentThoughtEvent(
-                        agent_thought_id=agent_thought.id
-                    ), PublishFrom.APPLICATION_MANAGER)
+                    self.queue_manager.publish(
+                        QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                    )
 
                 # update prompt tool message
                 for prompt_tool in self._prompt_messages_tools:
@@ -250,44 +231,45 @@ class CotAgentRunner(BaseAgentRunner, ABC):
             model=model_instance.model,
             prompt_messages=prompt_messages,
             delta=LLMResultChunkDelta(
-                index=0,
-                message=AssistantPromptMessage(
-                    content=final_answer
-                ),
-                usage=llm_usage['usage']
+                index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
             ),
-            system_fingerprint=''
+            system_fingerprint="",
         )
 
         # save agent thought
         self.save_agent_thought(
             agent_thought=agent_thought,
-            tool_name='',
+            tool_name="",
             tool_input={},
             tool_invoke_meta={},
             thought=final_answer,
             observation={},
             answer=final_answer,
-            messages_ids=[]
+            messages_ids=[],
         )
 
         self.update_db_variables(self.variables_pool, self.db_variables_pool)
         # publish end event
-        self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
-            model=model_instance.model,
-            prompt_messages=prompt_messages,
-            message=AssistantPromptMessage(
-                content=final_answer
+        self.queue_manager.publish(
+            QueueMessageEndEvent(
+                llm_result=LLMResult(
+                    model=model_instance.model,
+                    prompt_messages=prompt_messages,
+                    message=AssistantPromptMessage(content=final_answer),
+                    usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(),
+                    system_fingerprint="",
+                )
             ),
-            usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
-            system_fingerprint=''
-        )), PublishFrom.APPLICATION_MANAGER)
-
-    def _handle_invoke_action(self, action: AgentScratchpadUnit.Action,
-                              tool_instances: dict[str, Tool],
-                              message_file_ids: list[str],
-                              trace_manager: Optional[TraceQueueManager] = None
-                              ) -> tuple[str, ToolInvokeMeta]:
+            PublishFrom.APPLICATION_MANAGER,
+        )
+
+    def _handle_invoke_action(
+        self,
+        action: AgentScratchpadUnit.Action,
+        tool_instances: dict[str, Tool],
+        message_file_ids: list[str],
+        trace_manager: Optional[TraceQueueManager] = None,
+    ) -> tuple[str, ToolInvokeMeta]:
         """
         handle invoke action
         :param action: action
@@ -326,13 +308,12 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         # publish files
         for message_file_id, save_as in message_files:
             if save_as:
-                self.variables_pool.set_file(
-                    tool_name=tool_call_name, value=message_file_id, name=save_as)
+                self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
 
             # publish message file
-            self.queue_manager.publish(QueueMessageFileEvent(
-                message_file_id=message_file_id
-            ), PublishFrom.APPLICATION_MANAGER)
+            self.queue_manager.publish(
+                QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
+            )
             # add message file ids
             message_file_ids.append(message_file_id)
 
@@ -342,10 +323,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         """
         convert dict to action
         """
-        return AgentScratchpadUnit.Action(
-            action_name=action['action'],
-            action_input=action['action_input']
-        )
+        return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"])
 
     def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
         """
@@ -353,7 +331,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         """
         for key, value in inputs.items():
             try:
-                instruction = instruction.replace(f'{{{{{key}}}}}', str(value))
+                instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
             except Exception as e:
                 continue
 
@@ -370,14 +348,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
     @abstractmethod
     def _organize_prompt_messages(self) -> list[PromptMessage]:
         """
-            organize prompt messages
+        organize prompt messages
         """
 
     def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
         """
-            format assistant message
+        format assistant message
         """
-        message = ''
+        message = ""
         for scratchpad in agent_scratchpad:
             if scratchpad.is_final():
                 message += f"Final Answer: {scratchpad.agent_response}"
@@ -390,9 +368,11 @@ class CotAgentRunner(BaseAgentRunner, ABC):
 
         return message
 
-    def _organize_historic_prompt_messages(self, current_session_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+    def _organize_historic_prompt_messages(
+        self, current_session_messages: list[PromptMessage] = None
+    ) -> list[PromptMessage]:
         """
-            organize historic prompt messages
+        organize historic prompt messages
         """
         result: list[PromptMessage] = []
         scratchpads: list[AgentScratchpadUnit] = []
@@ -403,8 +383,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 if not current_scratchpad:
                     current_scratchpad = AgentScratchpadUnit(
                         agent_response=message.content,
-                        thought=message.content or 'I am thinking about how to help you',
-                        action_str='',
+                        thought=message.content or "I am thinking about how to help you",
+                        action_str="",
                         action=None,
                         observation=None,
                     )
@@ -413,12 +393,9 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                     try:
                         current_scratchpad.action = AgentScratchpadUnit.Action(
                             action_name=message.tool_calls[0].function.name,
-                            action_input=json.loads(
-                                message.tool_calls[0].function.arguments)
-                        )
-                        current_scratchpad.action_str = json.dumps(
-                            current_scratchpad.action.to_dict()
+                            action_input=json.loads(message.tool_calls[0].function.arguments),
                         )
+                        current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
                     except:
                         pass
             elif isinstance(message, ToolPromptMessage):
@@ -426,23 +403,19 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                     current_scratchpad.observation = message.content
             elif isinstance(message, UserPromptMessage):
                 if scratchpads:
-                    result.append(AssistantPromptMessage(
-                        content=self._format_assistant_message(scratchpads)
-                    ))
+                    result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
                     scratchpads = []
                     current_scratchpad = None
 
                 result.append(message)
 
         if scratchpads:
-            result.append(AssistantPromptMessage(
-                content=self._format_assistant_message(scratchpads)
-            ))
+            result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
 
         historic_prompts = AgentHistoryPromptTransform(
             model_config=self.model_config,
             prompt_messages=current_session_messages or [],
             history_messages=result,
-            memory=self.memory
+            memory=self.memory,
         ).get_prompt()
         return historic_prompts

+ 12 - 14
api/core/agent/cot_chat_agent_runner.py

@@ -19,14 +19,15 @@ class CotChatAgentRunner(CotAgentRunner):
         prompt_entity = self.app_config.agent.prompt
         first_prompt = prompt_entity.first_prompt
 
-        system_prompt = first_prompt \
-            .replace("{{instruction}}", self._instruction) \
-            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
-            .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
+        system_prompt = (
+            first_prompt.replace("{{instruction}}", self._instruction)
+            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
+            .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
+        )
 
         return SystemPromptMessage(content=system_prompt)
 
-    def _organize_user_query(self, query,  prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+    def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
         """
         Organize user query
         """
@@ -43,7 +44,7 @@ class CotChatAgentRunner(CotAgentRunner):
 
     def _organize_prompt_messages(self) -> list[PromptMessage]:
         """
-        Organize 
+        Organize
         """
         # organize system prompt
         system_message = self._organize_system_prompt()
@@ -53,7 +54,7 @@ class CotChatAgentRunner(CotAgentRunner):
         if not agent_scratchpad:
             assistant_messages = []
         else:
-            assistant_message = AssistantPromptMessage(content='')
+            assistant_message = AssistantPromptMessage(content="")
             for unit in agent_scratchpad:
                 if unit.is_final():
                     assistant_message.content += f"Final Answer: {unit.agent_response}"
@@ -71,18 +72,15 @@ class CotChatAgentRunner(CotAgentRunner):
 
         if assistant_messages:
             # organize historic prompt messages
-            historic_messages = self._organize_historic_prompt_messages([
-                system_message,
-                *query_messages,
-                *assistant_messages,
-                UserPromptMessage(content='continue')
-            ])
+            historic_messages = self._organize_historic_prompt_messages(
+                [system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")]
+            )
             messages = [
                 system_message,
                 *historic_messages,
                 *query_messages,
                 *assistant_messages,
-                UserPromptMessage(content='continue')
+                UserPromptMessage(content="continue"),
             ]
         else:
             # organize historic prompt messages

+ 12 - 9
api/core/agent/cot_completion_agent_runner.py

@@ -13,10 +13,12 @@ class CotCompletionAgentRunner(CotAgentRunner):
         prompt_entity = self.app_config.agent.prompt
         first_prompt = prompt_entity.first_prompt
 
-        system_prompt = first_prompt.replace("{{instruction}}", self._instruction) \
-            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
-            .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
-        
+        system_prompt = (
+            first_prompt.replace("{{instruction}}", self._instruction)
+            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
+            .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
+        )
+
         return system_prompt
 
     def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str:
@@ -46,7 +48,7 @@ class CotCompletionAgentRunner(CotAgentRunner):
 
         # organize current assistant messages
         agent_scratchpad = self._agent_scratchpad
-        assistant_prompt = ''
+        assistant_prompt = ""
         for unit in agent_scratchpad:
             if unit.is_final():
                 assistant_prompt += f"Final Answer: {unit.agent_response}"
@@ -61,9 +63,10 @@ class CotCompletionAgentRunner(CotAgentRunner):
         query_prompt = f"Question: {self._query}"
 
         # join all messages
-        prompt = system_prompt \
-            .replace("{{historic_messages}}", historic_prompt) \
-            .replace("{{agent_scratchpad}}", assistant_prompt) \
+        prompt = (
+            system_prompt.replace("{{historic_messages}}", historic_prompt)
+            .replace("{{agent_scratchpad}}", assistant_prompt)
             .replace("{{query}}", query_prompt)
+        )
 
-        return [UserPromptMessage(content=prompt)]
+        return [UserPromptMessage(content=prompt)]

+ 10 - 6
api/core/agent/entities.py

@@ -8,6 +8,7 @@ class AgentToolEntity(BaseModel):
     """
     Agent Tool Entity.
     """
+
     provider_type: Literal["builtin", "api", "workflow"]
     provider_id: str
     tool_name: str
@@ -18,6 +19,7 @@ class AgentPromptEntity(BaseModel):
     """
     Agent Prompt Entity.
     """
+
     first_prompt: str
     next_iteration: str
 
@@ -31,6 +33,7 @@ class AgentScratchpadUnit(BaseModel):
         """
         Action Entity.
         """
+
         action_name: str
         action_input: Union[dict, str]
 
@@ -39,8 +42,8 @@ class AgentScratchpadUnit(BaseModel):
             Convert to dictionary.
             """
             return {
-                'action': self.action_name,
-                'action_input': self.action_input,
+                "action": self.action_name,
+                "action_input": self.action_input,
             }
 
     agent_response: Optional[str] = None
@@ -54,10 +57,10 @@ class AgentScratchpadUnit(BaseModel):
         Check if the scratchpad unit is final.
         """
         return self.action is None or (
-            'final' in self.action.action_name.lower() and 
-            'answer' in self.action.action_name.lower()
+            "final" in self.action.action_name.lower() and "answer" in self.action.action_name.lower()
         )
 
+
 class AgentEntity(BaseModel):
     """
     Agent Entity.
@@ -67,8 +70,9 @@ class AgentEntity(BaseModel):
         """
         Agent Strategy.
         """
-        CHAIN_OF_THOUGHT = 'chain-of-thought'
-        FUNCTION_CALLING = 'function-calling'
+
+        CHAIN_OF_THOUGHT = "chain-of-thought"
+        FUNCTION_CALLING = "function-calling"
 
     provider: str
     model: str

+ 115 - 120
api/core/agent/fc_agent_runner.py

@@ -24,11 +24,9 @@ from models.model import Message
 
 logger = logging.getLogger(__name__)
 
-class FunctionCallAgentRunner(BaseAgentRunner):
 
-    def run(self, 
-            message: Message, query: str, **kwargs: Any
-    ) -> Generator[LLMResultChunk, None, None]:
+class FunctionCallAgentRunner(BaseAgentRunner):
+    def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
         """
         Run FunctionCall agent application
         """
@@ -45,19 +43,17 @@ class FunctionCallAgentRunner(BaseAgentRunner):
 
         # continue to run until there is not any tool call
         function_call_state = True
-        llm_usage = {
-            'usage': None
-        }
-        final_answer = ''
+        llm_usage = {"usage": None}
+        final_answer = ""
 
         # get tracing instance
         trace_manager = app_generate_entity.trace_manager
-        
+
         def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
-            if not final_llm_usage_dict['usage']:
-                final_llm_usage_dict['usage'] = usage
+            if not final_llm_usage_dict["usage"]:
+                final_llm_usage_dict["usage"] = usage
             else:
-                llm_usage = final_llm_usage_dict['usage']
+                llm_usage = final_llm_usage_dict["usage"]
                 llm_usage.prompt_tokens += usage.prompt_tokens
                 llm_usage.completion_tokens += usage.completion_tokens
                 llm_usage.prompt_price += usage.prompt_price
@@ -75,11 +71,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
 
             message_file_ids = []
             agent_thought = self.create_agent_thought(
-                message_id=message.id,
-                message='',
-                tool_name='',
-                tool_input='',
-                messages_ids=message_file_ids
+                message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
             )
 
             # recalc llm max tokens
@@ -99,11 +91,11 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             tool_calls: list[tuple[str, str, dict[str, Any]]] = []
 
             # save full response
-            response = ''
+            response = ""
 
             # save tool call names and inputs
-            tool_call_names = ''
-            tool_call_inputs = ''
+            tool_call_names = ""
+            tool_call_inputs = ""
 
             current_llm_usage = None
 
@@ -111,24 +103,22 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 is_first_chunk = True
                 for chunk in chunks:
                     if is_first_chunk:
-                        self.queue_manager.publish(QueueAgentThoughtEvent(
-                            agent_thought_id=agent_thought.id
-                        ), PublishFrom.APPLICATION_MANAGER)
+                        self.queue_manager.publish(
+                            QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                        )
                         is_first_chunk = False
                     # check if there is any tool call
                     if self.check_tool_calls(chunk):
                         function_call_state = True
                         tool_calls.extend(self.extract_tool_calls(chunk))
-                        tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
+                        tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                         try:
-                            tool_call_inputs = json.dumps({
-                                tool_call[1]: tool_call[2] for tool_call in tool_calls
-                            }, ensure_ascii=False)
+                            tool_call_inputs = json.dumps(
+                                {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
+                            )
                         except json.JSONDecodeError as e:
                             # ensure ascii to avoid encoding error
-                            tool_call_inputs = json.dumps({
-                                tool_call[1]: tool_call[2] for tool_call in tool_calls
-                            })
+                            tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
 
                     if chunk.delta.message and chunk.delta.message.content:
                         if isinstance(chunk.delta.message.content, list):
@@ -148,16 +138,14 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 if self.check_blocking_tool_calls(result):
                     function_call_state = True
                     tool_calls.extend(self.extract_blocking_tool_calls(result))
-                    tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
+                    tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                     try:
-                        tool_call_inputs = json.dumps({
-                            tool_call[1]: tool_call[2] for tool_call in tool_calls
-                        }, ensure_ascii=False)
+                        tool_call_inputs = json.dumps(
+                            {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
+                        )
                     except json.JSONDecodeError as e:
                         # ensure ascii to avoid encoding error
-                        tool_call_inputs = json.dumps({
-                            tool_call[1]: tool_call[2] for tool_call in tool_calls
-                        })
+                        tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
 
                 if result.usage:
                     increase_usage(llm_usage, result.usage)
@@ -171,12 +159,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                         response += result.message.content
 
                 if not result.message.content:
-                    result.message.content = ''
+                    result.message.content = ""
+
+                self.queue_manager.publish(
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                )
 
-                self.queue_manager.publish(QueueAgentThoughtEvent(
-                    agent_thought_id=agent_thought.id
-                ), PublishFrom.APPLICATION_MANAGER)
-                
                 yield LLMResultChunk(
                     model=model_instance.model,
                     prompt_messages=result.prompt_messages,
@@ -185,32 +173,29 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                         index=0,
                         message=result.message,
                         usage=result.usage,
-                    )
+                    ),
                 )
 
-            assistant_message = AssistantPromptMessage(
-                content='',
-                tool_calls=[]
-            )
+            assistant_message = AssistantPromptMessage(content="", tool_calls=[])
             if tool_calls:
-                assistant_message.tool_calls=[
+                assistant_message.tool_calls = [
                     AssistantPromptMessage.ToolCall(
                         id=tool_call[0],
-                        type='function',
+                        type="function",
                         function=AssistantPromptMessage.ToolCall.ToolCallFunction(
-                            name=tool_call[1],
-                            arguments=json.dumps(tool_call[2], ensure_ascii=False)
-                        )
-                    ) for tool_call in tool_calls
+                            name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
+                        ),
+                    )
+                    for tool_call in tool_calls
                 ]
             else:
                 assistant_message.content = response
-            
+
             self._current_thoughts.append(assistant_message)
 
             # save thought
             self.save_agent_thought(
-                agent_thought=agent_thought, 
+                agent_thought=agent_thought,
                 tool_name=tool_call_names,
                 tool_input=tool_call_inputs,
                 thought=response,
@@ -218,13 +203,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 observation=None,
                 answer=response,
                 messages_ids=[],
-                llm_usage=current_llm_usage
+                llm_usage=current_llm_usage,
             )
-            self.queue_manager.publish(QueueAgentThoughtEvent(
-                agent_thought_id=agent_thought.id
-            ), PublishFrom.APPLICATION_MANAGER)
-            
-            final_answer += response + '\n'
+            self.queue_manager.publish(
+                QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+            )
+
+            final_answer += response + "\n"
 
             # call tools
             tool_responses = []
@@ -235,7 +220,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                         "tool_call_id": tool_call_id,
                         "tool_call_name": tool_call_name,
                         "tool_response": f"there is not a tool named {tool_call_name}",
-                        "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict()
+                        "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(),
                     }
                 else:
                     # invoke tool
@@ -255,50 +240,49 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                             self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
 
                         # publish message file
-                        self.queue_manager.publish(QueueMessageFileEvent(
-                            message_file_id=message_file_id
-                        ), PublishFrom.APPLICATION_MANAGER)
+                        self.queue_manager.publish(
+                            QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
+                        )
                         # add message file ids
                         message_file_ids.append(message_file_id)
-                    
+
                     tool_response = {
                         "tool_call_id": tool_call_id,
                         "tool_call_name": tool_call_name,
                         "tool_response": tool_invoke_response,
-                        "meta": tool_invoke_meta.to_dict()
+                        "meta": tool_invoke_meta.to_dict(),
                     }
-                
+
                 tool_responses.append(tool_response)
-                if tool_response['tool_response'] is not None:
+                if tool_response["tool_response"] is not None:
                     self._current_thoughts.append(
                         ToolPromptMessage(
-                            content=tool_response['tool_response'],
+                            content=tool_response["tool_response"],
                             tool_call_id=tool_call_id,
                             name=tool_call_name,
                         )
-                    ) 
+                    )
 
             if len(tool_responses) > 0:
                 # save agent thought
                 self.save_agent_thought(
-                    agent_thought=agent_thought, 
+                    agent_thought=agent_thought,
                     tool_name=None,
                     tool_input=None,
-                    thought=None, 
+                    thought=None,
                     tool_invoke_meta={
-                        tool_response['tool_call_name']: tool_response['meta'] 
-                        for tool_response in tool_responses
+                        tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
                     },
                     observation={
-                        tool_response['tool_call_name']: tool_response['tool_response'] 
+                        tool_response["tool_call_name"]: tool_response["tool_response"]
                         for tool_response in tool_responses
                     },
                     answer=None,
-                    messages_ids=message_file_ids
+                    messages_ids=message_file_ids,
+                )
+                self.queue_manager.publish(
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                 )
-                self.queue_manager.publish(QueueAgentThoughtEvent(
-                    agent_thought_id=agent_thought.id
-                ), PublishFrom.APPLICATION_MANAGER)
 
             # update prompt tool
             for prompt_tool in prompt_messages_tools:
@@ -308,15 +292,18 @@ class FunctionCallAgentRunner(BaseAgentRunner):
 
         self.update_db_variables(self.variables_pool, self.db_variables_pool)
         # publish end event
-        self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
-            model=model_instance.model,
-            prompt_messages=prompt_messages,
-            message=AssistantPromptMessage(
-                content=final_answer
+        self.queue_manager.publish(
+            QueueMessageEndEvent(
+                llm_result=LLMResult(
+                    model=model_instance.model,
+                    prompt_messages=prompt_messages,
+                    message=AssistantPromptMessage(content=final_answer),
+                    usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(),
+                    system_fingerprint="",
+                )
             ),
-            usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
-            system_fingerprint=''
-        )), PublishFrom.APPLICATION_MANAGER)
+            PublishFrom.APPLICATION_MANAGER,
+        )
 
     def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
         """
@@ -325,7 +312,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         if llm_result_chunk.delta.message.tool_calls:
             return True
         return False
-    
+
     def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
         """
         Check if there is any blocking tool call in llm result
@@ -334,7 +321,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             return True
         return False
 
-    def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
+    def extract_tool_calls(
+        self, llm_result_chunk: LLMResultChunk
+    ) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
         """
         Extract tool calls from llm result chunk
 
@@ -344,17 +333,19 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         tool_calls = []
         for prompt_message in llm_result_chunk.delta.message.tool_calls:
             args = {}
-            if prompt_message.function.arguments != '':
+            if prompt_message.function.arguments != "":
                 args = json.loads(prompt_message.function.arguments)
 
-            tool_calls.append((
-                prompt_message.id,
-                prompt_message.function.name,
-                args,
-            ))
+            tool_calls.append(
+                (
+                    prompt_message.id,
+                    prompt_message.function.name,
+                    args,
+                )
+            )
 
         return tool_calls
-    
+
     def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
         """
         Extract blocking tool calls from llm result
@@ -365,18 +356,22 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         tool_calls = []
         for prompt_message in llm_result.message.tool_calls:
             args = {}
-            if prompt_message.function.arguments != '':
+            if prompt_message.function.arguments != "":
                 args = json.loads(prompt_message.function.arguments)
 
-            tool_calls.append((
-                prompt_message.id,
-                prompt_message.function.name,
-                args,
-            ))
+            tool_calls.append(
+                (
+                    prompt_message.id,
+                    prompt_message.function.name,
+                    args,
+                )
+            )
 
         return tool_calls
 
-    def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+    def _init_system_message(
+        self, prompt_template: str, prompt_messages: list[PromptMessage] = None
+    ) -> list[PromptMessage]:
         """
         Initialize system message
         """
@@ -384,13 +379,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             return [
                 SystemPromptMessage(content=prompt_template),
             ]
-        
+
         if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
             prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
 
         return prompt_messages
 
-    def _organize_user_query(self, query,  prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+    def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
         """
         Organize user query
         """
@@ -404,7 +399,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             prompt_messages.append(UserPromptMessage(content=query))
 
         return prompt_messages
-    
+
     def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
         """
         As for now, gpt supports both fc and vision at the first iteration.
@@ -415,17 +410,21 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         for prompt_message in prompt_messages:
             if isinstance(prompt_message, UserPromptMessage):
                 if isinstance(prompt_message.content, list):
-                    prompt_message.content = '\n'.join([
-                        content.data if content.type == PromptMessageContentType.TEXT else 
-                        '[image]' if content.type == PromptMessageContentType.IMAGE else
-                        '[file]' 
-                        for content in prompt_message.content 
-                    ])
+                    prompt_message.content = "\n".join(
+                        [
+                            content.data
+                            if content.type == PromptMessageContentType.TEXT
+                            else "[image]"
+                            if content.type == PromptMessageContentType.IMAGE
+                            else "[file]"
+                            for content in prompt_message.content
+                        ]
+                    )
 
         return prompt_messages
 
     def _organize_prompt_messages(self):
-        prompt_template = self.app_config.prompt_template.simple_prompt_template or ''
+        prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
         self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
         query_prompt_messages = self._organize_user_query(self.query, [])
 
@@ -433,14 +432,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             model_config=self.model_config,
             prompt_messages=[*query_prompt_messages, *self._current_thoughts],
             history_messages=self.history_prompt_messages,
-            memory=self.memory
+            memory=self.memory,
         ).get_prompt()
 
-        prompt_messages = [
-            *self.history_prompt_messages,
-            *query_prompt_messages,
-            *self._current_thoughts
-        ]
+        prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
         if len(self._current_thoughts) != 0:
             # clear messages after the first iteration
             prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)

+ 37 - 37
api/core/agent/output_parser/cot_output_parser.py

@@ -9,8 +9,9 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk
 
 class CotAgentOutputParser:
     @classmethod
-    def handle_react_stream_output(cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict) -> \
-        Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
+    def handle_react_stream_output(
+        cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
+    ) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
         def parse_action(json_str):
             try:
                 action = json.loads(json_str)
@@ -22,7 +23,7 @@ class CotAgentOutputParser:
                     action = action[0]
 
                 for key, value in action.items():
-                    if 'input' in key.lower():
+                    if "input" in key.lower():
                         action_input = value
                     else:
                         action_name = value
@@ -33,37 +34,37 @@ class CotAgentOutputParser:
                         action_input=action_input,
                     )
                 else:
-                    return json_str or ''
+                    return json_str or ""
             except:
-                return json_str or ''
-            
+                return json_str or ""
+
         def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
-            code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
+            code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL)
             if not code_blocks:
                 return
             for block in code_blocks:
-                json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
+                json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
                 yield parse_action(json_text)
-            
-        code_block_cache = ''
+
+        code_block_cache = ""
         code_block_delimiter_count = 0
         in_code_block = False
-        json_cache = ''
+        json_cache = ""
         json_quote_count = 0
         in_json = False
         got_json = False
 
-        action_cache = ''
-        action_str = 'action:'
+        action_cache = ""
+        action_str = "action:"
         action_idx = 0
 
-        thought_cache = ''
-        thought_str = 'thought:'
+        thought_cache = ""
+        thought_str = "thought:"
         thought_idx = 0
 
         for response in llm_response:
             if response.delta.usage:
-                usage_dict['usage'] = response.delta.usage
+                usage_dict["usage"] = response.delta.usage
             response = response.delta.message.content
             if not isinstance(response, str):
                 continue
@@ -72,24 +73,24 @@ class CotAgentOutputParser:
             index = 0
             while index < len(response):
                 steps = 1
-                delta = response[index:index+steps]
-                last_character = response[index-1] if index > 0 else ''
+                delta = response[index : index + steps]
+                last_character = response[index - 1] if index > 0 else ""
 
-                if delta == '`':
+                if delta == "`":
                     code_block_cache += delta
                     code_block_delimiter_count += 1
                 else:
                     if not in_code_block:
                         if code_block_delimiter_count > 0:
                             yield code_block_cache
-                        code_block_cache = ''
+                        code_block_cache = ""
                     else:
                         code_block_cache += delta
                     code_block_delimiter_count = 0
 
                 if not in_code_block and not in_json:
                     if delta.lower() == action_str[action_idx] and action_idx == 0:
-                        if last_character not in ['\n', ' ', '']:
+                        if last_character not in ["\n", " ", ""]:
                             index += steps
                             yield delta
                             continue
@@ -97,7 +98,7 @@ class CotAgentOutputParser:
                         action_cache += delta
                         action_idx += 1
                         if action_idx == len(action_str):
-                            action_cache = ''
+                            action_cache = ""
                             action_idx = 0
                         index += steps
                         continue
@@ -105,18 +106,18 @@ class CotAgentOutputParser:
                         action_cache += delta
                         action_idx += 1
                         if action_idx == len(action_str):
-                            action_cache = ''
+                            action_cache = ""
                             action_idx = 0
                         index += steps
                         continue
                     else:
                         if action_cache:
                             yield action_cache
-                            action_cache = ''
+                            action_cache = ""
                             action_idx = 0
-                    
+
                     if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
-                        if last_character not in ['\n', ' ', '']:
+                        if last_character not in ["\n", " ", ""]:
                             index += steps
                             yield delta
                             continue
@@ -124,7 +125,7 @@ class CotAgentOutputParser:
                         thought_cache += delta
                         thought_idx += 1
                         if thought_idx == len(thought_str):
-                            thought_cache = ''
+                            thought_cache = ""
                             thought_idx = 0
                         index += steps
                         continue
@@ -132,31 +133,31 @@ class CotAgentOutputParser:
                         thought_cache += delta
                         thought_idx += 1
                         if thought_idx == len(thought_str):
-                            thought_cache = ''
+                            thought_cache = ""
                             thought_idx = 0
                         index += steps
                         continue
                     else:
                         if thought_cache:
                             yield thought_cache
-                            thought_cache = ''
+                            thought_cache = ""
                             thought_idx = 0
 
                 if code_block_delimiter_count == 3:
                     if in_code_block:
                         yield from extra_json_from_code_block(code_block_cache)
-                        code_block_cache = ''
-                        
+                        code_block_cache = ""
+
                     in_code_block = not in_code_block
                     code_block_delimiter_count = 0
 
                 if not in_code_block:
                     # handle single json
-                    if delta == '{':
+                    if delta == "{":
                         json_quote_count += 1
                         in_json = True
                         json_cache += delta
-                    elif delta == '}':
+                    elif delta == "}":
                         json_cache += delta
                         if json_quote_count > 0:
                             json_quote_count -= 1
@@ -172,12 +173,12 @@ class CotAgentOutputParser:
                     if got_json:
                         got_json = False
                         yield parse_action(json_cache)
-                        json_cache = ''
+                        json_cache = ""
                         json_quote_count = 0
                         in_json = False
-                    
+
                 if not in_code_block and not in_json:
-                    yield delta.replace('`', '')
+                    yield delta.replace("`", "")
 
                 index += steps
 
@@ -186,4 +187,3 @@ class CotAgentOutputParser:
 
         if json_cache:
             yield parse_action(json_cache)
-
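
For context on the hunk above: the parser matches the literal prefixes "action:" and "thought:" one character at a time while the model output streams in. A minimal self-contained sketch of that incremental prefix match, assuming nothing from Dify's classes (the preceding-character boundary check in the real code is omitted for brevity):

    def strip_prefix_stream(chunks, prefix="action:"):
        # Buffer characters while they keep matching the prefix; on a full match
        # swallow the prefix, on a mismatch flush the buffer back out as text.
        cache, idx = "", 0
        for chunk in chunks:
            for ch in chunk:
                if idx < len(prefix) and ch.lower() == prefix[idx]:
                    cache += ch
                    idx += 1
                    if idx == len(prefix):
                        cache, idx = "", 0
                    continue
                if cache:
                    yield cache
                    cache, idx = "", 0
                yield ch
        if cache:
            yield cache

    print("".join(strip_prefix_stream(["Act", 'ion: {"tool": "x"}'])))  # -> ' {"tool": "x"}'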

+ 9 - 9
api/core/agent/prompt/template.py

@@ -91,14 +91,14 @@ Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use
 ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = ""
 
 REACT_PROMPT_TEMPLATES = {
-    'english': {
-        'chat': {
-            'prompt': ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
-            'agent_scratchpad': ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES
+    "english": {
+        "chat": {
+            "prompt": ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
+            "agent_scratchpad": ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES,
+        },
+        "completion": {
+            "prompt": ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
+            "agent_scratchpad": ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES,
         },
-        'completion': {
-            'prompt': ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
-            'agent_scratchpad': ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES
-        }
     }
-}
+}
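
REACT_PROMPT_TEMPLATES is keyed first by language, then by model mode, then by template kind; a caller indexes it as in this sketch (the literal below is a stand-in, not the real template text):

    templates = {
        "english": {
            "chat": {"prompt": "<chat prompt>", "agent_scratchpad": "<pad>"},
            "completion": {"prompt": "<completion prompt>", "agent_scratchpad": "<pad>"},
        }
    }
    mode = "completion"
    first_prompt = templates["english"][mode]["prompt"]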

+ 8 - 18
api/core/app/app_config/base_app_config_manager.py

@@ -26,34 +26,24 @@ class BaseAppConfigManager:
         config_dict = dict(config_dict.items())
 
         additional_features = AppAdditionalFeatures()
-        additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(
-            config=config_dict
-        )
+        additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict)
 
         additional_features.file_upload = FileUploadConfigManager.convert(
-            config=config_dict,
-            is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT]
+            config=config_dict, is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT]
         )
 
-        additional_features.opening_statement, additional_features.suggested_questions = \
-            OpeningStatementConfigManager.convert(
-                config=config_dict
-            )
+        additional_features.opening_statement, additional_features.suggested_questions = (
+            OpeningStatementConfigManager.convert(config=config_dict)
+        )
 
         additional_features.suggested_questions_after_answer = SuggestedQuestionsAfterAnswerConfigManager.convert(
             config=config_dict
         )
 
-        additional_features.more_like_this = MoreLikeThisConfigManager.convert(
-            config=config_dict
-        )
+        additional_features.more_like_this = MoreLikeThisConfigManager.convert(config=config_dict)
 
-        additional_features.speech_to_text = SpeechToTextConfigManager.convert(
-            config=config_dict
-        )
+        additional_features.speech_to_text = SpeechToTextConfigManager.convert(config=config_dict)
 
-        additional_features.text_to_speech = TextToSpeechConfigManager.convert(
-            config=config_dict
-        )
+        additional_features.text_to_speech = TextToSpeechConfigManager.convert(config=config_dict)
 
         return additional_features
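
The opening_statement hunk above also shows the formatter replacing a backslash continuation on a multi-target assignment with a parenthesized right-hand side. Both spellings unpack the same 2-tuple, as this illustrative pair shows (the convert helper here is hypothetical):

    def convert(config):
        return config.get("opening_statement"), config.get("suggested_questions")

    opening_statement, suggested_questions = \
        convert(config={"opening_statement": "Hi"})
    opening_statement, suggested_questions = (
        convert(config={"opening_statement": "Hi"})
    )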

+ 9 - 14
api/core/app/app_config/common/sensitive_word_avoidance/manager.py

@@ -7,25 +7,24 @@ from core.moderation.factory import ModerationFactory
 class SensitiveWordAvoidanceConfigManager:
     @classmethod
     def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]:
-        sensitive_word_avoidance_dict = config.get('sensitive_word_avoidance')
+        sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance")
         if not sensitive_word_avoidance_dict:
             return None
 
-        if sensitive_word_avoidance_dict.get('enabled'):
+        if sensitive_word_avoidance_dict.get("enabled"):
             return SensitiveWordAvoidanceEntity(
-                type=sensitive_word_avoidance_dict.get('type'),
-                config=sensitive_word_avoidance_dict.get('config'),
+                type=sensitive_word_avoidance_dict.get("type"),
+                config=sensitive_word_avoidance_dict.get("config"),
             )
         else:
             return None
 
     @classmethod
-    def validate_and_set_defaults(cls, tenant_id, config: dict, only_structure_validate: bool = False) \
-            -> tuple[dict, list[str]]:
+    def validate_and_set_defaults(
+        cls, tenant_id, config: dict, only_structure_validate: bool = False
+    ) -> tuple[dict, list[str]]:
         if not config.get("sensitive_word_avoidance"):
-            config["sensitive_word_avoidance"] = {
-                "enabled": False
-            }
+            config["sensitive_word_avoidance"] = {"enabled": False}
 
         if not isinstance(config["sensitive_word_avoidance"], dict):
             raise ValueError("sensitive_word_avoidance must be of dict type")
@@ -41,10 +40,6 @@ class SensitiveWordAvoidanceConfigManager:
                 typ = config["sensitive_word_avoidance"]["type"]
                 sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"]
 
-                ModerationFactory.validate_config(
-                    name=typ,
-                    tenant_id=tenant_id,
-                    config=sensitive_word_avoidance_config
-                )
+                ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config)
 
         return config, ["sensitive_word_avoidance"]
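
The convert path above returns an entity only when the feature block both exists and is enabled, and None otherwise. A standalone sketch of that shape, with a dataclass standing in for the real pydantic entity:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Avoidance:  # stand-in for SensitiveWordAvoidanceEntity
        type: str
        config: dict

    def convert(config: dict) -> Optional[Avoidance]:
        block = config.get("sensitive_word_avoidance")
        if block and block.get("enabled"):
            return Avoidance(type=block.get("type"), config=block.get("config", {}))
        return None

    assert convert({}) is None
    assert convert({"sensitive_word_avoidance": {"enabled": True, "type": "keywords"}}).type == "keywords"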

+ 33 - 30
api/core/app/app_config/easy_ui_based_app/agent/manager.py

@@ -12,67 +12,70 @@ class AgentConfigManager:
 
         :param config: model config args
         """
-        if 'agent_mode' in config and config['agent_mode'] \
-                and 'enabled' in config['agent_mode']:
+        if "agent_mode" in config and config["agent_mode"] and "enabled" in config["agent_mode"]:
+            agent_dict = config.get("agent_mode", {})
+            agent_strategy = agent_dict.get("strategy", "cot")
 
-            agent_dict = config.get('agent_mode', {})
-            agent_strategy = agent_dict.get('strategy', 'cot')
-
-            if agent_strategy == 'function_call':
+            if agent_strategy == "function_call":
                 strategy = AgentEntity.Strategy.FUNCTION_CALLING
-            elif agent_strategy == 'cot' or agent_strategy == 'react':
+            elif agent_strategy == "cot" or agent_strategy == "react":
                 strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
             else:
                 # old configs, try to detect default strategy
-                if config['model']['provider'] == 'openai':
+                if config["model"]["provider"] == "openai":
                     strategy = AgentEntity.Strategy.FUNCTION_CALLING
                 else:
                     strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
 
             agent_tools = []
-            for tool in agent_dict.get('tools', []):
+            for tool in agent_dict.get("tools", []):
                 keys = tool.keys()
                 if len(keys) >= 4:
                     if "enabled" not in tool or not tool["enabled"]:
                         continue
 
                     agent_tool_properties = {
-                        'provider_type': tool['provider_type'],
-                        'provider_id': tool['provider_id'],
-                        'tool_name': tool['tool_name'],
-                        'tool_parameters': tool.get('tool_parameters', {})
+                        "provider_type": tool["provider_type"],
+                        "provider_id": tool["provider_id"],
+                        "tool_name": tool["tool_name"],
+                        "tool_parameters": tool.get("tool_parameters", {}),
                     }
 
                     agent_tools.append(AgentToolEntity(**agent_tool_properties))
 
-            if 'strategy' in config['agent_mode'] and \
-                    config['agent_mode']['strategy'] not in ['react_router', 'router']:
-                agent_prompt = agent_dict.get('prompt', None) or {}
+            if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in [
+                "react_router",
+                "router",
+            ]:
+                agent_prompt = agent_dict.get("prompt", None) or {}
                 # check model mode
-                model_mode = config.get('model', {}).get('mode', 'completion')
-                if model_mode == 'completion':
+                model_mode = config.get("model", {}).get("mode", "completion")
+                if model_mode == "completion":
                     agent_prompt_entity = AgentPromptEntity(
-                        first_prompt=agent_prompt.get('first_prompt',
-                                                      REACT_PROMPT_TEMPLATES['english']['completion']['prompt']),
-                        next_iteration=agent_prompt.get('next_iteration',
-                                                        REACT_PROMPT_TEMPLATES['english']['completion'][
-                                                            'agent_scratchpad']),
+                        first_prompt=agent_prompt.get(
+                            "first_prompt", REACT_PROMPT_TEMPLATES["english"]["completion"]["prompt"]
+                        ),
+                        next_iteration=agent_prompt.get(
+                            "next_iteration", REACT_PROMPT_TEMPLATES["english"]["completion"]["agent_scratchpad"]
+                        ),
                     )
                 else:
                     agent_prompt_entity = AgentPromptEntity(
-                        first_prompt=agent_prompt.get('first_prompt',
-                                                      REACT_PROMPT_TEMPLATES['english']['chat']['prompt']),
-                        next_iteration=agent_prompt.get('next_iteration',
-                                                        REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']),
+                        first_prompt=agent_prompt.get(
+                            "first_prompt", REACT_PROMPT_TEMPLATES["english"]["chat"]["prompt"]
+                        ),
+                        next_iteration=agent_prompt.get(
+                            "next_iteration", REACT_PROMPT_TEMPLATES["english"]["chat"]["agent_scratchpad"]
+                        ),
                     )
 
                 return AgentEntity(
-                    provider=config['model']['provider'],
-                    model=config['model']['name'],
+                    provider=config["model"]["provider"],
+                    model=config["model"]["name"],
                     strategy=strategy,
                     prompt=agent_prompt_entity,
                     tools=agent_tools,
-                    max_iteration=agent_dict.get('max_iteration', 5)
+                    max_iteration=agent_dict.get("max_iteration", 5),
                 )
 
         return None
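
The strategy resolution in this file reads agent_mode.strategy, maps "function_call" and "cot"/"react" explicitly, and falls back to provider detection for old configs. A standalone restatement under those assumptions (the Enum is a stand-in for AgentEntity.Strategy):

    from enum import Enum

    class Strategy(Enum):
        FUNCTION_CALLING = "function_call"
        CHAIN_OF_THOUGHT = "cot"

    def resolve_strategy(config: dict) -> Strategy:
        name = config.get("agent_mode", {}).get("strategy", "cot")
        if name == "function_call":
            return Strategy.FUNCTION_CALLING
        if name in ("cot", "react"):
            return Strategy.CHAIN_OF_THOUGHT
        # old configs: infer the default strategy from the provider
        if config["model"]["provider"] == "openai":
            return Strategy.FUNCTION_CALLING
        return Strategy.CHAIN_OF_THOUGHT

    assert resolve_strategy({"agent_mode": {"strategy": "react"}, "model": {}}) is Strategy.CHAIN_OF_THOUGHT
    assert resolve_strategy({"agent_mode": {"strategy": "old"}, "model": {"provider": "openai"}}) is Strategy.FUNCTION_CALLING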

+ 40 - 48
api/core/app/app_config/easy_ui_based_app/dataset/manager.py

@@ -15,39 +15,38 @@ class DatasetConfigManager:
         :param config: model config args
         """
         dataset_ids = []
-        if 'datasets' in config.get('dataset_configs', {}):
-            datasets = config.get('dataset_configs', {}).get('datasets', {
-                'strategy': 'router',
-                'datasets': []
-            })
+        if "datasets" in config.get("dataset_configs", {}):
+            datasets = config.get("dataset_configs", {}).get("datasets", {"strategy": "router", "datasets": []})
 
-            for dataset in datasets.get('datasets', []):
+            for dataset in datasets.get("datasets", []):
                 keys = list(dataset.keys())
-                if len(keys) == 0 or keys[0] != 'dataset':
+                if len(keys) == 0 or keys[0] != "dataset":
                     continue
 
-                dataset = dataset['dataset']
+                dataset = dataset["dataset"]
 
-                if 'enabled' not in dataset or not dataset['enabled']:
+                if "enabled" not in dataset or not dataset["enabled"]:
                     continue
 
-                dataset_id = dataset.get('id', None)
+                dataset_id = dataset.get("id", None)
                 if dataset_id:
                     dataset_ids.append(dataset_id)
 
-        if 'agent_mode' in config and config['agent_mode'] \
-                and 'enabled' in config['agent_mode'] \
-                and config['agent_mode']['enabled']:
+        if (
+            "agent_mode" in config
+            and config["agent_mode"]
+            and "enabled" in config["agent_mode"]
+            and config["agent_mode"]["enabled"]
+        ):
+            agent_dict = config.get("agent_mode", {})
 
-            agent_dict = config.get('agent_mode', {})
-
-            for tool in agent_dict.get('tools', []):
+            for tool in agent_dict.get("tools", []):
                 keys = tool.keys()
                 if len(keys) == 1:
                     # old standard
                     key = list(tool.keys())[0]
 
-                    if key != 'dataset':
+                    if key != "dataset":
                         continue
 
                     tool_item = tool[key]
@@ -55,30 +54,28 @@ class DatasetConfigManager:
                     if "enabled" not in tool_item or not tool_item["enabled"]:
                         continue
 
-                    dataset_id = tool_item['id']
+                    dataset_id = tool_item["id"]
                     dataset_ids.append(dataset_id)
 
         if len(dataset_ids) == 0:
             return None
 
         # dataset configs
-        if 'dataset_configs' in config and config.get('dataset_configs'):
-            dataset_configs = config.get('dataset_configs')
+        if "dataset_configs" in config and config.get("dataset_configs"):
+            dataset_configs = config.get("dataset_configs")
         else:
-            dataset_configs = {
-                'retrieval_model': 'multiple'
-            }
-        query_variable = config.get('dataset_query_variable')
+            dataset_configs = {"retrieval_model": "multiple"}
+        query_variable = config.get("dataset_query_variable")
 
-        if dataset_configs['retrieval_model'] == 'single':
+        if dataset_configs["retrieval_model"] == "single":
             return DatasetEntity(
                 dataset_ids=dataset_ids,
                 retrieve_config=DatasetRetrieveConfigEntity(
                     query_variable=query_variable,
                     retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
-                        dataset_configs['retrieval_model']
-                    )
-                )
+                        dataset_configs["retrieval_model"]
+                    ),
+                ),
             )
         else:
             return DatasetEntity(
@@ -86,15 +83,15 @@ class DatasetConfigManager:
                 retrieve_config=DatasetRetrieveConfigEntity(
                     query_variable=query_variable,
                     retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
-                        dataset_configs['retrieval_model']
+                        dataset_configs["retrieval_model"]
                     ),
-                    top_k=dataset_configs.get('top_k', 4),
-                    score_threshold=dataset_configs.get('score_threshold'),
-                    reranking_model=dataset_configs.get('reranking_model'),
-                    weights=dataset_configs.get('weights'),
-                    reranking_enabled=dataset_configs.get('reranking_enabled', True),
-                    rerank_mode=dataset_configs.get('reranking_mode', 'reranking_model'),
-                )
+                    top_k=dataset_configs.get("top_k", 4),
+                    score_threshold=dataset_configs.get("score_threshold"),
+                    reranking_model=dataset_configs.get("reranking_model"),
+                    weights=dataset_configs.get("weights"),
+                    reranking_enabled=dataset_configs.get("reranking_enabled", True),
+                    rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
+                ),
             )
 
     @classmethod
@@ -111,13 +108,10 @@ class DatasetConfigManager:
 
         # dataset_configs
         if not config.get("dataset_configs"):
-            config["dataset_configs"] = {'retrieval_model': 'single'}
+            config["dataset_configs"] = {"retrieval_model": "single"}
 
         if not config["dataset_configs"].get("datasets"):
-            config["dataset_configs"]["datasets"] = {
-                "strategy": "router",
-                "datasets": []
-            }
+            config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}
 
         if not isinstance(config["dataset_configs"], dict):
             raise ValueError("dataset_configs must be of object type")
@@ -125,8 +119,9 @@ class DatasetConfigManager:
         if not isinstance(config["dataset_configs"], dict):
             raise ValueError("dataset_configs must be of object type")
 
-        need_manual_query_datasets = (config.get("dataset_configs")
-                                      and config["dataset_configs"].get("datasets", {}).get("datasets"))
+        need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get(
+            "datasets", {}
+        ).get("datasets")
 
         if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
             # Only check when mode is completion
@@ -148,10 +143,7 @@ class DatasetConfigManager:
         """
         # Extract dataset config for legacy compatibility
         if not config.get("agent_mode"):
-            config["agent_mode"] = {
-                "enabled": False,
-                "tools": []
-            }
+            config["agent_mode"] = {"enabled": False, "tools": []}
 
         if not isinstance(config["agent_mode"], dict):
             raise ValueError("agent_mode must be of object type")
@@ -188,7 +180,7 @@ class DatasetConfigManager:
                     if not isinstance(tool_item["enabled"], bool):
                         raise ValueError("enabled in agent_mode.tools must be of boolean type")
 
-                    if 'id' not in tool_item:
+                    if "id" not in tool_item:
                         raise ValueError("id is required in dataset")
 
                     try:
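
The id collection above walks dataset_configs for enabled datasets (the legacy agent-tool path has the same shape). A minimal restatement of the dataset_configs branch, with a hypothetical input mirroring the hunk:

    def collect_dataset_ids(config: dict) -> list:
        ids = []
        for entry in config.get("dataset_configs", {}).get("datasets", {}).get("datasets", []):
            dataset = entry.get("dataset", {})
            if dataset.get("enabled") and dataset.get("id"):
                ids.append(dataset["id"])
        return ids

    config = {"dataset_configs": {"datasets": {"datasets": [{"dataset": {"enabled": True, "id": "d1"}}]}}}
    assert collect_dataset_ids(config) == ["d1"]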

+ 9 - 21
api/core/app/app_config/easy_ui_based_app/model_config/converter.py

@@ -11,9 +11,7 @@ from core.provider_manager import ProviderManager
 
 class ModelConfigConverter:
     @classmethod
-    def convert(cls, app_config: EasyUIBasedAppConfig,
-                skip_check: bool = False) \
-            -> ModelConfigWithCredentialsEntity:
+    def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity:
         """
         Convert app model config dict to entity.
         :param app_config: app config
@@ -25,9 +23,7 @@ class ModelConfigConverter:
 
         provider_manager = ProviderManager()
         provider_model_bundle = provider_manager.get_provider_model_bundle(
-            tenant_id=app_config.tenant_id,
-            provider=model_config.provider,
-            model_type=ModelType.LLM
+            tenant_id=app_config.tenant_id, provider=model_config.provider, model_type=ModelType.LLM
         )
 
         provider_name = provider_model_bundle.configuration.provider.provider
@@ -38,8 +34,7 @@ class ModelConfigConverter:
 
         # check model credentials
         model_credentials = provider_model_bundle.configuration.get_current_credentials(
-            model_type=ModelType.LLM,
-            model=model_config.model
+            model_type=ModelType.LLM, model=model_config.model
         )
 
         if model_credentials is None:
@@ -51,8 +46,7 @@ class ModelConfigConverter:
         if not skip_check:
             # check model
             provider_model = provider_model_bundle.configuration.get_provider_model(
-                model=model_config.model,
-                model_type=ModelType.LLM
+                model=model_config.model, model_type=ModelType.LLM
             )
 
             if provider_model is None:
@@ -69,24 +63,18 @@ class ModelConfigConverter:
         # model config
         completion_params = model_config.parameters
         stop = []
-        if 'stop' in completion_params:
-            stop = completion_params['stop']
-            del completion_params['stop']
+        if "stop" in completion_params:
+            stop = completion_params["stop"]
+            del completion_params["stop"]
 
         # get model mode
         model_mode = model_config.mode
         if not model_mode:
-            mode_enum = model_type_instance.get_model_mode(
-                model=model_config.model,
-                credentials=model_credentials
-            )
+            mode_enum = model_type_instance.get_model_mode(model=model_config.model, credentials=model_credentials)
 
             model_mode = mode_enum.value
 
-        model_schema = model_type_instance.get_model_schema(
-            model_config.model,
-            model_credentials
-        )
+        model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
 
         if not skip_check and not model_schema:
             raise ValueError(f"Model {model_name} not exist.")

+ 16 - 17
api/core/app/app_config/easy_ui_based_app/model_config/manager.py

@@ -13,23 +13,23 @@ class ModelConfigManager:
         :param config: model config args
         """
         # model config
-        model_config = config.get('model')
+        model_config = config.get("model")
 
         if not model_config:
             raise ValueError("model is required")
 
-        completion_params = model_config.get('completion_params')
+        completion_params = model_config.get("completion_params")
         stop = []
-        if 'stop' in completion_params:
-            stop = completion_params['stop']
-            del completion_params['stop']
+        if "stop" in completion_params:
+            stop = completion_params["stop"]
+            del completion_params["stop"]
 
         # get model mode
-        model_mode = model_config.get('mode')
+        model_mode = model_config.get("mode")
 
         return ModelConfigEntity(
-            provider=config['model']['provider'],
-            model=config['model']['name'],
+            provider=config["model"]["provider"],
+            model=config["model"]["name"],
             mode=model_mode,
             parameters=completion_params,
             stop=stop,
@@ -43,7 +43,7 @@ class ModelConfigManager:
         :param tenant_id: tenant id
         :param config: app model config args
         """
-        if 'model' not in config:
+        if "model" not in config:
             raise ValueError("model is required")
 
         if not isinstance(config["model"], dict):
@@ -52,17 +52,16 @@ class ModelConfigManager:
         # model.provider
         provider_entities = model_provider_factory.get_providers()
         model_provider_names = [provider.provider for provider in provider_entities]
-        if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names:
+        if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names:
             raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
 
         # model.name
-        if 'name' not in config["model"]:
+        if "name" not in config["model"]:
             raise ValueError("model.name is required")
 
         provider_manager = ProviderManager()
         models = provider_manager.get_configurations(tenant_id).get_models(
-            provider=config["model"]["provider"],
-            model_type=ModelType.LLM
+            provider=config["model"]["provider"], model_type=ModelType.LLM
         )
 
         if not models:
@@ -80,12 +79,12 @@ class ModelConfigManager:
 
         # model.mode
         if model_mode:
-            config['model']["mode"] = model_mode
+            config["model"]["mode"] = model_mode
         else:
-            config['model']["mode"] = "completion"
+            config["model"]["mode"] = "completion"
 
         # model.completion_params
-        if 'completion_params' not in config["model"]:
+        if "completion_params" not in config["model"]:
             raise ValueError("model.completion_params is required")
 
         config["model"]["completion_params"] = cls.validate_model_completion_params(
@@ -101,7 +100,7 @@ class ModelConfigManager:
             raise ValueError("model.completion_params must be of object type")
 
         # stop
-        if 'stop' not in cp:
+        if "stop" not in cp:
             cp["stop"] = []
         elif not isinstance(cp["stop"], list):
             raise ValueError("stop in model.completion_params must be of list type")

+ 26 - 31
api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py

@@ -14,39 +14,33 @@ class PromptTemplateConfigManager:
         if not config.get("prompt_type"):
             raise ValueError("prompt_type is required")
 
-        prompt_type = PromptTemplateEntity.PromptType.value_of(config['prompt_type'])
+        prompt_type = PromptTemplateEntity.PromptType.value_of(config["prompt_type"])
         if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
             simple_prompt_template = config.get("pre_prompt", "")
-            return PromptTemplateEntity(
-                prompt_type=prompt_type,
-                simple_prompt_template=simple_prompt_template
-            )
+            return PromptTemplateEntity(prompt_type=prompt_type, simple_prompt_template=simple_prompt_template)
         else:
             advanced_chat_prompt_template = None
             chat_prompt_config = config.get("chat_prompt_config", {})
             if chat_prompt_config:
                 chat_prompt_messages = []
                 for message in chat_prompt_config.get("prompt", []):
-                    chat_prompt_messages.append({
-                        "text": message["text"],
-                        "role": PromptMessageRole.value_of(message["role"])
-                    })
+                    chat_prompt_messages.append(
+                        {"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
+                    )
 
-                advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(
-                    messages=chat_prompt_messages
-                )
+                advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)
 
             advanced_completion_prompt_template = None
             completion_prompt_config = config.get("completion_prompt_config", {})
             if completion_prompt_config:
                 completion_prompt_template_params = {
-                    'prompt': completion_prompt_config['prompt']['text'],
+                    "prompt": completion_prompt_config["prompt"]["text"],
                 }
 
-                if 'conversation_histories_role' in completion_prompt_config:
-                    completion_prompt_template_params['role_prefix'] = {
-                        'user': completion_prompt_config['conversation_histories_role']['user_prefix'],
-                        'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix']
+                if "conversation_histories_role" in completion_prompt_config:
+                    completion_prompt_template_params["role_prefix"] = {
+                        "user": completion_prompt_config["conversation_histories_role"]["user_prefix"],
+                        "assistant": completion_prompt_config["conversation_histories_role"]["assistant_prefix"],
                     }
 
                 advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
@@ -56,7 +50,7 @@ class PromptTemplateConfigManager:
             return PromptTemplateEntity(
                 prompt_type=prompt_type,
                 advanced_chat_prompt_template=advanced_chat_prompt_template,
-                advanced_completion_prompt_template=advanced_completion_prompt_template
+                advanced_completion_prompt_template=advanced_completion_prompt_template,
             )
 
     @classmethod
@@ -72,7 +66,7 @@ class PromptTemplateConfigManager:
             config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value
 
         prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType]
-        if config['prompt_type'] not in prompt_type_vals:
+        if config["prompt_type"] not in prompt_type_vals:
             raise ValueError(f"prompt_type must be in {prompt_type_vals}")
 
         # chat_prompt_config
@@ -89,27 +83,28 @@ class PromptTemplateConfigManager:
         if not isinstance(config["completion_prompt_config"], dict):
             raise ValueError("completion_prompt_config must be of object type")
 
-        if config['prompt_type'] == PromptTemplateEntity.PromptType.ADVANCED.value:
-            if not config['chat_prompt_config'] and not config['completion_prompt_config']:
-                raise ValueError("chat_prompt_config or completion_prompt_config is required "
-                                 "when prompt_type is advanced")
+        if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value:
+            if not config["chat_prompt_config"] and not config["completion_prompt_config"]:
+                raise ValueError(
+                    "chat_prompt_config or completion_prompt_config is required " "when prompt_type is advanced"
+                )
 
             model_mode_vals = [mode.value for mode in ModelMode]
-            if config['model']["mode"] not in model_mode_vals:
+            if config["model"]["mode"] not in model_mode_vals:
                 raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced")
 
-            if app_mode == AppMode.CHAT and config['model']["mode"] == ModelMode.COMPLETION.value:
-                user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
-                assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
+            if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION.value:
+                user_prefix = config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"]
+                assistant_prefix = config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"]
 
                 if not user_prefix:
-                    config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human'
+                    config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"] = "Human"
 
                 if not assistant_prefix:
-                    config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'
+                    config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] = "Assistant"
 
-            if config['model']["mode"] == ModelMode.CHAT.value:
-                prompt_list = config['chat_prompt_config']['prompt']
+            if config["model"]["mode"] == ModelMode.CHAT.value:
+                prompt_list = config["chat_prompt_config"]["prompt"]
 
                 if len(prompt_list) > 10:
                     raise ValueError("prompt messages must be less than 10")

+ 22 - 30
api/core/app/app_config/easy_ui_based_app/variables/manager.py

@@ -16,32 +16,30 @@ class BasicVariablesConfigManager:
         variable_entities = []
 
         # old external_data_tools
-        external_data_tools = config.get('external_data_tools', [])
+        external_data_tools = config.get("external_data_tools", [])
         for external_data_tool in external_data_tools:
-            if 'enabled' not in external_data_tool or not external_data_tool['enabled']:
+            if "enabled" not in external_data_tool or not external_data_tool["enabled"]:
                 continue
 
             external_data_variables.append(
                 ExternalDataVariableEntity(
-                    variable=external_data_tool['variable'],
-                    type=external_data_tool['type'],
-                    config=external_data_tool['config']
+                    variable=external_data_tool["variable"],
+                    type=external_data_tool["type"],
+                    config=external_data_tool["config"],
                 )
             )
 
         # variables and external_data_tools
-        for variables in config.get('user_input_form', []):
+        for variables in config.get("user_input_form", []):
             variable_type = list(variables.keys())[0]
             if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL:
                 variable = variables[variable_type]
-                if 'config' not in variable:
+                if "config" not in variable:
                     continue
 
                 external_data_variables.append(
                     ExternalDataVariableEntity(
-                        variable=variable['variable'],
-                        type=variable['type'],
-                        config=variable['config']
+                        variable=variable["variable"], type=variable["type"], config=variable["config"]
                     )
                 )
             elif variable_type in [
@@ -54,13 +52,13 @@ class BasicVariablesConfigManager:
                 variable_entities.append(
                     VariableEntity(
                         type=variable_type,
-                        variable=variable.get('variable'),
-                        description=variable.get('description'),
-                        label=variable.get('label'),
-                        required=variable.get('required', False),
-                        max_length=variable.get('max_length'),
-                        options=variable.get('options'),
-                        default=variable.get('default'),
+                        variable=variable.get("variable"),
+                        description=variable.get("description"),
+                        label=variable.get("label"),
+                        required=variable.get("required", False),
+                        max_length=variable.get("max_length"),
+                        options=variable.get("options"),
+                        default=variable.get("default"),
                     )
                 )
 
@@ -103,13 +101,13 @@ class BasicVariablesConfigManager:
                 raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph'  or 'select'")
 
             form_item = item[key]
-            if 'label' not in form_item:
+            if "label" not in form_item:
                 raise ValueError("label is required in user_input_form")
 
             if not isinstance(form_item["label"], str):
                 raise ValueError("label in user_input_form must be of string type")
 
-            if 'variable' not in form_item:
+            if "variable" not in form_item:
                 raise ValueError("variable is required in user_input_form")
 
             if not isinstance(form_item["variable"], str):
@@ -117,26 +115,24 @@ class BasicVariablesConfigManager:
 
             pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$")
             if pattern.match(form_item["variable"]) is None:
-                raise ValueError("variable in user_input_form must be a string, "
-                                 "and cannot start with a number")
+                raise ValueError("variable in user_input_form must be a string, " "and cannot start with a number")
 
             variables.append(form_item["variable"])
 
-            if 'required' not in form_item or not form_item["required"]:
+            if "required" not in form_item or not form_item["required"]:
                 form_item["required"] = False
 
             if not isinstance(form_item["required"], bool):
                 raise ValueError("required in user_input_form must be of boolean type")
 
             if key == "select":
-                if 'options' not in form_item or not form_item["options"]:
+                if "options" not in form_item or not form_item["options"]:
                     form_item["options"] = []
 
                 if not isinstance(form_item["options"], list):
                     raise ValueError("options in user_input_form must be a list of strings")
 
-                if "default" in form_item and form_item['default'] \
-                        and form_item["default"] not in form_item["options"]:
+                if "default" in form_item and form_item["default"] and form_item["default"] not in form_item["options"]:
                     raise ValueError("default value in user_input_form must be in the options list")
 
         return config, ["user_input_form"]
@@ -168,10 +164,6 @@ class BasicVariablesConfigManager:
             typ = tool["type"]
             config = tool["config"]
 
-            ExternalDataToolFactory.validate_config(
-                name=typ,
-                tenant_id=tenant_id,
-                config=config
-            )
+            ExternalDataToolFactory.validate_config(name=typ, tenant_id=tenant_id, config=config)
 
         return config, ["external_data_tools"]
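
The variable-name pattern above allows 1 to 100 word characters, CJK ideographs, or emoji, with a lookahead forbidding a leading digit. A quick check of that behavior:

    import re

    pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$")
    assert pattern.match("user_name")
    assert pattern.match("用户名")
    assert pattern.match("1name") is None  # cannot start with a number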

+ 29 - 17
api/core/app/app_config/entities.py

@@ -12,6 +12,7 @@ class ModelConfigEntity(BaseModel):
     """
     Model Config Entity.
     """
+
     provider: str
     model: str
     mode: Optional[str] = None
@@ -23,6 +24,7 @@ class AdvancedChatMessageEntity(BaseModel):
     """
     Advanced Chat Message Entity.
     """
+
     text: str
     role: PromptMessageRole
 
@@ -31,6 +33,7 @@ class AdvancedChatPromptTemplateEntity(BaseModel):
     """
     Advanced Chat Prompt Template Entity.
     """
+
     messages: list[AdvancedChatMessageEntity]
 
 
@@ -43,6 +46,7 @@ class AdvancedCompletionPromptTemplateEntity(BaseModel):
         """
         Role Prefix Entity.
         """
+
         user: str
         assistant: str
 
@@ -60,11 +64,12 @@ class PromptTemplateEntity(BaseModel):
         Prompt Type.
         'simple', 'advanced'
         """
-        SIMPLE = 'simple'
-        ADVANCED = 'advanced'
+
+        SIMPLE = "simple"
+        ADVANCED = "advanced"
 
         @classmethod
-        def value_of(cls, value: str) -> 'PromptType':
+        def value_of(cls, value: str) -> "PromptType":
             """
             Get value of given mode.
 
@@ -74,7 +79,7 @@ class PromptTemplateEntity(BaseModel):
             for mode in cls:
                 if mode.value == value:
                     return mode
-            raise ValueError(f'invalid prompt type value {value}')
+            raise ValueError(f"invalid prompt type value {value}")
 
     prompt_type: PromptType
     simple_prompt_template: Optional[str] = None
@@ -110,6 +115,7 @@ class ExternalDataVariableEntity(BaseModel):
     """
     External Data Variable Entity.
     """
+
     variable: str
     type: str
     config: dict[str, Any] = {}
@@ -125,11 +131,12 @@ class DatasetRetrieveConfigEntity(BaseModel):
         Dataset Retrieve Strategy.
         'single' or 'multiple'
         """
-        SINGLE = 'single'
-        MULTIPLE = 'multiple'
+
+        SINGLE = "single"
+        MULTIPLE = "multiple"
 
         @classmethod
-        def value_of(cls, value: str) -> 'RetrieveStrategy':
+        def value_of(cls, value: str) -> "RetrieveStrategy":
             """
             Get value of given mode.
 
@@ -139,25 +146,24 @@ class DatasetRetrieveConfigEntity(BaseModel):
             for mode in cls:
                 if mode.value == value:
                     return mode
-            raise ValueError(f'invalid retrieve strategy value {value}')
+            raise ValueError(f"invalid retrieve strategy value {value}")
 
     query_variable: Optional[str] = None  # Only when app mode is completion
 
     retrieve_strategy: RetrieveStrategy
     top_k: Optional[int] = None
-    score_threshold: Optional[float] = .0
-    rerank_mode: Optional[str] = 'reranking_model'
+    score_threshold: Optional[float] = 0.0
+    rerank_mode: Optional[str] = "reranking_model"
     reranking_model: Optional[dict] = None
     weights: Optional[dict] = None
     reranking_enabled: Optional[bool] = True
 
 
-
-
 class DatasetEntity(BaseModel):
     """
     Dataset Config Entity.
     """
+
     dataset_ids: list[str]
     retrieve_config: DatasetRetrieveConfigEntity
 
@@ -166,6 +172,7 @@ class SensitiveWordAvoidanceEntity(BaseModel):
     """
     Sensitive Word Avoidance Entity.
     """
+
     type: str
     config: dict[str, Any] = {}
 
@@ -174,6 +181,7 @@ class TextToSpeechEntity(BaseModel):
     """
     Text To Speech Entity.
     """
+
     enabled: bool
     voice: Optional[str] = None
     language: Optional[str] = None
@@ -183,12 +191,11 @@ class TracingConfigEntity(BaseModel):
     """
     Tracing Config Entity.
     """
+
     enabled: bool
     tracing_provider: str
 
 
-
-
 class AppAdditionalFeatures(BaseModel):
     file_upload: Optional[FileExtraConfig] = None
     opening_statement: Optional[str] = None
@@ -200,10 +207,12 @@ class AppAdditionalFeatures(BaseModel):
     text_to_speech: Optional[TextToSpeechEntity] = None
     trace_config: Optional[TracingConfigEntity] = None
 
+
 class AppConfig(BaseModel):
     """
     Application Config Entity.
     """
+
     tenant_id: str
     app_id: str
     app_mode: AppMode
@@ -216,15 +225,17 @@ class EasyUIBasedAppModelConfigFrom(Enum):
     """
     App Model Config From.
     """
-    ARGS = 'args'
-    APP_LATEST_CONFIG = 'app-latest-config'
-    CONVERSATION_SPECIFIC_CONFIG = 'conversation-specific-config'
+
+    ARGS = "args"
+    APP_LATEST_CONFIG = "app-latest-config"
+    CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config"
 
 
 class EasyUIBasedAppConfig(AppConfig):
     """
     Easy UI Based App Config Entity.
     """
+
     app_model_config_from: EasyUIBasedAppModelConfigFrom
     app_model_config_id: str
     app_model_config_dict: dict
@@ -238,4 +249,5 @@ class WorkflowUIBasedAppConfig(AppConfig):
     """
     Workflow UI Based App Config Entity.
     """
+
     workflow_id: str
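
PromptType and RetrieveStrategy above share the same value_of idiom; a standalone stand-in (not the project's enum):

    from enum import Enum

    class RetrieveStrategy(Enum):
        SINGLE = "single"
        MULTIPLE = "multiple"

        @classmethod
        def value_of(cls, value: str) -> "RetrieveStrategy":
            for mode in cls:
                if mode.value == value:
                    return mode
            raise ValueError(f"invalid retrieve strategy value {value}")

    assert RetrieveStrategy.value_of("single") is RetrieveStrategy.SINGLE

Plain RetrieveStrategy("single") would resolve the same member; the explicit classmethod mainly customizes the error message.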

+ 13 - 15
api/core/app/app_config/features/file_upload/manager.py

@@ -13,21 +13,19 @@ class FileUploadConfigManager:
         :param config: model config args
         :param is_vision: if True, the feature is vision feature
         """
-        file_upload_dict = config.get('file_upload')
+        file_upload_dict = config.get("file_upload")
         if file_upload_dict:
-            if file_upload_dict.get('image'):
-                if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']:
+            if file_upload_dict.get("image"):
+                if "enabled" in file_upload_dict["image"] and file_upload_dict["image"]["enabled"]:
                     image_config = {
-                        'number_limits': file_upload_dict['image']['number_limits'],
-                        'transfer_methods': file_upload_dict['image']['transfer_methods']
+                        "number_limits": file_upload_dict["image"]["number_limits"],
+                        "transfer_methods": file_upload_dict["image"]["transfer_methods"],
                     }
 
                     if is_vision:
-                        image_config['detail'] = file_upload_dict['image']['detail']
+                        image_config["detail"] = file_upload_dict["image"]["detail"]
 
-                    return FileExtraConfig(
-                        image_config=image_config
-                    )
+                    return FileExtraConfig(image_config=image_config)
 
         return None
 
@@ -49,21 +47,21 @@ class FileUploadConfigManager:
         if not config["file_upload"].get("image"):
             config["file_upload"]["image"] = {"enabled": False}
 
-        if config['file_upload']['image']['enabled']:
-            number_limits = config['file_upload']['image']['number_limits']
+        if config["file_upload"]["image"]["enabled"]:
+            number_limits = config["file_upload"]["image"]["number_limits"]
             if number_limits < 1 or number_limits > 6:
                 raise ValueError("number_limits must be in [1, 6]")
 
             if is_vision:
-                detail = config['file_upload']['image']['detail']
-                if detail not in ['high', 'low']:
+                detail = config["file_upload"]["image"]["detail"]
+                if detail not in ["high", "low"]:
                     raise ValueError("detail must be in ['high', 'low']")
 
-            transfer_methods = config['file_upload']['image']['transfer_methods']
+            transfer_methods = config["file_upload"]["image"]["transfer_methods"]
             if not isinstance(transfer_methods, list):
                 raise ValueError("transfer_methods must be of list type")
             for method in transfer_methods:
-                if method not in ['remote_url', 'local_file']:
+                if method not in ["remote_url", "local_file"]:
                     raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
 
         return config, ["file_upload"]

+ 3 - 5
api/core/app/app_config/features/more_like_this/manager.py

@@ -7,9 +7,9 @@ class MoreLikeThisConfigManager:
         :param config: model config args
         """
         more_like_this = False
-        more_like_this_dict = config.get('more_like_this')
+        more_like_this_dict = config.get("more_like_this")
         if more_like_this_dict:
-            if more_like_this_dict.get('enabled'):
+            if more_like_this_dict.get("enabled"):
                 more_like_this = True
 
         return more_like_this
@@ -22,9 +22,7 @@ class MoreLikeThisConfigManager:
         :param config: app model config args
         """
         if not config.get("more_like_this"):
-            config["more_like_this"] = {
-                "enabled": False
-            }
+            config["more_like_this"] = {"enabled": False}
 
         if not isinstance(config["more_like_this"], dict):
             raise ValueError("more_like_this must be of dict type")

+ 2 - 4
api/core/app/app_config/features/opening_statement/manager.py

@@ -1,5 +1,3 @@
-
-
 class OpeningStatementConfigManager:
     @classmethod
     def convert(cls, config: dict) -> tuple[str, list]:
@@ -9,10 +7,10 @@ class OpeningStatementConfigManager:
         :param config: model config args
         """
         # opening statement
-        opening_statement = config.get('opening_statement')
+        opening_statement = config.get("opening_statement")
 
         # suggested questions
-        suggested_questions_list = config.get('suggested_questions')
+        suggested_questions_list = config.get("suggested_questions")
 
         return opening_statement, suggested_questions_list
 

+ 3 - 5
api/core/app/app_config/features/retrieval_resource/manager.py

@@ -2,9 +2,9 @@ class RetrievalResourceConfigManager:
     @classmethod
     def convert(cls, config: dict) -> bool:
         show_retrieve_source = False
-        retriever_resource_dict = config.get('retriever_resource')
+        retriever_resource_dict = config.get("retriever_resource")
         if retriever_resource_dict:
-            if retriever_resource_dict.get('enabled'):
+            if retriever_resource_dict.get("enabled"):
                 show_retrieve_source = True
 
         return show_retrieve_source
@@ -17,9 +17,7 @@ class RetrievalResourceConfigManager:
         :param config: app model config args
         """
         if not config.get("retriever_resource"):
-            config["retriever_resource"] = {
-                "enabled": False
-            }
+            config["retriever_resource"] = {"enabled": False}
 
         if not isinstance(config["retriever_resource"], dict):
             raise ValueError("retriever_resource must be of dict type")

+ 3 - 5
api/core/app/app_config/features/speech_to_text/manager.py

@@ -7,9 +7,9 @@ class SpeechToTextConfigManager:
         :param config: model config args
         """
         speech_to_text = False
-        speech_to_text_dict = config.get('speech_to_text')
+        speech_to_text_dict = config.get("speech_to_text")
         if speech_to_text_dict:
-            if speech_to_text_dict.get('enabled'):
+            if speech_to_text_dict.get("enabled"):
                 speech_to_text = True
 
         return speech_to_text
@@ -22,9 +22,7 @@ class SpeechToTextConfigManager:
         :param config: app model config args
         """
         if not config.get("speech_to_text"):
-            config["speech_to_text"] = {
-                "enabled": False
-            }
+            config["speech_to_text"] = {"enabled": False}
 
         if not isinstance(config["speech_to_text"], dict):
             raise ValueError("speech_to_text must be of dict type")

+ 7 - 7
api/core/app/app_config/features/suggested_questions_after_answer/manager.py

@@ -7,9 +7,9 @@ class SuggestedQuestionsAfterAnswerConfigManager:
         :param config: model config args
         """
         suggested_questions_after_answer = False
-        suggested_questions_after_answer_dict = config.get('suggested_questions_after_answer')
+        suggested_questions_after_answer_dict = config.get("suggested_questions_after_answer")
         if suggested_questions_after_answer_dict:
-            if suggested_questions_after_answer_dict.get('enabled'):
+            if suggested_questions_after_answer_dict.get("enabled"):
                 suggested_questions_after_answer = True
 
         return suggested_questions_after_answer
@@ -22,15 +22,15 @@ class SuggestedQuestionsAfterAnswerConfigManager:
         :param config: app model config args
         """
         if not config.get("suggested_questions_after_answer"):
-            config["suggested_questions_after_answer"] = {
-                "enabled": False
-            }
+            config["suggested_questions_after_answer"] = {"enabled": False}
 
         if not isinstance(config["suggested_questions_after_answer"], dict):
             raise ValueError("suggested_questions_after_answer must be of dict type")
 
-        if "enabled" not in config["suggested_questions_after_answer"] or not \
-        config["suggested_questions_after_answer"]["enabled"]:
+        if (
+            "enabled" not in config["suggested_questions_after_answer"]
+            or not config["suggested_questions_after_answer"]["enabled"]
+        ):
             config["suggested_questions_after_answer"]["enabled"] = False
 
         if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool):

+ 6 - 10
api/core/app/app_config/features/text_to_speech/manager.py

@@ -10,13 +10,13 @@ class TextToSpeechConfigManager:
         :param config: model config args
         """
         text_to_speech = None
-        text_to_speech_dict = config.get('text_to_speech')
+        text_to_speech_dict = config.get("text_to_speech")
         if text_to_speech_dict:
-            if text_to_speech_dict.get('enabled'):
+            if text_to_speech_dict.get("enabled"):
                 text_to_speech = TextToSpeechEntity(
-                    enabled=text_to_speech_dict.get('enabled'),
-                    voice=text_to_speech_dict.get('voice'),
-                    language=text_to_speech_dict.get('language'),
+                    enabled=text_to_speech_dict.get("enabled"),
+                    voice=text_to_speech_dict.get("voice"),
+                    language=text_to_speech_dict.get("language"),
                 )
 
         return text_to_speech
@@ -29,11 +29,7 @@ class TextToSpeechConfigManager:
         :param config: app model config args
         """
         if not config.get("text_to_speech"):
-            config["text_to_speech"] = {
-                "enabled": False,
-                "voice": "",
-                "language": ""
-            }
+            config["text_to_speech"] = {"enabled": False, "voice": "", "language": ""}
 
         if not isinstance(config["text_to_speech"], dict):
             raise ValueError("text_to_speech must be of dict type")

+ 9 - 17
api/core/app/apps/advanced_chat/app_config_manager.py

@@ -1,4 +1,3 @@
-
 from core.app.app_config.base_app_config_manager import BaseAppConfigManager
 from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
 from core.app.app_config.entities import WorkflowUIBasedAppConfig
@@ -19,13 +18,13 @@ class AdvancedChatAppConfig(WorkflowUIBasedAppConfig):
     """
     Advanced Chatbot App Config Entity.
     """
+
     pass
 
 
 class AdvancedChatAppConfigManager(BaseAppConfigManager):
     @classmethod
-    def get_app_config(cls, app_model: App,
-                       workflow: Workflow) -> AdvancedChatAppConfig:
+    def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
         features_dict = workflow.features_dict
 
         app_mode = AppMode.value_of(app_model.mode)
@@ -34,13 +33,9 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
             app_id=app_model.id,
             app_mode=app_mode,
             workflow_id=workflow.id,
-            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
-                config=features_dict
-            ),
-            variables=WorkflowVariablesConfigManager.convert(
-                workflow=workflow
-            ),
-            additional_features=cls.convert_features(features_dict, app_mode)
+            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict),
+            variables=WorkflowVariablesConfigManager.convert(workflow=workflow),
+            additional_features=cls.convert_features(features_dict, app_mode),
         )
 
         return app_config
@@ -58,8 +53,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
 
         # file upload validation
         config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(
-            config=config,
-            is_vision=False
+            config=config, is_vision=False
         )
         related_config_keys.extend(current_related_config_keys)
 
@@ -69,7 +63,8 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
 
         # suggested_questions_after_answer
         config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
-            config)
+            config
+        )
         related_config_keys.extend(current_related_config_keys)
 
         # speech_to_text
@@ -86,9 +81,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
 
         # moderation validation
         config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
-            tenant_id=tenant_id,
-            config=config,
-            only_structure_validate=only_structure_validate
+            tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
         )
         related_config_keys.extend(current_related_config_keys)
 
@@ -98,4 +91,3 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
         filtered_config = {key: config.get(key) for key in related_config_keys}
 
         return filtered_config
-
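
Each validate_and_set_defaults call in config_validate above returns the normalized config plus the keys it owns; the manager unions those keys and finally projects the config down to them. A toy version of that chaining, with hypothetical validators standing in for the real managers:

    def validate_opening(config: dict) -> tuple[dict, list[str]]:
        config.setdefault("opening_statement", "")
        return config, ["opening_statement"]

    def validate_speech_to_text(config: dict) -> tuple[dict, list[str]]:
        config.setdefault("speech_to_text", {"enabled": False})
        return config, ["speech_to_text"]

    def config_validate(config: dict) -> dict:
        related_config_keys: list[str] = []
        for validator in (validate_opening, validate_speech_to_text):
            config, keys = validator(config)
            related_config_keys.extend(keys)
        # Deduplicate, then keep only the keys some validator claimed.
        related_config_keys = list(set(related_config_keys))
        return {key: config.get(key) for key in related_config_keys}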

+ 76 - 90
api/core/app/apps/advanced_chat/app_generator.py

@@ -34,7 +34,8 @@ logger = logging.getLogger(__name__)
 class AdvancedChatAppGenerator(MessageBasedAppGenerator):
     @overload
     def generate(
-        self, app_model: App,
+        self,
+        app_model: App,
         workflow: Workflow,
         user: Union[Account, EndUser],
         args: dict,
@@ -44,7 +45,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
 
     @overload
     def generate(
-        self, app_model: App,
+        self,
+        app_model: App,
         workflow: Workflow,
         user: Union[Account, EndUser],
         args: dict,
@@ -53,14 +55,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
     ) -> dict: ...
 
     def generate(
-            self,
-            app_model: App,
-            workflow: Workflow,
-            user: Union[Account, EndUser],
-            args: dict,
-            invoke_from: InvokeFrom,
-            stream: bool = True,
-    )  -> dict[str, Any] | Generator[str, Any, None]:
+        self,
+        app_model: App,
+        workflow: Workflow,
+        user: Union[Account, EndUser],
+        args: dict,
+        invoke_from: InvokeFrom,
+        stream: bool = True,
+    ) -> dict[str, Any] | Generator[str, Any, None]:
         """
         Generate App response.
 
@@ -71,44 +73,37 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         :param invoke_from: invoke from source
         :param stream: is stream
         """
-        if not args.get('query'):
-            raise ValueError('query is required')
+        if not args.get("query"):
+            raise ValueError("query is required")
 
-        query = args['query']
+        query = args["query"]
         if not isinstance(query, str):
-            raise ValueError('query must be a string')
+            raise ValueError("query must be a string")
 
-        query = query.replace('\x00', '')
-        inputs = args['inputs']
+        query = query.replace("\x00", "")
+        inputs = args["inputs"]
 
-        extras = {
-            "auto_generate_conversation_name": args.get('auto_generate_name', False)
-        }
+        extras = {"auto_generate_conversation_name": args.get("auto_generate_name", False)}
 
         # get conversation
         conversation = None
-        conversation_id = args.get('conversation_id')
+        conversation_id = args.get("conversation_id")
         if conversation_id:
-            conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
+            conversation = self._get_conversation_by_user(
+                app_model=app_model, conversation_id=conversation_id, user=user
+            )
 
         # parse files
-        files = args['files'] if args.get('files') else []
+        files = args["files"] if args.get("files") else []
         message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
         file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
         if file_extra_config:
-            file_objs = message_file_parser.validate_and_transform_files_arg(
-                files,
-                file_extra_config,
-                user
-            )
+            file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
         else:
             file_objs = []
 
         # convert to app config
-        app_config = AdvancedChatAppConfigManager.get_app_config(
-            app_model=app_model,
-            workflow=workflow
-        )
+        app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
 
         # get tracing instance
         user_id = user.id if isinstance(user, Account) else user.session_id
@@ -130,7 +125,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             stream=stream,
             invoke_from=invoke_from,
             extras=extras,
-            trace_manager=trace_manager
+            trace_manager=trace_manager,
         )
         contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
 
@@ -140,16 +135,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             invoke_from=invoke_from,
             application_generate_entity=application_generate_entity,
             conversation=conversation,
-            stream=stream
+            stream=stream,
         )
 
-    def single_iteration_generate(self, app_model: App,
-                                  workflow: Workflow,
-                                  node_id: str,
-                                  user: Account,
-                                  args: dict,
-                                  stream: bool = True) \
-            -> dict[str, Any] | Generator[str, Any, None]:
+    def single_iteration_generate(
+        self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, stream: bool = True
+    ) -> dict[str, Any] | Generator[str, Any, None]:
         """
         Generate App response.
 
@@ -161,16 +152,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         :param stream: is stream
         """
         if not node_id:
-            raise ValueError('node_id is required')
+            raise ValueError("node_id is required")
 
-        if args.get('inputs') is None:
-            raise ValueError('inputs is required')
+        if args.get("inputs") is None:
+            raise ValueError("inputs is required")
 
         # convert to app config
-        app_config = AdvancedChatAppConfigManager.get_app_config(
-            app_model=app_model,
-            workflow=workflow
-        )
+        app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
 
         # init application generate entity
         application_generate_entity = AdvancedChatAppGenerateEntity(
@@ -178,18 +166,15 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             app_config=app_config,
             conversation_id=None,
             inputs={},
-            query='',
+            query="",
             files=[],
             user_id=user.id,
             stream=stream,
             invoke_from=InvokeFrom.DEBUGGER,
-            extras={
-                "auto_generate_conversation_name": False
-            },
+            extras={"auto_generate_conversation_name": False},
             single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
-                node_id=node_id,
-                inputs=args['inputs']
-            )
+                node_id=node_id, inputs=args["inputs"]
+            ),
         )
         contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
 
@@ -199,17 +184,19 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             invoke_from=InvokeFrom.DEBUGGER,
             application_generate_entity=application_generate_entity,
             conversation=None,
-            stream=stream
+            stream=stream,
         )
 
-    def _generate(self, *,
-                  workflow: Workflow,
-                  user: Union[Account, EndUser],
-                  invoke_from: InvokeFrom,
-                  application_generate_entity: AdvancedChatAppGenerateEntity,
-                  conversation: Optional[Conversation] = None,
-                  stream: bool = True) \
-            -> dict[str, Any] | Generator[str, Any, None]:
+    def _generate(
+        self,
+        *,
+        workflow: Workflow,
+        user: Union[Account, EndUser],
+        invoke_from: InvokeFrom,
+        application_generate_entity: AdvancedChatAppGenerateEntity,
+        conversation: Optional[Conversation] = None,
+        stream: bool = True,
+    ) -> dict[str, Any] | Generator[str, Any, None]:
         """
         Generate App response.
 
@@ -225,10 +212,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             is_first_conversation = True
 
         # init generate records
-        (
-            conversation,
-            message
-        ) = self._init_generate_records(application_generate_entity, conversation)
+        (conversation, message) = self._init_generate_records(application_generate_entity, conversation)
 
         if is_first_conversation:
             # update conversation features
@@ -243,18 +227,21 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             invoke_from=application_generate_entity.invoke_from,
             conversation_id=conversation.id,
             app_mode=conversation.mode,
-            message_id=message.id
+            message_id=message.id,
         )
 
         # new thread
-        worker_thread = threading.Thread(target=self._generate_worker, kwargs={
-            'flask_app': current_app._get_current_object(), # type: ignore
-            'application_generate_entity': application_generate_entity,
-            'queue_manager': queue_manager,
-            'conversation_id': conversation.id,
-            'message_id': message.id,
-            'context': contextvars.copy_context(),
-        })
+        worker_thread = threading.Thread(
+            target=self._generate_worker,
+            kwargs={
+                "flask_app": current_app._get_current_object(),  # type: ignore
+                "application_generate_entity": application_generate_entity,
+                "queue_manager": queue_manager,
+                "conversation_id": conversation.id,
+                "message_id": message.id,
+                "context": contextvars.copy_context(),
+            },
+        )
 
         worker_thread.start()
 
@@ -269,17 +256,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             stream=stream,
         )
 
-        return AdvancedChatAppGenerateResponseConverter.convert(
-            response=response,
-            invoke_from=invoke_from
-        )
+        return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
 
-    def _generate_worker(self, flask_app: Flask,
-                         application_generate_entity: AdvancedChatAppGenerateEntity,
-                         queue_manager: AppQueueManager,
-                         conversation_id: str,
-                         message_id: str,
-                         context: contextvars.Context) -> None:
+    def _generate_worker(
+        self,
+        flask_app: Flask,
+        application_generate_entity: AdvancedChatAppGenerateEntity,
+        queue_manager: AppQueueManager,
+        conversation_id: str,
+        message_id: str,
+        context: contextvars.Context,
+    ) -> None:
         """
         Generate worker in a new thread.
         :param flask_app: Flask app
@@ -302,7 +289,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
                     application_generate_entity=application_generate_entity,
                     queue_manager=queue_manager,
                     conversation=conversation,
-                    message=message
+                    message=message,
                 )
 
                 runner.run()
@@ -310,14 +297,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
                 pass
             except InvokeAuthorizationError:
                 queue_manager.publish_error(
-                    InvokeAuthorizationError('Incorrect API key provided'),
-                    PublishFrom.APPLICATION_MANAGER
+                    InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
                 )
             except ValidationError as e:
                 logger.exception("Validation Error when generating")
                 queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
             except (ValueError, InvokeError) as e:
-                if os.environ.get("DEBUG", "false").lower() == 'true':
+                if os.environ.get("DEBUG", "false").lower() == "true":
                     logger.exception("Error when generating")
                 queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
             except Exception as e:
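
_generate above hands the Flask app object, the queue manager, and a contextvars.copy_context() snapshot to a worker thread so request-scoped state survives the thread hop. A self-contained sketch of that hand-off (names here are illustrative, not Dify's API):

    import contextvars
    import threading

    request_id: contextvars.ContextVar[str] = contextvars.ContextVar("request_id")

    def worker(context: contextvars.Context) -> None:
        # Run under the copied context so ContextVar values set on the
        # request thread are visible in this one.
        print(context.run(request_id.get))

    request_id.set("req-123")
    thread = threading.Thread(target=worker, args=(contextvars.copy_context(),))
    thread.start()
    thread.join()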

+ 16 - 21
api/core/app/apps/advanced_chat/app_generator_tts_publisher.py

@@ -25,10 +25,7 @@ def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str):
     if not text_content or text_content.isspace():
         return
     return model_instance.invoke_tts(
-        content_text=text_content.strip(),
-        user="responding_tts",
-        tenant_id=tenant_id,
-        voice=voice
+        content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice
     )
 
 
@@ -44,28 +41,26 @@ def _process_future(future_queue, audio_queue):
         except Exception as e:
             logging.getLogger(__name__).warning(e)
             break
-    audio_queue.put(AudioTrunk("finish", b''))
+    audio_queue.put(AudioTrunk("finish", b""))
 
 
 class AppGeneratorTTSPublisher:
-
     def __init__(self, tenant_id: str, voice: str):
         self.logger = logging.getLogger(__name__)
         self.tenant_id = tenant_id
-        self.msg_text = ''
+        self.msg_text = ""
         self._audio_queue = queue.Queue()
         self._msg_queue = queue.Queue()
-        self.match = re.compile(r'[。.!?]')
+        self.match = re.compile(r"[。.!?]")
         self.model_manager = ModelManager()
         self.model_instance = self.model_manager.get_default_model_instance(
-            tenant_id=self.tenant_id,
-            model_type=ModelType.TTS
+            tenant_id=self.tenant_id, model_type=ModelType.TTS
         )
         self.voices = self.model_instance.get_tts_voices()
-        values = [voice.get('value') for voice in self.voices]
+        values = [voice.get("value") for voice in self.voices]
         self.voice = voice
         if not voice or voice not in values:
-            self.voice = self.voices[0].get('value')
+            self.voice = self.voices[0].get("value")
         self.MAX_SENTENCE = 2
         self._last_audio_event = None
         self._runtime_thread = threading.Thread(target=self._runtime).start()
@@ -85,8 +80,9 @@ class AppGeneratorTTSPublisher:
                 message = self._msg_queue.get()
                 if message is None:
                     if self.msg_text and len(self.msg_text.strip()) > 0:
-                        futures_result = self.executor.submit(_invoiceTTS, self.msg_text,
-                                                              self.model_instance, self.tenant_id, self.voice)
+                        futures_result = self.executor.submit(
+                            _invoiceTTS, self.msg_text, self.model_instance, self.tenant_id, self.voice
+                        )
                         future_queue.put(futures_result)
                     break
                 elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):
@@ -94,21 +90,20 @@ class AppGeneratorTTSPublisher:
                 elif isinstance(message.event, QueueTextChunkEvent):
                     self.msg_text += message.event.text
                 elif isinstance(message.event, QueueNodeSucceededEvent):
-                    self.msg_text += message.event.outputs.get('output', '')
+                    self.msg_text += message.event.outputs.get("output", "")
                 self.last_message = message
                 sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
                 if len(sentence_arr) >= min(self.MAX_SENTENCE, 7):
                     self.MAX_SENTENCE += 1
-                    text_content = ''.join(sentence_arr)
-                    futures_result = self.executor.submit(_invoiceTTS, text_content,
-                                                          self.model_instance,
-                                                          self.tenant_id,
-                                                          self.voice)
+                    text_content = "".join(sentence_arr)
+                    futures_result = self.executor.submit(
+                        _invoiceTTS, text_content, self.model_instance, self.tenant_id, self.voice
+                    )
                     future_queue.put(futures_result)
                     if text_tmp:
                         self.msg_text = text_tmp
                     else:
-                        self.msg_text = ''
+                        self.msg_text = ""
 
             except Exception as e:
                 self.logger.warning(e)
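
The publisher batches streamed text into sentences before submitting TTS jobs, splitting on the [。.!?] pattern and letting MAX_SENTENCE grow as the stream lengthens. _extract_sentence itself is not shown in this diff; a rough guess at its contract, returning complete sentences plus the unfinished tail:

    import re

    SENTENCE_END = re.compile(r"[。.!?]")

    def extract_sentences(text: str) -> tuple[list[str], str]:
        # Complete sentences (terminator included) and the leftover fragment,
        # matching how sentence_arr/text_tmp are consumed in _runtime above.
        sentences, last = [], 0
        for match in SENTENCE_END.finditer(text):
            sentences.append(text[last : match.end()])
            last = match.end()
        return sentences, text[last:]

For example, extract_sentences("Hello. How are") yields (["Hello."], " How are").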

+ 35 - 49
api/core/app/apps/advanced_chat/app_runner.py

@@ -38,11 +38,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
     """
 
     def __init__(
-            self,
-            application_generate_entity: AdvancedChatAppGenerateEntity,
-            queue_manager: AppQueueManager,
-            conversation: Conversation,
-            message: Message
+        self,
+        application_generate_entity: AdvancedChatAppGenerateEntity,
+        queue_manager: AppQueueManager,
+        conversation: Conversation,
+        message: Message,
     ) -> None:
         """
         :param application_generate_entity: application generate entity
@@ -66,11 +66,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
 
         app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
         if not app_record:
-            raise ValueError('App not found')
+            raise ValueError("App not found")
 
         workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
         if not workflow:
-            raise ValueError('Workflow not initialized')
+            raise ValueError("Workflow not initialized")
 
         user_id = None
         if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
@@ -81,7 +81,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
             user_id = self.application_generate_entity.user_id
 
         workflow_callbacks: list[WorkflowCallback] = []
-        if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
+        if bool(os.environ.get("DEBUG", "False").lower() == "true"):
             workflow_callbacks.append(WorkflowLoggingCallback())
 
         if self.application_generate_entity.single_iteration_run:
@@ -89,7 +89,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
             graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
                 workflow=workflow,
                 node_id=self.application_generate_entity.single_iteration_run.node_id,
-                user_inputs=self.application_generate_entity.single_iteration_run.inputs
+                user_inputs=self.application_generate_entity.single_iteration_run.inputs,
             )
         else:
             inputs = self.application_generate_entity.inputs
@@ -98,26 +98,27 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
 
             # moderation
             if self.handle_input_moderation(
-                    app_record=app_record,
-                    app_generate_entity=self.application_generate_entity,
-                    inputs=inputs,
-                    query=query,
-                    message_id=self.message.id
+                app_record=app_record,
+                app_generate_entity=self.application_generate_entity,
+                inputs=inputs,
+                query=query,
+                message_id=self.message.id,
             ):
                 return
 
             # annotation reply
             if self.handle_annotation_reply(
-                    app_record=app_record,
-                    message=self.message,
-                    query=query,
-                    app_generate_entity=self.application_generate_entity
+                app_record=app_record,
+                message=self.message,
+                query=query,
+                app_generate_entity=self.application_generate_entity,
             ):
                 return
 
             # Init conversation variables
             stmt = select(ConversationVariable).where(
-                ConversationVariable.app_id == self.conversation.app_id, ConversationVariable.conversation_id == self.conversation.id
+                ConversationVariable.app_id == self.conversation.app_id,
+                ConversationVariable.conversation_id == self.conversation.id,
             )
             with Session(db.engine) as session:
                 conversation_variables = session.scalars(stmt).all()
@@ -190,12 +191,12 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
             self._handle_event(workflow_entry, event)
 
     def handle_input_moderation(
-            self,
-            app_record: App,
-            app_generate_entity: AdvancedChatAppGenerateEntity,
-            inputs: Mapping[str, Any],
-            query: str,
-            message_id: str
+        self,
+        app_record: App,
+        app_generate_entity: AdvancedChatAppGenerateEntity,
+        inputs: Mapping[str, Any],
+        query: str,
+        message_id: str,
     ) -> bool:
         """
         Handle input moderation
@@ -217,18 +218,14 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
                 message_id=message_id,
             )
         except ModerationException as e:
-            self._complete_with_stream_output(
-                text=str(e),
-                stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION
-            )
+            self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION)
             return True
 
         return False
 
-    def handle_annotation_reply(self, app_record: App,
-                                message: Message,
-                                query: str,
-                                app_generate_entity: AdvancedChatAppGenerateEntity) -> bool:
+    def handle_annotation_reply(
+        self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity
+    ) -> bool:
         """
         Handle annotation reply
         :param app_record: app record
@@ -246,32 +243,21 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         )
 
         if annotation_reply:
-            self._publish_event(
-                QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id)
-            )
+            self._publish_event(QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id))
 
             self._complete_with_stream_output(
-                text=annotation_reply.content,
-                stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
+                text=annotation_reply.content, stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
             )
             return True
 
         return False
 
-    def _complete_with_stream_output(self,
-                                     text: str,
-                                     stopped_by: QueueStopEvent.StopBy) -> None:
+    def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None:
         """
         Direct output
         :param text: text
         :return:
         """
-        self._publish_event(
-            QueueTextChunkEvent(
-                text=text
-            )
-        )
+        self._publish_event(QueueTextChunkEvent(text=text))
 
-        self._publish_event(
-            QueueStopEvent(stopped_by=stopped_by)
-        )
+        self._publish_event(QueueStopEvent(stopped_by=stopped_by))
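
The runner loads conversation variables with a SQLAlchemy 2.0-style select(...).where(...) executed through session.scalars(...). A self-contained sketch of that access pattern against a toy model and an in-memory engine (not Dify's actual schema):

    from sqlalchemy import String, create_engine, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class ConversationVariable(Base):
        __tablename__ = "conversation_variables"
        id: Mapped[int] = mapped_column(primary_key=True)
        app_id: Mapped[str] = mapped_column(String)
        conversation_id: Mapped[str] = mapped_column(String)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    stmt = select(ConversationVariable).where(
        ConversationVariable.app_id == "app-1",
        ConversationVariable.conversation_id == "conv-1",
    )
    with Session(engine) as session:
        conversation_variables = session.scalars(stmt).all()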

+ 29 - 25
api/core/app/apps/advanced_chat/generate_response_converter.py

@@ -28,15 +28,15 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         """
         blocking_response = cast(ChatbotAppBlockingResponse, blocking_response)
         response = {
-            'event': 'message',
-            'task_id': blocking_response.task_id,
-            'id': blocking_response.data.id,
-            'message_id': blocking_response.data.message_id,
-            'conversation_id': blocking_response.data.conversation_id,
-            'mode': blocking_response.data.mode,
-            'answer': blocking_response.data.answer,
-            'metadata': blocking_response.data.metadata,
-            'created_at': blocking_response.data.created_at
+            "event": "message",
+            "task_id": blocking_response.task_id,
+            "id": blocking_response.data.id,
+            "message_id": blocking_response.data.message_id,
+            "conversation_id": blocking_response.data.conversation_id,
+            "mode": blocking_response.data.mode,
+            "answer": blocking_response.data.answer,
+            "metadata": blocking_response.data.metadata,
+            "created_at": blocking_response.data.created_at,
         }
 
         return response
@@ -50,13 +50,15 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         """
         response = cls.convert_blocking_full_response(blocking_response)
 
-        metadata = response.get('metadata', {})
-        response['metadata'] = cls._get_simple_metadata(metadata)
+        metadata = response.get("metadata", {})
+        response["metadata"] = cls._get_simple_metadata(metadata)
 
         return response
 
     @classmethod
-    def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]:
+    def convert_stream_full_response(
+        cls, stream_response: Generator[AppStreamResponse, None, None]
+    ) -> Generator[str, Any, None]:
         """
         Convert stream full response.
         :param stream_response: stream response
@@ -67,14 +69,14 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             sub_stream_response = chunk.stream_response
 
             if isinstance(sub_stream_response, PingStreamResponse):
-                yield 'ping'
+                yield "ping"
                 continue
 
             response_chunk = {
-                'event': sub_stream_response.event.value,
-                'conversation_id': chunk.conversation_id,
-                'message_id': chunk.message_id,
-                'created_at': chunk.created_at
+                "event": sub_stream_response.event.value,
+                "conversation_id": chunk.conversation_id,
+                "message_id": chunk.message_id,
+                "created_at": chunk.created_at,
             }
 
             if isinstance(sub_stream_response, ErrorStreamResponse):
@@ -85,7 +87,9 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             yield json.dumps(response_chunk)
 
     @classmethod
-    def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]:
+    def convert_stream_simple_response(
+        cls, stream_response: Generator[AppStreamResponse, None, None]
+    ) -> Generator[str, Any, None]:
         """
         Convert stream simple response.
         :param stream_response: stream response
@@ -96,20 +100,20 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             sub_stream_response = chunk.stream_response
 
             if isinstance(sub_stream_response, PingStreamResponse):
-                yield 'ping'
+                yield "ping"
                 continue
 
             response_chunk = {
-                'event': sub_stream_response.event.value,
-                'conversation_id': chunk.conversation_id,
-                'message_id': chunk.message_id,
-                'created_at': chunk.created_at
+                "event": sub_stream_response.event.value,
+                "conversation_id": chunk.conversation_id,
+                "message_id": chunk.message_id,
+                "created_at": chunk.created_at,
             }
 
             if isinstance(sub_stream_response, MessageEndStreamResponse):
                 sub_stream_response_dict = sub_stream_response.to_dict()
-                metadata = sub_stream_response_dict.get('metadata', {})
-                sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata)
+                metadata = sub_stream_response_dict.get("metadata", {})
+                sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                 response_chunk.update(sub_stream_response_dict)
             if isinstance(sub_stream_response, ErrorStreamResponse):
                 data = cls._error_to_stream_response(sub_stream_response.err)
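
Both converters yield either the literal keep-alive string "ping" or a JSON-encoded chunk dict. A minimal consumer sketch under that assumption (the transport framing, e.g. SSE, is not shown in this excerpt):

    import json
    from collections.abc import Iterable, Iterator

    def iter_events(stream: Iterable[str]) -> Iterator[dict]:
        # Skip keep-alive pings; decode every other item as one event chunk.
        for item in stream:
            if item == "ping":
                continue
            yield json.loads(item)

    events = list(iter_events(['{"event": "message"}', "ping"]))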

+ 74 - 93
api/core/app/apps/advanced_chat/generate_task_pipeline.py

@@ -65,6 +65,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
     """
     AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
     """
+
     _task_state: WorkflowTaskState
     _application_generate_entity: AdvancedChatAppGenerateEntity
     _workflow: Workflow
@@ -72,14 +73,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
     _workflow_system_variables: dict[SystemVariableKey, Any]
 
     def __init__(
-            self,
-            application_generate_entity: AdvancedChatAppGenerateEntity,
-            workflow: Workflow,
-            queue_manager: AppQueueManager,
-            conversation: Conversation,
-            message: Message,
-            user: Union[Account, EndUser],
-            stream: bool,
+        self,
+        application_generate_entity: AdvancedChatAppGenerateEntity,
+        workflow: Workflow,
+        queue_manager: AppQueueManager,
+        conversation: Conversation,
+        message: Message,
+        user: Union[Account, EndUser],
+        stream: bool,
     ) -> None:
         """
         Initialize AdvancedChatAppGenerateTaskPipeline.
@@ -123,13 +124,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
 
         # start generate conversation name thread
         self._conversation_name_generate_thread = self._generate_conversation_name(
-            self._conversation,
-            self._application_generate_entity.query
+            self._conversation, self._application_generate_entity.query
         )
 
-        generator = self._wrapper_process_stream_response(
-            trace_manager=self._application_generate_entity.trace_manager
-        )
+        generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
 
         if self._stream:
             return self._to_stream_response(generator)
@@ -147,7 +145,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             elif isinstance(stream_response, MessageEndStreamResponse):
                 extras = {}
                 if stream_response.metadata:
-                    extras['metadata'] = stream_response.metadata
+                    extras["metadata"] = stream_response.metadata
 
                 return ChatbotAppBlockingResponse(
                     task_id=stream_response.task_id,
@@ -158,15 +156,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                         message_id=self._message.id,
                         answer=self._task_state.answer,
                         created_at=int(self._message.created_at.timestamp()),
-                        **extras
-                    )
+                        **extras,
+                    ),
                 )
             else:
                 continue
 
-        raise Exception('Queue listening stopped unexpectedly.')
+        raise Exception("Queue listening stopped unexpectedly.")
 
-    def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) -> Generator[ChatbotAppStreamResponse, Any, None]:
+    def _to_stream_response(
+        self, generator: Generator[StreamResponse, None, None]
+    ) -> Generator[ChatbotAppStreamResponse, Any, None]:
         """
         To stream response.
         :return:
@@ -176,7 +176,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 conversation_id=self._conversation.id,
                 message_id=self._message.id,
                 created_at=int(self._message.created_at.timestamp()),
-                stream_response=stream_response
+                stream_response=stream_response,
             )
 
     def _listenAudioMsg(self, publisher, task_id: str):
@@ -187,17 +187,20 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
         return None
 
-    def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
-            Generator[StreamResponse, None, None]:
-
+    def _wrapper_process_stream_response(
+        self, trace_manager: Optional[TraceQueueManager] = None
+    ) -> Generator[StreamResponse, None, None]:
         tts_publisher = None
         task_id = self._application_generate_entity.task_id
         tenant_id = self._application_generate_entity.app_config.tenant_id
         features_dict = self._workflow.features_dict
 
-        if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
-                'text_to_speech'].get('autoPlay') == 'enabled':
-            tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
+        if (
+            features_dict.get("text_to_speech")
+            and features_dict["text_to_speech"].get("enabled")
+            and features_dict["text_to_speech"].get("autoPlay") == "enabled"
+        ):
+            tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))
 
         for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
             while True:
@@ -228,12 +231,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             except Exception as e:
                 logger.error(e)
                 break
-        yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
+        yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
 
     def _process_stream_response(
-            self,
-            tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
-            trace_manager: Optional[TraceQueueManager] = None
+        self,
+        tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
+        trace_manager: Optional[TraceQueueManager] = None,
     ) -> Generator[StreamResponse, None, None]:
         """
         Process stream response.
@@ -267,22 +270,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 db.session.close()
 
                 yield self._workflow_start_to_stream_response(
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_run=workflow_run
+                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                 )
             elif isinstance(event, QueueNodeStartedEvent):
                 if not workflow_run:
-                    raise Exception('Workflow run not initialized.')
+                    raise Exception("Workflow run not initialized.")
 
-                workflow_node_execution = self._handle_node_execution_start(
-                    workflow_run=workflow_run,
-                    event=event
-                )
+                workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
 
                 response = self._workflow_node_start_to_stream_response(
                     event=event,
                     task_id=self._application_generate_entity.task_id,
-                    workflow_node_execution=workflow_node_execution
+                    workflow_node_execution=workflow_node_execution,
                 )
 
                 if response:
@@ -293,7 +292,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 response = self._workflow_node_finish_to_stream_response(
                     event=event,
                     task_id=self._application_generate_entity.task_id,
-                    workflow_node_execution=workflow_node_execution
+                    workflow_node_execution=workflow_node_execution,
                 )
 
                 if response:
@@ -304,62 +303,52 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 response = self._workflow_node_finish_to_stream_response(
                     event=event,
                     task_id=self._application_generate_entity.task_id,
-                    workflow_node_execution=workflow_node_execution
+                    workflow_node_execution=workflow_node_execution,
                 )
 
                 if response:
                     yield response
             elif isinstance(event, QueueParallelBranchRunStartedEvent):
                 if not workflow_run:
-                    raise Exception('Workflow run not initialized.')
+                    raise Exception("Workflow run not initialized.")
 
                 yield self._workflow_parallel_branch_start_to_stream_response(
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_run=workflow_run,
-                    event=event
+                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                 )
             elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
                 if not workflow_run:
-                    raise Exception('Workflow run not initialized.')
+                    raise Exception("Workflow run not initialized.")
 
                 yield self._workflow_parallel_branch_finished_to_stream_response(
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_run=workflow_run,
-                    event=event
+                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                 )
             elif isinstance(event, QueueIterationStartEvent):
                 if not workflow_run:
-                    raise Exception('Workflow run not initialized.')
+                    raise Exception("Workflow run not initialized.")
 
                 yield self._workflow_iteration_start_to_stream_response(
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_run=workflow_run,
-                    event=event
+                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                 )
             elif isinstance(event, QueueIterationNextEvent):
                 if not workflow_run:
-                    raise Exception('Workflow run not initialized.')
+                    raise Exception("Workflow run not initialized.")
 
                 yield self._workflow_iteration_next_to_stream_response(
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_run=workflow_run,
-                    event=event
+                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                 )
             elif isinstance(event, QueueIterationCompletedEvent):
                 if not workflow_run:
-                    raise Exception('Workflow run not initialized.')
+                    raise Exception("Workflow run not initialized.")
 
                 yield self._workflow_iteration_completed_to_stream_response(
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_run=workflow_run,
-                    event=event
+                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                 )
             elif isinstance(event, QueueWorkflowSucceededEvent):
                 if not workflow_run:
-                    raise Exception('Workflow run not initialized.')
+                    raise Exception("Workflow run not initialized.")
 
                 if not graph_runtime_state:
-                    raise Exception('Graph runtime state not initialized.')
+                    raise Exception("Graph runtime state not initialized.")
 
                 workflow_run = self._handle_workflow_run_success(
                     workflow_run=workflow_run,
@@ -372,20 +361,16 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 )
 
                 yield self._workflow_finish_to_stream_response(
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_run=workflow_run
+                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                 )
 
-                self._queue_manager.publish(
-                    QueueAdvancedChatMessageEndEvent(),
-                    PublishFrom.TASK_PIPELINE
-                )
+                self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
             elif isinstance(event, QueueWorkflowFailedEvent):
                 if not workflow_run:
-                    raise Exception('Workflow run not initialized.')
+                    raise Exception("Workflow run not initialized.")
 
                 if not graph_runtime_state:
-                    raise Exception('Graph runtime state not initialized.')
+                    raise Exception("Graph runtime state not initialized.")
 
                 workflow_run = self._handle_workflow_run_failed(
                     workflow_run=workflow_run,
@@ -399,11 +384,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 )
 
                 yield self._workflow_finish_to_stream_response(
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_run=workflow_run
+                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                 )
 
-                err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
+                err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
                 yield self._error_to_stream_response(self._handle_error(err_event, self._message))
                 break
             elif isinstance(event, QueueStopEvent):
@@ -420,8 +404,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                     )
 
                     yield self._workflow_finish_to_stream_response(
-                        task_id=self._application_generate_entity.task_id,
-                        workflow_run=workflow_run
+                        task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                     )
 
                 # Save message
@@ -434,8 +417,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
 
                 self._refetch_message()
 
-                self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
-                    if self._task_state.metadata else None
+                self._message.message_metadata = (
+                    json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
+                )
 
                 db.session.commit()
                 db.session.refresh(self._message)
@@ -445,8 +429,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
 
                 self._refetch_message()
 
-                self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
-                    if self._task_state.metadata else None
+                self._message.message_metadata = (
+                    json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
+                )
 
                 db.session.commit()
                 db.session.refresh(self._message)
@@ -472,7 +457,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 yield self._message_replace_to_stream_response(answer=event.text)
             elif isinstance(event, QueueAdvancedChatMessageEndEvent):
                 if not graph_runtime_state:
-                    raise Exception('Graph runtime state not initialized.')
+                    raise Exception("Graph runtime state not initialized.")
 
                 output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
                 if output_moderation_answer:
@@ -502,8 +487,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
 
         self._message.answer = self._task_state.answer
         self._message.provider_response_latency = time.perf_counter() - self._start_at
-        self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
-            if self._task_state.metadata else None
+        self._message.message_metadata = (
+            json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
+        )
 
         if graph_runtime_state and graph_runtime_state.llm_usage:
             usage = graph_runtime_state.llm_usage
@@ -523,7 +509,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             application_generate_entity=self._application_generate_entity,
             conversation=self._conversation,
             is_first_message=self._application_generate_entity.conversation_id is None,
-            extras=self._application_generate_entity.extras
+            extras=self._application_generate_entity.extras,
         )
 
     def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
@@ -533,15 +519,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         """
         extras = {}
         if self._task_state.metadata:
-            extras['metadata'] = self._task_state.metadata.copy()
+            extras["metadata"] = self._task_state.metadata.copy()
 
-            if 'annotation_reply' in extras['metadata']:
-                del extras['metadata']['annotation_reply']
+            if "annotation_reply" in extras["metadata"]:
+                del extras["metadata"]["annotation_reply"]
 
         return MessageEndStreamResponse(
-            task_id=self._application_generate_entity.task_id,
-            id=self._message.id,
-            **extras
+            task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
         )
 
     def _handle_output_moderation_chunk(self, text: str) -> bool:
@@ -555,14 +539,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 # stop subscribe new token when output moderation should direct output
                 self._task_state.answer = self._output_moderation_handler.get_final_output()
                 self._queue_manager.publish(
-                    QueueTextChunkEvent(
-                        text=self._task_state.answer
-                    ), PublishFrom.TASK_PIPELINE
+                    QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
                 )
 
                 self._queue_manager.publish(
-                    QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
-                    PublishFrom.TASK_PIPELINE
+                    QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
                 )
                 return True
             else:
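
The TTS gate in _wrapper_process_stream_response only builds a publisher when text_to_speech exists, is enabled, and has autoPlay set to "enabled". That condition reads cleanly as a standalone predicate (a sketch over the same features_dict shape shown above):

    def tts_autoplay_enabled(features: dict) -> bool:
        # Same three-way check as in _wrapper_process_stream_response.
        tts = features.get("text_to_speech") or {}
        return bool(tts.get("enabled")) and tts.get("autoPlay") == "enabled"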

+ 27 - 32
api/core/app/apps/agent_chat/app_config_manager.py

@@ -28,15 +28,19 @@ class AgentChatAppConfig(EasyUIBasedAppConfig):
     """
     Agent Chatbot App Config Entity.
     """
+
     agent: Optional[AgentEntity] = None
 
 
 class AgentChatAppConfigManager(BaseAppConfigManager):
     @classmethod
-    def get_app_config(cls, app_model: App,
-                       app_model_config: AppModelConfig,
-                       conversation: Optional[Conversation] = None,
-                       override_config_dict: Optional[dict] = None) -> AgentChatAppConfig:
+    def get_app_config(
+        cls,
+        app_model: App,
+        app_model_config: AppModelConfig,
+        conversation: Optional[Conversation] = None,
+        override_config_dict: Optional[dict] = None,
+    ) -> AgentChatAppConfig:
         """
         Convert app model config to agent chat app config
         :param app_model: app model
@@ -66,22 +70,12 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
             app_model_config_from=config_from,
             app_model_config_id=app_model_config.id,
             app_model_config_dict=config_dict,
-            model=ModelConfigManager.convert(
-                config=config_dict
-            ),
-            prompt_template=PromptTemplateConfigManager.convert(
-                config=config_dict
-            ),
-            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
-                config=config_dict
-            ),
-            dataset=DatasetConfigManager.convert(
-                config=config_dict
-            ),
-            agent=AgentConfigManager.convert(
-                config=config_dict
-            ),
-            additional_features=cls.convert_features(config_dict, app_mode)
+            model=ModelConfigManager.convert(config=config_dict),
+            prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
+            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
+            dataset=DatasetConfigManager.convert(config=config_dict),
+            agent=AgentConfigManager.convert(config=config_dict),
+            additional_features=cls.convert_features(config_dict, app_mode),
         )
 
         app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
@@ -128,7 +122,8 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
 
         # suggested_questions_after_answer
         config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
-            config)
+            config
+        )
         related_config_keys.extend(current_related_config_keys)
 
         # speech_to_text
@@ -145,13 +140,15 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
 
         # dataset configs
         # dataset_query_variable
-        config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode,
-                                                                                             config)
+        config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
+            tenant_id, app_mode, config
+        )
         related_config_keys.extend(current_related_config_keys)
 
         # moderation validation
-        config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id,
-                                                                                                            config)
+        config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
+            tenant_id, config
+        )
         related_config_keys.extend(current_related_config_keys)
 
         related_config_keys = list(set(related_config_keys))
@@ -170,10 +167,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
         :param config: app model config args
         """
         if not config.get("agent_mode"):
-            config["agent_mode"] = {
-                "enabled": False,
-                "tools": []
-            }
+            config["agent_mode"] = {"enabled": False, "tools": []}
 
         if not isinstance(config["agent_mode"], dict):
             raise ValueError("agent_mode must be of object type")
@@ -187,8 +181,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
         if not config["agent_mode"].get("strategy"):
             config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
 
-        if config["agent_mode"]["strategy"] not in [member.value for member in
-                                                    list(PlanningStrategy.__members__.values())]:
+        if config["agent_mode"]["strategy"] not in [
+            member.value for member in list(PlanningStrategy.__members__.values())
+        ]:
             raise ValueError("strategy in agent_mode must be in the specified strategy list")
 
         if not config["agent_mode"].get("tools"):
@@ -210,7 +205,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
                     raise ValueError("enabled in agent_mode.tools must be of boolean type")
 
                 if key == "dataset":
-                    if 'id' not in tool_item:
+                    if "id" not in tool_item:
                         raise ValueError("id is required in dataset")
 
                     try:
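
The strategy check reformatted above wraps `PlanningStrategy.__members__.values()` in a redundant `list()` call; iterating the enum class directly yields the same members (enum aliases aside). A minimal, self-contained sketch of the equivalent check — the member names and values here are illustrative stand-ins, not the real `PlanningStrategy` definition:

```python
from enum import Enum


class PlanningStrategy(Enum):
    # hypothetical members standing in for the real PlanningStrategy enum
    ROUTER = "router"
    REACT_ROUTER = "react_router"


# Iterating the Enum class directly is equivalent to
# [member.value for member in list(PlanningStrategy.__members__.values())]
valid_strategies = [member.value for member in PlanningStrategy]
assert "router" in valid_strategies
```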

+ 43 - 59  api/core/app/apps/agent_chat/app_generator.py

@@ -30,7 +30,8 @@ logger = logging.getLogger(__name__)
 class AgentChatAppGenerator(MessageBasedAppGenerator):
     @overload
     def generate(
-        self, app_model: App,
+        self,
+        app_model: App,
         user: Union[Account, EndUser],
         args: dict,
         invoke_from: InvokeFrom,
@@ -39,19 +40,17 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
 
     @overload
     def generate(
-        self, app_model: App,
+        self,
+        app_model: App,
         user: Union[Account, EndUser],
         args: dict,
         invoke_from: InvokeFrom,
         stream: Literal[False] = False,
     ) -> dict: ...
 
-    def generate(self, app_model: App,
-                 user: Union[Account, EndUser],
-                 args: Any,
-                 invoke_from: InvokeFrom,
-                 stream: bool = True) \
-            -> Union[dict, Generator[dict, None, None]]:
+    def generate(
+        self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True
+    ) -> Union[dict, Generator[dict, None, None]]:
         """
         Generate App response.
 
@@ -62,60 +61,48 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
         :param stream: is stream
         """
         if not stream:
-            raise ValueError('Agent Chat App does not support blocking mode')
+            raise ValueError("Agent Chat App does not support blocking mode")
 
-        if not args.get('query'):
-            raise ValueError('query is required')
+        if not args.get("query"):
+            raise ValueError("query is required")
 
-        query = args['query']
+        query = args["query"]
         if not isinstance(query, str):
-            raise ValueError('query must be a string')
+            raise ValueError("query must be a string")
 
-        query = query.replace('\x00', '')
-        inputs = args['inputs']
+        query = query.replace("\x00", "")
+        inputs = args["inputs"]
 
-        extras = {
-            "auto_generate_conversation_name": args.get('auto_generate_name', True)
-        }
+        extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)}
 
         # get conversation
         conversation = None
-        if args.get('conversation_id'):
-            conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
+        if args.get("conversation_id"):
+            conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)
 
         # get app model config
-        app_model_config = self._get_app_model_config(
-            app_model=app_model,
-            conversation=conversation
-        )
+        app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
 
         # validate override model config
         override_model_config_dict = None
-        if args.get('model_config'):
+        if args.get("model_config"):
             if invoke_from != InvokeFrom.DEBUGGER:
-                raise ValueError('Only in App debug mode can override model config')
+                raise ValueError("Only in App debug mode can override model config")
 
             # validate config
             override_model_config_dict = AgentChatAppConfigManager.config_validate(
-                tenant_id=app_model.tenant_id,
-                config=args.get('model_config')
+                tenant_id=app_model.tenant_id, config=args.get("model_config")
             )
 
             # always enable retriever resource in debugger mode
-            override_model_config_dict["retriever_resource"] = {
-                "enabled": True
-            }
+            override_model_config_dict["retriever_resource"] = {"enabled": True}
 
         # parse files
-        files = args['files'] if args.get('files') else []
+        files = args["files"] if args.get("files") else []
         message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
         file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
         if file_extra_config:
-            file_objs = message_file_parser.validate_and_transform_files_arg(
-                files,
-                file_extra_config,
-                user
-            )
+            file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
         else:
             file_objs = []
 
@@ -124,7 +111,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             app_model=app_model,
             app_model_config=app_model_config,
             conversation=conversation,
-            override_config_dict=override_model_config_dict
+            override_config_dict=override_model_config_dict,
         )
 
         # get tracing instance
@@ -145,14 +132,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             invoke_from=invoke_from,
             extras=extras,
             call_depth=0,
-            trace_manager=trace_manager
+            trace_manager=trace_manager,
         )
 
         # init generate records
-        (
-            conversation,
-            message
-        ) = self._init_generate_records(application_generate_entity, conversation)
+        (conversation, message) = self._init_generate_records(application_generate_entity, conversation)
 
         # init queue manager
         queue_manager = MessageBasedAppQueueManager(
@@ -161,17 +145,20 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             invoke_from=application_generate_entity.invoke_from,
             conversation_id=conversation.id,
             app_mode=conversation.mode,
-            message_id=message.id
+            message_id=message.id,
         )
 
         # new thread
-        worker_thread = threading.Thread(target=self._generate_worker, kwargs={
-            'flask_app': current_app._get_current_object(),
-            'application_generate_entity': application_generate_entity,
-            'queue_manager': queue_manager,
-            'conversation_id': conversation.id,
-            'message_id': message.id,
-        })
+        worker_thread = threading.Thread(
+            target=self._generate_worker,
+            kwargs={
+                "flask_app": current_app._get_current_object(),
+                "application_generate_entity": application_generate_entity,
+                "queue_manager": queue_manager,
+                "conversation_id": conversation.id,
+                "message_id": message.id,
+            },
+        )
 
         worker_thread.start()
 
@@ -185,13 +172,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             stream=stream,
         )
 
-        return AgentChatAppGenerateResponseConverter.convert(
-            response=response,
-            invoke_from=invoke_from
-        )
+        return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
 
     def _generate_worker(
-        self, flask_app: Flask,
+        self,
+        flask_app: Flask,
         application_generate_entity: AgentChatAppGenerateEntity,
         queue_manager: AppQueueManager,
         conversation_id: str,
@@ -224,14 +209,13 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
                 pass
             except InvokeAuthorizationError:
                 queue_manager.publish_error(
-                    InvokeAuthorizationError('Incorrect API key provided'),
-                    PublishFrom.APPLICATION_MANAGER
+                    InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
                 )
             except ValidationError as e:
                 logger.exception("Validation Error when generating")
                 queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
             except (ValueError, InvokeError) as e:
-                if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
+                if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
                     logger.exception("Error when generating")
                 queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
             except Exception as e:
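
The `generate` overloads reformatted at the top of this file follow the standard `typing.overload` + `Literal` pattern for a flag that changes the return type. A reduced, runnable sketch of the pattern under simplified assumptions (the real method takes several more parameters and dispatches to the task pipeline):

```python
from collections.abc import Generator
from typing import Literal, Union, overload


@overload
def generate(stream: Literal[True] = True) -> Generator[dict, None, None]: ...
@overload
def generate(stream: Literal[False]) -> dict: ...
def generate(stream: bool = True) -> Union[dict, Generator[dict, None, None]]:
    # toy body: streaming returns a generator of event dicts, blocking returns one dict
    if stream:
        def _chunks() -> Generator[dict, None, None]:
            yield {"event": "message", "answer": "..."}
        return _chunks()
    return {"event": "message", "answer": "..."}
```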

+ 49 - 44  api/core/app/apps/agent_chat/app_runner.py

@@ -30,7 +30,8 @@ class AgentChatAppRunner(AppRunner):
     """
 
     def run(
-        self, application_generate_entity: AgentChatAppGenerateEntity,
+        self,
+        application_generate_entity: AgentChatAppGenerateEntity,
         queue_manager: AppQueueManager,
         conversation: Conversation,
         message: Message,
@@ -65,7 +66,7 @@ class AgentChatAppRunner(AppRunner):
             prompt_template_entity=app_config.prompt_template,
             inputs=inputs,
             files=files,
-            query=query
+            query=query,
         )
 
         memory = None
@@ -73,13 +74,10 @@ class AgentChatAppRunner(AppRunner):
             # get memory of conversation (read-only)
             model_instance = ModelInstance(
                 provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
-                model=application_generate_entity.model_conf.model
+                model=application_generate_entity.model_conf.model,
             )
 
-            memory = TokenBufferMemory(
-                conversation=conversation,
-                model_instance=model_instance
-            )
+            memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
 
         # organize all inputs and template to prompt messages
         # Include: prompt template, inputs, query(optional), files(optional)
@@ -91,7 +89,7 @@ class AgentChatAppRunner(AppRunner):
             inputs=inputs,
             files=files,
             query=query,
-            memory=memory
+            memory=memory,
         )
 
         # moderation
@@ -103,7 +101,7 @@ class AgentChatAppRunner(AppRunner):
                 app_generate_entity=application_generate_entity,
                 inputs=inputs,
                 query=query,
-                message_id=message.id
+                message_id=message.id,
             )
         except ModerationException as e:
             self.direct_output(
@@ -111,7 +109,7 @@ class AgentChatAppRunner(AppRunner):
                 app_generate_entity=application_generate_entity,
                 prompt_messages=prompt_messages,
                 text=str(e),
-                stream=application_generate_entity.stream
+                stream=application_generate_entity.stream,
             )
             return
 
@@ -122,13 +120,13 @@ class AgentChatAppRunner(AppRunner):
                 message=message,
                 query=query,
                 user_id=application_generate_entity.user_id,
-                invoke_from=application_generate_entity.invoke_from
+                invoke_from=application_generate_entity.invoke_from,
             )
 
             if annotation_reply:
                 queue_manager.publish(
                     QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
-                    PublishFrom.APPLICATION_MANAGER
+                    PublishFrom.APPLICATION_MANAGER,
                 )
 
                 self.direct_output(
@@ -136,7 +134,7 @@ class AgentChatAppRunner(AppRunner):
                     app_generate_entity=application_generate_entity,
                     prompt_messages=prompt_messages,
                     text=annotation_reply.content,
-                    stream=application_generate_entity.stream
+                    stream=application_generate_entity.stream,
                 )
                 return
 
@@ -148,7 +146,7 @@ class AgentChatAppRunner(AppRunner):
                 app_id=app_record.id,
                 external_data_tools=external_data_tools,
                 inputs=inputs,
-                query=query
+                query=query,
             )
 
         # reorganize all inputs and template to prompt messages
@@ -161,14 +159,14 @@ class AgentChatAppRunner(AppRunner):
             inputs=inputs,
             files=files,
             query=query,
-            memory=memory
+            memory=memory,
         )
 
         # check hosting moderation
         hosting_moderation_result = self.check_hosting_moderation(
             application_generate_entity=application_generate_entity,
             queue_manager=queue_manager,
-            prompt_messages=prompt_messages
+            prompt_messages=prompt_messages,
         )
 
         if hosting_moderation_result:
@@ -177,9 +175,9 @@ class AgentChatAppRunner(AppRunner):
         agent_entity = app_config.agent
 
         # load tool variables
-        tool_conversation_variables = self._load_tool_variables(conversation_id=conversation.id,
-                                                   user_id=application_generate_entity.user_id,
-                                                   tenant_id=app_config.tenant_id)
+        tool_conversation_variables = self._load_tool_variables(
+            conversation_id=conversation.id, user_id=application_generate_entity.user_id, tenant_id=app_config.tenant_id
+        )
 
         # convert db variables to tool variables
         tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
@@ -187,7 +185,7 @@ class AgentChatAppRunner(AppRunner):
         # init model instance
         model_instance = ModelInstance(
             provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
-            model=application_generate_entity.model_conf.model
+            model=application_generate_entity.model_conf.model,
         )
         prompt_message, _ = self.organize_prompt_messages(
             app_record=app_record,
@@ -238,7 +236,7 @@ class AgentChatAppRunner(AppRunner):
             prompt_messages=prompt_message,
             variables_pool=tool_variables,
             db_variables=tool_conversation_variables,
-            model_instance=model_instance
+            model_instance=model_instance,
         )
 
         invoke_result = runner.run(
@@ -252,17 +250,21 @@ class AgentChatAppRunner(AppRunner):
             invoke_result=invoke_result,
             queue_manager=queue_manager,
             stream=application_generate_entity.stream,
-            agent=True
+            agent=True,
         )
 
     def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables:
         """
         load tool variables from database
         """
-        tool_variables: ToolConversationVariables = db.session.query(ToolConversationVariables).filter(
-            ToolConversationVariables.conversation_id == conversation_id,
-            ToolConversationVariables.tenant_id == tenant_id
-        ).first()
+        tool_variables: ToolConversationVariables = (
+            db.session.query(ToolConversationVariables)
+            .filter(
+                ToolConversationVariables.conversation_id == conversation_id,
+                ToolConversationVariables.tenant_id == tenant_id,
+            )
+            .first()
+        )
 
         if tool_variables:
             # save tool variables to session, so that we can update it later
@@ -273,34 +275,40 @@ class AgentChatAppRunner(AppRunner):
                 conversation_id=conversation_id,
                 user_id=user_id,
                 tenant_id=tenant_id,
-                variables_str='[]',
+                variables_str="[]",
             )
             db.session.add(tool_variables)
             db.session.commit()
 
         return tool_variables
-    
-    def _convert_db_variables_to_tool_variables(self, db_variables: ToolConversationVariables) -> ToolRuntimeVariablePool:
+
+    def _convert_db_variables_to_tool_variables(
+        self, db_variables: ToolConversationVariables
+    ) -> ToolRuntimeVariablePool:
         """
         convert db variables to tool variables
         """
-        return ToolRuntimeVariablePool(**{
-            'conversation_id': db_variables.conversation_id,
-            'user_id': db_variables.user_id,
-            'tenant_id': db_variables.tenant_id,
-            'pool': db_variables.variables
-        })
-
-    def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigWithCredentialsEntity,
-                                         message: Message) -> LLMUsage:
+        return ToolRuntimeVariablePool(
+            **{
+                "conversation_id": db_variables.conversation_id,
+                "user_id": db_variables.user_id,
+                "tenant_id": db_variables.tenant_id,
+                "pool": db_variables.variables,
+            }
+        )
+
+    def _get_usage_of_all_agent_thoughts(
+        self, model_config: ModelConfigWithCredentialsEntity, message: Message
+    ) -> LLMUsage:
         """
         Get usage of all agent thoughts
         :param model_config: model config
         :param message: message
         :return:
         """
-        agent_thoughts = (db.session.query(MessageAgentThought)
-                          .filter(MessageAgentThought.message_id == message.id).all())
+        agent_thoughts = (
+            db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message.id).all()
+        )
 
         all_message_tokens = 0
         all_answer_tokens = 0
@@ -312,8 +320,5 @@ class AgentChatAppRunner(AppRunner):
         model_type_instance = cast(LargeLanguageModel, model_type_instance)
 
         return model_type_instance._calc_response_usage(
-            model_config.model,
-            model_config.credentials,
-            all_message_tokens,
-            all_answer_tokens
+            model_config.model, model_config.credentials, all_message_tokens, all_answer_tokens
         )
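
The `_load_tool_variables` rewrite above shows the parenthesized method-chain style ruff applies to SQLAlchemy queries. A self-contained sketch of the same shape against a throwaway in-memory database — the model class here is a simplified stand-in for the real `ToolConversationVariables`, not its actual schema:

```python
from sqlalchemy import Column, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class ToolConversationVariables(Base):
    # simplified stand-in for the real model
    __tablename__ = "tool_conversation_variables"
    id = Column(String, primary_key=True)
    conversation_id = Column(String)
    tenant_id = Column(String)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # one call per line inside parentheses -- the chain style used above
    tool_variables = (
        session.query(ToolConversationVariables)
        .filter(
            ToolConversationVariables.conversation_id == "conversation-1",
            ToolConversationVariables.tenant_id == "tenant-1",
        )
        .first()
    )
```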

+ 29 - 27  api/core/app/apps/agent_chat/generate_response_converter.py

@@ -23,15 +23,15 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         :return:
         """
         response = {
-            'event': 'message',
-            'task_id': blocking_response.task_id,
-            'id': blocking_response.data.id,
-            'message_id': blocking_response.data.message_id,
-            'conversation_id': blocking_response.data.conversation_id,
-            'mode': blocking_response.data.mode,
-            'answer': blocking_response.data.answer,
-            'metadata': blocking_response.data.metadata,
-            'created_at': blocking_response.data.created_at
+            "event": "message",
+            "task_id": blocking_response.task_id,
+            "id": blocking_response.data.id,
+            "message_id": blocking_response.data.message_id,
+            "conversation_id": blocking_response.data.conversation_id,
+            "mode": blocking_response.data.mode,
+            "answer": blocking_response.data.answer,
+            "metadata": blocking_response.data.metadata,
+            "created_at": blocking_response.data.created_at,
         }
 
         return response
@@ -45,14 +45,15 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         """
         response = cls.convert_blocking_full_response(blocking_response)
 
-        metadata = response.get('metadata', {})
-        response['metadata'] = cls._get_simple_metadata(metadata)
+        metadata = response.get("metadata", {})
+        response["metadata"] = cls._get_simple_metadata(metadata)
 
         return response
 
     @classmethod
-    def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-            -> Generator[str, None, None]:
+    def convert_stream_full_response(
+        cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
+    ) -> Generator[str, None, None]:
         """
         Convert stream full response.
         :param stream_response: stream response
@@ -63,14 +64,14 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             sub_stream_response = chunk.stream_response
 
             if isinstance(sub_stream_response, PingStreamResponse):
-                yield 'ping'
+                yield "ping"
                 continue
 
             response_chunk = {
-                'event': sub_stream_response.event.value,
-                'conversation_id': chunk.conversation_id,
-                'message_id': chunk.message_id,
-                'created_at': chunk.created_at
+                "event": sub_stream_response.event.value,
+                "conversation_id": chunk.conversation_id,
+                "message_id": chunk.message_id,
+                "created_at": chunk.created_at,
             }
 
             if isinstance(sub_stream_response, ErrorStreamResponse):
@@ -81,8 +82,9 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             yield json.dumps(response_chunk)
 
     @classmethod
-    def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-            -> Generator[str, None, None]:
+    def convert_stream_simple_response(
+        cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
+    ) -> Generator[str, None, None]:
         """
         Convert stream simple response.
         :param stream_response: stream response
@@ -93,20 +95,20 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             sub_stream_response = chunk.stream_response
 
             if isinstance(sub_stream_response, PingStreamResponse):
-                yield 'ping'
+                yield "ping"
                 continue
 
             response_chunk = {
-                'event': sub_stream_response.event.value,
-                'conversation_id': chunk.conversation_id,
-                'message_id': chunk.message_id,
-                'created_at': chunk.created_at
+                "event": sub_stream_response.event.value,
+                "conversation_id": chunk.conversation_id,
+                "message_id": chunk.message_id,
+                "created_at": chunk.created_at,
             }
 
             if isinstance(sub_stream_response, MessageEndStreamResponse):
                 sub_stream_response_dict = sub_stream_response.to_dict()
-                metadata = sub_stream_response_dict.get('metadata', {})
-                sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata)
+                metadata = sub_stream_response_dict.get("metadata", {})
+                sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                 response_chunk.update(sub_stream_response_dict)
             if isinstance(sub_stream_response, ErrorStreamResponse):
                 data = cls._error_to_stream_response(sub_stream_response.err)
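
Both stream converters in this file share one skeleton: pass `ping` events through as-is and serialize every other event into a JSON chunk. A compact sketch of that loop, with the stream response objects reduced to plain dicts:

```python
import json
from collections.abc import Generator, Iterable


def convert_stream(chunks: Iterable[dict]) -> Generator[str, None, None]:
    for chunk in chunks:
        if chunk.get("event") == "ping":
            yield "ping"
            continue
        # everything else becomes one JSON-encoded chunk, as in the converters above
        yield json.dumps(chunk)


events = [{"event": "ping"}, {"event": "message", "answer": "hi"}]
assert list(convert_stream(events)) == ["ping", '{"event": "message", "answer": "hi"}']
```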

+ 45 - 40  api/core/app/apps/base_app_generate_response_converter.py

@@ -13,32 +13,33 @@ class AppGenerateResponseConverter(ABC):
     _blocking_response_type: type[AppBlockingResponse]
 
     @classmethod
-    def convert(cls, response: Union[
-        AppBlockingResponse,
-        Generator[AppStreamResponse, Any, None]
-    ], invoke_from: InvokeFrom) -> dict[str, Any] | Generator[str, Any, None]:
+    def convert(
+        cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
+    ) -> dict[str, Any] | Generator[str, Any, None]:
         if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
             if isinstance(response, AppBlockingResponse):
                 return cls.convert_blocking_full_response(response)
             else:
+
                 def _generate_full_response() -> Generator[str, Any, None]:
                     for chunk in cls.convert_stream_full_response(response):
-                        if chunk == 'ping':
-                            yield f'event: {chunk}\n\n'
+                        if chunk == "ping":
+                            yield f"event: {chunk}\n\n"
                         else:
-                            yield f'data: {chunk}\n\n'
+                            yield f"data: {chunk}\n\n"
 
                 return _generate_full_response()
         else:
             if isinstance(response, AppBlockingResponse):
                 return cls.convert_blocking_simple_response(response)
             else:
+
                 def _generate_simple_response() -> Generator[str, Any, None]:
                     for chunk in cls.convert_stream_simple_response(response):
-                        if chunk == 'ping':
-                            yield f'event: {chunk}\n\n'
+                        if chunk == "ping":
+                            yield f"event: {chunk}\n\n"
                         else:
-                            yield f'data: {chunk}\n\n'
+                            yield f"data: {chunk}\n\n"
 
                 return _generate_simple_response()
 
@@ -54,14 +55,16 @@ class AppGenerateResponseConverter(ABC):
 
     @classmethod
     @abstractmethod
-    def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) \
-            -> Generator[str, None, None]:
+    def convert_stream_full_response(
+        cls, stream_response: Generator[AppStreamResponse, None, None]
+    ) -> Generator[str, None, None]:
         raise NotImplementedError
 
     @classmethod
     @abstractmethod
-    def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) \
-            -> Generator[str, None, None]:
+    def convert_stream_simple_response(
+        cls, stream_response: Generator[AppStreamResponse, None, None]
+    ) -> Generator[str, None, None]:
         raise NotImplementedError
 
     @classmethod
@@ -72,24 +75,26 @@ class AppGenerateResponseConverter(ABC):
         :return:
         """
         # show_retrieve_source
-        if 'retriever_resources' in metadata:
-            metadata['retriever_resources'] = []
-            for resource in metadata['retriever_resources']:
-                metadata['retriever_resources'].append({
-                    'segment_id': resource['segment_id'],
-                    'position': resource['position'],
-                    'document_name': resource['document_name'],
-                    'score': resource['score'],
-                    'content': resource['content'],
-                })
+        if "retriever_resources" in metadata:
+            metadata["retriever_resources"] = []
+            for resource in metadata["retriever_resources"]:
+                metadata["retriever_resources"].append(
+                    {
+                        "segment_id": resource["segment_id"],
+                        "position": resource["position"],
+                        "document_name": resource["document_name"],
+                        "score": resource["score"],
+                        "content": resource["content"],
+                    }
+                )
 
         # show annotation reply
-        if 'annotation_reply' in metadata:
-            del metadata['annotation_reply']
+        if "annotation_reply" in metadata:
+            del metadata["annotation_reply"]
 
         # show usage
-        if 'usage' in metadata:
-            del metadata['usage']
+        if "usage" in metadata:
+            del metadata["usage"]
 
         return metadata
 
@@ -101,16 +106,16 @@ class AppGenerateResponseConverter(ABC):
         :return:
         """
         error_responses = {
-            ValueError: {'code': 'invalid_param', 'status': 400},
-            ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400},
+            ValueError: {"code": "invalid_param", "status": 400},
+            ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
             QuotaExceededError: {
-                'code': 'provider_quota_exceeded',
-                'message': "Your quota for Dify Hosted Model Provider has been exhausted. "
-                           "Please go to Settings -> Model Provider to complete your own provider credentials.",
-                'status': 400
+                "code": "provider_quota_exceeded",
+                "message": "Your quota for Dify Hosted Model Provider has been exhausted. "
+                "Please go to Settings -> Model Provider to complete your own provider credentials.",
+                "status": 400,
             },
-            ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400},
-            InvokeError: {'code': 'completion_request_error', 'status': 400}
+            ModelCurrentlyNotSupportError: {"code": "model_currently_not_support", "status": 400},
+            InvokeError: {"code": "completion_request_error", "status": 400},
         }
 
         # Determine the response based on the type of exception
@@ -120,13 +125,13 @@ class AppGenerateResponseConverter(ABC):
                 data = v
 
         if data:
-            data.setdefault('message', getattr(e, 'description', str(e)))
+            data.setdefault("message", getattr(e, "description", str(e)))
         else:
             logging.error(e)
             data = {
-                'code': 'internal_server_error',
-                'message': 'Internal Server Error, please contact support.',
-                'status': 500
+                "code": "internal_server_error",
+                "message": "Internal Server Error, please contact support.",
+                "status": 500,
             }
 
         return data
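
One pre-existing quirk survives this reformat untouched: `_get_simple_metadata` assigns `metadata["retriever_resources"] = []` and then iterates that same, now-empty list, so the loop body can never run and the simplified resources are always dropped. A sketch of what the code presumably intends — snapshot the original list before overwriting it (field names taken from the diff above):

```python
def simplify_retriever_resources(metadata: dict) -> dict:
    if "retriever_resources" in metadata:
        resources = metadata["retriever_resources"]  # snapshot before overwriting
        metadata["retriever_resources"] = [
            {
                "segment_id": resource["segment_id"],
                "position": resource["position"],
                "document_name": resource["document_name"],
                "score": resource["score"],
                "content": resource["content"],
            }
            for resource in resources
        ]
    return metadata
```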

+ 6 - 6  api/core/app/apps/base_app_generator.py

@@ -16,10 +16,10 @@ class BaseAppGenerator:
     def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
         user_input_value = inputs.get(var.variable)
         if var.required and not user_input_value:
-            raise ValueError(f'{var.variable} is required in input form')
+            raise ValueError(f"{var.variable} is required in input form")
         if not var.required and not user_input_value:
             # TODO: should we return None here if the default value is None?
-            return var.default or ''
+            return var.default or ""
         if (
             var.type
             in (
@@ -34,7 +34,7 @@ class BaseAppGenerator:
         if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
             # may raise ValueError if user_input_value is not a valid number
             try:
-                if '.' in user_input_value:
+                if "." in user_input_value:
                     return float(user_input_value)
                 else:
                     return int(user_input_value)
@@ -43,14 +43,14 @@ class BaseAppGenerator:
         if var.type == VariableEntityType.SELECT:
             options = var.options or []
             if user_input_value not in options:
-                raise ValueError(f'{var.variable} in input form must be one of the following: {options}')
+                raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
         elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
             if var.max_length and user_input_value and len(user_input_value) > var.max_length:
-                raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters')
+                raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
 
         return user_input_value
 
     def _sanitize_value(self, value: Any) -> Any:
         if isinstance(value, str):
-            return value.replace('\x00', '')
+            return value.replace("\x00", "")
         return value
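
The `VariableEntityType.NUMBER` branch above decides between `float` and `int` purely on whether the raw string contains a dot. The rule in isolation:

```python
def parse_number(raw: str) -> float | int:
    # same rule as the NUMBER branch above: a "." means float, otherwise int;
    # either constructor may still raise ValueError on malformed input
    if "." in raw:
        return float(raw)
    return int(raw)


assert parse_number("3.5") == 3.5
assert parse_number("42") == 42
```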

+ 14 - 16  api/core/app/apps/base_app_queue_manager.py

@@ -24,9 +24,7 @@ class PublishFrom(Enum):
 
 
 class AppQueueManager:
-    def __init__(self, task_id: str,
-                 user_id: str,
-                 invoke_from: InvokeFrom) -> None:
+    def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None:
         if not user_id:
             raise ValueError("user is required")
 
@@ -34,9 +32,10 @@ class AppQueueManager:
         self._user_id = user_id
         self._invoke_from = invoke_from
 
-        user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
-        redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800,
-                           f"{user_prefix}-{self._user_id}")
+        user_prefix = "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user"
+        redis_client.setex(
+            AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
+        )
 
         q = queue.Queue()
 
@@ -66,8 +65,7 @@ class AppQueueManager:
                     # publish two messages to make sure the client can receive the stop signal
                     # and stop listening after the stop signal processed
                     self.publish(
-                        QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL),
-                        PublishFrom.TASK_PIPELINE
+                        QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.TASK_PIPELINE
                     )
 
                 if elapsed_time // 10 > last_ping_time:
@@ -88,9 +86,7 @@ class AppQueueManager:
         :param pub_from: publish from
         :return:
         """
-        self.publish(QueueErrorEvent(
-            error=e
-        ), pub_from)
+        self.publish(QueueErrorEvent(error=e), pub_from)
 
     def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
         """
@@ -122,8 +118,8 @@ class AppQueueManager:
         if result is None:
             return
 
-        user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
-        if result.decode('utf-8') != f"{user_prefix}-{user_id}":
+        user_prefix = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user"
+        if result.decode("utf-8") != f"{user_prefix}-{user_id}":
             return
 
         stopped_cache_key = cls._generate_stopped_cache_key(task_id)
@@ -168,9 +164,11 @@ class AppQueueManager:
             for item in data:
                 self._check_for_sqlalchemy_models(item)
         else:
-            if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'):
-                raise TypeError("Critical Error: Passing SQLAlchemy Model instances "
-                                "that cause thread safety issues is not allowed.")
+            if isinstance(data, DeclarativeMeta) or hasattr(data, "_sa_instance_state"):
+                raise TypeError(
+                    "Critical Error: Passing SQLAlchemy Model instances "
+                    "that cause thread safety issues is not allowed."
+                )
 
 
 class GenerateTaskStoppedException(Exception):
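
`_check_for_sqlalchemy_models` above walks nested dicts and lists and rejects anything carrying `_sa_instance_state`, the attribute SQLAlchemy instrumentation adds to mapped instances. The recursion in isolation — simplified to the attribute check only, whereas the real method also tests `isinstance(data, DeclarativeMeta)`:

```python
def check_for_sqlalchemy_models(data) -> None:
    # recursive guard mirroring the method above: ORM instances are not
    # thread-safe, so they must not be published into the app queue
    if isinstance(data, dict):
        for value in data.values():
            check_for_sqlalchemy_models(value)
    elif isinstance(data, list):
        for item in data:
            check_for_sqlalchemy_models(item)
    elif hasattr(data, "_sa_instance_state"):
        raise TypeError(
            "Critical Error: Passing SQLAlchemy Model instances "
            "that cause thread safety issues is not allowed."
        )


check_for_sqlalchemy_models({"ok": [1, "two", {"three": 3.0}]})  # passes silently
```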

+ 125 - 147  api/core/app/apps/base_app_runner.py

@@ -31,12 +31,15 @@ if TYPE_CHECKING:
 
 
 class AppRunner:
-    def get_pre_calculate_rest_tokens(self, app_record: App,
-                                      model_config: ModelConfigWithCredentialsEntity,
-                                      prompt_template_entity: PromptTemplateEntity,
-                                      inputs: dict[str, str],
-                                      files: list["FileVar"],
-                                      query: Optional[str] = None) -> int:
+    def get_pre_calculate_rest_tokens(
+        self,
+        app_record: App,
+        model_config: ModelConfigWithCredentialsEntity,
+        prompt_template_entity: PromptTemplateEntity,
+        inputs: dict[str, str],
+        files: list["FileVar"],
+        query: Optional[str] = None,
+    ) -> int:
         """
         Get pre calculate rest tokens
         :param app_record: app record
@@ -49,18 +52,20 @@ class AppRunner:
         """
         # Invoke model
         model_instance = ModelInstance(
-            provider_model_bundle=model_config.provider_model_bundle,
-            model=model_config.model
+            provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
         )
 
         model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
 
         max_tokens = 0
         for parameter_rule in model_config.model_schema.parameter_rules:
-            if (parameter_rule.name == 'max_tokens'
-                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
-                max_tokens = (model_config.parameters.get(parameter_rule.name)
-                              or model_config.parameters.get(parameter_rule.use_template)) or 0
+            if parameter_rule.name == "max_tokens" or (
+                parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
+            ):
+                max_tokens = (
+                    model_config.parameters.get(parameter_rule.name)
+                    or model_config.parameters.get(parameter_rule.use_template)
+                ) or 0
 
         if model_context_tokens is None:
             return -1
@@ -75,36 +80,39 @@ class AppRunner:
             prompt_template_entity=prompt_template_entity,
             inputs=inputs,
             files=files,
-            query=query
+            query=query,
         )
 
-        prompt_tokens = model_instance.get_llm_num_tokens(
-            prompt_messages
-        )
+        prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
 
         rest_tokens = model_context_tokens - max_tokens - prompt_tokens
         if rest_tokens < 0:
-            raise InvokeBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
-                                        "or shrink the max token, or switch to a llm with a larger token limit size.")
+            raise InvokeBadRequestError(
+                "Query or prefix prompt is too long, you can reduce the prefix prompt, "
+                "or shrink the max token, or switch to a llm with a larger token limit size."
+            )
 
         return rest_tokens
 
-    def recalc_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity,
-                              prompt_messages: list[PromptMessage]):
+    def recalc_llm_max_tokens(
+        self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage]
+    ):
         # recalc max_tokens if sum(prompt_token +  max_tokens) over model token limit
         model_instance = ModelInstance(
-            provider_model_bundle=model_config.provider_model_bundle,
-            model=model_config.model
+            provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
         )
 
         model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
 
         max_tokens = 0
         for parameter_rule in model_config.model_schema.parameter_rules:
-            if (parameter_rule.name == 'max_tokens'
-                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
-                max_tokens = (model_config.parameters.get(parameter_rule.name)
-                              or model_config.parameters.get(parameter_rule.use_template)) or 0
+            if parameter_rule.name == "max_tokens" or (
+                parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
+            ):
+                max_tokens = (
+                    model_config.parameters.get(parameter_rule.name)
+                    or model_config.parameters.get(parameter_rule.use_template)
+                ) or 0
 
         if model_context_tokens is None:
             return -1
@@ -112,27 +120,28 @@ class AppRunner:
         if max_tokens is None:
             max_tokens = 0
 
-        prompt_tokens = model_instance.get_llm_num_tokens(
-            prompt_messages
-        )
+        prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
 
         if prompt_tokens + max_tokens > model_context_tokens:
             max_tokens = max(model_context_tokens - prompt_tokens, 16)
 
             for parameter_rule in model_config.model_schema.parameter_rules:
-                if (parameter_rule.name == 'max_tokens'
-                        or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
+                if parameter_rule.name == "max_tokens" or (
+                    parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
+                ):
                     model_config.parameters[parameter_rule.name] = max_tokens
 
-    def organize_prompt_messages(self, app_record: App,
-                                 model_config: ModelConfigWithCredentialsEntity,
-                                 prompt_template_entity: PromptTemplateEntity,
-                                 inputs: dict[str, str],
-                                 files: list["FileVar"],
-                                 query: Optional[str] = None,
-                                 context: Optional[str] = None,
-                                 memory: Optional[TokenBufferMemory] = None) \
-            -> tuple[list[PromptMessage], Optional[list[str]]]:
+    def organize_prompt_messages(
+        self,
+        app_record: App,
+        model_config: ModelConfigWithCredentialsEntity,
+        prompt_template_entity: PromptTemplateEntity,
+        inputs: dict[str, str],
+        files: list["FileVar"],
+        query: Optional[str] = None,
+        context: Optional[str] = None,
+        memory: Optional[TokenBufferMemory] = None,
+    ) -> tuple[list[PromptMessage], Optional[list[str]]]:
         """
         Organize prompt messages
         :param context:
@@ -152,60 +161,54 @@ class AppRunner:
                 app_mode=AppMode.value_of(app_record.mode),
                 prompt_template_entity=prompt_template_entity,
                 inputs=inputs,
-                query=query if query else '',
+                query=query if query else "",
                 files=files,
                 context=context,
                 memory=memory,
-                model_config=model_config
+                model_config=model_config,
             )
         else:
-            memory_config = MemoryConfig(
-                window=MemoryConfig.WindowConfig(
-                    enabled=False
-                )
-            )
+            memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
 
             model_mode = ModelMode.value_of(model_config.mode)
             if model_mode == ModelMode.COMPLETION:
                 advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template
-                prompt_template = CompletionModelPromptTemplate(
-                    text=advanced_completion_prompt_template.prompt
-                )
+                prompt_template = CompletionModelPromptTemplate(text=advanced_completion_prompt_template.prompt)
 
                 if advanced_completion_prompt_template.role_prefix:
                     memory_config.role_prefix = MemoryConfig.RolePrefix(
                         user=advanced_completion_prompt_template.role_prefix.user,
-                        assistant=advanced_completion_prompt_template.role_prefix.assistant
+                        assistant=advanced_completion_prompt_template.role_prefix.assistant,
                     )
             else:
                 prompt_template = []
                 for message in prompt_template_entity.advanced_chat_prompt_template.messages:
-                    prompt_template.append(ChatModelMessage(
-                        text=message.text,
-                        role=message.role
-                    ))
+                    prompt_template.append(ChatModelMessage(text=message.text, role=message.role))
 
             prompt_transform = AdvancedPromptTransform()
             prompt_messages = prompt_transform.get_prompt(
                 prompt_template=prompt_template,
                 inputs=inputs,
-                query=query if query else '',
+                query=query if query else "",
                 files=files,
                 context=context,
                 memory_config=memory_config,
                 memory=memory,
-                model_config=model_config
+                model_config=model_config,
             )
             stop = model_config.stop
 
         return prompt_messages, stop
 
-    def direct_output(self, queue_manager: AppQueueManager,
-                      app_generate_entity: EasyUIBasedAppGenerateEntity,
-                      prompt_messages: list,
-                      text: str,
-                      stream: bool,
-                      usage: Optional[LLMUsage] = None) -> None:
+    def direct_output(
+        self,
+        queue_manager: AppQueueManager,
+        app_generate_entity: EasyUIBasedAppGenerateEntity,
+        prompt_messages: list,
+        text: str,
+        stream: bool,
+        usage: Optional[LLMUsage] = None,
+    ) -> None:
         """
         Direct output
         :param queue_manager: application queue manager
@@ -222,17 +225,10 @@ class AppRunner:
                 chunk = LLMResultChunk(
                     model=app_generate_entity.model_conf.model,
                     prompt_messages=prompt_messages,
-                    delta=LLMResultChunkDelta(
-                        index=index,
-                        message=AssistantPromptMessage(content=token)
-                    )
+                    delta=LLMResultChunkDelta(index=index, message=AssistantPromptMessage(content=token)),
                 )
 
-                queue_manager.publish(
-                    QueueLLMChunkEvent(
-                        chunk=chunk
-                    ), PublishFrom.APPLICATION_MANAGER
-                )
+                queue_manager.publish(QueueLLMChunkEvent(chunk=chunk), PublishFrom.APPLICATION_MANAGER)
                 index += 1
                 time.sleep(0.01)
 
@@ -242,15 +238,19 @@ class AppRunner:
                     model=app_generate_entity.model_conf.model,
                     prompt_messages=prompt_messages,
                     message=AssistantPromptMessage(content=text),
-                    usage=usage if usage else LLMUsage.empty_usage()
+                    usage=usage if usage else LLMUsage.empty_usage(),
                 ),
-            ), PublishFrom.APPLICATION_MANAGER
+            ),
+            PublishFrom.APPLICATION_MANAGER,
         )
 
-    def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],
-                              queue_manager: AppQueueManager,
-                              stream: bool,
-                              agent: bool = False) -> None:
+    def _handle_invoke_result(
+        self,
+        invoke_result: Union[LLMResult, Generator],
+        queue_manager: AppQueueManager,
+        stream: bool,
+        agent: bool = False,
+    ) -> None:
         """
         Handle invoke result
         :param invoke_result: invoke result
@@ -260,21 +260,13 @@ class AppRunner:
         :return:
         """
         if not stream:
-            self._handle_invoke_result_direct(
-                invoke_result=invoke_result,
-                queue_manager=queue_manager,
-                agent=agent
-            )
+            self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
         else:
-            self._handle_invoke_result_stream(
-                invoke_result=invoke_result,
-                queue_manager=queue_manager,
-                agent=agent
-            )
+            self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
 
-    def _handle_invoke_result_direct(self, invoke_result: LLMResult,
-                                     queue_manager: AppQueueManager,
-                                     agent: bool) -> None:
+    def _handle_invoke_result_direct(
+        self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool
+    ) -> None:
         """
         Handle invoke result direct
         :param invoke_result: invoke result
@@ -285,12 +277,13 @@ class AppRunner:
         queue_manager.publish(
             QueueMessageEndEvent(
                 llm_result=invoke_result,
-            ), PublishFrom.APPLICATION_MANAGER
+            ),
+            PublishFrom.APPLICATION_MANAGER,
         )
 
-    def _handle_invoke_result_stream(self, invoke_result: Generator,
-                                     queue_manager: AppQueueManager,
-                                     agent: bool) -> None:
+    def _handle_invoke_result_stream(
+        self, invoke_result: Generator, queue_manager: AppQueueManager, agent: bool
+    ) -> None:
         """
         Handle invoke result
         :param invoke_result: invoke result
@@ -300,21 +293,13 @@ class AppRunner:
         """
         model = None
         prompt_messages = []
-        text = ''
+        text = ""
         usage = None
         for result in invoke_result:
             if not agent:
-                queue_manager.publish(
-                    QueueLLMChunkEvent(
-                        chunk=result
-                    ), PublishFrom.APPLICATION_MANAGER
-                )
+                queue_manager.publish(QueueLLMChunkEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
             else:
-                queue_manager.publish(
-                    QueueAgentMessageEvent(
-                        chunk=result
-                    ), PublishFrom.APPLICATION_MANAGER
-                )
+                queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
 
             text += result.delta.message.content
 
@@ -331,25 +316,24 @@ class AppRunner:
             usage = LLMUsage.empty_usage()
 
         llm_result = LLMResult(
-            model=model,
-            prompt_messages=prompt_messages,
-            message=AssistantPromptMessage(content=text),
-            usage=usage
+            model=model, prompt_messages=prompt_messages, message=AssistantPromptMessage(content=text), usage=usage
         )
 
         queue_manager.publish(
             QueueMessageEndEvent(
                 llm_result=llm_result,
-            ), PublishFrom.APPLICATION_MANAGER
+            ),
+            PublishFrom.APPLICATION_MANAGER,
         )
 
     def moderation_for_inputs(
-            self, app_id: str,
-            tenant_id: str,
-            app_generate_entity: AppGenerateEntity,
-            inputs: Mapping[str, Any],
-            query: str,
-            message_id: str,
+        self,
+        app_id: str,
+        tenant_id: str,
+        app_generate_entity: AppGenerateEntity,
+        inputs: Mapping[str, Any],
+        query: str,
+        message_id: str,
     ) -> tuple[bool, dict, str]:
         """
         Process sensitive_word_avoidance.
@@ -367,14 +351,17 @@ class AppRunner:
             tenant_id=tenant_id,
             app_config=app_generate_entity.app_config,
             inputs=inputs,
-            query=query if query else '',
+            query=query if query else "",
             message_id=message_id,
-            trace_manager=app_generate_entity.trace_manager
+            trace_manager=app_generate_entity.trace_manager,
         )
 
-    def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
-                                 queue_manager: AppQueueManager,
-                                 prompt_messages: list[PromptMessage]) -> bool:
+    def check_hosting_moderation(
+        self,
+        application_generate_entity: EasyUIBasedAppGenerateEntity,
+        queue_manager: AppQueueManager,
+        prompt_messages: list[PromptMessage],
+    ) -> bool:
         """
         Check hosting moderation
         :param application_generate_entity: application generate entity
@@ -384,8 +371,7 @@ class AppRunner:
         """
         hosting_moderation_feature = HostingModerationFeature()
         moderation_result = hosting_moderation_feature.check(
-            application_generate_entity=application_generate_entity,
-            prompt_messages=prompt_messages
+            application_generate_entity=application_generate_entity, prompt_messages=prompt_messages
         )
 
         if moderation_result:
@@ -393,18 +379,20 @@ class AppRunner:
                 queue_manager=queue_manager,
                 app_generate_entity=application_generate_entity,
                 prompt_messages=prompt_messages,
-                text="I apologize for any confusion, " \
-                     "but I'm an AI assistant to be helpful, harmless, and honest.",
-                stream=application_generate_entity.stream
+                text="I apologize for any confusion, " "but I'm an AI assistant to be helpful, harmless, and honest.",
+                stream=application_generate_entity.stream,
             )
 
         return moderation_result
 
-    def fill_in_inputs_from_external_data_tools(self, tenant_id: str,
-                                                app_id: str,
-                                                external_data_tools: list[ExternalDataVariableEntity],
-                                                inputs: dict,
-                                                query: str) -> dict:
+    def fill_in_inputs_from_external_data_tools(
+        self,
+        tenant_id: str,
+        app_id: str,
+        external_data_tools: list[ExternalDataVariableEntity],
+        inputs: dict,
+        query: str,
+    ) -> dict:
         """
         Fill in variable inputs from external data tools if exists.
 
@@ -417,18 +405,12 @@ class AppRunner:
         """
         external_data_fetch_feature = ExternalDataFetch()
         return external_data_fetch_feature.fetch(
-            tenant_id=tenant_id,
-            app_id=app_id,
-            external_data_tools=external_data_tools,
-            inputs=inputs,
-            query=query
+            tenant_id=tenant_id, app_id=app_id, external_data_tools=external_data_tools, inputs=inputs, query=query
         )
 
-    def query_app_annotations_to_reply(self, app_record: App,
-                                       message: Message,
-                                       query: str,
-                                       user_id: str,
-                                       invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
+    def query_app_annotations_to_reply(
+        self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
+    ) -> Optional[MessageAnnotation]:
         """
         Query app annotations to reply
         :param app_record: app record
@@ -440,9 +422,5 @@ class AppRunner:
         """
         annotation_reply_feature = AnnotationReplyFeature()
         return annotation_reply_feature.query(
-            app_record=app_record,
-            message=message,
-            query=query,
-            user_id=user_id,
-            invoke_from=invoke_from
+            app_record=app_record, message=message, query=query, user_id=user_id, invoke_from=invoke_from
         )
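
Stripped of the model plumbing, `get_pre_calculate_rest_tokens` above reduces to one subtraction over the context window. The arithmetic in isolation (the numbers below are illustrative, not taken from any real model):

```python
def pre_calculate_rest_tokens(context_size: int, max_tokens: int, prompt_tokens: int) -> int:
    # rest = what remains of the context window after reserving the
    # completion budget (max_tokens) and the prompt itself
    rest_tokens = context_size - max_tokens - prompt_tokens
    if rest_tokens < 0:
        raise ValueError("Query or prefix prompt is too long for this model")
    return rest_tokens


# recalc_llm_max_tokens clamps the completion budget the same way:
# max_tokens = max(context_size - prompt_tokens, 16)
assert pre_calculate_rest_tokens(4096, 512, 1000) == 2584
```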

+ 22 - 23  api/core/app/apps/chat/app_config_manager.py

@@ -22,15 +22,19 @@ class ChatAppConfig(EasyUIBasedAppConfig):
     """
     Chatbot App Config Entity.
     """
+
     pass
 
 
 class ChatAppConfigManager(BaseAppConfigManager):
     @classmethod
-    def get_app_config(cls, app_model: App,
-                       app_model_config: AppModelConfig,
-                       conversation: Optional[Conversation] = None,
-                       override_config_dict: Optional[dict] = None) -> ChatAppConfig:
+    def get_app_config(
+        cls,
+        app_model: App,
+        app_model_config: AppModelConfig,
+        conversation: Optional[Conversation] = None,
+        override_config_dict: Optional[dict] = None,
+    ) -> ChatAppConfig:
         """
         Convert app model config to chat app config
         :param app_model: app model
@@ -51,7 +55,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
             config_dict = app_model_config_dict.copy()
         else:
             if not override_config_dict:
-                raise Exception('override_config_dict is required when config_from is ARGS')
+                raise Exception("override_config_dict is required when config_from is ARGS")
 
             config_dict = override_config_dict
 
@@ -63,19 +67,11 @@ class ChatAppConfigManager(BaseAppConfigManager):
             app_model_config_from=config_from,
             app_model_config_id=app_model_config.id,
             app_model_config_dict=config_dict,
-            model=ModelConfigManager.convert(
-                config=config_dict
-            ),
-            prompt_template=PromptTemplateConfigManager.convert(
-                config=config_dict
-            ),
-            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
-                config=config_dict
-            ),
-            dataset=DatasetConfigManager.convert(
-                config=config_dict
-            ),
-            additional_features=cls.convert_features(config_dict, app_mode)
+            model=ModelConfigManager.convert(config=config_dict),
+            prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
+            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
+            dataset=DatasetConfigManager.convert(config=config_dict),
+            additional_features=cls.convert_features(config_dict, app_mode),
         )
 
         app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
@@ -113,8 +109,9 @@ class ChatAppConfigManager(BaseAppConfigManager):
         related_config_keys.extend(current_related_config_keys)
 
         # dataset_query_variable
-        config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode,
-                                                                                             config)
+        config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
+            tenant_id, app_mode, config
+        )
         related_config_keys.extend(current_related_config_keys)
 
         # opening_statement
@@ -123,7 +120,8 @@ class ChatAppConfigManager(BaseAppConfigManager):
 
         # suggested_questions_after_answer
         config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
-            config)
+            config
+        )
         related_config_keys.extend(current_related_config_keys)
 
         # speech_to_text
@@ -139,8 +137,9 @@ class ChatAppConfigManager(BaseAppConfigManager):
         related_config_keys.extend(current_related_config_keys)
 
         # moderation validation
-        config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id,
-                                                                                                            config)
+        config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
+            tenant_id, config
+        )
         related_config_keys.extend(current_related_config_keys)
 
         related_config_keys = list(set(related_config_keys))
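
Two mechanical rewrites recur throughout this file: string quotes are normalized to double quotes, and a blank line is enforced between a class docstring and the first statement in the body (hence the lone added line before `pass` above). A minimal sketch, with a placeholder class name:

    class ChatAppConfigStub:
        """Placeholder standing in for ChatAppConfig."""

        pass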

+ 48 - 58
api/core/app/apps/chat/app_generator.py

@@ -30,7 +30,8 @@ logger = logging.getLogger(__name__)
 class ChatAppGenerator(MessageBasedAppGenerator):
     @overload
     def generate(
-        self, app_model: App,
+        self,
+        app_model: App,
         user: Union[Account, EndUser],
         args: Any,
         invoke_from: InvokeFrom,
@@ -39,7 +40,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
 
     @overload
     def generate(
-        self, app_model: App,
+        self,
+        app_model: App,
         user: Union[Account, EndUser],
         args: Any,
         invoke_from: InvokeFrom,
@@ -47,7 +49,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
     ) -> dict: ...
 
     def generate(
-        self, app_model: App,
+        self,
+        app_model: App,
         user: Union[Account, EndUser],
         args: Any,
         invoke_from: InvokeFrom,
@@ -62,58 +65,46 @@ class ChatAppGenerator(MessageBasedAppGenerator):
         :param invoke_from: invoke from source
         :param stream: is stream
         """
-        if not args.get('query'):
-            raise ValueError('query is required')
+        if not args.get("query"):
+            raise ValueError("query is required")
 
-        query = args['query']
+        query = args["query"]
         if not isinstance(query, str):
-            raise ValueError('query must be a string')
+            raise ValueError("query must be a string")
 
-        query = query.replace('\x00', '')
-        inputs = args['inputs']
+        query = query.replace("\x00", "")
+        inputs = args["inputs"]
 
-        extras = {
-            "auto_generate_conversation_name": args.get('auto_generate_name', True)
-        }
+        extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)}
 
         # get conversation
         conversation = None
-        if args.get('conversation_id'):
-            conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
+        if args.get("conversation_id"):
+            conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)
 
         # get app model config
-        app_model_config = self._get_app_model_config(
-            app_model=app_model,
-            conversation=conversation
-        )
+        app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
 
         # validate override model config
         override_model_config_dict = None
-        if args.get('model_config'):
+        if args.get("model_config"):
             if invoke_from != InvokeFrom.DEBUGGER:
-                raise ValueError('Only in App debug mode can override model config')
+                raise ValueError("Only in App debug mode can override model config")
 
             # validate config
             override_model_config_dict = ChatAppConfigManager.config_validate(
-                tenant_id=app_model.tenant_id,
-                config=args.get('model_config')
+                tenant_id=app_model.tenant_id, config=args.get("model_config")
             )
 
             # always enable retriever resource in debugger mode
-            override_model_config_dict["retriever_resource"] = {
-                "enabled": True
-            }
+            override_model_config_dict["retriever_resource"] = {"enabled": True}
 
         # parse files
-        files = args['files'] if args.get('files') else []
+        files = args["files"] if args.get("files") else []
         message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
         file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
         if file_extra_config:
-            file_objs = message_file_parser.validate_and_transform_files_arg(
-                files,
-                file_extra_config,
-                user
-            )
+            file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
         else:
             file_objs = []
 
@@ -122,7 +113,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
             app_model=app_model,
             app_model_config=app_model_config,
             conversation=conversation,
-            override_config_dict=override_model_config_dict
+            override_config_dict=override_model_config_dict,
         )
 
         # get tracing instance
@@ -141,14 +132,11 @@ class ChatAppGenerator(MessageBasedAppGenerator):
             stream=stream,
             invoke_from=invoke_from,
             extras=extras,
-            trace_manager=trace_manager
+            trace_manager=trace_manager,
         )
 
         # init generate records
-        (
-            conversation,
-            message
-        ) = self._init_generate_records(application_generate_entity, conversation)
+        (conversation, message) = self._init_generate_records(application_generate_entity, conversation)
 
         # init queue manager
         queue_manager = MessageBasedAppQueueManager(
@@ -157,17 +145,20 @@ class ChatAppGenerator(MessageBasedAppGenerator):
             invoke_from=application_generate_entity.invoke_from,
             conversation_id=conversation.id,
             app_mode=conversation.mode,
-            message_id=message.id
+            message_id=message.id,
         )
 
         # new thread
-        worker_thread = threading.Thread(target=self._generate_worker, kwargs={
-            'flask_app': current_app._get_current_object(),
-            'application_generate_entity': application_generate_entity,
-            'queue_manager': queue_manager,
-            'conversation_id': conversation.id,
-            'message_id': message.id,
-        })
+        worker_thread = threading.Thread(
+            target=self._generate_worker,
+            kwargs={
+                "flask_app": current_app._get_current_object(),
+                "application_generate_entity": application_generate_entity,
+                "queue_manager": queue_manager,
+                "conversation_id": conversation.id,
+                "message_id": message.id,
+            },
+        )
 
         worker_thread.start()
 
@@ -181,16 +172,16 @@ class ChatAppGenerator(MessageBasedAppGenerator):
             stream=stream,
         )
 
-        return ChatAppGenerateResponseConverter.convert(
-            response=response,
-            invoke_from=invoke_from
-        )
+        return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
 
-    def _generate_worker(self, flask_app: Flask,
-                         application_generate_entity: ChatAppGenerateEntity,
-                         queue_manager: AppQueueManager,
-                         conversation_id: str,
-                         message_id: str) -> None:
+    def _generate_worker(
+        self,
+        flask_app: Flask,
+        application_generate_entity: ChatAppGenerateEntity,
+        queue_manager: AppQueueManager,
+        conversation_id: str,
+        message_id: str,
+    ) -> None:
         """
         Generate worker in a new thread.
         :param flask_app: Flask app
@@ -212,20 +203,19 @@ class ChatAppGenerator(MessageBasedAppGenerator):
                     application_generate_entity=application_generate_entity,
                     queue_manager=queue_manager,
                     conversation=conversation,
-                    message=message
+                    message=message,
                 )
             except GenerateTaskStoppedException:
                 pass
             except InvokeAuthorizationError:
                 queue_manager.publish_error(
-                    InvokeAuthorizationError('Incorrect API key provided'),
-                    PublishFrom.APPLICATION_MANAGER
+                    InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
                 )
             except ValidationError as e:
                 logger.exception("Validation Error when generating")
                 queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
             except (ValueError, InvokeError) as e:
-                if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
+                if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
                     logger.exception("Error when generating")
                 queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
             except Exception as e:
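
The reformatted DEBUG check still looks up `os.environ.get("DEBUG")` twice. An equivalent single lookup uses `get`'s default argument; this is a suggestion beyond the commit, not something ruff emits:

    import os

    # Falls back to "" when DEBUG is unset, so .lower() is always safe.
    if os.environ.get("DEBUG", "").lower() == "true":
        print("debug logging enabled")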

+ 23 - 28
api/core/app/apps/chat/app_runner.py

@@ -24,10 +24,13 @@ class ChatAppRunner(AppRunner):
     Chat Application Runner
     """
 
-    def run(self, application_generate_entity: ChatAppGenerateEntity,
-            queue_manager: AppQueueManager,
-            conversation: Conversation,
-            message: Message) -> None:
+    def run(
+        self,
+        application_generate_entity: ChatAppGenerateEntity,
+        queue_manager: AppQueueManager,
+        conversation: Conversation,
+        message: Message,
+    ) -> None:
         """
         Run application
         :param application_generate_entity: application generate entity
@@ -58,7 +61,7 @@ class ChatAppRunner(AppRunner):
             prompt_template_entity=app_config.prompt_template,
             inputs=inputs,
             files=files,
-            query=query
+            query=query,
         )
 
         memory = None
@@ -66,13 +69,10 @@ class ChatAppRunner(AppRunner):
             # get memory of conversation (read-only)
             model_instance = ModelInstance(
                 provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
-                model=application_generate_entity.model_conf.model
+                model=application_generate_entity.model_conf.model,
             )
 
-            memory = TokenBufferMemory(
-                conversation=conversation,
-                model_instance=model_instance
-            )
+            memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
 
         # organize all inputs and template to prompt messages
         # Include: prompt template, inputs, query(optional), files(optional)
@@ -84,7 +84,7 @@ class ChatAppRunner(AppRunner):
             inputs=inputs,
             files=files,
             query=query,
-            memory=memory
+            memory=memory,
         )
 
         # moderation
@@ -96,7 +96,7 @@ class ChatAppRunner(AppRunner):
                 app_generate_entity=application_generate_entity,
                 inputs=inputs,
                 query=query,
-                message_id=message.id
+                message_id=message.id,
             )
         except ModerationException as e:
             self.direct_output(
@@ -104,7 +104,7 @@ class ChatAppRunner(AppRunner):
                 app_generate_entity=application_generate_entity,
                 prompt_messages=prompt_messages,
                 text=str(e),
-                stream=application_generate_entity.stream
+                stream=application_generate_entity.stream,
             )
             return
 
@@ -115,13 +115,13 @@ class ChatAppRunner(AppRunner):
                 message=message,
                 query=query,
                 user_id=application_generate_entity.user_id,
-                invoke_from=application_generate_entity.invoke_from
+                invoke_from=application_generate_entity.invoke_from,
             )
 
             if annotation_reply:
                 queue_manager.publish(
                     QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
-                    PublishFrom.APPLICATION_MANAGER
+                    PublishFrom.APPLICATION_MANAGER,
                 )
 
                 self.direct_output(
@@ -129,7 +129,7 @@ class ChatAppRunner(AppRunner):
                     app_generate_entity=application_generate_entity,
                     prompt_messages=prompt_messages,
                     text=annotation_reply.content,
-                    stream=application_generate_entity.stream
+                    stream=application_generate_entity.stream,
                 )
                 return
 
@@ -141,7 +141,7 @@ class ChatAppRunner(AppRunner):
                 app_id=app_record.id,
                 external_data_tools=external_data_tools,
                 inputs=inputs,
-                query=query
+                query=query,
             )
 
         # get context from datasets
@@ -152,7 +152,7 @@ class ChatAppRunner(AppRunner):
                 app_record.id,
                 message.id,
                 application_generate_entity.user_id,
-                application_generate_entity.invoke_from
+                application_generate_entity.invoke_from,
             )
 
             dataset_retrieval = DatasetRetrieval(application_generate_entity)
@@ -181,29 +181,26 @@ class ChatAppRunner(AppRunner):
             files=files,
             query=query,
             context=context,
-            memory=memory
+            memory=memory,
         )
 
         # check hosting moderation
         hosting_moderation_result = self.check_hosting_moderation(
             application_generate_entity=application_generate_entity,
             queue_manager=queue_manager,
-            prompt_messages=prompt_messages
+            prompt_messages=prompt_messages,
         )
 
         if hosting_moderation_result:
             return
 
         # Recalculate max_tokens if sum(prompt_tokens + max_tokens) exceeds the model's token limit
-        self.recalc_llm_max_tokens(
-            model_config=application_generate_entity.model_conf,
-            prompt_messages=prompt_messages
-        )
+        self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages)
 
         # Invoke model
         model_instance = ModelInstance(
             provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
-            model=application_generate_entity.model_conf.model
+            model=application_generate_entity.model_conf.model,
         )
 
         db.session.close()
@@ -218,7 +215,5 @@ class ChatAppRunner(AppRunner):
 
         # handle invoke result
         self._handle_invoke_result(
-            invoke_result=invoke_result,
-            queue_manager=queue_manager,
-            stream=application_generate_entity.stream
+            invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
         )
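
This hunk shows both directions of ruff's width-driven layout: `recalc_llm_max_tokens(...)` collapses onto one line because it now fits, while calls such as `direct_output(...)` stay exploded because they end with a trailing comma (the "magic trailing comma"). A sketch with a placeholder function:

    def send(channel: str, payload: str, stream: bool = False) -> None:
        print(channel, payload, stream)

    # Fits within the limit, so ruff keeps the call on one line:
    send(channel="chat", payload="hi")

    # The trailing comma pins the one-argument-per-line form:
    send(
        channel="chat",
        payload="hi",
        stream=True,
    )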

+ 29 - 27
api/core/app/apps/chat/generate_response_converter.py

@@ -23,15 +23,15 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         :return:
         """
         response = {
-            'event': 'message',
-            'task_id': blocking_response.task_id,
-            'id': blocking_response.data.id,
-            'message_id': blocking_response.data.message_id,
-            'conversation_id': blocking_response.data.conversation_id,
-            'mode': blocking_response.data.mode,
-            'answer': blocking_response.data.answer,
-            'metadata': blocking_response.data.metadata,
-            'created_at': blocking_response.data.created_at
+            "event": "message",
+            "task_id": blocking_response.task_id,
+            "id": blocking_response.data.id,
+            "message_id": blocking_response.data.message_id,
+            "conversation_id": blocking_response.data.conversation_id,
+            "mode": blocking_response.data.mode,
+            "answer": blocking_response.data.answer,
+            "metadata": blocking_response.data.metadata,
+            "created_at": blocking_response.data.created_at,
         }
 
         return response
@@ -45,14 +45,15 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         """
         response = cls.convert_blocking_full_response(blocking_response)
 
-        metadata = response.get('metadata', {})
-        response['metadata'] = cls._get_simple_metadata(metadata)
+        metadata = response.get("metadata", {})
+        response["metadata"] = cls._get_simple_metadata(metadata)
 
         return response
 
     @classmethod
-    def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-            -> Generator[str, None, None]:
+    def convert_stream_full_response(
+        cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
+    ) -> Generator[str, None, None]:
         """
         Convert stream full response.
         :param stream_response: stream response
@@ -63,14 +64,14 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             sub_stream_response = chunk.stream_response
 
             if isinstance(sub_stream_response, PingStreamResponse):
-                yield 'ping'
+                yield "ping"
                 continue
 
             response_chunk = {
-                'event': sub_stream_response.event.value,
-                'conversation_id': chunk.conversation_id,
-                'message_id': chunk.message_id,
-                'created_at': chunk.created_at
+                "event": sub_stream_response.event.value,
+                "conversation_id": chunk.conversation_id,
+                "message_id": chunk.message_id,
+                "created_at": chunk.created_at,
             }
 
             if isinstance(sub_stream_response, ErrorStreamResponse):
@@ -81,8 +82,9 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             yield json.dumps(response_chunk)
 
     @classmethod
-    def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-            -> Generator[str, None, None]:
+    def convert_stream_simple_response(
+        cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
+    ) -> Generator[str, None, None]:
         """
         Convert stream simple response.
         :param stream_response: stream response
@@ -93,20 +95,20 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             sub_stream_response = chunk.stream_response
 
             if isinstance(sub_stream_response, PingStreamResponse):
-                yield 'ping'
+                yield "ping"
                 continue
 
             response_chunk = {
-                'event': sub_stream_response.event.value,
-                'conversation_id': chunk.conversation_id,
-                'message_id': chunk.message_id,
-                'created_at': chunk.created_at
+                "event": sub_stream_response.event.value,
+                "conversation_id": chunk.conversation_id,
+                "message_id": chunk.message_id,
+                "created_at": chunk.created_at,
             }
 
             if isinstance(sub_stream_response, MessageEndStreamResponse):
                 sub_stream_response_dict = sub_stream_response.to_dict()
-                metadata = sub_stream_response_dict.get('metadata', {})
-                sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata)
+                metadata = sub_stream_response_dict.get("metadata", {})
+                sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                 response_chunk.update(sub_stream_response_dict)
             if isinstance(sub_stream_response, ErrorStreamResponse):
                 data = cls._error_to_stream_response(sub_stream_response.err)
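
The signature rewrites above replace backslash continuations with a parenthesized parameter list, the form ruff emits for over-long `def` lines. A self-contained sketch with placeholder names; `Generator[str, None, None]` reads as yield type, send type, return type:

    from collections.abc import Generator

    def convert_stream(
        stream_response: Generator[str, None, None],
    ) -> Generator[str, None, None]:
        # Pass chunks through unchanged; the real converters serialize to JSON.
        yield from stream_response

    print(list(convert_stream(chunk for chunk in ["ping", "message"])))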

+ 15 - 20
api/core/app/apps/completion/app_config_manager.py

@@ -17,14 +17,15 @@ class CompletionAppConfig(EasyUIBasedAppConfig):
     """
     Completion App Config Entity.
     """
+
     pass
 
 
 class CompletionAppConfigManager(BaseAppConfigManager):
     @classmethod
-    def get_app_config(cls, app_model: App,
-                       app_model_config: AppModelConfig,
-                       override_config_dict: Optional[dict] = None) -> CompletionAppConfig:
+    def get_app_config(
+        cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: Optional[dict] = None
+    ) -> CompletionAppConfig:
         """
         Convert app model config to completion app config
         :param app_model: app model
@@ -51,19 +52,11 @@ class CompletionAppConfigManager(BaseAppConfigManager):
             app_model_config_from=config_from,
             app_model_config_id=app_model_config.id,
             app_model_config_dict=config_dict,
-            model=ModelConfigManager.convert(
-                config=config_dict
-            ),
-            prompt_template=PromptTemplateConfigManager.convert(
-                config=config_dict
-            ),
-            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
-                config=config_dict
-            ),
-            dataset=DatasetConfigManager.convert(
-                config=config_dict
-            ),
-            additional_features=cls.convert_features(config_dict, app_mode)
+            model=ModelConfigManager.convert(config=config_dict),
+            prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
+            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
+            dataset=DatasetConfigManager.convert(config=config_dict),
+            additional_features=cls.convert_features(config_dict, app_mode),
         )
 
         app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
@@ -101,8 +94,9 @@ class CompletionAppConfigManager(BaseAppConfigManager):
         related_config_keys.extend(current_related_config_keys)
 
         # dataset_query_variable
-        config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode,
-                                                                                             config)
+        config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
+            tenant_id, app_mode, config
+        )
         related_config_keys.extend(current_related_config_keys)
 
         # text_to_speech
@@ -114,8 +108,9 @@ class CompletionAppConfigManager(BaseAppConfigManager):
         related_config_keys.extend(current_related_config_keys)
 
         # moderation validation
-        config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id,
-                                                                                                            config)
+        config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
+            tenant_id, config
+        )
         related_config_keys.extend(current_related_config_keys)
 
         related_config_keys = list(set(related_config_keys))

+ 80 - 95
api/core/app/apps/completion/app_generator.py

@@ -32,7 +32,8 @@ logger = logging.getLogger(__name__)
 class CompletionAppGenerator(MessageBasedAppGenerator):
     @overload
     def generate(
-        self, app_model: App,
+        self,
+        app_model: App,
         user: Union[Account, EndUser],
         args: dict,
         invoke_from: InvokeFrom,
@@ -41,19 +42,17 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
 
     @overload
     def generate(
-        self, app_model: App,
+        self,
+        app_model: App,
         user: Union[Account, EndUser],
         args: dict,
         invoke_from: InvokeFrom,
         stream: Literal[False] = False,
     ) -> dict: ...
 
-    def generate(self, app_model: App,
-                 user: Union[Account, EndUser],
-                 args: Any,
-                 invoke_from: InvokeFrom,
-                 stream: bool = True) \
-            -> Union[dict, Generator[str, None, None]]:
+    def generate(
+        self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True
+    ) -> Union[dict, Generator[str, None, None]]:
         """
         Generate App response.
 
@@ -63,12 +62,12 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
         :param invoke_from: invoke from source
         :param stream: is stream
         """
-        query = args['query']
+        query = args["query"]
         if not isinstance(query, str):
-            raise ValueError('query must be a string')
+            raise ValueError("query must be a string")
 
-        query = query.replace('\x00', '')
-        inputs = args['inputs']
+        query = query.replace("\x00", "")
+        inputs = args["inputs"]
 
         extras = {}
 
@@ -76,41 +75,31 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
         conversation = None
 
         # get app model config
-        app_model_config = self._get_app_model_config(
-            app_model=app_model,
-            conversation=conversation
-        )
+        app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
 
         # validate override model config
         override_model_config_dict = None
-        if args.get('model_config'):
+        if args.get("model_config"):
             if invoke_from != InvokeFrom.DEBUGGER:
-                raise ValueError('Only in App debug mode can override model config')
+                raise ValueError("Only in App debug mode can override model config")
 
             # validate config
             override_model_config_dict = CompletionAppConfigManager.config_validate(
-                tenant_id=app_model.tenant_id,
-                config=args.get('model_config')
+                tenant_id=app_model.tenant_id, config=args.get("model_config")
             )
 
         # parse files
-        files = args['files'] if args.get('files') else []
+        files = args["files"] if args.get("files") else []
         message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
         file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
         if file_extra_config:
-            file_objs = message_file_parser.validate_and_transform_files_arg(
-                files,
-                file_extra_config,
-                user
-            )
+            file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
         else:
             file_objs = []
 
         # convert to app config
         app_config = CompletionAppConfigManager.get_app_config(
-            app_model=app_model,
-            app_model_config=app_model_config,
-            override_config_dict=override_model_config_dict
+            app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict
         )
 
         # get tracing instance
@@ -128,14 +117,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
             stream=stream,
             invoke_from=invoke_from,
             extras=extras,
-            trace_manager=trace_manager
+            trace_manager=trace_manager,
         )
 
         # init generate records
-        (
-            conversation,
-            message
-        ) = self._init_generate_records(application_generate_entity)
+        (conversation, message) = self._init_generate_records(application_generate_entity)
 
         # init queue manager
         queue_manager = MessageBasedAppQueueManager(
@@ -144,16 +130,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
             invoke_from=application_generate_entity.invoke_from,
             conversation_id=conversation.id,
             app_mode=conversation.mode,
-            message_id=message.id
+            message_id=message.id,
         )
 
         # new thread
-        worker_thread = threading.Thread(target=self._generate_worker, kwargs={
-            'flask_app': current_app._get_current_object(),
-            'application_generate_entity': application_generate_entity,
-            'queue_manager': queue_manager,
-            'message_id': message.id,
-        })
+        worker_thread = threading.Thread(
+            target=self._generate_worker,
+            kwargs={
+                "flask_app": current_app._get_current_object(),
+                "application_generate_entity": application_generate_entity,
+                "queue_manager": queue_manager,
+                "message_id": message.id,
+            },
+        )
 
         worker_thread.start()
 
@@ -167,15 +156,15 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
             stream=stream,
         )
 
-        return CompletionAppGenerateResponseConverter.convert(
-            response=response,
-            invoke_from=invoke_from
-        )
+        return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
 
-    def _generate_worker(self, flask_app: Flask,
-                         application_generate_entity: CompletionAppGenerateEntity,
-                         queue_manager: AppQueueManager,
-                         message_id: str) -> None:
+    def _generate_worker(
+        self,
+        flask_app: Flask,
+        application_generate_entity: CompletionAppGenerateEntity,
+        queue_manager: AppQueueManager,
+        message_id: str,
+    ) -> None:
         """
         Generate worker in a new thread.
         :param flask_app: Flask app
@@ -194,20 +183,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
                 runner.run(
                     application_generate_entity=application_generate_entity,
                     queue_manager=queue_manager,
-                    message=message
+                    message=message,
                 )
             except GenerateTaskStoppedException:
                 pass
             except InvokeAuthorizationError:
                 queue_manager.publish_error(
-                    InvokeAuthorizationError('Incorrect API key provided'),
-                    PublishFrom.APPLICATION_MANAGER
+                    InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
                 )
             except ValidationError as e:
                 logger.exception("Validation Error when generating")
                 queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
             except (ValueError, InvokeError) as e:
-                if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
+                if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
                     logger.exception("Error when generating")
                 queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
             except Exception as e:
@@ -216,12 +204,14 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
             finally:
                 db.session.close()
 
-    def generate_more_like_this(self, app_model: App,
-                                message_id: str,
-                                user: Union[Account, EndUser],
-                                invoke_from: InvokeFrom,
-                                stream: bool = True) \
-            -> Union[dict, Generator[str, None, None]]:
+    def generate_more_like_this(
+        self,
+        app_model: App,
+        message_id: str,
+        user: Union[Account, EndUser],
+        invoke_from: InvokeFrom,
+        stream: bool = True,
+    ) -> Union[dict, Generator[str, None, None]]:
         """
         Generate App response.
 
@@ -231,13 +221,17 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
         :param invoke_from: invoke from source
         :param stream: is stream
         """
-        message = db.session.query(Message).filter(
-            Message.id == message_id,
-            Message.app_id == app_model.id,
-            Message.from_source == ('api' if isinstance(user, EndUser) else 'console'),
-            Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
-            Message.from_account_id == (user.id if isinstance(user, Account) else None),
-        ).first()
+        message = (
+            db.session.query(Message)
+            .filter(
+                Message.id == message_id,
+                Message.app_id == app_model.id,
+                Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
+                Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
+                Message.from_account_id == (user.id if isinstance(user, Account) else None),
+            )
+            .first()
+        )
 
         if not message:
             raise MessageNotExistsError()
@@ -250,29 +244,23 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
 
         app_model_config = message.app_model_config
         override_model_config_dict = app_model_config.to_dict()
-        model_dict = override_model_config_dict['model']
-        completion_params = model_dict.get('completion_params')
-        completion_params['temperature'] = 0.9
-        model_dict['completion_params'] = completion_params
-        override_model_config_dict['model'] = model_dict
+        model_dict = override_model_config_dict["model"]
+        completion_params = model_dict.get("completion_params")
+        completion_params["temperature"] = 0.9
+        model_dict["completion_params"] = completion_params
+        override_model_config_dict["model"] = model_dict
 
         # parse files
         message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
         file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
         if file_extra_config:
-            file_objs = message_file_parser.validate_and_transform_files_arg(
-                message.files,
-                file_extra_config,
-                user
-            )
+            file_objs = message_file_parser.validate_and_transform_files_arg(message.files, file_extra_config, user)
         else:
             file_objs = []
 
         # convert to app config
         app_config = CompletionAppConfigManager.get_app_config(
-            app_model=app_model,
-            app_model_config=app_model_config,
-            override_config_dict=override_model_config_dict
+            app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict
         )
 
         # init application generate entity
@@ -286,14 +274,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
             user_id=user.id,
             stream=stream,
             invoke_from=invoke_from,
-            extras={}
+            extras={},
         )
 
         # init generate records
-        (
-            conversation,
-            message
-        ) = self._init_generate_records(application_generate_entity)
+        (conversation, message) = self._init_generate_records(application_generate_entity)
 
         # init queue manager
         queue_manager = MessageBasedAppQueueManager(
@@ -302,16 +287,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
             invoke_from=application_generate_entity.invoke_from,
             conversation_id=conversation.id,
             app_mode=conversation.mode,
-            message_id=message.id
+            message_id=message.id,
         )
 
         # new thread
-        worker_thread = threading.Thread(target=self._generate_worker, kwargs={
-            'flask_app': current_app._get_current_object(),
-            'application_generate_entity': application_generate_entity,
-            'queue_manager': queue_manager,
-            'message_id': message.id,
-        })
+        worker_thread = threading.Thread(
+            target=self._generate_worker,
+            kwargs={
+                "flask_app": current_app._get_current_object(),
+                "application_generate_entity": application_generate_entity,
+                "queue_manager": queue_manager,
+                "message_id": message.id,
+            },
+        )
 
         worker_thread.start()
 
@@ -325,7 +313,4 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
             stream=stream,
         )
 
-        return CompletionAppGenerateResponseConverter.convert(
-            response=response,
-            invoke_from=invoke_from
-        )
+        return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
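
The `db.session.query(Message).filter(...).first()` rewrite above shows ruff's treatment of over-long method chains: the whole expression is wrapped in parentheses with one leading-dot call per line, replacing backslash continuations. The same shape on plain string methods (shown small for readability; ruff would collapse a chain this short back onto one line):

    text = (
        "  Hello, WORLD  "
        .strip()
        .lower()
        .replace("world", "ruff")
    )
    print(text)  # hello, ruff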

+ 15 - 21
api/core/app/apps/completion/app_runner.py

@@ -22,9 +22,9 @@ class CompletionAppRunner(AppRunner):
     Completion Application Runner
     """
 
-    def run(self, application_generate_entity: CompletionAppGenerateEntity,
-            queue_manager: AppQueueManager,
-            message: Message) -> None:
+    def run(
+        self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message
+    ) -> None:
         """
         Run application
         :param application_generate_entity: application generate entity
@@ -54,7 +54,7 @@ class CompletionAppRunner(AppRunner):
             prompt_template_entity=app_config.prompt_template,
             inputs=inputs,
             files=files,
-            query=query
+            query=query,
         )
 
         # organize all inputs and template to prompt messages
@@ -65,7 +65,7 @@ class CompletionAppRunner(AppRunner):
             prompt_template_entity=app_config.prompt_template,
             inputs=inputs,
             files=files,
-            query=query
+            query=query,
         )
 
         # moderation
@@ -77,7 +77,7 @@ class CompletionAppRunner(AppRunner):
                 app_generate_entity=application_generate_entity,
                 inputs=inputs,
                 query=query,
-                message_id=message.id
+                message_id=message.id,
             )
         except ModerationException as e:
             self.direct_output(
@@ -85,7 +85,7 @@ class CompletionAppRunner(AppRunner):
                 app_generate_entity=application_generate_entity,
                 prompt_messages=prompt_messages,
                 text=str(e),
-                stream=application_generate_entity.stream
+                stream=application_generate_entity.stream,
             )
             return
 
@@ -97,7 +97,7 @@ class CompletionAppRunner(AppRunner):
                 app_id=app_record.id,
                 external_data_tools=external_data_tools,
                 inputs=inputs,
-                query=query
+                query=query,
             )
 
         # get context from datasets
@@ -108,7 +108,7 @@ class CompletionAppRunner(AppRunner):
                 app_record.id,
                 message.id,
                 application_generate_entity.user_id,
-                application_generate_entity.invoke_from
+                application_generate_entity.invoke_from,
             )
 
             dataset_config = app_config.dataset
@@ -126,7 +126,7 @@ class CompletionAppRunner(AppRunner):
                 invoke_from=application_generate_entity.invoke_from,
                 show_retrieve_source=app_config.additional_features.show_retrieve_source,
                 hit_callback=hit_callback,
-                message_id=message.id
+                message_id=message.id,
             )
 
         # reorganize all inputs and template to prompt messages
@@ -139,29 +139,26 @@ class CompletionAppRunner(AppRunner):
             inputs=inputs,
             files=files,
             query=query,
-            context=context
+            context=context,
         )
 
         # check hosting moderation
         hosting_moderation_result = self.check_hosting_moderation(
             application_generate_entity=application_generate_entity,
             queue_manager=queue_manager,
-            prompt_messages=prompt_messages
+            prompt_messages=prompt_messages,
         )
 
         if hosting_moderation_result:
             return
 
         # Recalculate max_tokens if sum(prompt_tokens + max_tokens) exceeds the model's token limit
-        self.recalc_llm_max_tokens(
-            model_config=application_generate_entity.model_conf,
-            prompt_messages=prompt_messages
-        )
+        self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages)
 
         # Invoke model
         model_instance = ModelInstance(
             provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
-            model=application_generate_entity.model_conf.model
+            model=application_generate_entity.model_conf.model,
         )
 
         db.session.close()
@@ -176,8 +173,5 @@ class CompletionAppRunner(AppRunner):
 
         # handle invoke result
         self._handle_invoke_result(
-            invoke_result=invoke_result,
-            queue_manager=queue_manager,
-            stream=application_generate_entity.stream
+            invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
         )
-    

+ 26 - 24
api/core/app/apps/completion/generate_response_converter.py

@@ -23,14 +23,14 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
         :return:
         """
         response = {
-            'event': 'message',
-            'task_id': blocking_response.task_id,
-            'id': blocking_response.data.id,
-            'message_id': blocking_response.data.message_id,
-            'mode': blocking_response.data.mode,
-            'answer': blocking_response.data.answer,
-            'metadata': blocking_response.data.metadata,
-            'created_at': blocking_response.data.created_at
+            "event": "message",
+            "task_id": blocking_response.task_id,
+            "id": blocking_response.data.id,
+            "message_id": blocking_response.data.message_id,
+            "mode": blocking_response.data.mode,
+            "answer": blocking_response.data.answer,
+            "metadata": blocking_response.data.metadata,
+            "created_at": blocking_response.data.created_at,
         }
 
         return response
@@ -44,14 +44,15 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
         """
         response = cls.convert_blocking_full_response(blocking_response)
 
-        metadata = response.get('metadata', {})
-        response['metadata'] = cls._get_simple_metadata(metadata)
+        metadata = response.get("metadata", {})
+        response["metadata"] = cls._get_simple_metadata(metadata)
 
         return response
 
     @classmethod
-    def convert_stream_full_response(cls, stream_response: Generator[CompletionAppStreamResponse, None, None]) \
-            -> Generator[str, None, None]:
+    def convert_stream_full_response(
+        cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
+    ) -> Generator[str, None, None]:
         """
         Convert stream full response.
         :param stream_response: stream response
@@ -62,13 +63,13 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
             sub_stream_response = chunk.stream_response
 
             if isinstance(sub_stream_response, PingStreamResponse):
-                yield 'ping'
+                yield "ping"
                 continue
 
             response_chunk = {
-                'event': sub_stream_response.event.value,
-                'message_id': chunk.message_id,
-                'created_at': chunk.created_at
+                "event": sub_stream_response.event.value,
+                "message_id": chunk.message_id,
+                "created_at": chunk.created_at,
             }
 
             if isinstance(sub_stream_response, ErrorStreamResponse):
@@ -79,8 +80,9 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
             yield json.dumps(response_chunk)
 
     @classmethod
-    def convert_stream_simple_response(cls, stream_response: Generator[CompletionAppStreamResponse, None, None]) \
-            -> Generator[str, None, None]:
+    def convert_stream_simple_response(
+        cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
+    ) -> Generator[str, None, None]:
         """
         Convert stream simple response.
         :param stream_response: stream response
@@ -91,19 +93,19 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
             sub_stream_response = chunk.stream_response
 
             if isinstance(sub_stream_response, PingStreamResponse):
-                yield 'ping'
+                yield "ping"
                 continue
 
             response_chunk = {
-                'event': sub_stream_response.event.value,
-                'message_id': chunk.message_id,
-                'created_at': chunk.created_at
+                "event": sub_stream_response.event.value,
+                "message_id": chunk.message_id,
+                "created_at": chunk.created_at,
             }
 
             if isinstance(sub_stream_response, MessageEndStreamResponse):
                 sub_stream_response_dict = sub_stream_response.to_dict()
-                metadata = sub_stream_response_dict.get('metadata', {})
-                sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata)
+                metadata = sub_stream_response_dict.get("metadata", {})
+                sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                 response_chunk.update(sub_stream_response_dict)
             if isinstance(sub_stream_response, ErrorStreamResponse):
                 data = cls._error_to_stream_response(sub_stream_response.err)

+ 50 - 54
api/core/app/apps/message_based_app_generator.py

@@ -35,23 +35,23 @@ logger = logging.getLogger(__name__)
 
 
 class MessageBasedAppGenerator(BaseAppGenerator):
-
     def _handle_response(
-            self, application_generate_entity: Union[
-                ChatAppGenerateEntity,
-                CompletionAppGenerateEntity,
-                AgentChatAppGenerateEntity,
-                AdvancedChatAppGenerateEntity
-            ],
-            queue_manager: AppQueueManager,
-            conversation: Conversation,
-            message: Message,
-            user: Union[Account, EndUser],
-            stream: bool = False,
+        self,
+        application_generate_entity: Union[
+            ChatAppGenerateEntity,
+            CompletionAppGenerateEntity,
+            AgentChatAppGenerateEntity,
+            AdvancedChatAppGenerateEntity,
+        ],
+        queue_manager: AppQueueManager,
+        conversation: Conversation,
+        message: Message,
+        user: Union[Account, EndUser],
+        stream: bool = False,
     ) -> Union[
         ChatbotAppBlockingResponse,
         CompletionAppBlockingResponse,
-        Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]
+        Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
     ]:
         """
         Handle response.
@@ -70,7 +70,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
             conversation=conversation,
             message=message,
             user=user,
-            stream=stream
+            stream=stream,
         )
 
         try:
@@ -82,12 +82,13 @@ class MessageBasedAppGenerator(BaseAppGenerator):
                 logger.exception(e)
                 raise e
 
-    def _get_conversation_by_user(self, app_model: App, conversation_id: str,
-                                  user: Union[Account, EndUser]) -> Conversation:
+    def _get_conversation_by_user(
+        self, app_model: App, conversation_id: str, user: Union[Account, EndUser]
+    ) -> Conversation:
         conversation_filter = [
             Conversation.id == conversation_id,
             Conversation.app_id == app_model.id,
-            Conversation.status == 'normal'
+            Conversation.status == "normal",
         ]
 
         if isinstance(user, Account):
@@ -100,19 +101,18 @@ class MessageBasedAppGenerator(BaseAppGenerator):
         if not conversation:
             raise ConversationNotExistsError()
 
-        if conversation.status != 'normal':
+        if conversation.status != "normal":
             raise ConversationCompletedError()
 
         return conversation
 
-    def _get_app_model_config(self, app_model: App,
-                              conversation: Optional[Conversation] = None) \
-            -> AppModelConfig:
+    def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig:
         if conversation:
-            app_model_config = db.session.query(AppModelConfig).filter(
-                AppModelConfig.id == conversation.app_model_config_id,
-                AppModelConfig.app_id == app_model.id
-            ).first()
+            app_model_config = (
+                db.session.query(AppModelConfig)
+                .filter(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
+                .first()
+            )
 
             if not app_model_config:
                 raise AppModelConfigBrokenError()
@@ -127,15 +127,16 @@ class MessageBasedAppGenerator(BaseAppGenerator):
 
         return app_model_config
 
-    def _init_generate_records(self,
-                               application_generate_entity: Union[
-                                   ChatAppGenerateEntity,
-                                   CompletionAppGenerateEntity,
-                                   AgentChatAppGenerateEntity,
-                                   AdvancedChatAppGenerateEntity
-                               ],
-                               conversation: Optional[Conversation] = None) \
-            -> tuple[Conversation, Message]:
+    def _init_generate_records(
+        self,
+        application_generate_entity: Union[
+            ChatAppGenerateEntity,
+            CompletionAppGenerateEntity,
+            AgentChatAppGenerateEntity,
+            AdvancedChatAppGenerateEntity,
+        ],
+        conversation: Optional[Conversation] = None,
+    ) -> tuple[Conversation, Message]:
         """
         Initialize generate records
         :param application_generate_entity: application generate entity
@@ -148,10 +149,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
         end_user_id = None
         account_id = None
         if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
-            from_source = 'api'
+            from_source = "api"
             end_user_id = application_generate_entity.user_id
         else:
-            from_source = 'console'
+            from_source = "console"
             account_id = application_generate_entity.user_id
 
         if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity):
@@ -164,8 +165,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
             model_provider = application_generate_entity.model_conf.provider
             model_id = application_generate_entity.model_conf.model
             override_model_configs = None
-            if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS \
-                    and app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]:
+            if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in [
+                AppMode.AGENT_CHAT,
+                AppMode.CHAT,
+                AppMode.COMPLETION,
+            ]:
                 override_model_configs = app_config.app_model_config_dict
 
         # get conversation introduction
@@ -179,12 +183,12 @@ class MessageBasedAppGenerator(BaseAppGenerator):
                 model_id=model_id,
                 override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
                 mode=app_config.app_mode.value,
-                name='New conversation',
+                name="New conversation",
                 inputs=application_generate_entity.inputs,
                 introduction=introduction,
                 system_instruction="",
                 system_instruction_tokens=0,
-                status='normal',
+                status="normal",
                 invoke_from=application_generate_entity.invoke_from.value,
                 from_source=from_source,
                 from_end_user_id=end_user_id,
@@ -216,11 +220,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
             answer_price_unit=0,
             provider_response_latency=0,
             total_price=0,
-            currency='USD',
+            currency="USD",
             invoke_from=application_generate_entity.invoke_from.value,
             from_source=from_source,
             from_end_user_id=end_user_id,
-            from_account_id=account_id
+            from_account_id=account_id,
         )
 
         db.session.add(message)
@@ -232,10 +236,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
                 message_id=message.id,
                 type=file.type.value,
                 transfer_method=file.transfer_method.value,
-                belongs_to='user',
+                belongs_to="user",
                 url=file.url,
                 upload_file_id=file.related_id,
-                created_by_role=('account' if account_id else 'end_user'),
+                created_by_role=("account" if account_id else "end_user"),
                 created_by=account_id or end_user_id,
             )
             db.session.add(message_file)
@@ -269,11 +273,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
         :param conversation_id: conversation id
         :return: conversation
         """
-        conversation = (
-            db.session.query(Conversation)
-            .filter(Conversation.id == conversation_id)
-            .first()
-        )
+        conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
 
         if not conversation:
             raise ConversationNotExistsError()
@@ -286,10 +286,6 @@ class MessageBasedAppGenerator(BaseAppGenerator):
         :param message_id: message id
         :return: message
         """
-        message = (
-            db.session.query(Message)
-            .filter(Message.id == message_id)
-            .first()
-        )
+        message = db.session.query(Message).filter(Message.id == message_id).first()
 
         return message
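
Note on the `from_source` branch near the top of this file's hunk: WEB_APP and SERVICE_API callers are recorded as API end users, while any other origin (console, debugger, explore) is attributed to an account. A minimal sketch of that mapping, using a stand-in InvokeFrom enum rather than the real one from core.app.entities.app_invoke_entities (member values here are illustrative):

from enum import Enum
from typing import Optional

class InvokeFrom(Enum):  # stand-in; the real enum has more members
    WEB_APP = "web-app"
    SERVICE_API = "service-api"
    DEBUGGER = "debugger"

def resolve_actor(invoke_from: InvokeFrom, user_id: str) -> tuple[str, Optional[str], Optional[str]]:
    """Return (from_source, end_user_id, account_id) the way _init_generate_records does."""
    if invoke_from in (InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API):
        return "api", user_id, None   # end user reached the app via web app or service API
    return "console", None, user_id   # otherwise an authenticated console account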

+ 8 - 13
api/core/app/apps/message_based_app_queue_manager.py

@@ -12,12 +12,9 @@ from core.app.entities.queue_entities import (
 
 
 class MessageBasedAppQueueManager(AppQueueManager):
-    def __init__(self, task_id: str,
-                 user_id: str,
-                 invoke_from: InvokeFrom,
-                 conversation_id: str,
-                 app_mode: str,
-                 message_id: str) -> None:
+    def __init__(
+        self, task_id: str, user_id: str, invoke_from: InvokeFrom, conversation_id: str, app_mode: str, message_id: str
+    ) -> None:
         super().__init__(task_id, user_id, invoke_from)
 
         self._conversation_id = str(conversation_id)
@@ -30,7 +27,7 @@ class MessageBasedAppQueueManager(AppQueueManager):
             message_id=self._message_id,
             conversation_id=self._conversation_id,
             app_mode=self._app_mode,
-            event=event
+            event=event,
         )
 
     def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
@@ -45,17 +42,15 @@ class MessageBasedAppQueueManager(AppQueueManager):
             message_id=self._message_id,
             conversation_id=self._conversation_id,
             app_mode=self._app_mode,
-            event=event
+            event=event,
         )
 
         self._q.put(message)
 
-        if isinstance(event, QueueStopEvent
-                             | QueueErrorEvent
-                             | QueueMessageEndEvent
-                             | QueueAdvancedChatMessageEndEvent):
+        if isinstance(
+            event, QueueStopEvent | QueueErrorEvent | QueueMessageEndEvent | QueueAdvancedChatMessageEndEvent
+        ):
             self.stop_listen()
 
         if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
             raise GenerateTaskStoppedException()
-
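
The terminal-event check above relies on isinstance() accepting a PEP 604 union (X | Y) directly, which works on Python 3.10+ and is equivalent to passing a tuple of types. A self-contained illustration with toy event classes:

class QueueStopEvent: ...
class QueueErrorEvent: ...
class QueuePingEvent: ...

def is_terminal(event: object) -> bool:
    # equivalent to isinstance(event, (QueueStopEvent, QueueErrorEvent))
    return isinstance(event, QueueStopEvent | QueueErrorEvent)

assert is_terminal(QueueStopEvent())
assert not is_terminal(QueuePingEvent())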

+ 6 - 12
api/core/app/apps/workflow/app_config_manager.py

@@ -12,6 +12,7 @@ class WorkflowAppConfig(WorkflowUIBasedAppConfig):
     """
     Workflow App Config Entity.
     """
+
     pass
 
 
@@ -26,13 +27,9 @@ class WorkflowAppConfigManager(BaseAppConfigManager):
             app_id=app_model.id,
             app_mode=app_mode,
             workflow_id=workflow.id,
-            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
-                config=features_dict
-            ),
-            variables=WorkflowVariablesConfigManager.convert(
-                workflow=workflow
-            ),
-            additional_features=cls.convert_features(features_dict, app_mode)
+            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict),
+            variables=WorkflowVariablesConfigManager.convert(workflow=workflow),
+            additional_features=cls.convert_features(features_dict, app_mode),
         )
 
         return app_config
@@ -50,8 +47,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager):
 
         # file upload validation
         config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(
-            config=config,
-            is_vision=False
+            config=config, is_vision=False
         )
         related_config_keys.extend(current_related_config_keys)
 
@@ -61,9 +57,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager):
 
         # moderation validation
         config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
-            tenant_id=tenant_id,
-            config=config,
-            only_structure_validate=only_structure_validate
+            tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
         )
         related_config_keys.extend(current_related_config_keys)
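
Each validator called in config_validate above follows the same contract: take the raw config dict, fill in defaults, and return the keys it owns, which the caller accumulates. A sketch of that shape with two hypothetical stand-in validators (the real managers carry tenant and structure-validation arguments as well):

def validate_file_upload(config: dict) -> tuple[dict, list[str]]:
    config.setdefault("file_upload", {"enabled": False})
    return config, ["file_upload"]

def validate_moderation(config: dict) -> tuple[dict, list[str]]:
    config.setdefault("sensitive_word_avoidance", {"enabled": False})
    return config, ["sensitive_word_avoidance"]

def config_validate(config: dict) -> dict:
    related_config_keys: list[str] = []
    for validate in (validate_file_upload, validate_moderation):
        config, keys = validate(config)
        related_config_keys.extend(keys)
    # keep only the keys some validator claimed
    return {k: v for k, v in config.items() if k in related_config_keys}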
 

+ 59 - 70
api/core/app/apps/workflow/app_generator.py

@@ -34,26 +34,28 @@ logger = logging.getLogger(__name__)
 class WorkflowAppGenerator(BaseAppGenerator):
     @overload
     def generate(
-        self, app_model: App,
+        self,
+        app_model: App,
         workflow: Workflow,
         user: Union[Account, EndUser],
         args: dict,
         invoke_from: InvokeFrom,
         stream: Literal[True] = True,
         call_depth: int = 0,
-        workflow_thread_pool_id: Optional[str] = None
+        workflow_thread_pool_id: Optional[str] = None,
     ) -> Generator[str, None, None]: ...
 
     @overload
     def generate(
-        self, app_model: App,
+        self,
+        app_model: App,
         workflow: Workflow,
         user: Union[Account, EndUser],
         args: dict,
         invoke_from: InvokeFrom,
         stream: Literal[False] = False,
         call_depth: int = 0,
-        workflow_thread_pool_id: Optional[str] = None
+        workflow_thread_pool_id: Optional[str] = None,
     ) -> dict: ...
 
     def generate(
@@ -65,7 +67,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
         invoke_from: InvokeFrom,
         stream: bool = True,
         call_depth: int = 0,
-        workflow_thread_pool_id: Optional[str] = None
+        workflow_thread_pool_id: Optional[str] = None,
     ):
         """
         Generate App response.
@@ -79,26 +81,19 @@ class WorkflowAppGenerator(BaseAppGenerator):
         :param call_depth: call depth
         :param workflow_thread_pool_id: workflow thread pool id
         """
-        inputs = args['inputs']
+        inputs = args["inputs"]
 
         # parse files
-        files = args['files'] if args.get('files') else []
+        files = args["files"] if args.get("files") else []
         message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
         file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
         if file_extra_config:
-            file_objs = message_file_parser.validate_and_transform_files_arg(
-                files,
-                file_extra_config,
-                user
-            )
+            file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
         else:
             file_objs = []
 
         # convert to app config
-        app_config = WorkflowAppConfigManager.get_app_config(
-            app_model=app_model,
-            workflow=workflow
-        )
+        app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
 
         # get tracing instance
         user_id = user.id if isinstance(user, Account) else user.session_id
@@ -114,7 +109,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
             stream=stream,
             invoke_from=invoke_from,
             call_depth=call_depth,
-            trace_manager=trace_manager
+            trace_manager=trace_manager,
         )
         contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
 
@@ -125,18 +120,19 @@ class WorkflowAppGenerator(BaseAppGenerator):
             application_generate_entity=application_generate_entity,
             invoke_from=invoke_from,
             stream=stream,
-            workflow_thread_pool_id=workflow_thread_pool_id
+            workflow_thread_pool_id=workflow_thread_pool_id,
         )
 
     def _generate(
-        self, *,
+        self,
+        *,
         app_model: App,
         workflow: Workflow,
         user: Union[Account, EndUser],
         application_generate_entity: WorkflowAppGenerateEntity,
         invoke_from: InvokeFrom,
         stream: bool = True,
-        workflow_thread_pool_id: Optional[str] = None
+        workflow_thread_pool_id: Optional[str] = None,
     ) -> dict[str, Any] | Generator[str, None, None]:
         """
         Generate App response.
@@ -154,17 +150,20 @@ class WorkflowAppGenerator(BaseAppGenerator):
             task_id=application_generate_entity.task_id,
             user_id=application_generate_entity.user_id,
             invoke_from=application_generate_entity.invoke_from,
-            app_mode=app_model.mode
+            app_mode=app_model.mode,
         )
 
         # new thread
-        worker_thread = threading.Thread(target=self._generate_worker, kwargs={
-            'flask_app': current_app._get_current_object(), # type: ignore
-            'application_generate_entity': application_generate_entity,
-            'queue_manager': queue_manager,
-            'context': contextvars.copy_context(),
-            'workflow_thread_pool_id': workflow_thread_pool_id
-        })
+        worker_thread = threading.Thread(
+            target=self._generate_worker,
+            kwargs={
+                "flask_app": current_app._get_current_object(),  # type: ignore
+                "application_generate_entity": application_generate_entity,
+                "queue_manager": queue_manager,
+                "context": contextvars.copy_context(),
+                "workflow_thread_pool_id": workflow_thread_pool_id,
+            },
+        )
 
         worker_thread.start()
 
@@ -177,17 +176,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
             stream=stream,
         )
 
-        return WorkflowAppGenerateResponseConverter.convert(
-            response=response,
-            invoke_from=invoke_from
-        )
+        return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
 
-    def single_iteration_generate(self, app_model: App,
-                                  workflow: Workflow,
-                                  node_id: str,
-                                  user: Account,
-                                  args: dict,
-                                  stream: bool = True) -> dict[str, Any] | Generator[str, Any, None]:
+    def single_iteration_generate(
+        self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, stream: bool = True
+    ) -> dict[str, Any] | Generator[str, Any, None]:
         """
         Generate App response.
 
@@ -199,16 +192,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
         :param stream: is stream
         """
         if not node_id:
-            raise ValueError('node_id is required')
+            raise ValueError("node_id is required")
 
-        if args.get('inputs') is None:
-            raise ValueError('inputs is required')
+        if args.get("inputs") is None:
+            raise ValueError("inputs is required")
 
         # convert to app config
-        app_config = WorkflowAppConfigManager.get_app_config(
-            app_model=app_model,
-            workflow=workflow
-        )
+        app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
 
         # init application generate entity
         application_generate_entity = WorkflowAppGenerateEntity(
@@ -219,13 +209,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
             user_id=user.id,
             stream=stream,
             invoke_from=InvokeFrom.DEBUGGER,
-            extras={
-                "auto_generate_conversation_name": False
-            },
+            extras={"auto_generate_conversation_name": False},
             single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
-                node_id=node_id,
-                inputs=args['inputs']
-            )
+                node_id=node_id, inputs=args["inputs"]
+            ),
         )
         contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
 
@@ -235,14 +222,17 @@ class WorkflowAppGenerator(BaseAppGenerator):
             user=user,
             invoke_from=InvokeFrom.DEBUGGER,
             application_generate_entity=application_generate_entity,
-            stream=stream
+            stream=stream,
         )
 
-    def _generate_worker(self, flask_app: Flask,
-                         application_generate_entity: WorkflowAppGenerateEntity,
-                         queue_manager: AppQueueManager,
-                         context: contextvars.Context,
-                         workflow_thread_pool_id: Optional[str] = None) -> None:
+    def _generate_worker(
+        self,
+        flask_app: Flask,
+        application_generate_entity: WorkflowAppGenerateEntity,
+        queue_manager: AppQueueManager,
+        context: contextvars.Context,
+        workflow_thread_pool_id: Optional[str] = None,
+    ) -> None:
         """
         Generate worker in a new thread.
         :param flask_app: Flask app
@@ -259,7 +249,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
                 runner = WorkflowAppRunner(
                     application_generate_entity=application_generate_entity,
                     queue_manager=queue_manager,
-                    workflow_thread_pool_id=workflow_thread_pool_id
+                    workflow_thread_pool_id=workflow_thread_pool_id,
                 )
 
                 runner.run()
@@ -267,14 +257,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
                 pass
             except InvokeAuthorizationError:
                 queue_manager.publish_error(
-                    InvokeAuthorizationError('Incorrect API key provided'),
-                    PublishFrom.APPLICATION_MANAGER
+                    InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
                 )
             except ValidationError as e:
                 logger.exception("Validation Error when generating")
                 queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
             except (ValueError, InvokeError) as e:
-                if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == 'true':
+                if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == "true":
                     logger.exception("Error when generating")
                 queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
             except Exception as e:
@@ -283,14 +272,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
             finally:
                 db.session.close()
 
-    def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity,
-                         workflow: Workflow,
-                         queue_manager: AppQueueManager,
-                         user: Union[Account, EndUser],
-                         stream: bool = False) -> Union[
-        WorkflowAppBlockingResponse,
-        Generator[WorkflowAppStreamResponse, None, None]
-    ]:
+    def _handle_response(
+        self,
+        application_generate_entity: WorkflowAppGenerateEntity,
+        workflow: Workflow,
+        queue_manager: AppQueueManager,
+        user: Union[Account, EndUser],
+        stream: bool = False,
+    ) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
         """
         Handle response.
         :param application_generate_entity: application generate entity
@@ -306,7 +295,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
             workflow=workflow,
             queue_manager=queue_manager,
             user=user,
-            stream=stream
+            stream=stream,
         )
 
         try:
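
The two @overload stubs at the top of this file let type checkers narrow generate()'s return type from the stream flag: a string generator when stream=True, a plain dict when stream=False. The same idiom in isolation, stripped of the app-specific parameters:

from collections.abc import Generator
from typing import Literal, Union, overload

@overload
def generate(stream: Literal[True] = True) -> Generator[str, None, None]: ...
@overload
def generate(stream: Literal[False]) -> dict: ...

def generate(stream: bool = True) -> Union[dict, Generator[str, None, None]]:
    # runtime behavior lives only in this final definition
    if stream:
        return (chunk for chunk in ("partial", "answer"))
    return {"answer": "partial answer"}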

+ 10 - 14
api/core/app/apps/workflow/app_queue_manager.py

@@ -12,10 +12,7 @@ from core.app.entities.queue_entities import (
 
 
 class WorkflowAppQueueManager(AppQueueManager):
-    def __init__(self, task_id: str,
-                 user_id: str,
-                 invoke_from: InvokeFrom,
-                 app_mode: str) -> None:
+    def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None:
         super().__init__(task_id, user_id, invoke_from)
 
         self._app_mode = app_mode
@@ -27,19 +24,18 @@ class WorkflowAppQueueManager(AppQueueManager):
         :param pub_from:
         :return:
         """
-        message = WorkflowQueueMessage(
-            task_id=self._task_id,
-            app_mode=self._app_mode,
-            event=event
-        )
+        message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event)
 
         self._q.put(message)
 
-        if isinstance(event, QueueStopEvent
-                             | QueueErrorEvent
-                             | QueueMessageEndEvent
-                             | QueueWorkflowSucceededEvent
-                             | QueueWorkflowFailedEvent):
+        if isinstance(
+            event,
+            QueueStopEvent
+            | QueueErrorEvent
+            | QueueMessageEndEvent
+            | QueueWorkflowSucceededEvent
+            | QueueWorkflowFailedEvent,
+        ):
             self.stop_listen()
 
         if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
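
Both queue managers in this commit share one shape: _publish puts the event on an in-process queue, and terminal events (stop, error, end) close the listen loop. A framework-free sketch of that producer/consumer pattern, with strings standing in for the event classes:

import queue

TERMINAL_EVENTS = {"stop", "error", "message_end"}

def publish(q: "queue.Queue[str]", event: str) -> None:
    q.put(event)

def listen(q: "queue.Queue[str]"):
    while True:
        event = q.get()
        yield event
        if event in TERMINAL_EVENTS:
            return

q: "queue.Queue[str]" = queue.Queue()
publish(q, "chunk")
publish(q, "message_end")
assert list(listen(q)) == ["chunk", "message_end"]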

+ 10 - 13
api/core/app/apps/workflow/app_runner.py

@@ -28,10 +28,10 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
     """
 
     def __init__(
-            self,
-            application_generate_entity: WorkflowAppGenerateEntity,
-            queue_manager: AppQueueManager,
-            workflow_thread_pool_id: Optional[str] = None
+        self,
+        application_generate_entity: WorkflowAppGenerateEntity,
+        queue_manager: AppQueueManager,
+        workflow_thread_pool_id: Optional[str] = None,
     ) -> None:
         """
         :param application_generate_entity: application generate entity
@@ -62,16 +62,16 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
 
         app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
         if not app_record:
-            raise ValueError('App not found')
+            raise ValueError("App not found")
 
         workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
         if not workflow:
-            raise ValueError('Workflow not initialized')
+            raise ValueError("Workflow not initialized")
 
         db.session.close()
 
         workflow_callbacks: list[WorkflowCallback] = []
-        if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
+        if bool(os.environ.get("DEBUG", "False").lower() == "true"):
             workflow_callbacks.append(WorkflowLoggingCallback())
 
         # if only single iteration run is requested
@@ -80,10 +80,9 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
             graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
                 workflow=workflow,
                 node_id=self.application_generate_entity.single_iteration_run.node_id,
-                user_inputs=self.application_generate_entity.single_iteration_run.inputs
+                user_inputs=self.application_generate_entity.single_iteration_run.inputs,
             )
         else:
-
             inputs = self.application_generate_entity.inputs
             files = self.application_generate_entity.files
 
@@ -120,12 +119,10 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
             invoke_from=self.application_generate_entity.invoke_from,
             call_depth=self.application_generate_entity.call_depth,
             variable_pool=variable_pool,
-            thread_pool_id=self.workflow_thread_pool_id
+            thread_pool_id=self.workflow_thread_pool_id,
         )
 
-        generator = workflow_entry.run(
-            callbacks=workflow_callbacks
-        )
+        generator = workflow_entry.run(callbacks=workflow_callbacks)
 
         for event in generator:
             self._handle_event(workflow_entry, event)
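
The DEBUG switch above is a case-insensitive string compare on an environment variable. Extracted as a helper purely for illustration (no such function exists in the codebase):

import os

def env_flag(name: str, default: str = "false") -> bool:
    return os.environ.get(name, default).lower() == "true"

# env_flag("DEBUG") is True for DEBUG=true, DEBUG=True, DEBUG=TRUE; False otherwise.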

+ 12 - 10
api/core/app/apps/workflow/generate_response_converter.py

@@ -35,8 +35,9 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
         return cls.convert_blocking_full_response(blocking_response)
 
     @classmethod
-    def convert_stream_full_response(cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]) \
-            -> Generator[str, None, None]:
+    def convert_stream_full_response(
+        cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
+    ) -> Generator[str, None, None]:
         """
         Convert stream full response.
         :param stream_response: stream response
@@ -47,12 +48,12 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
             sub_stream_response = chunk.stream_response
 
             if isinstance(sub_stream_response, PingStreamResponse):
-                yield 'ping'
+                yield "ping"
                 continue
 
             response_chunk = {
-                'event': sub_stream_response.event.value,
-                'workflow_run_id': chunk.workflow_run_id,
+                "event": sub_stream_response.event.value,
+                "workflow_run_id": chunk.workflow_run_id,
             }
 
             if isinstance(sub_stream_response, ErrorStreamResponse):
@@ -63,8 +64,9 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
             yield json.dumps(response_chunk)
 
     @classmethod
-    def convert_stream_simple_response(cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]) \
-            -> Generator[str, None, None]:
+    def convert_stream_simple_response(
+        cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
+    ) -> Generator[str, None, None]:
         """
         Convert stream simple response.
         :param stream_response: stream response
@@ -75,12 +77,12 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
             sub_stream_response = chunk.stream_response
 
             if isinstance(sub_stream_response, PingStreamResponse):
-                yield 'ping'
+                yield "ping"
                 continue
 
             response_chunk = {
-                'event': sub_stream_response.event.value,
-                'workflow_run_id': chunk.workflow_run_id,
+                "event": sub_stream_response.event.value,
+                "workflow_run_id": chunk.workflow_run_id,
             }
 
             if isinstance(sub_stream_response, ErrorStreamResponse):
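
Both converters in this file reduce to the same loop: pass pings through as bare text and serialize everything else as one JSON string per chunk. A sketch with plain dicts standing in for the typed stream-response objects:

import json
from collections.abc import Generator, Iterable

def convert_stream(chunks: Iterable[dict]) -> Generator[str, None, None]:
    for chunk in chunks:
        if chunk.get("event") == "ping":
            yield "ping"  # keep-alive, no JSON envelope
            continue
        yield json.dumps({"event": chunk["event"], "workflow_run_id": chunk.get("workflow_run_id")})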

+ 60 - 73
api/core/app/apps/workflow/generate_task_pipeline.py

@@ -63,17 +63,21 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
     """
     WorkflowAppGenerateTaskPipeline generates stream output and manages state for the application.
     """
+
     _workflow: Workflow
     _user: Union[Account, EndUser]
     _task_state: WorkflowTaskState
     _application_generate_entity: WorkflowAppGenerateEntity
     _workflow_system_variables: dict[SystemVariableKey, Any]
 
-    def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
-                 workflow: Workflow,
-                 queue_manager: AppQueueManager,
-                 user: Union[Account, EndUser],
-                 stream: bool) -> None:
+    def __init__(
+        self,
+        application_generate_entity: WorkflowAppGenerateEntity,
+        workflow: Workflow,
+        queue_manager: AppQueueManager,
+        user: Union[Account, EndUser],
+        stream: bool,
+    ) -> None:
         """
         Initialize GenerateTaskPipeline.
         :param application_generate_entity: application generate entity
@@ -92,7 +96,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
         self._workflow = workflow
         self._workflow_system_variables = {
             SystemVariableKey.FILES: application_generate_entity.files,
-            SystemVariableKey.USER_ID: user_id
+            SystemVariableKey.USER_ID: user_id,
         }
 
         self._task_state = WorkflowTaskState()
@@ -106,16 +110,13 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
         db.session.refresh(self._user)
         db.session.close()
 
-        generator = self._wrapper_process_stream_response(
-            trace_manager=self._application_generate_entity.trace_manager
-        )
+        generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
         if self._stream:
             return self._to_stream_response(generator)
         else:
             return self._to_blocking_response(generator)
 
-    def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) \
-            -> WorkflowAppBlockingResponse:
+    def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse:
         """
         To blocking response.
         :return:
@@ -137,18 +138,19 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                         total_tokens=stream_response.data.total_tokens,
                         total_steps=stream_response.data.total_steps,
                         created_at=int(stream_response.data.created_at),
-                        finished_at=int(stream_response.data.finished_at)
-                    )
+                        finished_at=int(stream_response.data.finished_at),
+                    ),
                 )
 
                 return response
             else:
                 continue
 
-        raise Exception('Queue listening stopped unexpectedly.')
+        raise Exception("Queue listening stopped unexpectedly.")
 
-    def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \
-            -> Generator[WorkflowAppStreamResponse, None, None]:
+    def _to_stream_response(
+        self, generator: Generator[StreamResponse, None, None]
+    ) -> Generator[WorkflowAppStreamResponse, None, None]:
         """
         To stream response.
         :return:
@@ -158,10 +160,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
             if isinstance(stream_response, WorkflowStartStreamResponse):
                 workflow_run_id = stream_response.workflow_run_id
 
-            yield WorkflowAppStreamResponse(
-                workflow_run_id=workflow_run_id,
-                stream_response=stream_response
-            )
+            yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)
 
     def _listenAudioMsg(self, publisher, task_id: str):
         if not publisher:
@@ -171,17 +170,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
             return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
         return None
 
-    def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
-            Generator[StreamResponse, None, None]:
-
+    def _wrapper_process_stream_response(
+        self, trace_manager: Optional[TraceQueueManager] = None
+    ) -> Generator[StreamResponse, None, None]:
         tts_publisher = None
         task_id = self._application_generate_entity.task_id
         tenant_id = self._application_generate_entity.app_config.tenant_id
         features_dict = self._workflow.features_dict
 
-        if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
-                'text_to_speech'].get('autoPlay') == 'enabled':
-            tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
+        if (
+            features_dict.get("text_to_speech")
+            and features_dict["text_to_speech"].get("enabled")
+            and features_dict["text_to_speech"].get("autoPlay") == "enabled"
+        ):
+            tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))
 
         for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
             while True:
@@ -210,13 +212,12 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
             except Exception as e:
                 logger.error(e)
                 break
-        yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
-
+        yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
 
     def _process_stream_response(
         self,
         tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
-        trace_manager: Optional[TraceQueueManager] = None
+        trace_manager: Optional[TraceQueueManager] = None,
     ) -> Generator[StreamResponse, None, None]:
         """
         Process stream response.
@@ -241,22 +242,18 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 # init workflow run
                 workflow_run = self._handle_workflow_run_start()
                 yield self._workflow_start_to_stream_response(
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_run=workflow_run
+                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                 )
             elif isinstance(event, QueueNodeStartedEvent):
                 if not workflow_run:
-                    raise Exception('Workflow run not initialized.')
+                    raise Exception("Workflow run not initialized.")
 
-                workflow_node_execution = self._handle_node_execution_start(
-                    workflow_run=workflow_run,
-                    event=event
-                )
+                workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
 
                 response = self._workflow_node_start_to_stream_response(
                     event=event,
                     task_id=self._application_generate_entity.task_id,
-                    workflow_node_execution=workflow_node_execution
+                    workflow_node_execution=workflow_node_execution,
                 )
 
                 if response:
@@ -267,7 +264,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 response = self._workflow_node_finish_to_stream_response(
                     event=event,
                     task_id=self._application_generate_entity.task_id,
-                    workflow_node_execution=workflow_node_execution
+                    workflow_node_execution=workflow_node_execution,
                 )
 
                 if response:
@@ -278,69 +275,61 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 response = self._workflow_node_finish_to_stream_response(
                     event=event,
                     task_id=self._application_generate_entity.task_id,
-                    workflow_node_execution=workflow_node_execution
+                    workflow_node_execution=workflow_node_execution,
                 )
 
                 if response:
                     yield response
             elif isinstance(event, QueueParallelBranchRunStartedEvent):
                 if not workflow_run:
-                    raise Exception('Workflow run not initialized.')
+                    raise Exception("Workflow run not initialized.")
 
                 yield self._workflow_parallel_branch_start_to_stream_response(
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_run=workflow_run,
-                    event=event
+                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                 )
             elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
                 if not workflow_run:
-                    raise Exception('Workflow run not initialized.')
+                    raise Exception("Workflow run not initialized.")
 
                 yield self._workflow_parallel_branch_finished_to_stream_response(
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_run=workflow_run,
-                    event=event
+                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                 )
             elif isinstance(event, QueueIterationStartEvent):
                 if not workflow_run:
-                    raise Exception('Workflow run not initialized.')
+                    raise Exception("Workflow run not initialized.")
 
                 yield self._workflow_iteration_start_to_stream_response(
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_run=workflow_run,
-                    event=event
+                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                 )
             elif isinstance(event, QueueIterationNextEvent):
                 if not workflow_run:
-                    raise Exception('Workflow run not initialized.')
+                    raise Exception("Workflow run not initialized.")
 
                 yield self._workflow_iteration_next_to_stream_response(
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_run=workflow_run,
-                    event=event
+                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                 )
             elif isinstance(event, QueueIterationCompletedEvent):
                 if not workflow_run:
-                    raise Exception('Workflow run not initialized.')
+                    raise Exception("Workflow run not initialized.")
 
                 yield self._workflow_iteration_completed_to_stream_response(
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_run=workflow_run,
-                    event=event
+                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                 )
             elif isinstance(event, QueueWorkflowSucceededEvent):
                 if not workflow_run:
-                    raise Exception('Workflow run not initialized.')
+                    raise Exception("Workflow run not initialized.")
 
                 if not graph_runtime_state:
-                    raise Exception('Graph runtime state not initialized.')
+                    raise Exception("Graph runtime state not initialized.")
 
                 workflow_run = self._handle_workflow_run_success(
                     workflow_run=workflow_run,
                     start_at=graph_runtime_state.start_at,
                     total_tokens=graph_runtime_state.total_tokens,
                     total_steps=graph_runtime_state.node_run_steps,
-                    outputs=json.dumps(event.outputs) if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs else None,
+                    outputs=json.dumps(event.outputs)
+                    if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs
+                    else None,
                     conversation_id=None,
                     trace_manager=trace_manager,
                 )
@@ -349,22 +338,23 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 self._save_workflow_app_log(workflow_run)
 
                 yield self._workflow_finish_to_stream_response(
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_run=workflow_run
+                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                 )
             elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
                 if not workflow_run:
-                    raise Exception('Workflow run not initialized.')
+                    raise Exception("Workflow run not initialized.")
 
                 if not graph_runtime_state:
-                    raise Exception('Graph runtime state not initialized.')
+                    raise Exception("Graph runtime state not initialized.")
 
                 workflow_run = self._handle_workflow_run_failed(
                     workflow_run=workflow_run,
                     start_at=graph_runtime_state.start_at,
                     total_tokens=graph_runtime_state.total_tokens,
                     total_steps=graph_runtime_state.node_run_steps,
-                    status=WorkflowRunStatus.FAILED if isinstance(event, QueueWorkflowFailedEvent) else WorkflowRunStatus.STOPPED,
+                    status=WorkflowRunStatus.FAILED
+                    if isinstance(event, QueueWorkflowFailedEvent)
+                    else WorkflowRunStatus.STOPPED,
                     error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
                     conversation_id=None,
                     trace_manager=trace_manager,
@@ -374,8 +364,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 self._save_workflow_app_log(workflow_run)
 
                 yield self._workflow_finish_to_stream_response(
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_run=workflow_run
+                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                 )
             elif isinstance(event, QueueTextChunkEvent):
                 delta_text = event.text
@@ -394,7 +383,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
         if tts_publisher:
             tts_publisher.publish(None)
 
-
     def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
         """
         Save workflow app log.
@@ -417,7 +405,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
         workflow_app_log.workflow_id = workflow_run.workflow_id
         workflow_app_log.workflow_run_id = workflow_run.id
         workflow_app_log.created_from = created_from.value
-        workflow_app_log.created_by_role = 'account' if isinstance(self._user, Account) else 'end_user'
+        workflow_app_log.created_by_role = "account" if isinstance(self._user, Account) else "end_user"
         workflow_app_log.created_by = self._user.id
 
         db.session.add(workflow_app_log)
@@ -431,8 +419,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
         :return:
         """
         response = TextChunkStreamResponse(
-            task_id=self._application_generate_entity.task_id,
-            data=TextChunkStreamResponse.Data(text=text)
+            task_id=self._application_generate_entity.task_id, data=TextChunkStreamResponse.Data(text=text)
         )
 
         return response
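
The text-to-speech gate in _wrapper_process_stream_response above checks three nested feature flags before constructing a publisher. The same condition as a standalone predicate over a features dict:

def tts_autoplay_enabled(features: dict) -> bool:
    tts = features.get("text_to_speech") or {}
    return bool(tts.get("enabled")) and tts.get("autoPlay") == "enabled"

assert tts_autoplay_enabled({"text_to_speech": {"enabled": True, "autoPlay": "enabled"}})
assert not tts_autoplay_enabled({"text_to_speech": {"enabled": True}})
assert not tts_autoplay_enabled({})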

+ 71 - 79
api/core/app/apps/workflow_app_runner.py

@@ -58,89 +58,86 @@ class WorkflowBasedAppRunner(AppRunner):
         """
         Init graph
         """
-        if 'nodes' not in graph_config or 'edges' not in graph_config:
-            raise ValueError('nodes or edges not found in workflow graph')
+        if "nodes" not in graph_config or "edges" not in graph_config:
+            raise ValueError("nodes or edges not found in workflow graph")
 
-        if not isinstance(graph_config.get('nodes'), list):
-            raise ValueError('nodes in workflow graph must be a list')
+        if not isinstance(graph_config.get("nodes"), list):
+            raise ValueError("nodes in workflow graph must be a list")
 
-        if not isinstance(graph_config.get('edges'), list):
-            raise ValueError('edges in workflow graph must be a list')
+        if not isinstance(graph_config.get("edges"), list):
+            raise ValueError("edges in workflow graph must be a list")
         # init graph
-        graph = Graph.init(
-            graph_config=graph_config
-        )
+        graph = Graph.init(graph_config=graph_config)
 
         if not graph:
-            raise ValueError('graph not found in workflow')
-        
+            raise ValueError("graph not found in workflow")
+
         return graph
 
     def _get_graph_and_variable_pool_of_single_iteration(
-            self, 
-            workflow: Workflow,
-            node_id: str,
-            user_inputs: dict,
-        ) -> tuple[Graph, VariablePool]:
+        self,
+        workflow: Workflow,
+        node_id: str,
+        user_inputs: dict,
+    ) -> tuple[Graph, VariablePool]:
         """
         Get variable pool of single iteration
         """
         # fetch workflow graph
         graph_config = workflow.graph_dict
         if not graph_config:
-            raise ValueError('workflow graph not found')
-        
+            raise ValueError("workflow graph not found")
+
         graph_config = cast(dict[str, Any], graph_config)
 
-        if 'nodes' not in graph_config or 'edges' not in graph_config:
-            raise ValueError('nodes or edges not found in workflow graph')
+        if "nodes" not in graph_config or "edges" not in graph_config:
+            raise ValueError("nodes or edges not found in workflow graph")
 
-        if not isinstance(graph_config.get('nodes'), list):
-            raise ValueError('nodes in workflow graph must be a list')
+        if not isinstance(graph_config.get("nodes"), list):
+            raise ValueError("nodes in workflow graph must be a list")
 
-        if not isinstance(graph_config.get('edges'), list):
-            raise ValueError('edges in workflow graph must be a list')
+        if not isinstance(graph_config.get("edges"), list):
+            raise ValueError("edges in workflow graph must be a list")
 
         # filter nodes only in iteration
         node_configs = [
-            node for node in graph_config.get('nodes', []) 
-            if node.get('id') == node_id or node.get('data', {}).get('iteration_id', '') == node_id
+            node
+            for node in graph_config.get("nodes", [])
+            if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id
         ]
 
-        graph_config['nodes'] = node_configs
+        graph_config["nodes"] = node_configs
 
-        node_ids = [node.get('id') for node in node_configs]
+        node_ids = [node.get("id") for node in node_configs]
 
         # filter edges only in iteration
         edge_configs = [
-            edge for edge in graph_config.get('edges', []) 
-            if (edge.get('source') is None or edge.get('source') in node_ids) 
-            and (edge.get('target') is None or edge.get('target') in node_ids) 
+            edge
+            for edge in graph_config.get("edges", [])
+            if (edge.get("source") is None or edge.get("source") in node_ids)
+            and (edge.get("target") is None or edge.get("target") in node_ids)
         ]
 
-        graph_config['edges'] = edge_configs
+        graph_config["edges"] = edge_configs
 
         # init graph
-        graph = Graph.init(
-            graph_config=graph_config,
-            root_node_id=node_id
-        )
+        graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
 
         if not graph:
-            raise ValueError('graph not found in workflow')
-        
+            raise ValueError("graph not found in workflow")
+
         # fetch node config from node id
         iteration_node_config = None
         for node in node_configs:
-            if node.get('id') == node_id:
+            if node.get("id") == node_id:
                 iteration_node_config = node
                 break
 
         if not iteration_node_config:
-            raise ValueError('iteration node id not found in workflow graph')
-        
+            raise ValueError("iteration node id not found in workflow graph")
+
         # Get node class
-        node_type = NodeType.value_of(iteration_node_config.get('data', {}).get('type'))
+        node_type = NodeType.value_of(iteration_node_config.get("data", {}).get("type"))
         node_cls = node_classes.get(node_type)
         node_cls = cast(type[BaseNode], node_cls)
 
@@ -153,8 +150,7 @@ class WorkflowBasedAppRunner(AppRunner):
 
         try:
             variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
-                graph_config=workflow.graph_dict, 
-                config=iteration_node_config
+                graph_config=workflow.graph_dict, config=iteration_node_config
             )
         except NotImplementedError:
             variable_mapping = {}
@@ -165,7 +161,7 @@ class WorkflowBasedAppRunner(AppRunner):
             variable_pool=variable_pool,
             tenant_id=workflow.tenant_id,
             node_type=node_type,
-            node_data=IterationNodeData(**iteration_node_config.get('data', {}))
+            node_data=IterationNodeData(**iteration_node_config.get("data", {})),
         )
 
         return graph, variable_pool
@@ -178,18 +174,12 @@ class WorkflowBasedAppRunner(AppRunner):
         """
         if isinstance(event, GraphRunStartedEvent):
             self._publish_event(
-                QueueWorkflowStartedEvent(
-                    graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state
-                )
+                QueueWorkflowStartedEvent(graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state)
             )
         elif isinstance(event, GraphRunSucceededEvent):
-            self._publish_event(
-                QueueWorkflowSucceededEvent(outputs=event.outputs)
-            )
+            self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
         elif isinstance(event, GraphRunFailedEvent):
-            self._publish_event(
-                QueueWorkflowFailedEvent(error=event.error)
-            )
+            self._publish_event(QueueWorkflowFailedEvent(error=event.error))
         elif isinstance(event, NodeRunStartedEvent):
             self._publish_event(
                 QueueNodeStartedEvent(
@@ -204,7 +194,7 @@ class WorkflowBasedAppRunner(AppRunner):
                     start_at=event.route_node_state.start_at,
                     node_run_index=event.route_node_state.index,
                     predecessor_node_id=event.predecessor_node_id,
-                    in_iteration_id=event.in_iteration_id
+                    in_iteration_id=event.in_iteration_id,
                 )
             )
         elif isinstance(event, NodeRunSucceededEvent):
@@ -220,14 +210,18 @@ class WorkflowBasedAppRunner(AppRunner):
                     parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                     start_at=event.route_node_state.start_at,
                     inputs=event.route_node_state.node_run_result.inputs
-                    if event.route_node_state.node_run_result else {},
+                    if event.route_node_state.node_run_result
+                    else {},
                     process_data=event.route_node_state.node_run_result.process_data
-                    if event.route_node_state.node_run_result else {},
+                    if event.route_node_state.node_run_result
+                    else {},
                     outputs=event.route_node_state.node_run_result.outputs
-                    if event.route_node_state.node_run_result else {},
+                    if event.route_node_state.node_run_result
+                    else {},
                     execution_metadata=event.route_node_state.node_run_result.metadata
-                    if event.route_node_state.node_run_result else {},
-                    in_iteration_id=event.in_iteration_id
+                    if event.route_node_state.node_run_result
+                    else {},
+                    in_iteration_id=event.in_iteration_id,
                 )
             )
         elif isinstance(event, NodeRunFailedEvent):
@@ -243,16 +237,18 @@ class WorkflowBasedAppRunner(AppRunner):
                     parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                     start_at=event.route_node_state.start_at,
                     inputs=event.route_node_state.node_run_result.inputs
-                    if event.route_node_state.node_run_result else {},
+                    if event.route_node_state.node_run_result
+                    else {},
                     process_data=event.route_node_state.node_run_result.process_data
-                    if event.route_node_state.node_run_result else {},
+                    if event.route_node_state.node_run_result
+                    else {},
                     outputs=event.route_node_state.node_run_result.outputs
-                    if event.route_node_state.node_run_result else {},
-                    error=event.route_node_state.node_run_result.error
                     if event.route_node_state.node_run_result
-                       and event.route_node_state.node_run_result.error
+                    else {},
+                    error=event.route_node_state.node_run_result.error
+                    if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
                     else "Unknown error",
-                    in_iteration_id=event.in_iteration_id
+                    in_iteration_id=event.in_iteration_id,
                 )
             )
         elif isinstance(event, NodeRunStreamChunkEvent):
@@ -260,14 +256,13 @@ class WorkflowBasedAppRunner(AppRunner):
                 QueueTextChunkEvent(
                     text=event.chunk_content,
                     from_variable_selector=event.from_variable_selector,
-                    in_iteration_id=event.in_iteration_id
+                    in_iteration_id=event.in_iteration_id,
                 )
             )
         elif isinstance(event, NodeRunRetrieverResourceEvent):
             self._publish_event(
                 QueueRetrieverResourcesEvent(
-                    retriever_resources=event.retriever_resources,
-                    in_iteration_id=event.in_iteration_id
+                    retriever_resources=event.retriever_resources, in_iteration_id=event.in_iteration_id
                 )
             )
         elif isinstance(event, ParallelBranchRunStartedEvent):
@@ -277,7 +272,7 @@ class WorkflowBasedAppRunner(AppRunner):
                     parallel_start_node_id=event.parallel_start_node_id,
                     parent_parallel_id=event.parent_parallel_id,
                     parent_parallel_start_node_id=event.parent_parallel_start_node_id,
-                    in_iteration_id=event.in_iteration_id
+                    in_iteration_id=event.in_iteration_id,
                 )
             )
         elif isinstance(event, ParallelBranchRunSucceededEvent):
@@ -287,7 +282,7 @@ class WorkflowBasedAppRunner(AppRunner):
                     parallel_start_node_id=event.parallel_start_node_id,
                     parent_parallel_id=event.parent_parallel_id,
                     parent_parallel_start_node_id=event.parent_parallel_start_node_id,
-                    in_iteration_id=event.in_iteration_id
+                    in_iteration_id=event.in_iteration_id,
                 )
             )
         elif isinstance(event, ParallelBranchRunFailedEvent):
@@ -298,7 +293,7 @@ class WorkflowBasedAppRunner(AppRunner):
                     parent_parallel_id=event.parent_parallel_id,
                     parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                     in_iteration_id=event.in_iteration_id,
-                    error=event.error
+                    error=event.error,
                 )
             )
         elif isinstance(event, IterationRunStartedEvent):
@@ -316,7 +311,7 @@ class WorkflowBasedAppRunner(AppRunner):
                     node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
                     inputs=event.inputs,
                     predecessor_node_id=event.predecessor_node_id,
-                    metadata=event.metadata
+                    metadata=event.metadata,
                 )
             )
         elif isinstance(event, IterationRunNextEvent):
@@ -352,7 +347,7 @@ class WorkflowBasedAppRunner(AppRunner):
                     outputs=event.outputs,
                     metadata=event.metadata,
                     steps=event.steps,
-                    error=event.error if isinstance(event, IterationRunFailedEvent) else None
+                    error=event.error if isinstance(event, IterationRunFailedEvent) else None,
                 )
             )
 
@@ -371,9 +366,6 @@ class WorkflowBasedAppRunner(AppRunner):
 
         # return workflow
         return workflow
-    
+
     def _publish_event(self, event: AppQueueEvent) -> None:
-        self.queue_manager.publish(
-            event,
-            PublishFrom.APPLICATION_MANAGER
-        )
+        self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)
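
The single-iteration path above rebuilds a subgraph: keep the iteration node plus any node whose data.iteration_id points at it, then keep only edges whose endpoints both survive. The same filter over a bare graph dict, without the Graph/VariablePool machinery:

def prune_to_iteration(graph_config: dict, node_id: str) -> dict:
    nodes = [
        node
        for node in graph_config.get("nodes", [])
        if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id
    ]
    node_ids = {node.get("id") for node in nodes}
    edges = [
        edge
        for edge in graph_config.get("edges", [])
        if (edge.get("source") is None or edge.get("source") in node_ids)
        and (edge.get("target") is None or edge.get("target") in node_ids)
    ]
    return {"nodes": nodes, "edges": edges}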

+ 85 - 115
api/core/app/apps/workflow_logging_callback.py

@@ -30,169 +30,145 @@ _TEXT_COLOR_MAPPING = {
 
 
 class WorkflowLoggingCallback(WorkflowCallback):
-
     def __init__(self) -> None:
         self.current_node_id = None
 
-    def on_event(
-            self,
-            event: GraphEngineEvent
-    ) -> None:
+    def on_event(self, event: GraphEngineEvent) -> None:
         if isinstance(event, GraphRunStartedEvent):
-            self.print_text("\n[GraphRunStartedEvent]", color='pink')
+            self.print_text("\n[GraphRunStartedEvent]", color="pink")
         elif isinstance(event, GraphRunSucceededEvent):
-            self.print_text("\n[GraphRunSucceededEvent]", color='green')
+            self.print_text("\n[GraphRunSucceededEvent]", color="green")
         elif isinstance(event, GraphRunFailedEvent):
-            self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color='red')
+            self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red")
         elif isinstance(event, NodeRunStartedEvent):
-            self.on_workflow_node_execute_started(
-                event=event
-            )
+            self.on_workflow_node_execute_started(event=event)
         elif isinstance(event, NodeRunSucceededEvent):
-            self.on_workflow_node_execute_succeeded(
-                event=event
-            )
+            self.on_workflow_node_execute_succeeded(event=event)
         elif isinstance(event, NodeRunFailedEvent):
-            self.on_workflow_node_execute_failed(
-                event=event
-            )
+            self.on_workflow_node_execute_failed(event=event)
         elif isinstance(event, NodeRunStreamChunkEvent):
-            self.on_node_text_chunk(
-                event=event
-            )
+            self.on_node_text_chunk(event=event)
         elif isinstance(event, ParallelBranchRunStartedEvent):
-            self.on_workflow_parallel_started(
-                event=event
-            )
+            self.on_workflow_parallel_started(event=event)
         elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
-            self.on_workflow_parallel_completed(
-                event=event
-            )
+            self.on_workflow_parallel_completed(event=event)
         elif isinstance(event, IterationRunStartedEvent):
-            self.on_workflow_iteration_started(
-                event=event
-            )
+            self.on_workflow_iteration_started(event=event)
         elif isinstance(event, IterationRunNextEvent):
-            self.on_workflow_iteration_next(
-                event=event
-            )
+            self.on_workflow_iteration_next(event=event)
         elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent):
-            self.on_workflow_iteration_completed(
-                event=event
-            )
+            self.on_workflow_iteration_completed(event=event)
         else:
-            self.print_text(f"\n[{event.__class__.__name__}]", color='blue')
+            self.print_text(f"\n[{event.__class__.__name__}]", color="blue")
 
-    def on_workflow_node_execute_started(
-            self,
-            event: NodeRunStartedEvent
-    ) -> None:
+    def on_workflow_node_execute_started(self, event: NodeRunStartedEvent) -> None:
         """
         Workflow node execute started
         """
-        self.print_text("\n[NodeRunStartedEvent]", color='yellow')
-        self.print_text(f"Node ID: {event.node_id}", color='yellow')
-        self.print_text(f"Node Title: {event.node_data.title}", color='yellow')
-        self.print_text(f"Type: {event.node_type.value}", color='yellow')
+        self.print_text("\n[NodeRunStartedEvent]", color="yellow")
+        self.print_text(f"Node ID: {event.node_id}", color="yellow")
+        self.print_text(f"Node Title: {event.node_data.title}", color="yellow")
+        self.print_text(f"Type: {event.node_type.value}", color="yellow")
 
-    def on_workflow_node_execute_succeeded(
-            self,
-            event: NodeRunSucceededEvent
-    ) -> None:
+    def on_workflow_node_execute_succeeded(self, event: NodeRunSucceededEvent) -> None:
         """
         Workflow node execute succeeded
         """
         route_node_state = event.route_node_state
 
-        self.print_text("\n[NodeRunSucceededEvent]", color='green')
-        self.print_text(f"Node ID: {event.node_id}", color='green')
-        self.print_text(f"Node Title: {event.node_data.title}", color='green')
-        self.print_text(f"Type: {event.node_type.value}", color='green')
+        self.print_text("\n[NodeRunSucceededEvent]", color="green")
+        self.print_text(f"Node ID: {event.node_id}", color="green")
+        self.print_text(f"Node Title: {event.node_data.title}", color="green")
+        self.print_text(f"Type: {event.node_type.value}", color="green")
 
         if route_node_state.node_run_result:
             node_run_result = route_node_state.node_run_result
-            self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
-                            color='green')
+            self.print_text(
+                f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", color="green"
+            )
             self.print_text(
                 f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
-                color='green')
-            self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
-                            color='green')
+                color="green",
+            )
+            self.print_text(
+                f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
+                color="green",
+            )
             self.print_text(
                 f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}",
-                color='green')
+                color="green",
+            )
 
-    def on_workflow_node_execute_failed(
-            self,
-            event: NodeRunFailedEvent
-    ) -> None:
+    def on_workflow_node_execute_failed(self, event: NodeRunFailedEvent) -> None:
         """
         Workflow node execute failed
         """
         route_node_state = event.route_node_state
 
-        self.print_text("\n[NodeRunFailedEvent]", color='red')
-        self.print_text(f"Node ID: {event.node_id}", color='red')
-        self.print_text(f"Node Title: {event.node_data.title}", color='red')
-        self.print_text(f"Type: {event.node_type.value}", color='red')
+        self.print_text("\n[NodeRunFailedEvent]", color="red")
+        self.print_text(f"Node ID: {event.node_id}", color="red")
+        self.print_text(f"Node Title: {event.node_data.title}", color="red")
+        self.print_text(f"Type: {event.node_type.value}", color="red")
 
         if route_node_state.node_run_result:
             node_run_result = route_node_state.node_run_result
-            self.print_text(f"Error: {node_run_result.error}", color='red')
-            self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
-                            color='red')
+            self.print_text(f"Error: {node_run_result.error}", color="red")
+            self.print_text(
+                f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", color="red"
+            )
             self.print_text(
                 f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
-                color='red')
-            self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
-                            color='red')
+                color="red",
+            )
+            self.print_text(
+                f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", color="red"
+            )
 
-    def on_node_text_chunk(
-            self,
-            event: NodeRunStreamChunkEvent
-    ) -> None:
+    def on_node_text_chunk(self, event: NodeRunStreamChunkEvent) -> None:
         """
         Publish text chunk
         """
         route_node_state = event.route_node_state
         if not self.current_node_id or self.current_node_id != route_node_state.node_id:
             self.current_node_id = route_node_state.node_id
-            self.print_text('\n[NodeRunStreamChunkEvent]')
+            self.print_text("\n[NodeRunStreamChunkEvent]")
             self.print_text(f"Node ID: {route_node_state.node_id}")
 
             node_run_result = route_node_state.node_run_result
             if node_run_result:
                 self.print_text(
-                    f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}")
+                    f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}"
+                )
 
         self.print_text(event.chunk_content, color="pink", end="")
 
-    def on_workflow_parallel_started(
-            self,
-            event: ParallelBranchRunStartedEvent
-    ) -> None:
+    def on_workflow_parallel_started(self, event: ParallelBranchRunStartedEvent) -> None:
         """
         Publish parallel started
         """
-        self.print_text("\n[ParallelBranchRunStartedEvent]", color='blue')
-        self.print_text(f"Parallel ID: {event.parallel_id}", color='blue')
-        self.print_text(f"Branch ID: {event.parallel_start_node_id}", color='blue')
+        self.print_text("\n[ParallelBranchRunStartedEvent]", color="blue")
+        self.print_text(f"Parallel ID: {event.parallel_id}", color="blue")
+        self.print_text(f"Branch ID: {event.parallel_start_node_id}", color="blue")
         if event.in_iteration_id:
-            self.print_text(f"Iteration ID: {event.in_iteration_id}", color='blue')
+            self.print_text(f"Iteration ID: {event.in_iteration_id}", color="blue")
 
     def on_workflow_parallel_completed(
-            self,
-            event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
+        self, event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
     ) -> None:
         """
         Publish parallel completed
         """
         if isinstance(event, ParallelBranchRunSucceededEvent):
-            color = 'blue'
+            color = "blue"
         elif isinstance(event, ParallelBranchRunFailedEvent):
-            color = 'red'
-
-        self.print_text("\n[ParallelBranchRunSucceededEvent]" if isinstance(event, ParallelBranchRunSucceededEvent) else "\n[ParallelBranchRunFailedEvent]", color=color)
+            color = "red"
+
+        self.print_text(
+            "\n[ParallelBranchRunSucceededEvent]"
+            if isinstance(event, ParallelBranchRunSucceededEvent)
+            else "\n[ParallelBranchRunFailedEvent]",
+            color=color,
+        )
         self.print_text(f"Parallel ID: {event.parallel_id}", color=color)
         self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color)
         if event.in_iteration_id:
@@ -201,43 +177,37 @@ class WorkflowLoggingCallback(WorkflowCallback):
         if isinstance(event, ParallelBranchRunFailedEvent):
             self.print_text(f"Error: {event.error}", color=color)
 
-    def on_workflow_iteration_started(
-            self,
-            event: IterationRunStartedEvent
-    ) -> None:
+    def on_workflow_iteration_started(self, event: IterationRunStartedEvent) -> None:
         """
         Publish iteration started
         """
-        self.print_text("\n[IterationRunStartedEvent]", color='blue')
-        self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')
+        self.print_text("\n[IterationRunStartedEvent]", color="blue")
+        self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
 
-    def on_workflow_iteration_next(
-            self,
-            event: IterationRunNextEvent
-    ) -> None:
+    def on_workflow_iteration_next(self, event: IterationRunNextEvent) -> None:
         """
         Publish iteration next
         """
-        self.print_text("\n[IterationRunNextEvent]", color='blue')
-        self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')
-        self.print_text(f"Iteration Index: {event.index}", color='blue')
+        self.print_text("\n[IterationRunNextEvent]", color="blue")
+        self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
+        self.print_text(f"Iteration Index: {event.index}", color="blue")
 
-    def on_workflow_iteration_completed(
-            self,
-            event: IterationRunSucceededEvent | IterationRunFailedEvent
-    ) -> None:
+    def on_workflow_iteration_completed(self, event: IterationRunSucceededEvent | IterationRunFailedEvent) -> None:
         """
         Publish iteration completed
         """
-        self.print_text("\n[IterationRunSucceededEvent]" if isinstance(event, IterationRunSucceededEvent) else "\n[IterationRunFailedEvent]", color='blue')
-        self.print_text(f"Node ID: {event.iteration_id}", color='blue')
+        self.print_text(
+            "\n[IterationRunSucceededEvent]"
+            if isinstance(event, IterationRunSucceededEvent)
+            else "\n[IterationRunFailedEvent]",
+            color="blue",
+        )
+        self.print_text(f"Node ID: {event.iteration_id}", color="blue")
 
-    def print_text(
-            self, text: str, color: Optional[str] = None, end: str = "\n"
-    ) -> None:
+    def print_text(self, text: str, color: Optional[str] = None, end: str = "\n") -> None:
         """Print text with highlighting and no end characters."""
         text_to_print = self._get_colored_text(text, color) if color else text
-        print(f'{text_to_print}', end=end)
+        print(f"{text_to_print}", end=end)
 
     def _get_colored_text(self, text: str, color: str) -> str:
         """Get colored text."""

+ 23 - 11
api/core/app/entities/app_invoke_entities.py

@@ -15,13 +15,14 @@ class InvokeFrom(Enum):
     """
     Invoke From.
     """
-    SERVICE_API = 'service-api'
-    WEB_APP = 'web-app'
-    EXPLORE = 'explore'
-    DEBUGGER = 'debugger'
+
+    SERVICE_API = "service-api"
+    WEB_APP = "web-app"
+    EXPLORE = "explore"
+    DEBUGGER = "debugger"
 
     @classmethod
-    def value_of(cls, value: str) -> 'InvokeFrom':
+    def value_of(cls, value: str) -> "InvokeFrom":
         """
         Get value of given mode.
 
@@ -31,7 +32,7 @@ class InvokeFrom(Enum):
         for mode in cls:
             if mode.value == value:
                 return mode
-        raise ValueError(f'invalid invoke from value {value}')
+        raise ValueError(f"invalid invoke from value {value}")
 
     def to_source(self) -> str:
         """
@@ -40,21 +41,22 @@ class InvokeFrom(Enum):
         :return: source
         """
         if self == InvokeFrom.WEB_APP:
-            return 'web_app'
+            return "web_app"
         elif self == InvokeFrom.DEBUGGER:
-            return 'dev'
+            return "dev"
         elif self == InvokeFrom.EXPLORE:
-            return 'explore_app'
+            return "explore_app"
         elif self == InvokeFrom.SERVICE_API:
-            return 'api'
+            return "api"
 
-        return 'dev'
+        return "dev"
 
 
 class ModelConfigWithCredentialsEntity(BaseModel):
     """
     Model Config With Credentials Entity.
     """
+
     provider: str
     model: str
     model_schema: AIModelEntity
@@ -72,6 +74,7 @@ class AppGenerateEntity(BaseModel):
     """
     App Generate Entity.
     """
+
     task_id: str
 
     # app config
@@ -102,6 +105,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
     """
     EasyUI-based Application Generate Entity.
     """
+
     # app config
     app_config: EasyUIBasedAppConfig
     model_conf: ModelConfigWithCredentialsEntity
@@ -116,6 +120,7 @@ class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
     """
     Chat Application Generate Entity.
     """
+
     conversation_id: Optional[str] = None
 
 
@@ -123,6 +128,7 @@ class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
     """
     Completion Application Generate Entity.
     """
+
     pass
 
 
@@ -130,6 +136,7 @@ class AgentChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
     """
     Agent Chat Application Generate Entity.
     """
+
     conversation_id: Optional[str] = None
 
 
@@ -137,6 +144,7 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
     """
     Advanced Chat Application Generate Entity.
     """
+
     # app config
     app_config: WorkflowUIBasedAppConfig
 
@@ -147,15 +155,18 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
         """
         Single Iteration Run Entity.
         """
+
         node_id: str
         inputs: dict
 
     single_iteration_run: Optional[SingleIterationRunEntity] = None
 
+
 class WorkflowAppGenerateEntity(AppGenerateEntity):
     """
     Workflow Application Generate Entity.
     """
+
     # app config
     app_config: WorkflowUIBasedAppConfig
 
@@ -163,6 +174,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
         """
         Single Iteration Run Entity.
         """
+
         node_id: str
         inputs: dict
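
For readers skimming the quote changes above: the lookup `value_of` implements is just a linear scan over the enum members. A trimmed, runnable re-creation (two members only):

    from enum import Enum

    class InvokeFrom(Enum):
        SERVICE_API = "service-api"
        WEB_APP = "web-app"

        @classmethod
        def value_of(cls, value: str) -> "InvokeFrom":
            for mode in cls:
                if mode.value == value:
                    return mode
            raise ValueError(f"invalid invoke from value {value}")

    assert InvokeFrom.value_of("web-app") is InvokeFrom.WEB_APP

The stdlib offers the same lookup as `InvokeFrom("web-app")`; the explicit classmethod mainly buys the custom error message.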
 

+ 44 - 10
api/core/app/entities/queue_entities.py

@@ -14,6 +14,7 @@ class QueueEvent(str, Enum):
     """
     QueueEvent enum
     """
+
     LLM_CHUNK = "llm_chunk"
     TEXT_CHUNK = "text_chunk"
     AGENT_MESSAGE = "agent_message"
@@ -45,6 +46,7 @@ class AppQueueEvent(BaseModel):
     """
     QueueEvent abstract entity
     """
+
     event: QueueEvent
 
 
@@ -53,13 +55,16 @@ class QueueLLMChunkEvent(AppQueueEvent):
     QueueLLMChunkEvent entity
     Only for basic mode apps
     """
+
     event: QueueEvent = QueueEvent.LLM_CHUNK
     chunk: LLMResultChunk
 
+
 class QueueIterationStartEvent(AppQueueEvent):
     """
     QueueIterationStartEvent entity
     """
+
     event: QueueEvent = QueueEvent.ITERATION_START
     node_execution_id: str
     node_id: str
@@ -80,10 +85,12 @@ class QueueIterationStartEvent(AppQueueEvent):
     predecessor_node_id: Optional[str] = None
     metadata: Optional[dict[str, Any]] = None
 
+
 class QueueIterationNextEvent(AppQueueEvent):
     """
     QueueIterationNextEvent entity
     """
+
     event: QueueEvent = QueueEvent.ITERATION_NEXT
 
     index: int
@@ -101,9 +108,9 @@ class QueueIterationNextEvent(AppQueueEvent):
     """parent parallel start node id if node is in parallel"""
 
     node_run_index: int
-    output: Optional[Any] = None # output for the current iteration
+    output: Optional[Any] = None  # output for the current iteration
 
-    @field_validator('output', mode='before')
+    @field_validator("output", mode="before")
     @classmethod
     def set_output(cls, v):
         """
@@ -113,12 +120,14 @@ class QueueIterationNextEvent(AppQueueEvent):
             return None
         if isinstance(v, int | float | str | bool | dict | list):
             return v
-        raise ValueError('output must be a valid type')
+        raise ValueError("output must be a valid type")
+
 
 class QueueIterationCompletedEvent(AppQueueEvent):
     """
     QueueIterationCompletedEvent entity
     """
+
     event: QueueEvent = QueueEvent.ITERATION_COMPLETED
 
     node_execution_id: str
@@ -134,7 +143,7 @@ class QueueIterationCompletedEvent(AppQueueEvent):
     parent_parallel_start_node_id: Optional[str] = None
     """parent parallel start node id if node is in parallel"""
     start_at: datetime
-    
+
     node_run_index: int
     inputs: Optional[dict[str, Any]] = None
     outputs: Optional[dict[str, Any]] = None
@@ -148,6 +157,7 @@ class QueueTextChunkEvent(AppQueueEvent):
     """
     QueueTextChunkEvent entity
     """
+
     event: QueueEvent = QueueEvent.TEXT_CHUNK
     text: str
     from_variable_selector: Optional[list[str]] = None
@@ -160,14 +170,16 @@ class QueueAgentMessageEvent(AppQueueEvent):
     """
     QueueAgentMessageEvent entity
     """
+
     event: QueueEvent = QueueEvent.AGENT_MESSAGE
     chunk: LLMResultChunk
 
-    
+
 class QueueMessageReplaceEvent(AppQueueEvent):
     """
     QueueMessageReplaceEvent entity
     """
+
     event: QueueEvent = QueueEvent.MESSAGE_REPLACE
     text: str
 
@@ -176,6 +188,7 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
     """
     QueueRetrieverResourcesEvent entity
     """
+
     event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES
     retriever_resources: list[dict]
     in_iteration_id: Optional[str] = None
@@ -186,6 +199,7 @@ class QueueAnnotationReplyEvent(AppQueueEvent):
     """
     QueueAnnotationReplyEvent entity
     """
+
     event: QueueEvent = QueueEvent.ANNOTATION_REPLY
     message_annotation_id: str
 
@@ -194,6 +208,7 @@ class QueueMessageEndEvent(AppQueueEvent):
     """
     QueueMessageEndEvent entity
     """
+
     event: QueueEvent = QueueEvent.MESSAGE_END
     llm_result: Optional[LLMResult] = None
 
@@ -202,6 +217,7 @@ class QueueAdvancedChatMessageEndEvent(AppQueueEvent):
     """
     QueueAdvancedChatMessageEndEvent entity
     """
+
     event: QueueEvent = QueueEvent.ADVANCED_CHAT_MESSAGE_END
 
 
@@ -209,6 +225,7 @@ class QueueWorkflowStartedEvent(AppQueueEvent):
     """
     QueueWorkflowStartedEvent entity
     """
+
     event: QueueEvent = QueueEvent.WORKFLOW_STARTED
     graph_runtime_state: GraphRuntimeState
 
@@ -217,6 +234,7 @@ class QueueWorkflowSucceededEvent(AppQueueEvent):
     """
     QueueWorkflowSucceededEvent entity
     """
+
     event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED
     outputs: Optional[dict[str, Any]] = None
 
@@ -225,6 +243,7 @@ class QueueWorkflowFailedEvent(AppQueueEvent):
     """
     QueueWorkflowFailedEvent entity
     """
+
     event: QueueEvent = QueueEvent.WORKFLOW_FAILED
     error: str
 
@@ -233,6 +252,7 @@ class QueueNodeStartedEvent(AppQueueEvent):
     """
     QueueNodeStartedEvent entity
     """
+
     event: QueueEvent = QueueEvent.NODE_STARTED
 
     node_execution_id: str
@@ -258,6 +278,7 @@ class QueueNodeSucceededEvent(AppQueueEvent):
     """
     QueueNodeSucceededEvent entity
     """
+
     event: QueueEvent = QueueEvent.NODE_SUCCEEDED
 
     node_execution_id: str
@@ -288,6 +309,7 @@ class QueueNodeFailedEvent(AppQueueEvent):
     """
     QueueNodeFailedEvent entity
     """
+
     event: QueueEvent = QueueEvent.NODE_FAILED
 
     node_execution_id: str
@@ -317,6 +339,7 @@ class QueueAgentThoughtEvent(AppQueueEvent):
     """
     QueueAgentThoughtEvent entity
     """
+
     event: QueueEvent = QueueEvent.AGENT_THOUGHT
     agent_thought_id: str
 
@@ -325,6 +348,7 @@ class QueueMessageFileEvent(AppQueueEvent):
     """
     QueueMessageFileEvent entity
     """
+
     event: QueueEvent = QueueEvent.MESSAGE_FILE
     message_file_id: str
 
@@ -333,6 +357,7 @@ class QueueErrorEvent(AppQueueEvent):
     """
     QueueErrorEvent entity
     """
+
     event: QueueEvent = QueueEvent.ERROR
     error: Any = None
 
@@ -341,6 +366,7 @@ class QueuePingEvent(AppQueueEvent):
     """
     QueuePingEvent entity
     """
+
     event: QueueEvent = QueueEvent.PING
 
 
@@ -348,10 +374,12 @@ class QueueStopEvent(AppQueueEvent):
     """
     QueueStopEvent entity
     """
+
     class StopBy(Enum):
         """
         Stop by enum
         """
+
         USER_MANUAL = "user-manual"
         ANNOTATION_REPLY = "annotation-reply"
         OUTPUT_MODERATION = "output-moderation"
@@ -365,19 +393,20 @@ class QueueStopEvent(AppQueueEvent):
         To stop reason
         """
         reason_mapping = {
-            QueueStopEvent.StopBy.USER_MANUAL: 'Stopped by user.',
-            QueueStopEvent.StopBy.ANNOTATION_REPLY: 'Stopped by annotation reply.',
-            QueueStopEvent.StopBy.OUTPUT_MODERATION: 'Stopped by output moderation.',
-            QueueStopEvent.StopBy.INPUT_MODERATION: 'Stopped by input moderation.'
+            QueueStopEvent.StopBy.USER_MANUAL: "Stopped by user.",
+            QueueStopEvent.StopBy.ANNOTATION_REPLY: "Stopped by annotation reply.",
+            QueueStopEvent.StopBy.OUTPUT_MODERATION: "Stopped by output moderation.",
+            QueueStopEvent.StopBy.INPUT_MODERATION: "Stopped by input moderation.",
         }
 
-        return reason_mapping.get(self.stopped_by, 'Stopped by unknown reason.')
+        return reason_mapping.get(self.stopped_by, "Stopped by unknown reason.")
 
 
 class QueueMessage(BaseModel):
     """
     QueueMessage abstract entity
     """
+
     task_id: str
     app_mode: str
     event: AppQueueEvent
@@ -387,6 +416,7 @@ class MessageQueueMessage(QueueMessage):
     """
     MessageQueueMessage entity
     """
+
     message_id: str
     conversation_id: str
 
@@ -395,6 +425,7 @@ class WorkflowQueueMessage(QueueMessage):
     """
     WorkflowQueueMessage entity
     """
+
     pass
 
 
@@ -402,6 +433,7 @@ class QueueParallelBranchRunStartedEvent(AppQueueEvent):
     """
     QueueParallelBranchRunStartedEvent entity
     """
+
     event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED
 
     parallel_id: str
@@ -418,6 +450,7 @@ class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
     """
     QueueParallelBranchRunSucceededEvent entity
     """
+
     event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED
 
     parallel_id: str
@@ -434,6 +467,7 @@ class QueueParallelBranchRunFailedEvent(AppQueueEvent):
     """
     QueueParallelBranchRunFailedEvent entity
     """
+
     event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED
 
     parallel_id: str
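
The only behavioural code in this file is the `mode="before"` validator on `QueueIterationNextEvent.output`; the rest is blank-line-after-docstring churn. A runnable sketch of that validator in isolation (pydantic v2, as the `field_validator` import implies):

    from typing import Any, Optional
    from pydantic import BaseModel, field_validator

    class IterationNext(BaseModel):
        output: Optional[Any] = None

        @field_validator("output", mode="before")
        @classmethod
        def set_output(cls, v):
            # runs before coercion; whitelists plain JSON-ish types
            if v is None:
                return None
            if isinstance(v, int | float | str | bool | dict | list):
                return v
            raise ValueError("output must be a valid type")

    IterationNext(output={"step": 1})   # accepted
    # IterationNext(output=object())    # would raise a ValidationError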

+ 37 - 3
api/core/app/entities/task_entities.py

@@ -12,6 +12,7 @@ class TaskState(BaseModel):
     """
     TaskState entity
     """
+
     metadata: dict = {}
 
 
@@ -19,6 +20,7 @@ class EasyUITaskState(TaskState):
     """
     EasyUITaskState entity
     """
+
     llm_result: LLMResult
 
 
@@ -26,6 +28,7 @@ class WorkflowTaskState(TaskState):
     """
     WorkflowTaskState entity
     """
+
     answer: str = ""
 
 
@@ -33,6 +36,7 @@ class StreamEvent(Enum):
     """
     Stream event
     """
+
     PING = "ping"
     ERROR = "error"
     MESSAGE = "message"
@@ -60,6 +64,7 @@ class StreamResponse(BaseModel):
     """
     StreamResponse entity
     """
+
     event: StreamEvent
     task_id: str
 
@@ -71,6 +76,7 @@ class ErrorStreamResponse(StreamResponse):
     """
     ErrorStreamResponse entity
     """
+
     event: StreamEvent = StreamEvent.ERROR
     err: Exception
     model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -80,6 +86,7 @@ class MessageStreamResponse(StreamResponse):
     """
     MessageStreamResponse entity
     """
+
     event: StreamEvent = StreamEvent.MESSAGE
     id: str
     answer: str
@@ -89,6 +96,7 @@ class MessageAudioStreamResponse(StreamResponse):
     """
     MessageAudioStreamResponse entity
     """
+
     event: StreamEvent = StreamEvent.TTS_MESSAGE
     audio: str
 
@@ -97,6 +105,7 @@ class MessageAudioEndStreamResponse(StreamResponse):
     """
     MessageAudioEndStreamResponse entity
     """
+
     event: StreamEvent = StreamEvent.TTS_MESSAGE_END
     audio: str
 
@@ -105,6 +114,7 @@ class MessageEndStreamResponse(StreamResponse):
     """
     MessageEndStreamResponse entity
     """
+
     event: StreamEvent = StreamEvent.MESSAGE_END
     id: str
     metadata: dict = {}
@@ -114,6 +124,7 @@ class MessageFileStreamResponse(StreamResponse):
     """
     MessageFileStreamResponse entity
     """
+
     event: StreamEvent = StreamEvent.MESSAGE_FILE
     id: str
     type: str
@@ -125,6 +136,7 @@ class MessageReplaceStreamResponse(StreamResponse):
     """
     MessageReplaceStreamResponse entity
     """
+
     event: StreamEvent = StreamEvent.MESSAGE_REPLACE
     answer: str
 
@@ -133,6 +145,7 @@ class AgentThoughtStreamResponse(StreamResponse):
     """
     AgentThoughtStreamResponse entity
     """
+
     event: StreamEvent = StreamEvent.AGENT_THOUGHT
     id: str
     position: int
@@ -148,6 +161,7 @@ class AgentMessageStreamResponse(StreamResponse):
     """
     AgentMessageStreamResponse entity
     """
+
     event: StreamEvent = StreamEvent.AGENT_MESSAGE
     id: str
     answer: str
@@ -162,6 +176,7 @@ class WorkflowStartStreamResponse(StreamResponse):
         """
         Data entity
         """
+
         id: str
         workflow_id: str
         sequence_number: int
@@ -182,6 +197,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
         """
         Data entity
         """
+
         id: str
         workflow_id: str
         sequence_number: int
@@ -210,6 +226,7 @@ class NodeStartStreamResponse(StreamResponse):
         """
         Data entity
         """
+
         id: str
         node_id: str
         node_type: str
@@ -249,7 +266,7 @@ class NodeStartStreamResponse(StreamResponse):
                 "parent_parallel_id": self.data.parent_parallel_id,
                 "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
                 "iteration_id": self.data.iteration_id,
-            }
+            },
         }
 
 
@@ -262,6 +279,7 @@ class NodeFinishStreamResponse(StreamResponse):
         """
         Data entity
         """
+
         id: str
         node_id: str
         node_type: str
@@ -315,9 +333,9 @@ class NodeFinishStreamResponse(StreamResponse):
                 "parent_parallel_id": self.data.parent_parallel_id,
                 "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
                 "iteration_id": self.data.iteration_id,
-            }
+            },
         }
-    
+
 
 class ParallelBranchStartStreamResponse(StreamResponse):
     """
@@ -328,6 +346,7 @@ class ParallelBranchStartStreamResponse(StreamResponse):
         """
         Data entity
         """
+
         parallel_id: str
         parallel_branch_id: str
         parent_parallel_id: Optional[str] = None
@@ -349,6 +368,7 @@ class ParallelBranchFinishedStreamResponse(StreamResponse):
         """
         Data entity
         """
+
         parallel_id: str
         parallel_branch_id: str
         parent_parallel_id: Optional[str] = None
@@ -372,6 +392,7 @@ class IterationNodeStartStreamResponse(StreamResponse):
         """
         Data entity
         """
+
         id: str
         node_id: str
         node_type: str
@@ -397,6 +418,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
         """
         Data entity
         """
+
         id: str
         node_id: str
         node_type: str
@@ -422,6 +444,7 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
         """
         Data entity
         """
+
         id: str
         node_id: str
         node_type: str
@@ -454,6 +477,7 @@ class TextChunkStreamResponse(StreamResponse):
         """
         Data entity
         """
+
         text: str
 
     event: StreamEvent = StreamEvent.TEXT_CHUNK
@@ -469,6 +493,7 @@ class TextReplaceStreamResponse(StreamResponse):
         """
         Data entity
         """
+
         text: str
 
     event: StreamEvent = StreamEvent.TEXT_REPLACE
@@ -479,6 +504,7 @@ class PingStreamResponse(StreamResponse):
     """
     PingStreamResponse entity
     """
+
     event: StreamEvent = StreamEvent.PING
 
 
@@ -486,6 +512,7 @@ class AppStreamResponse(BaseModel):
     """
     AppStreamResponse entity
     """
+
     stream_response: StreamResponse
 
 
@@ -493,6 +520,7 @@ class ChatbotAppStreamResponse(AppStreamResponse):
     """
     ChatbotAppStreamResponse entity
     """
+
     conversation_id: str
     message_id: str
     created_at: int
@@ -502,6 +530,7 @@ class CompletionAppStreamResponse(AppStreamResponse):
     """
     CompletionAppStreamResponse entity
     """
+
     message_id: str
     created_at: int
 
@@ -510,6 +539,7 @@ class WorkflowAppStreamResponse(AppStreamResponse):
     """
     WorkflowAppStreamResponse entity
     """
+
     workflow_run_id: Optional[str] = None
 
 
@@ -517,6 +547,7 @@ class AppBlockingResponse(BaseModel):
     """
     AppBlockingResponse entity
     """
+
     task_id: str
 
     def to_dict(self) -> dict:
@@ -532,6 +563,7 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
         """
         Data entity
         """
+
         id: str
         mode: str
         conversation_id: str
@@ -552,6 +584,7 @@ class CompletionAppBlockingResponse(AppBlockingResponse):
         """
         Data entity
         """
+
         id: str
         mode: str
         message_id: str
@@ -571,6 +604,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
         """
         Data entity
         """
+
         id: str
         workflow_id: str
         status: str
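
The convention throughout task_entities.py is that each concrete response pins its `event` by overriding the base field's default. A minimal sketch of why that works under pydantic v2:

    from enum import Enum
    from pydantic import BaseModel

    class StreamEvent(Enum):
        PING = "ping"
        ERROR = "error"

    class StreamResponse(BaseModel):
        event: StreamEvent
        task_id: str

    class PingStreamResponse(StreamResponse):
        # redeclaring the field with a default fixes the event per subclass
        event: StreamEvent = StreamEvent.PING

    print(PingStreamResponse(task_id="t-1").model_dump())
    # {'event': <StreamEvent.PING: 'ping'>, 'task_id': 't-1'}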

+ 27 - 33
api/core/app/features/annotation_reply/annotation_reply.py

@@ -13,11 +13,9 @@ logger = logging.getLogger(__name__)
 
 
 class AnnotationReplyFeature:
-    def query(self, app_record: App,
-              message: Message,
-              query: str,
-              user_id: str,
-              invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
+    def query(
+        self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
+    ) -> Optional[MessageAnnotation]:
         """
         Query app annotations to reply
         :param app_record: app record
@@ -27,8 +25,9 @@ class AnnotationReplyFeature:
         :param invoke_from: invoke from
         :return:
         """
-        annotation_setting = db.session.query(AppAnnotationSetting).filter(
-            AppAnnotationSetting.app_id == app_record.id).first()
+        annotation_setting = (
+            db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_record.id).first()
+        )
 
         if not annotation_setting:
             return None
@@ -41,55 +40,50 @@ class AnnotationReplyFeature:
             embedding_model_name = collection_binding_detail.model_name
 
             dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-                embedding_provider_name,
-                embedding_model_name,
-                'annotation'
+                embedding_provider_name, embedding_model_name, "annotation"
             )
 
             dataset = Dataset(
                 id=app_record.id,
                 tenant_id=app_record.tenant_id,
-                indexing_technique='high_quality',
+                indexing_technique="high_quality",
                 embedding_model_provider=embedding_provider_name,
                 embedding_model=embedding_model_name,
-                collection_binding_id=dataset_collection_binding.id
+                collection_binding_id=dataset_collection_binding.id,
             )
 
-            vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
+            vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
 
             documents = vector.search_by_vector(
-                query=query,
-                top_k=1,
-                score_threshold=score_threshold,
-                filter={
-                    'group_id': [dataset.id]
-                }
+                query=query, top_k=1, score_threshold=score_threshold, filter={"group_id": [dataset.id]}
             )
 
             if documents:
-                annotation_id = documents[0].metadata['annotation_id']
-                score = documents[0].metadata['score']
+                annotation_id = documents[0].metadata["annotation_id"]
+                score = documents[0].metadata["score"]
                 annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
                 if annotation:
                     if invoke_from in [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP]:
-                        from_source = 'api'
+                        from_source = "api"
                     else:
-                        from_source = 'console'
+                        from_source = "console"
 
                     # insert annotation history
-                    AppAnnotationService.add_annotation_history(annotation.id,
-                                                                app_record.id,
-                                                                annotation.question,
-                                                                annotation.content,
-                                                                query,
-                                                                user_id,
-                                                                message.id,
-                                                                from_source,
-                                                                score)
+                    AppAnnotationService.add_annotation_history(
+                        annotation.id,
+                        app_record.id,
+                        annotation.question,
+                        annotation.content,
+                        query,
+                        user_id,
+                        message.id,
+                        from_source,
+                        score,
+                    )
 
                     return annotation
         except Exception as e:
-            logger.warning(f'Query annotation failed, exception: {str(e)}.')
+            logger.warning(f"Query annotation failed, exception: {str(e)}.")
             return None
 
         return None
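
Stepping back from the formatting: the flow above embeds the query, asks the vector store for the single best match (`top_k=1`) above the configured `score_threshold`, and records a history row before returning the annotation. A simplified, hypothetical sketch of just the decision step (in the real code the vector search applies the threshold itself; `Hit` and `pick_annotation` are stand-ins, not Dify's API):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Hit:
        annotation_id: str
        score: float

    def pick_annotation(hits: list[Hit], score_threshold: float) -> Optional[str]:
        # mirrors top_k=1 plus threshold filtering in AnnotationReplyFeature.query
        if not hits:
            return None
        best = hits[0]
        return best.annotation_id if best.score >= score_threshold else None

    assert pick_annotation([Hit("a1", 0.92)], 0.8) == "a1"
    assert pick_annotation([Hit("a1", 0.50)], 0.8) is None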

+ 4 - 6
api/core/app/features/hosting_moderation/hosting_moderation.py

@@ -8,8 +8,9 @@ logger = logging.getLogger(__name__)
 
 
 class HostingModerationFeature:
-    def check(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
-              prompt_messages: list[PromptMessage]) -> bool:
+    def check(
+        self, application_generate_entity: EasyUIBasedAppGenerateEntity, prompt_messages: list[PromptMessage]
+    ) -> bool:
         """
         Check hosting moderation
         :param application_generate_entity: application generate entity
@@ -23,9 +24,6 @@ class HostingModerationFeature:
             if isinstance(prompt_message.content, str):
                 text += prompt_message.content + "\n"
 
-        moderation_result = moderation.check_moderation(
-            model_config,
-            text
-        )
+        moderation_result = moderation.check_moderation(model_config, text)
 
         return moderation_result
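
Worth noting while the signature is re-wrapped: the check only accumulates string prompt contents, so multimodal list payloads never reach the moderation model. The filtering, stand-alone (`Msg` is a throwaway stub):

    def collect_text(prompt_messages: list) -> str:
        text = ""
        for message in prompt_messages:
            if isinstance(message.content, str):  # image/list contents are skipped
                text += message.content + "\n"
        return text

    class Msg:
        def __init__(self, content):
            self.content = content

    assert collect_text([Msg("hello"), Msg([{"type": "image"}])]) == "hello\n"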

+ 14 - 9
api/core/app/features/rate_limiting/rate_limit.py

@@ -19,7 +19,7 @@ class RateLimit:
     _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60  # recalculate request_count from request_detail every 5 minutes
     _instance_dict = {}
 
-    def __new__(cls: type['RateLimit'], client_id: str, max_active_requests: int):
+    def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int):
         if client_id not in cls._instance_dict:
             instance = super().__new__(cls)
             cls._instance_dict[client_id] = instance
@@ -27,13 +27,13 @@ class RateLimit:
 
     def __init__(self, client_id: str, max_active_requests: int):
         self.max_active_requests = max_active_requests
-        if hasattr(self, 'initialized'):
+        if hasattr(self, "initialized"):
             return
         self.initialized = True
         self.client_id = client_id
         self.active_requests_key = self._ACTIVE_REQUESTS_KEY.format(client_id)
         self.max_active_requests_key = self._MAX_ACTIVE_REQUESTS_KEY.format(client_id)
-        self.last_recalculate_time = float('-inf')
+        self.last_recalculate_time = float("-inf")
         self.flush_cache(use_local_value=True)
 
     def flush_cache(self, use_local_value=False):
@@ -46,7 +46,7 @@ class RateLimit:
                 pipe.execute()
         else:
             with redis_client.pipeline() as pipe:
-                self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode('utf-8'))
+                self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode("utf-8"))
                 redis_client.expire(self.max_active_requests_key, timedelta(days=1))
 
         # flush max active requests (in-transit request list)
@@ -54,8 +54,11 @@ class RateLimit:
             return
         request_details = redis_client.hgetall(self.active_requests_key)
         redis_client.expire(self.active_requests_key, timedelta(days=1))
-        timeout_requests = [k for k, v in request_details.items() if
-                            time.time() - float(v.decode('utf-8')) > RateLimit._REQUEST_MAX_ALIVE_TIME]
+        timeout_requests = [
+            k
+            for k, v in request_details.items()
+            if time.time() - float(v.decode("utf-8")) > RateLimit._REQUEST_MAX_ALIVE_TIME
+        ]
         if timeout_requests:
             redis_client.hdel(self.active_requests_key, *timeout_requests)
 
@@ -69,8 +72,10 @@ class RateLimit:
 
         active_requests_count = redis_client.hlen(self.active_requests_key)
         if active_requests_count >= self.max_active_requests:
-            raise AppInvokeQuotaExceededError("Too many requests. Please try again later. The current maximum "
-                                              "concurrent requests allowed is {}.".format(self.max_active_requests))
+            raise AppInvokeQuotaExceededError(
+                "Too many requests. Please try again later. The current maximum "
+                "concurrent requests allowed is {}.".format(self.max_active_requests)
+            )
         redis_client.hset(self.active_requests_key, request_id, str(time.time()))
         return request_id
 
@@ -116,5 +121,5 @@ class RateLimitGenerator:
         if not self.closed:
             self.closed = True
             self.rate_limit.exit(self.request_id)
-            if self.generator is not None and hasattr(self.generator, 'close'):
+            if self.generator is not None and hasattr(self.generator, "close"):
                 self.generator.close()
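
`RateLimit` is a per-client singleton: `__new__` caches one instance per `client_id`, and `__init__` (which re-runs on every construction of the cached instance) bails out via the `initialized` flag. Note that `max_active_requests` is reassigned before that guard, so reconstructing can still update the limit. A stripped-down sketch of the idiom, minus the Redis bookkeeping:

    class PerClient:
        _instances: dict = {}

        def __new__(cls, client_id: str, limit: int):
            if client_id not in cls._instances:
                cls._instances[client_id] = super().__new__(cls)
            return cls._instances[client_id]

        def __init__(self, client_id: str, limit: int):
            self.limit = limit              # updated even on re-entry
            if hasattr(self, "initialized"):
                return
            self.initialized = True
            self.client_id = client_id

    a = PerClient("a", 10)
    assert PerClient("a", 20) is a and a.limit == 20
    assert PerClient("b", 5) is not a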

+ 21 - 21
api/core/app/segments/__init__.py

@@ -25,25 +25,25 @@ from .variables import (
 )
 
 __all__ = [
-    'IntegerVariable',
-    'FloatVariable',
-    'ObjectVariable',
-    'SecretVariable',
-    'StringVariable',
-    'ArrayAnyVariable',
-    'Variable',
-    'SegmentType',
-    'SegmentGroup',
-    'Segment',
-    'NoneSegment',
-    'NoneVariable',
-    'IntegerSegment',
-    'FloatSegment',
-    'ObjectSegment',
-    'ArrayAnySegment',
-    'StringSegment',
-    'ArrayStringVariable',
-    'ArrayNumberVariable',
-    'ArrayObjectVariable',
-    'ArraySegment',
+    "IntegerVariable",
+    "FloatVariable",
+    "ObjectVariable",
+    "SecretVariable",
+    "StringVariable",
+    "ArrayAnyVariable",
+    "Variable",
+    "SegmentType",
+    "SegmentGroup",
+    "Segment",
+    "NoneSegment",
+    "NoneVariable",
+    "IntegerSegment",
+    "FloatSegment",
+    "ObjectSegment",
+    "ArrayAnySegment",
+    "StringSegment",
+    "ArrayStringVariable",
+    "ArrayNumberVariable",
+    "ArrayObjectVariable",
+    "ArraySegment",
 ]

+ 10 - 10
api/core/app/segments/factory.py

@@ -28,12 +28,12 @@ from .variables import (
 
 
 def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
-    if (value_type := mapping.get('value_type')) is None:
-        raise VariableError('missing value type')
-    if not mapping.get('name'):
-        raise VariableError('missing name')
-    if (value := mapping.get('value')) is None:
-        raise VariableError('missing value')
+    if (value_type := mapping.get("value_type")) is None:
+        raise VariableError("missing value type")
+    if not mapping.get("name"):
+        raise VariableError("missing name")
+    if (value := mapping.get("value")) is None:
+        raise VariableError("missing value")
     match value_type:
         case SegmentType.STRING:
             result = StringVariable.model_validate(mapping)
@@ -44,7 +44,7 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
         case SegmentType.NUMBER if isinstance(value, float):
             result = FloatVariable.model_validate(mapping)
         case SegmentType.NUMBER if not isinstance(value, float | int):
-            raise VariableError(f'invalid number value {value}')
+            raise VariableError(f"invalid number value {value}")
         case SegmentType.OBJECT if isinstance(value, dict):
             result = ObjectVariable.model_validate(mapping)
         case SegmentType.ARRAY_STRING if isinstance(value, list):
@@ -54,9 +54,9 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
         case SegmentType.ARRAY_OBJECT if isinstance(value, list):
             result = ArrayObjectVariable.model_validate(mapping)
         case _:
-            raise VariableError(f'not supported value type {value_type}')
+            raise VariableError(f"not supported value type {value_type}")
     if result.size > dify_config.MAX_VARIABLE_SIZE:
-        raise VariableError(f'variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}')
+        raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
     return result
 
 
@@ -73,4 +73,4 @@ def build_segment(value: Any, /) -> Segment:
         return ObjectSegment(value=value)
     if isinstance(value, list):
         return ArrayAnySegment(value=value)
-    raise ValueError(f'not supported value {value}')
+    raise ValueError(f"not supported value {value}")

+ 2 - 2
api/core/app/segments/parser.py

@@ -4,14 +4,14 @@ from core.workflow.entities.variable_pool import VariablePool
 
 from . import SegmentGroup, factory
 
-VARIABLE_PATTERN = re.compile(r'\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}')
+VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
 
 
 def convert_template(*, template: str, variable_pool: VariablePool):
     parts = re.split(VARIABLE_PATTERN, template)
     segments = []
     for part in filter(lambda x: x, parts):
-        if '.' in part and (value := variable_pool.get(part.split('.'))):
+        if "." in part and (value := variable_pool.get(part.split("."))):
             segments.append(value)
         else:
             segments.append(factory.build_segment(part))
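
Since `VARIABLE_PATTERN` only changed quote style, a quick demonstration of what it actually does may help: `re.split` with the capturing group interleaves literal text and variable selectors, which `convert_template` then turns into segments.

    import re

    VARIABLE_PATTERN = re.compile(
        r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
    )

    template = "Hello {{#sys.query#}} and {{#node1.output.text#}}!"
    print(re.split(VARIABLE_PATTERN, template))
    # ['Hello ', 'sys.query', ' and ', 'node1.output.text', '!']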

+ 3 - 3
api/core/app/segments/segment_group.py

@@ -8,15 +8,15 @@ class SegmentGroup(Segment):
 
     @property
     def text(self):
-        return ''.join([segment.text for segment in self.value])
+        return "".join([segment.text for segment in self.value])
 
     @property
     def log(self):
-        return ''.join([segment.log for segment in self.value])
+        return "".join([segment.log for segment in self.value])
 
     @property
     def markdown(self):
-        return ''.join([segment.markdown for segment in self.value])
+        return "".join([segment.markdown for segment in self.value])
 
     def to_object(self):
         return [segment.to_object() for segment in self.value]

+ 10 - 14
api/core/app/segments/segments.py

@@ -14,13 +14,13 @@ class Segment(BaseModel):
     value_type: SegmentType
     value: Any
 
-    @field_validator('value_type')
+    @field_validator("value_type")
     def validate_value_type(cls, value):
         """
         This validator checks if the provided value is equal to the default value of the 'value_type' field.
         If the value is different, a ValueError is raised.
         """
-        if value != cls.model_fields['value_type'].default:
+        if value != cls.model_fields["value_type"].default:
             raise ValueError("Cannot modify 'value_type'")
         return value
 
@@ -50,15 +50,15 @@ class NoneSegment(Segment):
 
     @property
     def text(self) -> str:
-        return 'null'
+        return "null"
 
     @property
     def log(self) -> str:
-        return 'null'
+        return "null"
 
     @property
     def markdown(self) -> str:
-        return 'null'
+        return "null"
 
 
 class StringSegment(Segment):
@@ -76,24 +76,21 @@ class IntegerSegment(Segment):
     value: int
 
 
-
-
-
 class ObjectSegment(Segment):
     value_type: SegmentType = SegmentType.OBJECT
     value: Mapping[str, Any]
 
     @property
     def text(self) -> str:
-        return json.dumps(self.model_dump()['value'], ensure_ascii=False)
+        return json.dumps(self.model_dump()["value"], ensure_ascii=False)
 
     @property
     def log(self) -> str:
-        return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
+        return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2)
 
     @property
     def markdown(self) -> str:
-        return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
+        return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2)
 
 
 class ArraySegment(Segment):
@@ -101,11 +98,11 @@ class ArraySegment(Segment):
     def markdown(self) -> str:
         items = []
         for item in self.value:
-            if hasattr(item, 'to_markdown'):
+            if hasattr(item, "to_markdown"):
                 items.append(item.to_markdown())
             else:
                 items.append(str(item))
-        return '\n'.join(items)
+        return "\n".join(items)
 
 
 class ArrayAnySegment(ArraySegment):
@@ -126,4 +123,3 @@ class ArrayNumberSegment(ArraySegment):
 class ArrayObjectSegment(ArraySegment):
     value_type: SegmentType = SegmentType.ARRAY_OBJECT
     value: Sequence[Mapping[str, Any]]
-
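
The one piece of logic in this file is the `value_type` freeze at the top of the hunk: the validator compares the incoming value against the field's declared default, so each subclass accepts only its own tag. A runnable sketch (pydantic v2; `"base"`/`"string"` stand in for the SegmentType members):

    from typing import Any
    from pydantic import BaseModel, field_validator

    class Segment(BaseModel):
        value_type: str = "base"
        value: Any = None

        @field_validator("value_type")
        @classmethod
        def validate_value_type(cls, value):
            # cls is the concrete model, so the subclass default applies
            if value != cls.model_fields["value_type"].default:
                raise ValueError("Cannot modify 'value_type'")
            return value

    class StringSegment(Segment):
        value_type: str = "string"

    StringSegment(value_type="string", value="ok")   # accepted
    # StringSegment(value_type="object")             # would raise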

+ 10 - 10
api/core/app/segments/types.py

@@ -2,14 +2,14 @@ from enum import Enum
 
 
 class SegmentType(str, Enum):
-    NONE = 'none'
-    NUMBER = 'number'
-    STRING = 'string'
-    SECRET = 'secret'
-    ARRAY_ANY = 'array[any]'
-    ARRAY_STRING = 'array[string]'
-    ARRAY_NUMBER = 'array[number]'
-    ARRAY_OBJECT = 'array[object]'
-    OBJECT = 'object'
+    NONE = "none"
+    NUMBER = "number"
+    STRING = "string"
+    SECRET = "secret"
+    ARRAY_ANY = "array[any]"
+    ARRAY_STRING = "array[string]"
+    ARRAY_NUMBER = "array[number]"
+    ARRAY_OBJECT = "array[object]"
+    OBJECT = "object"
 
-    GROUP = 'group'
+    GROUP = "group"

+ 2 - 3
api/core/app/segments/variables.py

@@ -23,11 +23,11 @@ class Variable(Segment):
     """
 
     id: str = Field(
-        default='',
+        default="",
         description="Unique identity for variable. It's only used by environment variables now.",
     )
     name: str
-    description: str = Field(default='', description='Description of the variable.')
+    description: str = Field(default="", description="Description of the variable.")
 
 
 class StringVariable(StringSegment, Variable):
@@ -62,7 +62,6 @@ class ArrayObjectVariable(ArrayObjectSegment, Variable):
     pass
 
 
-
 class SecretVariable(StringVariable):
     value_type: SegmentType = SegmentType.SECRET
 

+ 20 - 22
api/core/app/task_pipeline/based_generate_task_pipeline.py

@@ -32,10 +32,13 @@ class BasedGenerateTaskPipeline:
     _task_state: TaskState
     _application_generate_entity: AppGenerateEntity
 
-    def __init__(self, application_generate_entity: AppGenerateEntity,
-                 queue_manager: AppQueueManager,
-                 user: Union[Account, EndUser],
-                 stream: bool) -> None:
+    def __init__(
+        self,
+        application_generate_entity: AppGenerateEntity,
+        queue_manager: AppQueueManager,
+        user: Union[Account, EndUser],
+        stream: bool,
+    ) -> None:
         """
         Initialize GenerateTaskPipeline.
         :param application_generate_entity: application generate entity
@@ -61,18 +64,18 @@ class BasedGenerateTaskPipeline:
         e = event.error
 
         if isinstance(e, InvokeAuthorizationError):
-            err = InvokeAuthorizationError('Incorrect API key provided')
+            err = InvokeAuthorizationError("Incorrect API key provided")
         elif isinstance(e, InvokeError) or isinstance(e, ValueError):
             err = e
         else:
-            err = Exception(e.description if getattr(e, 'description', None) is not None else str(e))
+            err = Exception(e.description if getattr(e, "description", None) is not None else str(e))
 
         if message:
             refetch_message = db.session.query(Message).filter(Message.id == message.id).first()
 
             if refetch_message:
                 err_desc = self._error_to_desc(err)
-                refetch_message.status = 'error'
+                refetch_message.status = "error"
                 refetch_message.error = err_desc
 
                 db.session.commit()
@@ -86,12 +89,14 @@ class BasedGenerateTaskPipeline:
         :return:
         """
         if isinstance(e, QuotaExceededError):
-            return ("Your quota for Dify Hosted Model Provider has been exhausted. "
-                    "Please go to Settings -> Model Provider to complete your own provider credentials.")
+            return (
+                "Your quota for Dify Hosted Model Provider has been exhausted. "
+                "Please go to Settings -> Model Provider to complete your own provider credentials."
+            )
 
-        message = getattr(e, 'description', str(e))
+        message = getattr(e, "description", str(e))
         if not message:
-            message = 'Internal Server Error, please contact support.'
+            message = "Internal Server Error, please contact support."
 
         return message
 
@@ -101,10 +106,7 @@ class BasedGenerateTaskPipeline:
         :param e: exception
         :return:
         """
-        return ErrorStreamResponse(
-            task_id=self._application_generate_entity.task_id,
-            err=e
-        )
+        return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e)
 
     def _ping_stream_response(self) -> PingStreamResponse:
         """
@@ -125,11 +127,8 @@ class BasedGenerateTaskPipeline:
             return OutputModeration(
                 tenant_id=app_config.tenant_id,
                 app_id=app_config.app_id,
-                rule=ModerationRule(
-                    type=sensitive_word_avoidance.type,
-                    config=sensitive_word_avoidance.config
-                ),
-                queue_manager=self._queue_manager
+                rule=ModerationRule(type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config),
+                queue_manager=self._queue_manager,
             )
 
     def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]:
@@ -143,8 +142,7 @@ class BasedGenerateTaskPipeline:
             self._output_moderation_handler.stop_thread()
 
             completion = self._output_moderation_handler.moderation_completion(
-                completion=completion,
-                public_event=False
+                completion=completion, public_event=False
             )
 
             self._output_moderation_handler = None
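
A stand-alone version of the fallback half of `_error_to_desc` being re-wrapped here (the `QuotaExceededError` special case is omitted): prefer a rich `.description` attribute when one exists, fall back to `str(e)`, and never return an empty message. `ApiError` below is a hypothetical stand-in:

    def error_to_desc(e: Exception) -> str:
        message = getattr(e, "description", str(e))
        if not message:
            message = "Internal Server Error, please contact support."
        return message

    class ApiError(Exception):
        description = "quota exhausted"

    assert error_to_desc(ApiError()) == "quota exhausted"
    assert error_to_desc(ValueError("bad input")) == "bad input"
    assert error_to_desc(Exception()).startswith("Internal Server Error")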

+ 76 - 99
api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py

@@ -64,23 +64,21 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
     """
     EasyUIBasedGenerateTaskPipeline generates streaming output and manages state for the application.
     """
+
     _task_state: EasyUITaskState
-    _application_generate_entity: Union[
-        ChatAppGenerateEntity,
-        CompletionAppGenerateEntity,
-        AgentChatAppGenerateEntity
-    ]
-
-    def __init__(self, application_generate_entity: Union[
-        ChatAppGenerateEntity,
-        CompletionAppGenerateEntity,
-        AgentChatAppGenerateEntity
-    ],
-                 queue_manager: AppQueueManager,
-                 conversation: Conversation,
-                 message: Message,
-                 user: Union[Account, EndUser],
-                 stream: bool) -> None:
+    _application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
+
+    def __init__(
+        self,
+        application_generate_entity: Union[
+            ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity
+        ],
+        queue_manager: AppQueueManager,
+        conversation: Conversation,
+        message: Message,
+        user: Union[Account, EndUser],
+        stream: bool,
+    ) -> None:
         """
         Initialize GenerateTaskPipeline.
         :param application_generate_entity: application generate entity
@@ -101,18 +99,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                 model=self._model_config.model,
                 prompt_messages=[],
                 message=AssistantPromptMessage(content=""),
-                usage=LLMUsage.empty_usage()
+                usage=LLMUsage.empty_usage(),
             )
         )
 
         self._conversation_name_generate_thread = None
 
     def process(
-            self,
+        self,
     ) -> Union[
         ChatbotAppBlockingResponse,
         CompletionAppBlockingResponse,
-        Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]
+        Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
     ]:
         """
         Process generate task pipeline.
@@ -125,22 +123,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
         if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
             # start generate conversation name thread
             self._conversation_name_generate_thread = self._generate_conversation_name(
-                self._conversation,
-                self._application_generate_entity.query
+                self._conversation, self._application_generate_entity.query
             )
 
-        generator = self._wrapper_process_stream_response(
-            trace_manager=self._application_generate_entity.trace_manager
-        )
+        generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
         if self._stream:
             return self._to_stream_response(generator)
         else:
             return self._to_blocking_response(generator)
 
-    def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> Union[
-        ChatbotAppBlockingResponse,
-        CompletionAppBlockingResponse
-    ]:
+    def _to_blocking_response(
+        self, generator: Generator[StreamResponse, None, None]
+    ) -> Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]:
         """
         Process blocking response.
         :return:
@@ -149,11 +143,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
             if isinstance(stream_response, ErrorStreamResponse):
                 raise stream_response.err
             elif isinstance(stream_response, MessageEndStreamResponse):
-                extras = {
-                    'usage': jsonable_encoder(self._task_state.llm_result.usage)
-                }
+                extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)}
                 if self._task_state.metadata:
-                    extras['metadata'] = self._task_state.metadata
+                    extras["metadata"] = self._task_state.metadata
 
                 if self._conversation.mode == AppMode.COMPLETION.value:
                     response = CompletionAppBlockingResponse(
@@ -164,8 +156,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                             message_id=self._message.id,
                             answer=self._task_state.llm_result.message.content,
                             created_at=int(self._message.created_at.timestamp()),
-                            **extras
-                        )
+                            **extras,
+                        ),
                     )
                 else:
                     response = ChatbotAppBlockingResponse(
@@ -177,18 +169,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                             message_id=self._message.id,
                             answer=self._task_state.llm_result.message.content,
                             created_at=int(self._message.created_at.timestamp()),
-                            **extras
-                        )
+                            **extras,
+                        ),
                     )
 
                 return response
             else:
                 continue
 
-        raise Exception('Queue listening stopped unexpectedly.')
+        raise Exception("Queue listening stopped unexpectedly.")
 
-    def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \
-            -> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]:
+    def _to_stream_response(
+        self, generator: Generator[StreamResponse, None, None]
+    ) -> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]:
         """
         To stream response.
         :return:
@@ -198,14 +191,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                 yield CompletionAppStreamResponse(
                     message_id=self._message.id,
                     created_at=int(self._message.created_at.timestamp()),
-                    stream_response=stream_response
+                    stream_response=stream_response,
                 )
             else:
                 yield ChatbotAppStreamResponse(
                     conversation_id=self._conversation.id,
                     message_id=self._message.id,
                     created_at=int(self._message.created_at.timestamp()),
-                    stream_response=stream_response
+                    stream_response=stream_response,
                 )
 
     def _listenAudioMsg(self, publisher, task_id: str):
@@ -217,15 +210,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
             return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
         return None
 
-    def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
-            Generator[StreamResponse, None, None]:
-
+    def _wrapper_process_stream_response(
+        self, trace_manager: Optional[TraceQueueManager] = None
+    ) -> Generator[StreamResponse, None, None]:
         tenant_id = self._application_generate_entity.app_config.tenant_id
         task_id = self._application_generate_entity.task_id
         publisher = None
-        text_to_speech_dict = self._app_config.app_model_config_dict.get('text_to_speech')
-        if text_to_speech_dict and text_to_speech_dict.get('autoPlay') == 'enabled' and text_to_speech_dict.get('enabled'):
-            publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get('voice', None))
+        text_to_speech_dict = self._app_config.app_model_config_dict.get("text_to_speech")
+        if (
+            text_to_speech_dict
+            and text_to_speech_dict.get("autoPlay") == "enabled"
+            and text_to_speech_dict.get("enabled")
+        ):
+            publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None))
         for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
             while True:
                 audio_response = self._listenAudioMsg(publisher, task_id)
@@ -250,14 +247,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                 break
             else:
                 start_listener_time = time.time()
-                yield MessageAudioStreamResponse(audio=audio.audio,
-                                                 task_id=task_id)
-        yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
+                yield MessageAudioStreamResponse(audio=audio.audio, task_id=task_id)
+        yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
 
     def _process_stream_response(
-            self,
-            publisher: AppGeneratorTTSPublisher,
-            trace_manager: Optional[TraceQueueManager] = None
+        self, publisher: AppGeneratorTTSPublisher, trace_manager: Optional[TraceQueueManager] = None
     ) -> Generator[StreamResponse, None, None]:
         """
         Process stream response.
@@ -333,9 +327,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
         if self._conversation_name_generate_thread:
             self._conversation_name_generate_thread.join()
 
-    def _save_message(
-            self, trace_manager: Optional[TraceQueueManager] = None
-    ) -> None:
+    def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> None:
         """
         Save message.
         :return:
@@ -347,31 +339,32 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
         self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
 
         self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
-            self._model_config.mode,
-            self._task_state.llm_result.prompt_messages
+            self._model_config.mode, self._task_state.llm_result.prompt_messages
         )
         self._message.message_tokens = usage.prompt_tokens
         self._message.message_unit_price = usage.prompt_unit_price
         self._message.message_price_unit = usage.prompt_price_unit
-        self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \
-            if llm_result.message.content else ''
+        self._message.answer = (
+            PromptTemplateParser.remove_template_variables(llm_result.message.content.strip())
+            if llm_result.message.content
+            else ""
+        )
         self._message.answer_tokens = usage.completion_tokens
         self._message.answer_unit_price = usage.completion_unit_price
         self._message.answer_price_unit = usage.completion_price_unit
         self._message.provider_response_latency = time.perf_counter() - self._start_at
         self._message.total_price = usage.total_price
         self._message.currency = usage.currency
-        self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
-            if self._task_state.metadata else None
+        self._message.message_metadata = (
+            json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
+        )
 
         db.session.commit()
 
         if trace_manager:
             trace_manager.add_trace_task(
                 TraceTask(
-                    TraceTaskName.MESSAGE_TRACE,
-                    conversation_id=self._conversation.id,
-                    message_id=self._message.id
+                    TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation.id, message_id=self._message.id
                 )
             )
 
@@ -379,11 +372,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
             self._message,
             application_generate_entity=self._application_generate_entity,
             conversation=self._conversation,
-            is_first_message=self._application_generate_entity.app_config.app_mode in [
-                AppMode.AGENT_CHAT,
-                AppMode.CHAT
-            ] and self._application_generate_entity.conversation_id is None,
-            extras=self._application_generate_entity.extras
+            is_first_message=self._application_generate_entity.app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT]
+            and self._application_generate_entity.conversation_id is None,
+            extras=self._application_generate_entity.extras,
         )
 
     def _handle_stop(self, event: QueueStopEvent) -> None:
@@ -395,22 +386,17 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
         model = model_config.model
 
         model_instance = ModelInstance(
-            provider_model_bundle=model_config.provider_model_bundle,
-            model=model_config.model
+            provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
         )
 
         # calculate num tokens
         prompt_tokens = 0
         if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
-            prompt_tokens = model_instance.get_llm_num_tokens(
-                self._task_state.llm_result.prompt_messages
-            )
+            prompt_tokens = model_instance.get_llm_num_tokens(self._task_state.llm_result.prompt_messages)
 
         completion_tokens = 0
         if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
-            completion_tokens = model_instance.get_llm_num_tokens(
-                [self._task_state.llm_result.message]
-            )
+            completion_tokens = model_instance.get_llm_num_tokens([self._task_state.llm_result.message])
 
         credentials = model_config.credentials
 
@@ -418,10 +404,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
         model_type_instance = model_config.provider_model_bundle.model_type_instance
         model_type_instance = cast(LargeLanguageModel, model_type_instance)
         self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
-            model,
-            credentials,
-            prompt_tokens,
-            completion_tokens
+            model, credentials, prompt_tokens, completion_tokens
         )
 
     def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
@@ -429,16 +412,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
         Message end to stream response.
         :return:
         """
-        self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage)
+        self._task_state.metadata["usage"] = jsonable_encoder(self._task_state.llm_result.usage)
 
         extras = {}
         if self._task_state.metadata:
-            extras['metadata'] = self._task_state.metadata
+            extras["metadata"] = self._task_state.metadata
 
         return MessageEndStreamResponse(
-            task_id=self._application_generate_entity.task_id,
-            id=self._message.id,
-            **extras
+            task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
         )
 
     def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
@@ -449,9 +430,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
         :return:
         """
         return AgentMessageStreamResponse(
-            task_id=self._application_generate_entity.task_id,
-            id=message_id,
-            answer=answer
+            task_id=self._application_generate_entity.task_id, id=message_id, answer=answer
         )
 
     def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Optional[AgentThoughtStreamResponse]:
@@ -461,9 +440,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
         :return:
         """
         agent_thought: MessageAgentThought = (
-            db.session.query(MessageAgentThought)
-            .filter(MessageAgentThought.id == event.agent_thought_id)
-            .first()
+            db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first()
         )
         db.session.refresh(agent_thought)
         db.session.close()
@@ -478,7 +455,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                 tool=agent_thought.tool,
                 tool_labels=agent_thought.tool_labels,
                 tool_input=agent_thought.tool_input,
-                message_files=agent_thought.files
+                message_files=agent_thought.files,
             )
 
         return None
@@ -500,15 +477,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                             prompt_messages=self._task_state.llm_result.prompt_messages,
                             delta=LLMResultChunkDelta(
                                 index=0,
-                                message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
-                            )
+                                message=AssistantPromptMessage(content=self._task_state.llm_result.message.content),
+                            ),
                         )
-                    ), PublishFrom.TASK_PIPELINE
+                    ),
+                    PublishFrom.TASK_PIPELINE,
                 )
 
                 self._queue_manager.publish(
-                    QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
-                    PublishFrom.TASK_PIPELINE
+                    QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
                 )
                 return True
             else:
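
Note on the hunks above: nearly every change in this file is mechanical formatter output, string literals normalized to double quotes and argument lists either collapsed onto one line when they fit or kept exploded by a magic trailing comma. A minimal sketch of the trailing-comma rule, with an illustrative function that is not from this codebase:

    def make_response(task_id: str, message_id: str, answer: str) -> dict:
        # stand-in for the response constructors reformatted in the hunks above
        return {"task_id": task_id, "message_id": message_id, "answer": answer}

    # without a trailing comma, the formatter collapses a call that fits on one line
    short = make_response(task_id="t-1", message_id="m-1", answer="hi")

    # a magic trailing comma pins the expanded, one-argument-per-line layout
    expanded = make_response(
        task_id="t-1",
        message_id="m-1",
        answer="hi",
    )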

+ 27 - 48
api/core/app/task_pipeline/message_cycle_manage.py

@@ -30,10 +30,7 @@ from services.annotation_service import AppAnnotationService
 
 class MessageCycleManage:
     _application_generate_entity: Union[
-        ChatAppGenerateEntity,
-        CompletionAppGenerateEntity,
-        AgentChatAppGenerateEntity,
-        AdvancedChatAppGenerateEntity
+        ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity
     ]
     _task_state: Union[EasyUITaskState, WorkflowTaskState]
 
@@ -49,15 +46,18 @@ class MessageCycleManage:
 
         is_first_message = self._application_generate_entity.conversation_id is None
         extras = self._application_generate_entity.extras
-        auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True)
+        auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True)
 
         if auto_generate_conversation_name and is_first_message:
             # start generate thread
-            thread = Thread(target=self._generate_conversation_name_worker, kwargs={
-                'flask_app': current_app._get_current_object(), # type: ignore
-                'conversation_id': conversation.id,
-                'query': query
-            })
+            thread = Thread(
+                target=self._generate_conversation_name_worker,
+                kwargs={
+                    "flask_app": current_app._get_current_object(),  # type: ignore
+                    "conversation_id": conversation.id,
+                    "query": query,
+                },
+            )
 
             thread.start()
 
@@ -65,17 +65,10 @@ class MessageCycleManage:
 
         return None
 
-    def _generate_conversation_name_worker(self,
-                                           flask_app: Flask,
-                                           conversation_id: str,
-                                           query: str):
+    def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
         with flask_app.app_context():
             # get conversation and message
-            conversation = (
-                db.session.query(Conversation)
-                .filter(Conversation.id == conversation_id)
-                .first()
-            )
+            conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
 
             if not conversation:
                 return
@@ -105,12 +98,9 @@ class MessageCycleManage:
         annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
         if annotation:
             account = annotation.account
-            self._task_state.metadata['annotation_reply'] = {
-                'id': annotation.id,
-                'account': {
-                    'id': annotation.account_id,
-                    'name': account.name if account else 'Dify user'
-                }
+            self._task_state.metadata["annotation_reply"] = {
+                "id": annotation.id,
+                "account": {"id": annotation.account_id, "name": account.name if account else "Dify user"},
             }
 
             return annotation
@@ -124,7 +114,7 @@ class MessageCycleManage:
         :return:
         """
         if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
-            self._task_state.metadata['retriever_resources'] = event.retriever_resources
+            self._task_state.metadata["retriever_resources"] = event.retriever_resources
 
     def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
         """
@@ -132,27 +122,23 @@ class MessageCycleManage:
         :param event: event
         :return:
         """
-        message_file = (
-            db.session.query(MessageFile)
-            .filter(MessageFile.id == event.message_file_id)
-            .first()
-        )
+        message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first()
 
         if message_file:
             # get tool file id
-            tool_file_id = message_file.url.split('/')[-1]
+            tool_file_id = message_file.url.split("/")[-1]
             # trim extension
-            tool_file_id = tool_file_id.split('.')[0]
+            tool_file_id = tool_file_id.split(".")[0]
 
             # get extension
-            if '.' in message_file.url:
+            if "." in message_file.url:
                 extension = f'.{message_file.url.split(".")[-1]}'
                 if len(extension) > 10:
-                    extension = '.bin'
+                    extension = ".bin"
             else:
-                extension = '.bin'
+                extension = ".bin"
             # add sign url to local file
-            if message_file.url.startswith('http'):
+            if message_file.url.startswith("http"):
                 url = message_file.url
             else:
                 url = ToolFileManager.sign_file(tool_file_id=tool_file_id, extension=extension)
@@ -161,8 +147,8 @@ class MessageCycleManage:
                 task_id=self._application_generate_entity.task_id,
                 id=message_file.id,
                 type=message_file.type,
-                belongs_to=message_file.belongs_to or 'user',
-                url=url
+                belongs_to=message_file.belongs_to or "user",
+                url=url,
             )
 
         return None
@@ -174,11 +160,7 @@ class MessageCycleManage:
         :param message_id: message id
         :return:
         """
-        return MessageStreamResponse(
-            task_id=self._application_generate_entity.task_id,
-            id=message_id,
-            answer=answer
-        )
+        return MessageStreamResponse(task_id=self._application_generate_entity.task_id, id=message_id, answer=answer)
 
     def _message_replace_to_stream_response(self, answer: str) -> MessageReplaceStreamResponse:
         """
@@ -186,7 +168,4 @@ class MessageCycleManage:
         :param answer: answer
         :return:
         """
-        return MessageReplaceStreamResponse(
-            task_id=self._application_generate_entity.task_id,
-            answer=answer
-        )
+        return MessageReplaceStreamResponse(task_id=self._application_generate_entity.task_id, answer=answer)
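
The `_generate_conversation_name_worker` reshaped above runs on a background thread but still needs Flask's application context for database access. A minimal, self-contained sketch of that pattern; the app and worker here are hypothetical stand-ins, not Dify's:

    from threading import Thread

    from flask import Flask

    app = Flask(__name__)  # stand-in; the pipeline passes current_app._get_current_object()

    def worker(flask_app: Flask, conversation_id: str, query: str) -> None:
        # pushing the app context gives the thread access to extensions such as db
        with flask_app.app_context():
            print(f"naming conversation {conversation_id} from {query!r}")

    thread = Thread(target=worker, kwargs={"flask_app": app, "conversation_id": "c-1", "query": "hi"})
    thread.start()
    thread.join()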

+ 61 - 56
api/core/app/task_pipeline/workflow_cycle_manage.py

@@ -70,14 +70,14 @@ class WorkflowCycleManage:
 
         inputs = {**self._application_generate_entity.inputs}
         for key, value in (self._workflow_system_variables or {}).items():
-            if key.value == 'conversation':
+            if key.value == "conversation":
                 continue
 
-            inputs[f'sys.{key.value}'] = value
+            inputs[f"sys.{key.value}"] = value
 
         inputs = WorkflowEntry.handle_special_values(inputs)
 
-        triggered_from= (
+        triggered_from = (
             WorkflowRunTriggeredFrom.DEBUGGING
             if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
             else WorkflowRunTriggeredFrom.APP_RUN
@@ -185,20 +185,26 @@ class WorkflowCycleManage:
 
         db.session.commit()
 
-        running_workflow_node_executions = db.session.query(WorkflowNodeExecution).filter(
-            WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
-            WorkflowNodeExecution.app_id == workflow_run.app_id,
-            WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
-            WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
-            WorkflowNodeExecution.workflow_run_id == workflow_run.id,
-            WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value
-        ).all()
+        running_workflow_node_executions = (
+            db.session.query(WorkflowNodeExecution)
+            .filter(
+                WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
+                WorkflowNodeExecution.app_id == workflow_run.app_id,
+                WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
+                WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
+                WorkflowNodeExecution.workflow_run_id == workflow_run.id,
+                WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
+            )
+            .all()
+        )
 
         for workflow_node_execution in running_workflow_node_executions:
             workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
             workflow_node_execution.error = error
             workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
-            workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - workflow_node_execution.created_at).total_seconds()
+            workflow_node_execution.elapsed_time = (
+                workflow_node_execution.finished_at - workflow_node_execution.created_at
+            ).total_seconds()
             db.session.commit()
 
         db.session.refresh(workflow_run)
@@ -216,7 +222,9 @@ class WorkflowCycleManage:
 
         return workflow_run
 
-    def _handle_node_execution_start(self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> WorkflowNodeExecution:
+    def _handle_node_execution_start(
+        self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
+    ) -> WorkflowNodeExecution:
         # init workflow node execution
         workflow_node_execution = WorkflowNodeExecution()
         workflow_node_execution.tenant_id = workflow_run.tenant_id
@@ -333,16 +341,16 @@ class WorkflowCycleManage:
             created_by_account = workflow_run.created_by_account
             if created_by_account:
                 created_by = {
-                    'id': created_by_account.id,
-                    'name': created_by_account.name,
-                    'email': created_by_account.email,
+                    "id": created_by_account.id,
+                    "name": created_by_account.name,
+                    "email": created_by_account.email,
                 }
         else:
             created_by_end_user = workflow_run.created_by_end_user
             if created_by_end_user:
                 created_by = {
-                    'id': created_by_end_user.id,
-                    'user': created_by_end_user.session_id,
+                    "id": created_by_end_user.id,
+                    "user": created_by_end_user.session_id,
                 }
 
         return WorkflowFinishStreamResponse(
@@ -401,7 +409,7 @@ class WorkflowCycleManage:
         # extras logic
         if event.node_type == NodeType.TOOL:
             node_data = cast(ToolNodeData, event.node_data)
-            response.data.extras['icon'] = ToolManager.get_tool_icon(
+            response.data.extras["icon"] = ToolManager.get_tool_icon(
                 tenant_id=self._application_generate_entity.app_config.tenant_id,
                 provider_type=node_data.provider_type,
                 provider_id=node_data.provider_id,
@@ -410,10 +418,10 @@ class WorkflowCycleManage:
         return response
 
     def _workflow_node_finish_to_stream_response(
-        self, 
-        event: QueueNodeSucceededEvent | QueueNodeFailedEvent, 
-        task_id: str, 
-        workflow_node_execution: WorkflowNodeExecution
+        self,
+        event: QueueNodeSucceededEvent | QueueNodeFailedEvent,
+        task_id: str,
+        workflow_node_execution: WorkflowNodeExecution,
     ) -> Optional[NodeFinishStreamResponse]:
         """
         Workflow node finish to stream response.
@@ -424,7 +432,7 @@ class WorkflowCycleManage:
         """
         if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
             return None
-        
+
         return NodeFinishStreamResponse(
             task_id=task_id,
             workflow_run_id=workflow_node_execution.workflow_run_id,
@@ -452,13 +460,10 @@ class WorkflowCycleManage:
                 iteration_id=event.in_iteration_id,
             ),
         )
-    
+
     def _workflow_parallel_branch_start_to_stream_response(
-            self,
-            task_id: str,
-            workflow_run: WorkflowRun,
-            event: QueueParallelBranchRunStartedEvent
-        ) -> ParallelBranchStartStreamResponse:
+        self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
+    ) -> ParallelBranchStartStreamResponse:
         """
         Workflow parallel branch start to stream response
         :param task_id: task id
@@ -476,15 +481,15 @@ class WorkflowCycleManage:
                 parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                 iteration_id=event.in_iteration_id,
                 created_at=int(time.time()),
-            )
+            ),
         )
-    
+
     def _workflow_parallel_branch_finished_to_stream_response(
-            self,
-            task_id: str,
-            workflow_run: WorkflowRun,
-            event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent
-        ) -> ParallelBranchFinishedStreamResponse:
+        self,
+        task_id: str,
+        workflow_run: WorkflowRun,
+        event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
+    ) -> ParallelBranchFinishedStreamResponse:
         """
         Workflow parallel branch finished to stream response
         :param task_id: task id
@@ -501,18 +506,15 @@ class WorkflowCycleManage:
                 parent_parallel_id=event.parent_parallel_id,
                 parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                 iteration_id=event.in_iteration_id,
-                status='succeeded' if isinstance(event, QueueParallelBranchRunSucceededEvent) else 'failed',
+                status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed",
                 error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
                 created_at=int(time.time()),
-            )
+            ),
         )
 
     def _workflow_iteration_start_to_stream_response(
-            self,
-            task_id: str,
-            workflow_run: WorkflowRun,
-            event: QueueIterationStartEvent
-        ) -> IterationNodeStartStreamResponse:
+        self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent
+    ) -> IterationNodeStartStreamResponse:
         """
         Workflow iteration start to stream response
         :param task_id: task id
@@ -534,10 +536,12 @@ class WorkflowCycleManage:
                 metadata=event.metadata or {},
                 parallel_id=event.parallel_id,
                 parallel_start_node_id=event.parallel_start_node_id,
-            )
+            ),
         )
 
-    def _workflow_iteration_next_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent) -> IterationNodeNextStreamResponse:
+    def _workflow_iteration_next_to_stream_response(
+        self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent
+    ) -> IterationNodeNextStreamResponse:
         """
         Workflow iteration next to stream response
         :param task_id: task id
@@ -559,10 +563,12 @@ class WorkflowCycleManage:
                 extras={},
                 parallel_id=event.parallel_id,
                 parallel_start_node_id=event.parallel_start_node_id,
-            )
+            ),
         )
 
-    def _workflow_iteration_completed_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent) -> IterationNodeCompletedStreamResponse:
+    def _workflow_iteration_completed_to_stream_response(
+        self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent
+    ) -> IterationNodeCompletedStreamResponse:
         """
         Workflow iteration completed to stream response
         :param task_id: task id
@@ -585,13 +591,13 @@ class WorkflowCycleManage:
                 status=WorkflowNodeExecutionStatus.SUCCEEDED,
                 error=None,
                 elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(),
-                total_tokens=event.metadata.get('total_tokens', 0) if event.metadata else 0,
+                total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
                 execution_metadata=event.metadata,
                 finished_at=int(time.time()),
                 steps=event.steps,
                 parallel_id=event.parallel_id,
                 parallel_start_node_id=event.parallel_start_node_id,
-            )
+            ),
         )
 
     def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]:
@@ -643,7 +649,7 @@ class WorkflowCycleManage:
             return None
 
         if isinstance(value, dict):
-            if '__variant' in value and value['__variant'] == FileVar.__name__:
+            if "__variant" in value and value["__variant"] == FileVar.__name__:
                 return value
         elif isinstance(value, FileVar):
             return value.to_dict()
@@ -656,11 +662,10 @@ class WorkflowCycleManage:
         :param workflow_run_id: workflow run id
         :return:
         """
-        workflow_run = db.session.query(WorkflowRun).filter(
-            WorkflowRun.id == workflow_run_id).first()
+        workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
 
         if not workflow_run:
-            raise Exception(f'Workflow run not found: {workflow_run_id}')
+            raise Exception(f"Workflow run not found: {workflow_run_id}")
 
         return workflow_run
 
@@ -683,6 +688,6 @@ class WorkflowCycleManage:
         )
 
         if not workflow_node_execution:
-            raise Exception(f'Workflow node execution not found: {node_execution_id}')
+            raise Exception(f"Workflow node execution not found: {node_execution_id}")
 
-        return workflow_node_execution
+        return workflow_node_execution
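
Several hunks in this file rewrap SQLAlchemy query chains into the parenthesized, one-call-per-line form. A runnable sketch of the same shape against a throwaway in-memory model; the names are illustrative, not Dify's:

    from sqlalchemy import Column, String, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class NodeExecution(Base):  # hypothetical stand-in for WorkflowNodeExecution
        __tablename__ = "node_execution"
        id = Column(String, primary_key=True)
        status = Column(String)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(NodeExecution(id="n-1", status="running"))
        # explicit parentheses let each chained call sit on its own line
        running = (
            session.query(NodeExecution)
            .filter(NodeExecution.status == "running")
            .all()
        )
        assert len(running) == 1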

+ 17 - 21
api/core/callback_handler/agent_tool_callback_handler.py

@@ -16,31 +16,32 @@ _TEXT_COLOR_MAPPING = {
     "red": "31;1",
 }
 
+
 def get_colored_text(text: str, color: str) -> str:
     """Get colored text."""
     color_str = _TEXT_COLOR_MAPPING[color]
     return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
 
 
-def print_text(
-    text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
-) -> None:
+def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None) -> None:
     """Print text with highlighting and no end characters."""
     text_to_print = get_colored_text(text, color) if color else text
     print(text_to_print, end=end, file=file)
     if file:
         file.flush()  # ensure all printed content is written to the file
 
+
 class DifyAgentCallbackHandler(BaseModel):
     """Callback Handler that prints to std out."""
-    color: Optional[str] = ''
+
+    color: Optional[str] = ""
     current_loop: int = 1
 
     def __init__(self, color: Optional[str] = None) -> None:
         """Initialize callback handler."""
         super().__init__()
         # use a specific color if not specified
-        self.color = color or 'green'
+        self.color = color or "green"
         self.current_loop = 1
 
     def on_tool_start(
@@ -58,7 +59,7 @@ class DifyAgentCallbackHandler(BaseModel):
         tool_outputs: Sequence[ToolInvokeMessage],
         message_id: Optional[str] = None,
         timer: Optional[Any] = None,
-        trace_manager: Optional[TraceQueueManager] = None
+        trace_manager: Optional[TraceQueueManager] = None,
     ) -> None:
         """If not the final action, print out observation."""
         print_text("\n[on_tool_end]\n", color=self.color)
@@ -79,26 +80,21 @@ class DifyAgentCallbackHandler(BaseModel):
                 )
             )
 
-    def on_tool_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
+    def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
         """Do nothing."""
-        print_text("\n[on_tool_error] Error: " + str(error) + "\n", color='red')
+        print_text("\n[on_tool_error] Error: " + str(error) + "\n", color="red")
 
-    def on_agent_start(
-        self, thought: str
-    ) -> None:
+    def on_agent_start(self, thought: str) -> None:
         """Run on agent start."""
         if thought:
-            print_text("\n[on_agent_start] \nCurrent Loop: " + \
-                        str(self.current_loop) + \
-                        "\nThought: " + thought + "\n", color=self.color)
+            print_text(
+                "\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\nThought: " + thought + "\n",
+                color=self.color,
+            )
         else:
             print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color)
 
-    def on_agent_finish(
-        self, color: Optional[str] = None, **kwargs: Any
-    ) -> None:
+    def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any) -> None:
         """Run on agent end."""
         print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color)
 
@@ -107,9 +103,9 @@ class DifyAgentCallbackHandler(BaseModel):
     @property
     def ignore_agent(self) -> bool:
         """Whether to ignore agent callbacks."""
-        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
+        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != "true"
 
     @property
     def ignore_chat_model(self) -> bool:
         """Whether to ignore chat model callbacks."""
-        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
+        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != "true"

+ 28 - 34
api/core/callback_handler/index_tool_callback_handler.py

@@ -1,4 +1,3 @@
-
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
@@ -11,11 +10,9 @@ from models.model import DatasetRetrieverResource
 class DatasetIndexToolCallbackHandler:
     """Callback handler for dataset tool."""
 
-    def __init__(self, queue_manager: AppQueueManager,
-                 app_id: str,
-                 message_id: str,
-                 user_id: str,
-                 invoke_from: InvokeFrom) -> None:
+    def __init__(
+        self, queue_manager: AppQueueManager, app_id: str, message_id: str, user_id: str, invoke_from: InvokeFrom
+    ) -> None:
         self._queue_manager = queue_manager
         self._app_id = app_id
         self._message_id = message_id
@@ -29,11 +26,12 @@ class DatasetIndexToolCallbackHandler:
         dataset_query = DatasetQuery(
             dataset_id=dataset_id,
             content=query,
-            source='app',
+            source="app",
             source_app_id=self._app_id,
-            created_by_role=('account'
-                             if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'),
-            created_by=self._user_id
+            created_by_role=(
+                "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"
+            ),
+            created_by=self._user_id,
         )
 
         db.session.add(dataset_query)
@@ -43,18 +41,15 @@ class DatasetIndexToolCallbackHandler:
         """Handle tool end."""
         for document in documents:
             query = db.session.query(DocumentSegment).filter(
-                DocumentSegment.index_node_id == document.metadata['doc_id']
+                DocumentSegment.index_node_id == document.metadata["doc_id"]
             )
 
-            if 'dataset_id' in document.metadata:
-                query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id'])
+            if "dataset_id" in document.metadata:
+                query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
 
             # add hit count to document segment
-            query.update(
-                {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
-                synchronize_session=False
-            )
+            query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
 
             db.session.commit()
 
@@ -64,26 +59,25 @@ class DatasetIndexToolCallbackHandler:
             for item in resource:
                 dataset_retriever_resource = DatasetRetrieverResource(
                     message_id=self._message_id,
-                    position=item.get('position'),
-                    dataset_id=item.get('dataset_id'),
-                    dataset_name=item.get('dataset_name'),
-                    document_id=item.get('document_id'),
-                    document_name=item.get('document_name'),
-                    data_source_type=item.get('data_source_type'),
-                    segment_id=item.get('segment_id'),
-                    score=item.get('score') if 'score' in item else None,
-                    hit_count=item.get('hit_count') if 'hit_count' else None,
-                    word_count=item.get('word_count') if 'word_count' in item else None,
-                    segment_position=item.get('segment_position') if 'segment_position' in item else None,
-                    index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None,
-                    content=item.get('content'),
-                    retriever_from=item.get('retriever_from'),
-                    created_by=self._user_id
+                    position=item.get("position"),
+                    dataset_id=item.get("dataset_id"),
+                    dataset_name=item.get("dataset_name"),
+                    document_id=item.get("document_id"),
+                    document_name=item.get("document_name"),
+                    data_source_type=item.get("data_source_type"),
+                    segment_id=item.get("segment_id"),
+                    score=item.get("score") if "score" in item else None,
+                    hit_count=item.get("hit_count") if "hit_count" else None,
+                    word_count=item.get("word_count") if "word_count" in item else None,
+                    segment_position=item.get("segment_position") if "segment_position" in item else None,
+                    index_node_hash=item.get("index_node_hash") if "index_node_hash" in item else None,
+                    content=item.get("content"),
+                    retriever_from=item.get("retriever_from"),
+                    created_by=self._user_id,
                 )
                 db.session.add(dataset_retriever_resource)
                 db.session.commit()
 
         self._queue_manager.publish(
-            QueueRetrieverResourcesEvent(retriever_resources=resource),
-            PublishFrom.APPLICATION_MANAGER
+            QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
         )
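
One removed line above carried a subtle no-op beyond quoting: `item.get('hit_count') if 'hit_count' else None` tests a non-empty string literal, which is always truthy, so the fallback branch was dead code; the membership test used by the neighboring fields was evidently intended. A short demonstration:

    item = {"position": 1}

    # the string literal is always truthy, so this never takes the None branch
    confusing = item.get("hit_count") if "hit_count" else None

    # the membership test that the surrounding fields use
    explicit = item.get("hit_count") if "hit_count" in item else None

    # both collapse to plain .get(), since dict.get already defaults to None
    assert confusing is None and explicit is None and item.get("hit_count") is None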

+ 1 - 1
api/core/callback_handler/workflow_tool_callback_handler.py

@@ -2,4 +2,4 @@ from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackH
 
 
 class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler):
-    """Callback Handler that prints to std out."""
+    """Callback Handler that prints to std out."""

+ 27 - 23
api/core/embedding/cached_embedding.py

@@ -29,9 +29,13 @@ class CacheEmbedding(Embeddings):
         embedding_queue_indices = []
         for i, text in enumerate(texts):
             hash = helper.generate_text_hash(text)
-            embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model,
-                                                              hash=hash,
-                                                              provider_name=self._model_instance.provider).first()
+            embedding = (
+                db.session.query(Embedding)
+                .filter_by(
+                    model_name=self._model_instance.model, hash=hash, provider_name=self._model_instance.provider
+                )
+                .first()
+            )
             if embedding:
                 text_embeddings[i] = embedding.get_embedding()
             else:
@@ -41,17 +45,18 @@ class CacheEmbedding(Embeddings):
             embedding_queue_embeddings = []
             try:
                 model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
-                model_schema = model_type_instance.get_model_schema(self._model_instance.model,
-                                                                    self._model_instance.credentials)
-                max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \
-                    if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1
+                model_schema = model_type_instance.get_model_schema(
+                    self._model_instance.model, self._model_instance.credentials
+                )
+                max_chunks = (
+                    model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
+                    if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties
+                    else 1
+                )
                 for i in range(0, len(embedding_queue_texts), max_chunks):
-                    batch_texts = embedding_queue_texts[i:i + max_chunks]
+                    batch_texts = embedding_queue_texts[i : i + max_chunks]
 
-                    embedding_result = self._model_instance.invoke_text_embedding(
-                        texts=batch_texts,
-                        user=self._user
-                    )
+                    embedding_result = self._model_instance.invoke_text_embedding(texts=batch_texts, user=self._user)
 
                     for vector in embedding_result.embeddings:
                         try:
@@ -60,16 +65,18 @@ class CacheEmbedding(Embeddings):
                         except IntegrityError:
                             db.session.rollback()
                         except Exception as e:
-                            logging.exception('Failed transform embedding: ', e)
+                            logging.exception("Failed transform embedding: ", e)
                 cache_embeddings = []
                 try:
                     for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
                         text_embeddings[i] = embedding
                         hash = helper.generate_text_hash(texts[i])
                         if hash not in cache_embeddings:
-                            embedding_cache = Embedding(model_name=self._model_instance.model,
-                                                        hash=hash,
-                                                        provider_name=self._model_instance.provider)
+                            embedding_cache = Embedding(
+                                model_name=self._model_instance.model,
+                                hash=hash,
+                                provider_name=self._model_instance.provider,
+                            )
                             embedding_cache.set_embedding(embedding)
                             db.session.add(embedding_cache)
                             cache_embeddings.append(hash)
@@ -78,7 +85,7 @@ class CacheEmbedding(Embeddings):
                     db.session.rollback()
             except Exception as ex:
                 db.session.rollback()
-                logger.error('Failed to embed documents: ', ex)
+                logger.error("Failed to embed documents: ", ex)
                 raise ex
 
         return text_embeddings
@@ -87,16 +94,13 @@ class CacheEmbedding(Embeddings):
         """Embed query text."""
         # use doc embedding cache or store if not exists
         hash = helper.generate_text_hash(text)
-        embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
+        embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{hash}"
         embedding = redis_client.get(embedding_cache_key)
         if embedding:
             redis_client.expire(embedding_cache_key, 600)
             return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
         try:
-            embedding_result = self._model_instance.invoke_text_embedding(
-                texts=[text],
-                user=self._user
-            )
+            embedding_result = self._model_instance.invoke_text_embedding(texts=[text], user=self._user)
 
             embedding_results = embedding_result.embeddings[0]
             embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
@@ -116,6 +120,6 @@ class CacheEmbedding(Embeddings):
         except IntegrityError:
             db.session.rollback()
         except:
-            logging.exception('Failed to add embedding to redis')
+            logging.exception("Failed to add embedding to redis")
 
         return embedding_results
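
The batching loop above also shows the formatter's PEP 8 slice spacing: slices with compound bounds gain spaces around the colon, as in `embedding_queue_texts[i : i + max_chunks]`. A small sketch of the batching itself, with illustrative values:

    texts = [f"text-{n}" for n in range(7)]
    max_chunks = 3  # stand-in for the model schema's MAX_CHUNKS property

    for i in range(0, len(texts), max_chunks):
        batch = texts[i : i + max_chunks]  # batches of 3, 3, then the final 1
        print(batch)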

+ 4 - 4
api/core/entities/agent_entities.py

@@ -2,7 +2,7 @@ from enum import Enum
 
 
 class PlanningStrategy(Enum):
-    ROUTER = 'router'
-    REACT_ROUTER = 'react_router'
-    REACT = 'react'
-    FUNCTION_CALL = 'function_call'
+    ROUTER = "router"
+    REACT_ROUTER = "react_router"
+    REACT = "react"
+    FUNCTION_CALL = "function_call"

+ 3 - 3
api/core/entities/message_entities.py

@@ -5,7 +5,7 @@ from pydantic import BaseModel
 
 
 class PromptMessageFileType(enum.Enum):
-    IMAGE = 'image'
+    IMAGE = "image"
 
     @staticmethod
     def value_of(value):
@@ -22,8 +22,8 @@ class PromptMessageFile(BaseModel):
 
 class ImagePromptMessageFile(PromptMessageFile):
     class DETAIL(enum.Enum):
-        LOW = 'low'
-        HIGH = 'high'
+        LOW = "low"
+        HIGH = "high"
 
     type: PromptMessageFileType = PromptMessageFileType.IMAGE
     detail: DETAIL = DETAIL.LOW

+ 7 - 1
api/core/entities/model_entities.py

@@ -12,6 +12,7 @@ class ModelStatus(Enum):
     """
     Enum class for model status.
     """
+
     ACTIVE = "active"
     NO_CONFIGURE = "no-configure"
     QUOTA_EXCEEDED = "quota-exceeded"
@@ -23,6 +24,7 @@ class SimpleModelProviderEntity(BaseModel):
     """
     Simple provider.
     """
+
     provider: str
     label: I18nObject
     icon_small: Optional[I18nObject] = None
@@ -40,7 +42,7 @@ class SimpleModelProviderEntity(BaseModel):
             label=provider_entity.label,
             icon_small=provider_entity.icon_small,
             icon_large=provider_entity.icon_large,
-            supported_model_types=provider_entity.supported_model_types
+            supported_model_types=provider_entity.supported_model_types,
         )
 
 
@@ -48,6 +50,7 @@ class ProviderModelWithStatusEntity(ProviderModel):
     """
     Model class for model response.
     """
+
     status: ModelStatus
     load_balancing_enabled: bool = False
 
@@ -56,6 +59,7 @@ class ModelWithProviderEntity(ProviderModelWithStatusEntity):
     """
     Model with provider entity.
     """
+
     provider: SimpleModelProviderEntity
 
 
@@ -63,6 +67,7 @@ class DefaultModelProviderEntity(BaseModel):
     """
     Default model provider entity.
     """
+
     provider: str
     label: I18nObject
     icon_small: Optional[I18nObject] = None
@@ -74,6 +79,7 @@ class DefaultModelEntity(BaseModel):
     """
     Default model entity.
     """
+
     model: str
     model_type: ModelType
     provider: DefaultModelProviderEntity
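
Every hunk in this file adds the same thing: a blank line between a class docstring and the first field, which the formatter evidently enforces for class bodies. In miniature, with one of the models above:

    from pydantic import BaseModel

    class DefaultModelProviderEntity(BaseModel):
        """
        Default model provider entity.
        """

        provider: str  # the blank line above is what these hunks insert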

+ 201 - 167
api/core/entities/provider_configuration.py

@@ -47,6 +47,7 @@ class ProviderConfiguration(BaseModel):
     """
     Model class for provider configuration.
     """
+
     tenant_id: str
     provider: ProviderEntity
     preferred_provider_type: ProviderType
@@ -67,9 +68,13 @@ class ProviderConfiguration(BaseModel):
                 original_provider_configurate_methods[self.provider.provider].append(configurate_method)
 
         if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
-            if (any(len(quota_configuration.restrict_models) > 0
-                     for quota_configuration in self.system_configuration.quota_configurations)
-                    and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods):
+            if (
+                any(
+                    len(quota_configuration.restrict_models) > 0
+                    for quota_configuration in self.system_configuration.quota_configurations
+                )
+                and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods
+            ):
                 self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
 
     def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
@@ -83,10 +88,9 @@ class ProviderConfiguration(BaseModel):
         if self.model_settings:
             # check if model is disabled by admin
             for model_setting in self.model_settings:
-                if (model_setting.model_type == model_type
-                        and model_setting.model == model):
+                if model_setting.model_type == model_type and model_setting.model == model:
                     if not model_setting.enabled:
-                        raise ValueError(f'Model {model} is disabled.')
+                        raise ValueError(f"Model {model} is disabled.")
 
         if self.using_provider_type == ProviderType.SYSTEM:
             restrict_models = []
@@ -99,10 +103,12 @@ class ProviderConfiguration(BaseModel):
             copy_credentials = self.system_configuration.credentials.copy()
             if restrict_models:
                 for restrict_model in restrict_models:
-                    if (restrict_model.model_type == model_type
-                            and restrict_model.model == model
-                            and restrict_model.base_model_name):
-                        copy_credentials['base_model_name'] = restrict_model.base_model_name
+                    if (
+                        restrict_model.model_type == model_type
+                        and restrict_model.model == model
+                        and restrict_model.base_model_name
+                    ):
+                        copy_credentials["base_model_name"] = restrict_model.base_model_name
 
             return copy_credentials
         else:
@@ -128,20 +134,21 @@ class ProviderConfiguration(BaseModel):
 
         current_quota_type = self.system_configuration.current_quota_type
         current_quota_configuration = next(
-            (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type),
-            None
+            (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None
         )
 
-        return SystemConfigurationStatus.ACTIVE if current_quota_configuration.is_valid else \
-            SystemConfigurationStatus.QUOTA_EXCEEDED
+        return (
+            SystemConfigurationStatus.ACTIVE
+            if current_quota_configuration.is_valid
+            else SystemConfigurationStatus.QUOTA_EXCEEDED
+        )
 
     def is_custom_configuration_available(self) -> bool:
         """
         Check custom configuration available.
         :return:
         """
-        return (self.custom_configuration.provider is not None
-                or len(self.custom_configuration.models) > 0)
+        return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0
 
     def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
         """
@@ -161,7 +168,8 @@ class ProviderConfiguration(BaseModel):
         return self.obfuscated_credentials(
             credentials=credentials,
             credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
-            if self.provider.provider_credential_schema else []
+            if self.provider.provider_credential_schema
+            else [],
         )
 
     def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]:
@@ -171,17 +179,21 @@ class ProviderConfiguration(BaseModel):
         :return:
         """
         # get provider
-        provider_record = db.session.query(Provider) \
+        provider_record = (
+            db.session.query(Provider)
             .filter(
-            Provider.tenant_id == self.tenant_id,
-            Provider.provider_name == self.provider.provider,
-            Provider.provider_type == ProviderType.CUSTOM.value
-        ).first()
+                Provider.tenant_id == self.tenant_id,
+                Provider.provider_name == self.provider.provider,
+                Provider.provider_type == ProviderType.CUSTOM.value,
+            )
+            .first()
+        )
 
         # Get provider credential secret variables
         provider_credential_secret_variables = self.extract_secret_variables(
             self.provider.provider_credential_schema.credential_form_schemas
-            if self.provider.provider_credential_schema else []
+            if self.provider.provider_credential_schema
+            else []
         )
 
         if provider_record:
@@ -189,9 +201,7 @@ class ProviderConfiguration(BaseModel):
                 # fix origin data
                 if provider_record.encrypted_config:
                     if not provider_record.encrypted_config.startswith("{"):
-                        original_credentials = {
-                            "openai_api_key": provider_record.encrypted_config
-                        }
+                        original_credentials = {"openai_api_key": provider_record.encrypted_config}
                     else:
                         original_credentials = json.loads(provider_record.encrypted_config)
                 else:
@@ -207,8 +217,7 @@ class ProviderConfiguration(BaseModel):
                         credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
 
         credentials = model_provider_factory.provider_credentials_validate(
-            provider=self.provider.provider,
-            credentials=credentials
+            provider=self.provider.provider, credentials=credentials
         )
 
         for key, value in credentials.items():
@@ -239,15 +248,13 @@ class ProviderConfiguration(BaseModel):
                 provider_name=self.provider.provider,
                 provider_type=ProviderType.CUSTOM.value,
                 encrypted_config=json.dumps(credentials),
-                is_valid=True
+                is_valid=True,
             )
             db.session.add(provider_record)
             db.session.commit()
 
         provider_model_credentials_cache = ProviderCredentialsCache(
-            tenant_id=self.tenant_id,
-            identity_id=provider_record.id,
-            cache_type=ProviderCredentialsCacheType.PROVIDER
+            tenant_id=self.tenant_id, identity_id=provider_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER
         )
 
         provider_model_credentials_cache.delete()
@@ -260,12 +267,15 @@ class ProviderConfiguration(BaseModel):
         :return:
         """
         # get provider
-        provider_record = db.session.query(Provider) \
+        provider_record = (
+            db.session.query(Provider)
             .filter(
-            Provider.tenant_id == self.tenant_id,
-            Provider.provider_name == self.provider.provider,
-            Provider.provider_type == ProviderType.CUSTOM.value
-        ).first()
+                Provider.tenant_id == self.tenant_id,
+                Provider.provider_name == self.provider.provider,
+                Provider.provider_type == ProviderType.CUSTOM.value,
+            )
+            .first()
+        )
 
         # delete provider
         if provider_record:
@@ -277,13 +287,14 @@ class ProviderConfiguration(BaseModel):
             provider_model_credentials_cache = ProviderCredentialsCache(
                 tenant_id=self.tenant_id,
                 identity_id=provider_record.id,
-                cache_type=ProviderCredentialsCacheType.PROVIDER
+                cache_type=ProviderCredentialsCacheType.PROVIDER,
             )
 
             provider_model_credentials_cache.delete()
 
-    def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
-            -> Optional[dict]:
+    def get_custom_model_credentials(
+        self, model_type: ModelType, model: str, obfuscated: bool = False
+    ) -> Optional[dict]:
         """
         Get custom model credentials.
 
@@ -305,13 +316,15 @@ class ProviderConfiguration(BaseModel):
                 return self.obfuscated_credentials(
                     credentials=credentials,
                     credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
-                    if self.provider.model_credential_schema else []
+                    if self.provider.model_credential_schema
+                    else [],
                 )
 
         return None
 
-    def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
-            -> tuple[ProviderModel, dict]:
+    def custom_model_credentials_validate(
+        self, model_type: ModelType, model: str, credentials: dict
+    ) -> tuple[ProviderModel, dict]:
         """
         Validate custom model credentials.
 
@@ -321,24 +334,29 @@ class ProviderConfiguration(BaseModel):
         :return:
         """
         # get provider model
-        provider_model_record = db.session.query(ProviderModel) \
+        provider_model_record = (
+            db.session.query(ProviderModel)
             .filter(
-            ProviderModel.tenant_id == self.tenant_id,
-            ProviderModel.provider_name == self.provider.provider,
-            ProviderModel.model_name == model,
-            ProviderModel.model_type == model_type.to_origin_model_type()
-        ).first()
+                ProviderModel.tenant_id == self.tenant_id,
+                ProviderModel.provider_name == self.provider.provider,
+                ProviderModel.model_name == model,
+                ProviderModel.model_type == model_type.to_origin_model_type(),
+            )
+            .first()
+        )
 
         # Get provider credential secret variables
         provider_credential_secret_variables = self.extract_secret_variables(
             self.provider.model_credential_schema.credential_form_schemas
-            if self.provider.model_credential_schema else []
+            if self.provider.model_credential_schema
+            else []
         )
 
         if provider_model_record:
             try:
-                original_credentials = json.loads(
-                    provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
+                original_credentials = (
+                    json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
+                )
             except JSONDecodeError:
                 original_credentials = {}
 
@@ -350,10 +368,7 @@ class ProviderConfiguration(BaseModel):
                         credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
 
         credentials = model_provider_factory.model_credentials_validate(
-            provider=self.provider.provider,
-            model_type=model_type,
-            model=model,
-            credentials=credentials
+            provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
         )
 
         for key, value in credentials.items():
@@ -388,7 +403,7 @@ class ProviderConfiguration(BaseModel):
                 model_name=model,
                 model_type=model_type.to_origin_model_type(),
                 encrypted_config=json.dumps(credentials),
-                is_valid=True
+                is_valid=True,
             )
             db.session.add(provider_model_record)
             db.session.commit()
@@ -396,7 +411,7 @@ class ProviderConfiguration(BaseModel):
         provider_model_credentials_cache = ProviderCredentialsCache(
             tenant_id=self.tenant_id,
             identity_id=provider_model_record.id,
-            cache_type=ProviderCredentialsCacheType.MODEL
+            cache_type=ProviderCredentialsCacheType.MODEL,
         )
 
         provider_model_credentials_cache.delete()
@@ -409,13 +424,16 @@ class ProviderConfiguration(BaseModel):
         :return:
         """
         # get provider model
-        provider_model_record = db.session.query(ProviderModel) \
+        provider_model_record = (
+            db.session.query(ProviderModel)
             .filter(
-            ProviderModel.tenant_id == self.tenant_id,
-            ProviderModel.provider_name == self.provider.provider,
-            ProviderModel.model_name == model,
-            ProviderModel.model_type == model_type.to_origin_model_type()
-        ).first()
+                ProviderModel.tenant_id == self.tenant_id,
+                ProviderModel.provider_name == self.provider.provider,
+                ProviderModel.model_name == model,
+                ProviderModel.model_type == model_type.to_origin_model_type(),
+            )
+            .first()
+        )
 
         # delete provider model
         if provider_model_record:
@@ -425,7 +443,7 @@ class ProviderConfiguration(BaseModel):
             provider_model_credentials_cache = ProviderCredentialsCache(
                 tenant_id=self.tenant_id,
                 identity_id=provider_model_record.id,
-                cache_type=ProviderCredentialsCacheType.MODEL
+                cache_type=ProviderCredentialsCacheType.MODEL,
             )
 
             provider_model_credentials_cache.delete()
@@ -437,13 +455,16 @@ class ProviderConfiguration(BaseModel):
         :param model: model name
         :return:
         """
-        model_setting = db.session.query(ProviderModelSetting) \
+        model_setting = (
+            db.session.query(ProviderModelSetting)
             .filter(
-            ProviderModelSetting.tenant_id == self.tenant_id,
-            ProviderModelSetting.provider_name == self.provider.provider,
-            ProviderModelSetting.model_type == model_type.to_origin_model_type(),
-            ProviderModelSetting.model_name == model
-        ).first()
+                ProviderModelSetting.tenant_id == self.tenant_id,
+                ProviderModelSetting.provider_name == self.provider.provider,
+                ProviderModelSetting.model_type == model_type.to_origin_model_type(),
+                ProviderModelSetting.model_name == model,
+            )
+            .first()
+        )
 
         if model_setting:
             model_setting.enabled = True
@@ -455,7 +476,7 @@ class ProviderConfiguration(BaseModel):
                 provider_name=self.provider.provider,
                 model_type=model_type.to_origin_model_type(),
                 model_name=model,
-                enabled=True
+                enabled=True,
             )
             db.session.add(model_setting)
             db.session.commit()
@@ -469,13 +490,16 @@ class ProviderConfiguration(BaseModel):
         :param model: model name
         :return:
         """
-        model_setting = db.session.query(ProviderModelSetting) \
+        model_setting = (
+            db.session.query(ProviderModelSetting)
             .filter(
-            ProviderModelSetting.tenant_id == self.tenant_id,
-            ProviderModelSetting.provider_name == self.provider.provider,
-            ProviderModelSetting.model_type == model_type.to_origin_model_type(),
-            ProviderModelSetting.model_name == model
-        ).first()
+                ProviderModelSetting.tenant_id == self.tenant_id,
+                ProviderModelSetting.provider_name == self.provider.provider,
+                ProviderModelSetting.model_type == model_type.to_origin_model_type(),
+                ProviderModelSetting.model_name == model,
+            )
+            .first()
+        )
 
         if model_setting:
             model_setting.enabled = False
@@ -487,7 +511,7 @@ class ProviderConfiguration(BaseModel):
                 provider_name=self.provider.provider,
                 model_type=model_type.to_origin_model_type(),
                 model_name=model,
-                enabled=False
+                enabled=False,
             )
             db.session.add(model_setting)
             db.session.commit()
@@ -501,13 +525,16 @@ class ProviderConfiguration(BaseModel):
         :param model: model name
         :return:
         """
-        return db.session.query(ProviderModelSetting) \
+        return (
+            db.session.query(ProviderModelSetting)
             .filter(
-            ProviderModelSetting.tenant_id == self.tenant_id,
-            ProviderModelSetting.provider_name == self.provider.provider,
-            ProviderModelSetting.model_type == model_type.to_origin_model_type(),
-            ProviderModelSetting.model_name == model
-        ).first()
+                ProviderModelSetting.tenant_id == self.tenant_id,
+                ProviderModelSetting.provider_name == self.provider.provider,
+                ProviderModelSetting.model_type == model_type.to_origin_model_type(),
+                ProviderModelSetting.model_name == model,
+            )
+            .first()
+        )
 
     def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
         """
@@ -516,24 +543,30 @@ class ProviderConfiguration(BaseModel):
         :param model: model name
         :return:
         """
-        load_balancing_config_count = db.session.query(LoadBalancingModelConfig) \
+        load_balancing_config_count = (
+            db.session.query(LoadBalancingModelConfig)
             .filter(
-            LoadBalancingModelConfig.tenant_id == self.tenant_id,
-            LoadBalancingModelConfig.provider_name == self.provider.provider,
-            LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
-            LoadBalancingModelConfig.model_name == model
-        ).count()
+                LoadBalancingModelConfig.tenant_id == self.tenant_id,
+                LoadBalancingModelConfig.provider_name == self.provider.provider,
+                LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
+                LoadBalancingModelConfig.model_name == model,
+            )
+            .count()
+        )
 
         if load_balancing_config_count <= 1:
-            raise ValueError('Model load balancing configuration must be more than 1.')
+            raise ValueError("Model load balancing configuration must be more than 1.")
 
-        model_setting = db.session.query(ProviderModelSetting) \
+        model_setting = (
+            db.session.query(ProviderModelSetting)
             .filter(
-            ProviderModelSetting.tenant_id == self.tenant_id,
-            ProviderModelSetting.provider_name == self.provider.provider,
-            ProviderModelSetting.model_type == model_type.to_origin_model_type(),
-            ProviderModelSetting.model_name == model
-        ).first()
+                ProviderModelSetting.tenant_id == self.tenant_id,
+                ProviderModelSetting.provider_name == self.provider.provider,
+                ProviderModelSetting.model_type == model_type.to_origin_model_type(),
+                ProviderModelSetting.model_name == model,
+            )
+            .first()
+        )
 
         if model_setting:
             model_setting.load_balancing_enabled = True
@@ -545,7 +578,7 @@ class ProviderConfiguration(BaseModel):
                 provider_name=self.provider.provider,
                 model_type=model_type.to_origin_model_type(),
                 model_name=model,
-                load_balancing_enabled=True
+                load_balancing_enabled=True,
             )
             db.session.add(model_setting)
             db.session.commit()
@@ -559,13 +592,16 @@ class ProviderConfiguration(BaseModel):
         :param model: model name
         :return:
         """
-        model_setting = db.session.query(ProviderModelSetting) \
+        model_setting = (
+            db.session.query(ProviderModelSetting)
             .filter(
-            ProviderModelSetting.tenant_id == self.tenant_id,
-            ProviderModelSetting.provider_name == self.provider.provider,
-            ProviderModelSetting.model_type == model_type.to_origin_model_type(),
-            ProviderModelSetting.model_name == model
-        ).first()
+                ProviderModelSetting.tenant_id == self.tenant_id,
+                ProviderModelSetting.provider_name == self.provider.provider,
+                ProviderModelSetting.model_type == model_type.to_origin_model_type(),
+                ProviderModelSetting.model_name == model,
+            )
+            .first()
+        )
 
         if model_setting:
             model_setting.load_balancing_enabled = False
@@ -577,7 +613,7 @@ class ProviderConfiguration(BaseModel):
                 provider_name=self.provider.provider,
                 model_type=model_type.to_origin_model_type(),
                 model_name=model,
-                load_balancing_enabled=False
+                load_balancing_enabled=False,
             )
             db.session.add(model_setting)
             db.session.commit()
@@ -617,11 +653,14 @@ class ProviderConfiguration(BaseModel):
             return
 
         # get preferred provider
-        preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
+        preferred_model_provider = (
+            db.session.query(TenantPreferredModelProvider)
             .filter(
-            TenantPreferredModelProvider.tenant_id == self.tenant_id,
-            TenantPreferredModelProvider.provider_name == self.provider.provider
-        ).first()
+                TenantPreferredModelProvider.tenant_id == self.tenant_id,
+                TenantPreferredModelProvider.provider_name == self.provider.provider,
+            )
+            .first()
+        )
 
         if preferred_model_provider:
             preferred_model_provider.preferred_provider_type = provider_type.value
@@ -629,7 +668,7 @@ class ProviderConfiguration(BaseModel):
             preferred_model_provider = TenantPreferredModelProvider(
                 tenant_id=self.tenant_id,
                 provider_name=self.provider.provider,
-                preferred_provider_type=provider_type.value
+                preferred_provider_type=provider_type.value,
             )
             db.session.add(preferred_model_provider)
 
@@ -658,9 +697,7 @@ class ProviderConfiguration(BaseModel):
         :return:
         """
         # Get provider credential secret variables
-        credential_secret_variables = self.extract_secret_variables(
-            credential_form_schemas
-        )
+        credential_secret_variables = self.extract_secret_variables(credential_form_schemas)
 
         # Obfuscate provider credentials
         copy_credentials = credentials.copy()
@@ -670,9 +707,9 @@ class ProviderConfiguration(BaseModel):
 
         return copy_credentials
 
-    def get_provider_model(self, model_type: ModelType,
-                           model: str,
-                           only_active: bool = False) -> Optional[ModelWithProviderEntity]:
+    def get_provider_model(
+        self, model_type: ModelType, model: str, only_active: bool = False
+    ) -> Optional[ModelWithProviderEntity]:
         """
         Get provider model.
         :param model_type: model type
@@ -688,8 +725,9 @@ class ProviderConfiguration(BaseModel):
 
         return None
 
-    def get_provider_models(self, model_type: Optional[ModelType] = None,
-                            only_active: bool = False) -> list[ModelWithProviderEntity]:
+    def get_provider_models(
+        self, model_type: Optional[ModelType] = None, only_active: bool = False
+    ) -> list[ModelWithProviderEntity]:
         """
         Get provider models.
         :param model_type: model type
@@ -711,15 +749,11 @@ class ProviderConfiguration(BaseModel):
 
         if self.using_provider_type == ProviderType.SYSTEM:
             provider_models = self._get_system_provider_models(
-                model_types=model_types,
-                provider_instance=provider_instance,
-                model_setting_map=model_setting_map
+                model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
             )
         else:
             provider_models = self._get_custom_provider_models(
-                model_types=model_types,
-                provider_instance=provider_instance,
-                model_setting_map=model_setting_map
+                model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
             )
 
         if only_active:
@@ -728,11 +762,12 @@ class ProviderConfiguration(BaseModel):
         # resort provider_models
         return sorted(provider_models, key=lambda x: x.model_type.value)
 
-    def _get_system_provider_models(self,
-                                    model_types: list[ModelType],
-                                    provider_instance: ModelProvider,
-                                    model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
-            -> list[ModelWithProviderEntity]:
+    def _get_system_provider_models(
+        self,
+        model_types: list[ModelType],
+        provider_instance: ModelProvider,
+        model_setting_map: dict[ModelType, dict[str, ModelSettings]],
+    ) -> list[ModelWithProviderEntity]:
         """
         Get system provider models.
 
@@ -760,7 +795,7 @@ class ProviderConfiguration(BaseModel):
                         model_properties=m.model_properties,
                         deprecated=m.deprecated,
                         provider=SimpleModelProviderEntity(self.provider),
-                        status=status
+                        status=status,
                     )
                 )
 
@@ -783,23 +818,20 @@ class ProviderConfiguration(BaseModel):
 
             if should_use_custom_model:
                 if original_provider_configurate_methods[self.provider.provider] == [
-                    ConfigurateMethod.CUSTOMIZABLE_MODEL]:
+                    ConfigurateMethod.CUSTOMIZABLE_MODEL
+                ]:
                     # only customizable model
                     for restrict_model in restrict_models:
                         copy_credentials = self.system_configuration.credentials.copy()
                         if restrict_model.base_model_name:
-                            copy_credentials['base_model_name'] = restrict_model.base_model_name
+                            copy_credentials["base_model_name"] = restrict_model.base_model_name
 
                         try:
-                            custom_model_schema = (
-                                provider_instance.get_model_instance(restrict_model.model_type)
-                                .get_customizable_model_schema_from_credentials(
-                                    restrict_model.model,
-                                    copy_credentials
-                                )
-                            )
+                            custom_model_schema = provider_instance.get_model_instance(
+                                restrict_model.model_type
+                            ).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials)
                         except Exception as ex:
-                            logger.warning(f'get custom model schema failed, {ex}')
+                            logger.warning(f"get custom model schema failed, {ex}")
                             continue
 
                         if not custom_model_schema:
@@ -809,8 +841,10 @@ class ProviderConfiguration(BaseModel):
                             continue
 
                         status = ModelStatus.ACTIVE
-                        if (custom_model_schema.model_type in model_setting_map
-                                and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]):
+                        if (
+                            custom_model_schema.model_type in model_setting_map
+                            and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
+                        ):
                             model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
                             if model_setting.enabled is False:
                                 status = ModelStatus.DISABLED
@@ -825,7 +859,7 @@ class ProviderConfiguration(BaseModel):
                                 model_properties=custom_model_schema.model_properties,
                                 deprecated=custom_model_schema.deprecated,
                                 provider=SimpleModelProviderEntity(self.provider),
-                                status=status
+                                status=status,
                             )
                         )
 
@@ -839,11 +873,12 @@ class ProviderConfiguration(BaseModel):
 
         return provider_models
 
-    def _get_custom_provider_models(self,
-                                    model_types: list[ModelType],
-                                    provider_instance: ModelProvider,
-                                    model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
-            -> list[ModelWithProviderEntity]:
+    def _get_custom_provider_models(
+        self,
+        model_types: list[ModelType],
+        provider_instance: ModelProvider,
+        model_setting_map: dict[ModelType, dict[str, ModelSettings]],
+    ) -> list[ModelWithProviderEntity]:
         """
         Get custom provider models.
 
@@ -885,7 +920,7 @@ class ProviderConfiguration(BaseModel):
                         deprecated=m.deprecated,
                         provider=SimpleModelProviderEntity(self.provider),
                         status=status,
-                        load_balancing_enabled=load_balancing_enabled
+                        load_balancing_enabled=load_balancing_enabled,
                     )
                 )
 
@@ -895,15 +930,13 @@ class ProviderConfiguration(BaseModel):
                 continue
 
             try:
-                custom_model_schema = (
-                    provider_instance.get_model_instance(model_configuration.model_type)
-                    .get_customizable_model_schema_from_credentials(
-                        model_configuration.model,
-                        model_configuration.credentials
-                    )
+                custom_model_schema = provider_instance.get_model_instance(
+                    model_configuration.model_type
+                ).get_customizable_model_schema_from_credentials(
+                    model_configuration.model, model_configuration.credentials
                 )
             except Exception as ex:
-                logger.warning(f'get custom model schema failed, {ex}')
+                logger.warning(f"get custom model schema failed, {ex}")
                 continue
 
             if not custom_model_schema:
@@ -911,8 +944,10 @@ class ProviderConfiguration(BaseModel):
 
             status = ModelStatus.ACTIVE
             load_balancing_enabled = False
-            if (custom_model_schema.model_type in model_setting_map
-                    and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]):
+            if (
+                custom_model_schema.model_type in model_setting_map
+                and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
+            ):
                 model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
                 if model_setting.enabled is False:
                     status = ModelStatus.DISABLED
@@ -931,7 +966,7 @@ class ProviderConfiguration(BaseModel):
                     deprecated=custom_model_schema.deprecated,
                     provider=SimpleModelProviderEntity(self.provider),
                     status=status,
-                    load_balancing_enabled=load_balancing_enabled
+                    load_balancing_enabled=load_balancing_enabled,
                 )
             )
 
@@ -942,17 +977,16 @@ class ProviderConfigurations(BaseModel):
     """
     Model class for provider configuration dict.
     """
+
     tenant_id: str
     configurations: dict[str, ProviderConfiguration] = {}
 
     def __init__(self, tenant_id: str):
         super().__init__(tenant_id=tenant_id)
 
-    def get_models(self,
-                   provider: Optional[str] = None,
-                   model_type: Optional[ModelType] = None,
-                   only_active: bool = False) \
-            -> list[ModelWithProviderEntity]:
+    def get_models(
+        self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False
+    ) -> list[ModelWithProviderEntity]:
         """
         Get available models.
 
@@ -1019,10 +1053,10 @@ class ProviderModelBundle(BaseModel):
     """
     Provider model bundle.
     """
+
     configuration: ProviderConfiguration
     provider_instance: ModelProvider
     model_type_instance: AIModel
 
     # pydantic configs
-    model_config = ConfigDict(arbitrary_types_allowed=True,
-                              protected_namespaces=())
+    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
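
The bulk of this file's churn is one mechanical pattern: ruff's Black-style formatter replaces backslash line continuations with a single parenthesized expression, putting each chained call and each filter argument on its own line with a trailing comma. A minimal sketch of the before/after shape, using a hypothetical `Thing` model and `session` in place of the project's classes:

    # Before: backslash continuation, arguments hanging at odd depths
    record = session.query(Thing) \
        .filter(
        Thing.tenant_id == tenant_id,
        Thing.name == name
    ).first()

    # After: one parenthesized expression, one chained call per line, trailing comma
    record = (
        session.query(Thing)
        .filter(
            Thing.tenant_id == tenant_id,
            Thing.name == name,
        )
        .first()
    )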

+ 14 - 6
api/core/entities/provider_entities.py

@@ -8,18 +8,19 @@ from models.provider import ProviderQuotaType
 
 
 class QuotaUnit(Enum):
-    TIMES = 'times'
-    TOKENS = 'tokens'
-    CREDITS = 'credits'
+    TIMES = "times"
+    TOKENS = "tokens"
+    CREDITS = "credits"
 
 
 class SystemConfigurationStatus(Enum):
     """
     Enum class for system configuration status.
     """
-    ACTIVE = 'active'
-    QUOTA_EXCEEDED = 'quota-exceeded'
-    UNSUPPORTED = 'unsupported'
+
+    ACTIVE = "active"
+    QUOTA_EXCEEDED = "quota-exceeded"
+    UNSUPPORTED = "unsupported"
 
 
 class RestrictModel(BaseModel):
@@ -35,6 +36,7 @@ class QuotaConfiguration(BaseModel):
     """
     Model class for provider quota configuration.
     """
+
     quota_type: ProviderQuotaType
     quota_unit: QuotaUnit
     quota_limit: int
@@ -47,6 +49,7 @@ class SystemConfiguration(BaseModel):
     """
     Model class for provider system configuration.
     """
+
     enabled: bool
     current_quota_type: Optional[ProviderQuotaType] = None
     quota_configurations: list[QuotaConfiguration] = []
@@ -57,6 +60,7 @@ class CustomProviderConfiguration(BaseModel):
     """
     Model class for provider custom configuration.
     """
+
     credentials: dict
 
 
@@ -64,6 +68,7 @@ class CustomModelConfiguration(BaseModel):
     """
     Model class for provider custom model configuration.
     """
+
     model: str
     model_type: ModelType
     credentials: dict
@@ -76,6 +81,7 @@ class CustomConfiguration(BaseModel):
     """
     Model class for provider custom configuration.
     """
+
     provider: Optional[CustomProviderConfiguration] = None
     models: list[CustomModelConfiguration] = []
 
@@ -84,6 +90,7 @@ class ModelLoadBalancingConfiguration(BaseModel):
     """
     Class for model load balancing configuration.
     """
+
     id: str
     name: str
     credentials: dict
@@ -93,6 +100,7 @@ class ModelSettings(BaseModel):
     """
     Model class for model settings.
     """
+
     model: str
     model_type: ModelType
     enabled: bool = True
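
Two conventions account for everything in this file: string literals are normalized to double quotes, and a class body that opens with a docstring gets exactly one blank line between the docstring and the first statement. A sketch, with an illustrative field:

    from pydantic import BaseModel

    class ExampleConfiguration(BaseModel):
        """
        Model class for an example configuration.
        """

        kind: str = "example"  # 'example' would be rewritten to "example"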

+ 7 - 0
api/core/errors/error.py

@@ -3,6 +3,7 @@ from typing import Optional
 
 class LLMError(Exception):
     """Base class for all LLM exceptions."""
+
     description: Optional[str] = None
 
     def __init__(self, description: Optional[str] = None) -> None:
@@ -11,6 +12,7 @@ class LLMError(Exception):
 
 class LLMBadRequestError(LLMError):
     """Raised when the LLM returns bad request."""
+
     description = "Bad Request"
 
 
@@ -18,6 +20,7 @@ class ProviderTokenNotInitError(Exception):
     """
     Custom exception raised when the provider token is not initialized.
     """
+
     description = "Provider Token Not Init"
 
     def __init__(self, *args, **kwargs):
@@ -28,6 +31,7 @@ class QuotaExceededError(Exception):
     """
     Custom exception raised when the quota for a provider has been exceeded.
     """
+
     description = "Quota Exceeded"
 
 
@@ -35,6 +39,7 @@ class AppInvokeQuotaExceededError(Exception):
     """
     Custom exception raised when the quota for an app has been exceeded.
     """
+
     description = "App Invoke Quota Exceeded"
 
 
@@ -42,9 +47,11 @@ class ModelCurrentlyNotSupportError(Exception):
     """
     Custom exception raised when the model not support
     """
+
     description = "Model Currently Not Support"
 
 
 class InvokeRateLimitError(Exception):
     """Raised when the Invoke returns rate limit error."""
+
     description = "Rate Limit Error"

+ 9 - 16
api/core/extension/api_based_extension_requestor.py

@@ -20,10 +20,7 @@ class APIBasedExtensionRequestor:
         :param params: the request params
         :return: the response json
         """
-        headers = {
-            "Content-Type": "application/json",
-            "Authorization": "Bearer {}".format(self.api_key)
-        }
+        headers = {"Content-Type": "application/json", "Authorization": "Bearer {}".format(self.api_key)}
 
         url = self.api_endpoint
 
@@ -32,20 +29,17 @@ class APIBasedExtensionRequestor:
             proxies = None
             if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
                 proxies = {
-                    'http': dify_config.SSRF_PROXY_HTTP_URL,
-                    'https': dify_config.SSRF_PROXY_HTTPS_URL,
+                    "http": dify_config.SSRF_PROXY_HTTP_URL,
+                    "https": dify_config.SSRF_PROXY_HTTPS_URL,
                 }
 
             response = requests.request(
-                method='POST',
+                method="POST",
                 url=url,
-                json={
-                    'point': point.value,
-                    'params': params
-                },
+                json={"point": point.value, "params": params},
                 headers=headers,
                 timeout=self.timeout,
-                proxies=proxies
+                proxies=proxies,
             )
         except requests.exceptions.Timeout:
             raise ValueError("request timeout")
@@ -53,9 +47,8 @@ class APIBasedExtensionRequestor:
             raise ValueError("request connection error")
 
         if response.status_code != 200:
-            raise ValueError("request error, status_code: {}, content: {}".format(
-                response.status_code,
-                response.text[:100]
-            ))
+            raise ValueError(
+                "request error, status_code: {}, content: {}".format(response.status_code, response.text[:100])
+            )
 
         return response.json()
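
This file shows how the formatter chooses between one line and many: a dict or call that fits within the line-length limit is collapsed onto a single line, while one that cannot fit is exploded to one argument per line with a trailing comma (the "magic trailing comma", which also keeps it exploded on later runs). A hedged sketch, assuming `api_key`, `url`, `point`, `params`, `timeout`, and `proxies` are already defined:

    import requests

    # Fits within the line limit, so it is collapsed
    headers = {"Content-Type": "application/json", "Authorization": "Bearer {}".format(api_key)}

    # Too long to fit, so it stays exploded with one argument per line
    response = requests.request(
        method="POST",
        url=url,
        json={"point": point, "params": params},
        headers=headers,
        timeout=timeout,
        proxies=proxies,
    )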

+ 26 - 22
api/core/extension/extensible.py

@@ -11,8 +11,8 @@ from core.helper.position_helper import sort_to_dict_by_position_map
 
 
 class ExtensionModule(enum.Enum):
-    MODERATION = 'moderation'
-    EXTERNAL_DATA_TOOL = 'external_data_tool'
+    MODERATION = "moderation"
+    EXTERNAL_DATA_TOOL = "external_data_tool"
 
 
 class ModuleExtension(BaseModel):
@@ -41,12 +41,12 @@ class Extensible:
         position_map = {}
 
         # get the path of the current class
-        current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py')
+        current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py")
         current_dir_path = os.path.dirname(current_path)
 
         # traverse subdirectories
         for subdir_name in os.listdir(current_dir_path):
-            if subdir_name.startswith('__'):
+            if subdir_name.startswith("__"):
                 continue
 
             subdir_path = os.path.join(current_dir_path, subdir_name)
@@ -58,21 +58,21 @@ class Extensible:
                 # in the front-end page and business logic, there are special treatments.
                 builtin = False
                 position = None
-                if '__builtin__' in file_names:
+                if "__builtin__" in file_names:
                     builtin = True
 
-                    builtin_file_path = os.path.join(subdir_path, '__builtin__')
+                    builtin_file_path = os.path.join(subdir_path, "__builtin__")
                     if os.path.exists(builtin_file_path):
-                        with open(builtin_file_path, encoding='utf-8') as f:
+                        with open(builtin_file_path, encoding="utf-8") as f:
                             position = int(f.read().strip())
                     position_map[extension_name] = position
 
-                if (extension_name + '.py') not in file_names:
+                if (extension_name + ".py") not in file_names:
                     logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
                     continue
 
                 # Dynamic loading {subdir_name}.py file and find the subclass of Extensible
-                py_path = os.path.join(subdir_path, extension_name + '.py')
+                py_path = os.path.join(subdir_path, extension_name + ".py")
                 spec = importlib.util.spec_from_file_location(extension_name, py_path)
                 if not spec or not spec.loader:
                     raise Exception(f"Failed to load module {extension_name} from {py_path}")
@@ -91,25 +91,29 @@ class Extensible:
 
                 json_data = {}
                 if not builtin:
-                    if 'schema.json' not in file_names:
+                    if "schema.json" not in file_names:
                         logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")
                         continue
 
-                    json_path = os.path.join(subdir_path, 'schema.json')
+                    json_path = os.path.join(subdir_path, "schema.json")
                     json_data = {}
                     if os.path.exists(json_path):
-                        with open(json_path, encoding='utf-8') as f:
+                        with open(json_path, encoding="utf-8") as f:
                             json_data = json.load(f)
 
-                extensions.append(ModuleExtension(
-                    extension_class=extension_class,
-                    name=extension_name,
-                    label=json_data.get('label'),
-                    form_schema=json_data.get('form_schema'),
-                    builtin=builtin,
-                    position=position
-                ))
-
-        sorted_extensions = sort_to_dict_by_position_map(position_map=position_map, data=extensions, name_func=lambda x: x.name)
+                extensions.append(
+                    ModuleExtension(
+                        extension_class=extension_class,
+                        name=extension_name,
+                        label=json_data.get("label"),
+                        form_schema=json_data.get("form_schema"),
+                        builtin=builtin,
+                        position=position,
+                    )
+                )
+
+        sorted_extensions = sort_to_dict_by_position_map(
+            position_map=position_map, data=extensions, name_func=lambda x: x.name
+        )
 
         return sorted_extensions
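
Beyond the quote normalization, this hunk passes through the loader's core mechanism: importing a module from an arbitrary file path via `importlib`. A self-contained sketch of that mechanism (the wrapper function is made up; the calls mirror the ones visible in the diff):

    import importlib.util

    def load_module_from_path(name: str, py_path: str):
        # Build a module spec from the file path, then execute the module in a fresh namespace
        spec = importlib.util.spec_from_file_location(name, py_path)
        if not spec or not spec.loader:
            raise ImportError(f"Failed to load module {name} from {py_path}")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module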

+ 1 - 4
api/core/extension/extension.py

@@ -6,10 +6,7 @@ from core.moderation.base import Moderation
 class Extension:
     __module_extensions: dict[str, dict[str, ModuleExtension]] = {}
 
-    module_classes = {
-        ExtensionModule.MODERATION: Moderation,
-        ExtensionModule.EXTERNAL_DATA_TOOL: ExternalDataTool
-    }
+    module_classes = {ExtensionModule.MODERATION: Moderation, ExtensionModule.EXTERNAL_DATA_TOOL: ExternalDataTool}
 
     def init(self):
         for module, module_class in self.module_classes.items():

+ 36 - 40
api/core/external_data_tool/api/api.py

@@ -30,10 +30,11 @@ class ApiExternalDataTool(ExternalDataTool):
             raise ValueError("api_based_extension_id is required")
 
         # get api_based_extension
-        api_based_extension = db.session.query(APIBasedExtension).filter(
-            APIBasedExtension.tenant_id == tenant_id,
-            APIBasedExtension.id == api_based_extension_id
-        ).first()
+        api_based_extension = (
+            db.session.query(APIBasedExtension)
+            .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
+            .first()
+        )
 
         if not api_based_extension:
             raise ValueError("api_based_extension_id is invalid")
@@ -50,47 +51,42 @@ class ApiExternalDataTool(ExternalDataTool):
         api_based_extension_id = self.config.get("api_based_extension_id")
 
         # get api_based_extension
-        api_based_extension = db.session.query(APIBasedExtension).filter(
-            APIBasedExtension.tenant_id == self.tenant_id,
-            APIBasedExtension.id == api_based_extension_id
-        ).first()
+        api_based_extension = (
+            db.session.query(APIBasedExtension)
+            .filter(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id)
+            .first()
+        )
 
         if not api_based_extension:
-            raise ValueError("[External data tool] API query failed, variable: {}, "
-                             "error: api_based_extension_id is invalid"
-                             .format(self.variable))
+            raise ValueError(
+                "[External data tool] API query failed, variable: {}, "
+                "error: api_based_extension_id is invalid".format(self.variable)
+            )
 
         # decrypt api_key
-        api_key = encrypter.decrypt_token(
-            tenant_id=self.tenant_id,
-            token=api_based_extension.api_key
-        )
+        api_key = encrypter.decrypt_token(tenant_id=self.tenant_id, token=api_based_extension.api_key)
 
         try:
             # request api
-            requestor = APIBasedExtensionRequestor(
-                api_endpoint=api_based_extension.api_endpoint,
-                api_key=api_key
-            )
+            requestor = APIBasedExtensionRequestor(api_endpoint=api_based_extension.api_endpoint, api_key=api_key)
         except Exception as e:
-            raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format(
-                self.variable,
-                e
-            ))
-
-        response_json = requestor.request(point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, params={
-            'app_id': self.app_id,
-            'tool_variable': self.variable,
-            'inputs': inputs,
-            'query': query
-        })
-
-        if 'result' not in response_json:
-            raise ValueError("[External data tool] API query failed, variable: {}, error: result not found in response"
-                             .format(self.variable))
-
-        if not isinstance(response_json['result'], str):
-            raise ValueError("[External data tool] API query failed, variable: {}, error: result is not string"
-                             .format(self.variable))
-
-        return response_json['result']
+            raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format(self.variable, e))
+
+        response_json = requestor.request(
+            point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY,
+            params={"app_id": self.app_id, "tool_variable": self.variable, "inputs": inputs, "query": query},
+        )
+
+        if "result" not in response_json:
+            raise ValueError(
+                "[External data tool] API query failed, variable: {}, error: result not found in response".format(
+                    self.variable
+                )
+            )
+
+        if not isinstance(response_json["result"], str):
+            raise ValueError(
+                "[External data tool] API query failed, variable: {}, error: result is not string".format(self.variable)
+            )
+
+        return response_json["result"]
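
The rewrapped error messages above rely on implicit string-literal concatenation: adjacent literals inside the parentheses are joined at compile time, so the message remains a single string with no `+` or backslash, and `.format()` applies to the whole concatenated literal. Sketch:

    variable = "user_name"  # illustrative value
    raise ValueError(
        "[External data tool] API query failed, variable: {}, "
        "error: api_based_extension_id is invalid".format(variable)
    )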

+ 20 - 21
api/core/external_data_tool/external_data_fetch.py

@@ -12,11 +12,14 @@ logger = logging.getLogger(__name__)
 
 
 class ExternalDataFetch:
-    def fetch(self, tenant_id: str,
-              app_id: str,
-              external_data_tools: list[ExternalDataVariableEntity],
-              inputs: dict,
-              query: str) -> dict:
+    def fetch(
+        self,
+        tenant_id: str,
+        app_id: str,
+        external_data_tools: list[ExternalDataVariableEntity],
+        inputs: dict,
+        query: str,
+    ) -> dict:
         """
         Fill in variable inputs from external data tools if exists.
 
@@ -38,7 +41,7 @@ class ExternalDataFetch:
                     app_id,
                     tool,
                     inputs,
-                    query
+                    query,
                 )
 
                 futures[future] = tool
@@ -50,12 +53,15 @@ class ExternalDataFetch:
         inputs.update(results)
         return inputs
 
-    def _query_external_data_tool(self, flask_app: Flask,
-                                  tenant_id: str,
-                                  app_id: str,
-                                  external_data_tool: ExternalDataVariableEntity,
-                                  inputs: dict,
-                                  query: str) -> tuple[Optional[str], Optional[str]]:
+    def _query_external_data_tool(
+        self,
+        flask_app: Flask,
+        tenant_id: str,
+        app_id: str,
+        external_data_tool: ExternalDataVariableEntity,
+        inputs: dict,
+        query: str,
+    ) -> tuple[Optional[str], Optional[str]]:
         """
         Query external data tool.
         :param flask_app: flask app
@@ -72,17 +78,10 @@ class ExternalDataFetch:
             tool_config = external_data_tool.config
 
             external_data_tool_factory = ExternalDataToolFactory(
-                name=tool_type,
-                tenant_id=tenant_id,
-                app_id=app_id,
-                variable=tool_variable,
-                config=tool_config
+                name=tool_type, tenant_id=tenant_id, app_id=app_id, variable=tool_variable, config=tool_config
             )
 
             # query external data tool
-            result = external_data_tool_factory.query(
-                inputs=inputs,
-                query=query
-            )
+            result = external_data_tool_factory.query(inputs=inputs, query=query)
 
             return tool_variable, result
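
Long `def` signatures get the same treatment as long calls: the backslash-wrapped parameter lists become one parameter per line with a trailing comma, and the closing parenthesis sits on its own line ahead of the return annotation. A sketch with placeholder parameters:

    class ExampleFetcher:
        def fetch(
            self,
            tenant_id: str,
            app_id: str,
            inputs: dict,
            query: str,
        ) -> dict:
            ...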

+ 1 - 5
api/core/external_data_tool/factory.py

@@ -5,14 +5,10 @@ from extensions.ext_code_based_extension import code_based_extension
 
 
 class ExternalDataToolFactory:
-
     def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict) -> None:
         extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
         self.__extension_instance = extension_class(
-            tenant_id=tenant_id,
-            app_id=app_id,
-            variable=variable,
-            config=config
+            tenant_id=tenant_id, app_id=app_id, variable=variable, config=config
         )
 
     @classmethod
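
The one-line deletion at the top of this class reflects another formatter rule: a blank line immediately after a block opener (here, directly under `class ExternalDataToolFactory:`) is stripped. Sketch:

    class ExampleFactory:
        def __init__(self) -> None:  # no blank line between the class line and the first statement
            ...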

+ 31 - 28
api/core/file/file_obj.py

@@ -13,11 +13,12 @@ class FileExtraConfig(BaseModel):
     """
     File Upload Entity.
     """
+
     image_config: Optional[dict[str, Any]] = None
 
 
 class FileType(enum.Enum):
-    IMAGE = 'image'
+    IMAGE = "image"
 
     @staticmethod
     def value_of(value):
@@ -28,9 +29,9 @@ class FileType(enum.Enum):
 
 
 class FileTransferMethod(enum.Enum):
-    REMOTE_URL = 'remote_url'
-    LOCAL_FILE = 'local_file'
-    TOOL_FILE = 'tool_file'
+    REMOTE_URL = "remote_url"
+    LOCAL_FILE = "local_file"
+    TOOL_FILE = "tool_file"
 
     @staticmethod
     def value_of(value):
@@ -39,9 +40,10 @@ class FileTransferMethod(enum.Enum):
                 return member
         raise ValueError(f"No matching enum found for value '{value}'")
 
+
 class FileBelongsTo(enum.Enum):
-    USER = 'user'
-    ASSISTANT = 'assistant'
+    USER = "user"
+    ASSISTANT = "assistant"
 
     @staticmethod
     def value_of(value):
@@ -65,16 +67,16 @@ class FileVar(BaseModel):
 
     def to_dict(self) -> dict:
         return {
-            '__variant': self.__class__.__name__,
-            'tenant_id': self.tenant_id,
-            'type': self.type.value,
-            'transfer_method': self.transfer_method.value,
-            'url': self.preview_url,
-            'remote_url': self.url,
-            'related_id': self.related_id,
-            'filename': self.filename,
-            'extension': self.extension,
-            'mime_type': self.mime_type,
+            "__variant": self.__class__.__name__,
+            "tenant_id": self.tenant_id,
+            "type": self.type.value,
+            "transfer_method": self.transfer_method.value,
+            "url": self.preview_url,
+            "remote_url": self.url,
+            "related_id": self.related_id,
+            "filename": self.filename,
+            "extension": self.extension,
+            "mime_type": self.mime_type,
         }
 
     def to_markdown(self) -> str:
@@ -86,7 +88,7 @@ class FileVar(BaseModel):
         if self.type == FileType.IMAGE:
             text = f'![{self.filename or ""}]({preview_url})'
         else:
-            text = f'[{self.filename or preview_url}]({preview_url})'
+            text = f"[{self.filename or preview_url}]({preview_url})"
 
         return text
 
@@ -115,28 +117,29 @@ class FileVar(BaseModel):
             return ImagePromptMessageContent(
                 data=self.data,
                 detail=ImagePromptMessageContent.DETAIL.HIGH
-                if image_config.get("detail") == "high" else ImagePromptMessageContent.DETAIL.LOW
+                if image_config.get("detail") == "high"
+                else ImagePromptMessageContent.DETAIL.LOW,
             )
 
     def _get_data(self, force_url: bool = False) -> Optional[str]:
         from models.model import UploadFile
+
         if self.type == FileType.IMAGE:
             if self.transfer_method == FileTransferMethod.REMOTE_URL:
                 return self.url
             elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
-                upload_file = (db.session.query(UploadFile)
-                               .filter(
-                    UploadFile.id == self.related_id,
-                    UploadFile.tenant_id == self.tenant_id
-                ).first())
-
-                return UploadFileParser.get_image_data(
-                    upload_file=upload_file,
-                    force_url=force_url
+                upload_file = (
+                    db.session.query(UploadFile)
+                    .filter(UploadFile.id == self.related_id, UploadFile.tenant_id == self.tenant_id)
+                    .first()
                 )
+
+                return UploadFileParser.get_image_data(upload_file=upload_file, force_url=force_url)
             elif self.transfer_method == FileTransferMethod.TOOL_FILE:
                 extension = self.extension
                 # add sign url
-                return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=self.related_id, extension=extension)
+                return ToolFileParser.get_tool_file_manager().sign_file(
+                    tool_file_id=self.related_id, extension=extension
+                )
 
         return None
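
When a conditional (ternary) expression no longer fits on one line, the formatter breaks it so the value, the `if` clause, and the `else` clause each take their own line, as in the `ImagePromptMessageContent` hunk above. A self-contained sketch with stand-in names:

    from enum import Enum

    class Detail(Enum):
        HIGH = "high"
        LOW = "low"

    image_config = {"detail": "high"}  # illustrative
    detail = (
        Detail.HIGH
        if image_config.get("detail") == "high"
        else Detail.LOW
    )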

+ 52 - 48
api/core/file/message_file_parser.py

@@ -13,13 +13,13 @@ from services.file_service import IMAGE_EXTENSIONS
 
 
 class MessageFileParser:
-
     def __init__(self, tenant_id: str, app_id: str) -> None:
         self.tenant_id = tenant_id
         self.app_id = app_id
 
-    def validate_and_transform_files_arg(self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig,
-                                         user: Union[Account, EndUser]) -> list[FileVar]:
+    def validate_and_transform_files_arg(
+        self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, user: Union[Account, EndUser]
+    ) -> list[FileVar]:
         """
         validate and transform files arg
 
@@ -30,22 +30,22 @@ class MessageFileParser:
         """
         for file in files:
             if not isinstance(file, dict):
-                raise ValueError('Invalid file format, must be dict')
-            if not file.get('type'):
-                raise ValueError('Missing file type')
-            FileType.value_of(file.get('type'))
-            if not file.get('transfer_method'):
-                raise ValueError('Missing file transfer method')
-            FileTransferMethod.value_of(file.get('transfer_method'))
-            if file.get('transfer_method') == FileTransferMethod.REMOTE_URL.value:
-                if not file.get('url'):
-                    raise ValueError('Missing file url')
-                if not file.get('url').startswith('http'):
-                    raise ValueError('Invalid file url')
-            if file.get('transfer_method') == FileTransferMethod.LOCAL_FILE.value and not file.get('upload_file_id'):
-                raise ValueError('Missing file upload_file_id')
-            if file.get('transform_method') == FileTransferMethod.TOOL_FILE.value and not file.get('tool_file_id'):
-                raise ValueError('Missing file tool_file_id')
+                raise ValueError("Invalid file format, must be dict")
+            if not file.get("type"):
+                raise ValueError("Missing file type")
+            FileType.value_of(file.get("type"))
+            if not file.get("transfer_method"):
+                raise ValueError("Missing file transfer method")
+            FileTransferMethod.value_of(file.get("transfer_method"))
+            if file.get("transfer_method") == FileTransferMethod.REMOTE_URL.value:
+                if not file.get("url"):
+                    raise ValueError("Missing file url")
+                if not file.get("url").startswith("http"):
+                    raise ValueError("Invalid file url")
+            if file.get("transfer_method") == FileTransferMethod.LOCAL_FILE.value and not file.get("upload_file_id"):
+                raise ValueError("Missing file upload_file_id")
+            if file.get("transform_method") == FileTransferMethod.TOOL_FILE.value and not file.get("tool_file_id"):
+                raise ValueError("Missing file tool_file_id")
 
         # transform files to file objs
         type_file_objs = self._to_file_objs(files, file_extra_config)
@@ -62,17 +62,17 @@ class MessageFileParser:
                     continue
 
                 # Validate number of files
-                if len(files) > image_config['number_limits']:
+                if len(files) > image_config["number_limits"]:
                     raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}")
 
                 for file_obj in file_objs:
                     # Validate transfer method
-                    if file_obj.transfer_method.value not in image_config['transfer_methods']:
-                        raise ValueError(f'Invalid transfer method: {file_obj.transfer_method.value}')
+                    if file_obj.transfer_method.value not in image_config["transfer_methods"]:
+                        raise ValueError(f"Invalid transfer method: {file_obj.transfer_method.value}")
 
                     # Validate file type
                     if file_obj.type != FileType.IMAGE:
-                        raise ValueError(f'Invalid file type: {file_obj.type}')
+                        raise ValueError(f"Invalid file type: {file_obj.type}")
 
                     if file_obj.transfer_method == FileTransferMethod.REMOTE_URL:
                         # check remote url valid and is image
@@ -81,18 +81,21 @@ class MessageFileParser:
                             raise ValueError(error)
                     elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE:
                         # get upload file from upload_file_id
-                        upload_file = (db.session.query(UploadFile)
-                                       .filter(
-                            UploadFile.id == file_obj.related_id,
-                            UploadFile.tenant_id == self.tenant_id,
-                            UploadFile.created_by == user.id,
-                            UploadFile.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
-                            UploadFile.extension.in_(IMAGE_EXTENSIONS)
-                        ).first())
+                        upload_file = (
+                            db.session.query(UploadFile)
+                            .filter(
+                                UploadFile.id == file_obj.related_id,
+                                UploadFile.tenant_id == self.tenant_id,
+                                UploadFile.created_by == user.id,
+                                UploadFile.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
+                                UploadFile.extension.in_(IMAGE_EXTENSIONS),
+                            )
+                            .first()
+                        )
 
                         # check upload file is belong to tenant and user
                         if not upload_file:
-                            raise ValueError('Invalid upload file')
+                            raise ValueError("Invalid upload file")
 
                     new_files.append(file_obj)
 
@@ -113,8 +116,9 @@ class MessageFileParser:
         # return all file objs
         return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]
 
-    def _to_file_objs(self, files: list[Union[dict, MessageFile]],
-                      file_extra_config: FileExtraConfig) -> dict[FileType, list[FileVar]]:
+    def _to_file_objs(
+        self, files: list[Union[dict, MessageFile]], file_extra_config: FileExtraConfig
+    ) -> dict[FileType, list[FileVar]]:
         """
         transform files to file objs
 
@@ -152,23 +156,23 @@ class MessageFileParser:
         :return:
         """
         if isinstance(file, dict):
-            transfer_method = FileTransferMethod.value_of(file.get('transfer_method'))
+            transfer_method = FileTransferMethod.value_of(file.get("transfer_method"))
             if transfer_method != FileTransferMethod.TOOL_FILE:
                 return FileVar(
                     tenant_id=self.tenant_id,
-                    type=FileType.value_of(file.get('type')),
+                    type=FileType.value_of(file.get("type")),
                     transfer_method=transfer_method,
-                    url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
-                    related_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
-                    extra_config=file_extra_config
+                    url=file.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None,
+                    related_id=file.get("upload_file_id") if transfer_method == FileTransferMethod.LOCAL_FILE else None,
+                    extra_config=file_extra_config,
                 )
             return FileVar(
                 tenant_id=self.tenant_id,
-                type=FileType.value_of(file.get('type')),
+                type=FileType.value_of(file.get("type")),
                 transfer_method=transfer_method,
                 url=None,
-                related_id=file.get('tool_file_id'),
-                extra_config=file_extra_config
+                related_id=file.get("tool_file_id"),
+                extra_config=file_extra_config,
             )
         else:
             return FileVar(
@@ -178,7 +182,7 @@ class MessageFileParser:
                 transfer_method=FileTransferMethod.value_of(file.transfer_method),
                 url=file.url,
                 related_id=file.upload_file_id or None,
-                extra_config=file_extra_config
+                extra_config=file_extra_config,
             )
 
     def _check_image_remote_url(self, url):
@@ -190,17 +194,17 @@ class MessageFileParser:
             def is_s3_presigned_url(url):
                 try:
                     parsed_url = urlparse(url)
-                    if 'amazonaws.com' not in parsed_url.netloc:
+                    if "amazonaws.com" not in parsed_url.netloc:
                         return False
                     query_params = parse_qs(parsed_url.query)
-                    required_params = ['Signature', 'Expires']
+                    required_params = ["Signature", "Expires"]
                     for param in required_params:
                         if param not in query_params:
                             return False
-                    if not query_params['Expires'][0].isdigit():
+                    if not query_params["Expires"][0].isdigit():
                         return False
-                    signature = query_params['Signature'][0]
-                    if not re.match(r'^[A-Za-z0-9+/]+={0,2}$', signature):
+                    signature = query_params["Signature"][0]
+                    if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature):
                         return False
                     return True
                 except Exception:

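The `is_s3_presigned_url` helper above recognizes an S3 presigned URL purely by shape: an amazonaws.com host, `Signature` and `Expires` query parameters, and a base64-looking signature. A self-contained sketch of the same heuristic (written as a standalone function for illustration; the original is nested inside `MessageFileParser._check_image_remote_url`):

    import re
    from urllib.parse import parse_qs, urlparse

    def is_s3_presigned_url(url: str) -> bool:
        try:
            parsed_url = urlparse(url)
            if "amazonaws.com" not in parsed_url.netloc:
                return False
            query_params = parse_qs(parsed_url.query)
            # Presigned URLs carry a Signature and a numeric Expires timestamp.
            for param in ("Signature", "Expires"):
                if param not in query_params:
                    return False
            if not query_params["Expires"][0].isdigit():
                return False
            # The signature itself must look like base64.
            return bool(re.match(r"^[A-Za-z0-9+/]+={0,2}$", query_params["Signature"][0]))
        except Exception:
            return False

    # is_s3_presigned_url("https://bucket.s3.amazonaws.com/k?Signature=YWJjZA%3D%3D&Expires=1700000000") -> True
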
+ 4 - 5
api/core/file/tool_file_parser.py

@@ -1,8 +1,7 @@
-tool_file_manager = {
-    'manager': None
-}
+tool_file_manager = {"manager": None}
+
 
 class ToolFileParser:
     @staticmethod
-    def get_tool_file_manager() -> 'ToolFileManager':
-        return tool_file_manager['manager']
+    def get_tool_file_manager() -> "ToolFileManager":
+        return tool_file_manager["manager"]

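`tool_file_parser.py` holds a one-slot module-level dict instead of importing `ToolFileManager` directly, and the quoted return annotation `"ToolFileManager"` is a forward reference for the same reason: it breaks a circular import by deferring the lookup to call time. The slot is presumably populated once at startup, along these lines (the registration function below is hypothetical, shown only to illustrate the pattern):

    # Late-binding registry: the manager is injected after both modules load.
    tool_file_manager = {"manager": None}

    def register_tool_file_manager(manager) -> None:
        # Hypothetical wiring, called once ToolFileManager has been constructed.
        tool_file_manager["manager"] = manager
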
+ 6 - 6
api/core/file/upload_file_parser.py

@@ -9,7 +9,7 @@ from typing import Optional
 from configs import dify_config
 from extensions.ext_storage import storage
 
-IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg']
+IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
 IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
 
 
@@ -22,18 +22,18 @@ class UploadFileParser:
         if upload_file.extension not in IMAGE_EXTENSIONS:
             return None
 
-        if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == 'url' or force_url:
+        if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url" or force_url:
             return cls.get_signed_temp_image_url(upload_file.id)
         else:
             # get image file base64
             try:
                 data = storage.load(upload_file.key)
             except FileNotFoundError:
-                logging.error(f'File not found: {upload_file.key}')
+                logging.error(f"File not found: {upload_file.key}")
                 return None
 
-            encoded_string = base64.b64encode(data).decode('utf-8')
-            return f'data:{upload_file.mime_type};base64,{encoded_string}'
+            encoded_string = base64.b64encode(data).decode("utf-8")
+            return f"data:{upload_file.mime_type};base64,{encoded_string}"
 
     @classmethod
     def get_signed_temp_image_url(cls, upload_file_id) -> str:
@@ -44,7 +44,7 @@ class UploadFileParser:
         :return:
         """
         base_url = dify_config.FILES_URL
-        image_preview_url = f'{base_url}/files/{upload_file_id}/image-preview'
+        image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"
 
         timestamp = str(int(time.time()))
         nonce = os.urandom(16).hex()

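The visible part of `get_signed_temp_image_url` builds the preview URL plus a timestamp and a random nonce; the signing step itself falls outside the hunk. A sketch of the HMAC pattern those ingredients imply (the key name, message layout, and query-parameter names below are assumptions, not taken from the diff):

    import base64
    import hashlib
    import hmac
    import os
    import time

    SECRET_KEY = b"change-me"  # assumed: the real key would come from app config

    def sign_image_preview_url(base_url: str, upload_file_id: str) -> str:
        image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"
        timestamp = str(int(time.time()))
        nonce = os.urandom(16).hex()
        # Assumed layout: bind the signature to the resource, the time, and the
        # nonce so a leaked URL expires and cannot be replayed for other files.
        msg = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
        digest = hmac.new(SECRET_KEY, msg.encode(), hashlib.sha256).digest()
        sign = base64.urlsafe_b64encode(digest).decode()
        return f"{image_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={sign}"
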
+ 38 - 34
api/core/helper/code_executor/code_executor.py

@@ -15,9 +15,11 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
 
 logger = logging.getLogger(__name__)
 
+
 class CodeExecutionException(Exception):
     pass
 
+
 class CodeExecutionResponse(BaseModel):
     class Data(BaseModel):
         stdout: Optional[str] = None
@@ -29,9 +31,9 @@ class CodeExecutionResponse(BaseModel):
 
 
 class CodeLanguage(str, Enum):
-    PYTHON3 = 'python3'
-    JINJA2 = 'jinja2'
-    JAVASCRIPT = 'javascript'
+    PYTHON3 = "python3"
+    JINJA2 = "jinja2"
+    JAVASCRIPT = "javascript"
 
 
 class CodeExecutor:
@@ -45,63 +47,65 @@ class CodeExecutor:
     }
 
     code_language_to_running_language = {
-        CodeLanguage.JAVASCRIPT: 'nodejs',
+        CodeLanguage.JAVASCRIPT: "nodejs",
         CodeLanguage.JINJA2: CodeLanguage.PYTHON3,
         CodeLanguage.PYTHON3: CodeLanguage.PYTHON3,
     }
 
-    supported_dependencies_languages: set[CodeLanguage] = {
-        CodeLanguage.PYTHON3
-    }
+    supported_dependencies_languages: set[CodeLanguage] = {CodeLanguage.PYTHON3}
 
     @classmethod
-    def execute_code(cls, 
-                     language: CodeLanguage, 
-                     preload: str, 
-                     code: str) -> str:
+    def execute_code(cls, language: CodeLanguage, preload: str, code: str) -> str:
         """
         Execute code
         :param language: code language
         :param preload: preload code to run before the user code
         :param code: code
         :return:
         """
-        url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) / 'v1' / 'sandbox' / 'run'
+        url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) / "v1" / "sandbox" / "run"
 
-        headers = {
-            'X-Api-Key': dify_config.CODE_EXECUTION_API_KEY
-        }
+        headers = {"X-Api-Key": dify_config.CODE_EXECUTION_API_KEY}
 
         data = {
-            'language': cls.code_language_to_running_language.get(language),
-            'code': code,
-            'preload': preload,
-            'enable_network': True
+            "language": cls.code_language_to_running_language.get(language),
+            "code": code,
+            "preload": preload,
+            "enable_network": True,
         }
 
         try:
-            response = post(str(url), json=data, headers=headers,
-                            timeout=Timeout(
-                                connect=dify_config.CODE_EXECUTION_CONNECT_TIMEOUT,
-                                read=dify_config.CODE_EXECUTION_READ_TIMEOUT,
-                                write=dify_config.CODE_EXECUTION_WRITE_TIMEOUT,
-                                pool=None))
+            response = post(
+                str(url),
+                json=data,
+                headers=headers,
+                timeout=Timeout(
+                    connect=dify_config.CODE_EXECUTION_CONNECT_TIMEOUT,
+                    read=dify_config.CODE_EXECUTION_READ_TIMEOUT,
+                    write=dify_config.CODE_EXECUTION_WRITE_TIMEOUT,
+                    pool=None,
+                ),
+            )
             if response.status_code == 503:
-                raise CodeExecutionException('Code execution service is unavailable')
+                raise CodeExecutionException("Code execution service is unavailable")
             elif response.status_code != 200:
-                raise Exception(f'Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running')
+                raise Exception(
+                    f"Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running"
+                )
         except CodeExecutionException as e:
             raise e
         except Exception as e:
-            raise CodeExecutionException('Failed to execute code, which is likely a network issue,'
-                                         ' please check if the sandbox service is running.'
-                                         f' ( Error: {str(e)} )')
+            raise CodeExecutionException(
+                "Failed to execute code, which is likely a network issue,"
+                " please check if the sandbox service is running."
+                f" ( Error: {str(e)} )"
+            )
 
         try:
             response = response.json()
         except:
-            raise CodeExecutionException('Failed to parse response')
+            raise CodeExecutionException("Failed to parse response")
 
-        if (code := response.get('code')) != 0:
+        if (code := response.get("code")) != 0:
             raise CodeExecutionException(f"Got error code: {code}. Got error msg: {response.get('message')}")
 
         response = CodeExecutionResponse(**response)
@@ -109,7 +113,7 @@ class CodeExecutor:
         if response.data.error:
             raise CodeExecutionException(response.data.error)
 
-        return response.data.stdout or ''
+        return response.data.stdout or ""
 
     @classmethod
     def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict) -> dict:
@@ -122,7 +126,7 @@ class CodeExecutor:
         """
         template_transformer = cls.code_template_transformers.get(language)
         if not template_transformer:
-            raise CodeExecutionException(f'Unsupported language {language}')
+            raise CodeExecutionException(f"Unsupported language {language}")
 
         runner, preload = template_transformer.transform_caller(code, inputs)
 

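After the reflow, `execute_code` still does one thing: POST `{language, code, preload, enable_network}` to the sandbox's `/v1/sandbox/run` endpoint with the configured timeouts, then unwrap stdout. `execute_workflow_code_template` layers a language-specific transformer on top. A minimal caller sketch (it assumes `CODE_EXECUTION_ENDPOINT` and `CODE_EXECUTION_API_KEY` are configured and the sandbox is reachable):

    from core.helper.code_executor.code_executor import (
        CodeExecutionException,
        CodeExecutor,
        CodeLanguage,
    )

    try:
        outputs = CodeExecutor.execute_workflow_code_template(
            language=CodeLanguage.PYTHON3,
            code="def main(arg1: int, arg2: int) -> dict:\n    return {'result': arg1 + arg2}",
            inputs={"arg1": 1, "arg2": 2},
        )
        # expected: {"result": 3}, extracted from the sandbox's stdout by the transformer
    except CodeExecutionException:
        pass  # sandbox unreachable, non-zero response code, or unparseable output
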
+ 3 - 17
api/core/helper/code_executor/code_node_provider.py

@@ -26,23 +26,9 @@ class CodeNodeProvider(BaseModel):
         return {
             "type": "code",
             "config": {
-                "variables": [
-                    {
-                        "variable": "arg1",
-                        "value_selector": []
-                    },
-                    {
-                        "variable": "arg2",
-                        "value_selector": []
-                    }
-                ],
+                "variables": [{"variable": "arg1", "value_selector": []}, {"variable": "arg2", "value_selector": []}],
                 "code_language": cls.get_language(),
                 "code": cls.get_default_code(),
-                "outputs": {
-                    "result": {
-                        "type": "string",
-                        "children": None
-                    }
-                }
-            }
+                "outputs": {"result": {"type": "string", "children": None}},
+            },
         }

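Ruff collapsed the default code-node config into two long literals; the content is unchanged. For readability, here is the equivalent structure spelled out (the enclosing method name sits outside the hunk, so the standalone function below is only a stand-in):

    def default_code_node_config(language: str, code: str) -> dict:
        # Content-identical expansion of the collapsed literal above: two input
        # variables and a single string output wired to the provider's defaults.
        return {
            "type": "code",
            "config": {
                "variables": [
                    {"variable": "arg1", "value_selector": []},
                    {"variable": "arg2", "value_selector": []},
                ],
                "code_language": language,
                "code": code,
                "outputs": {"result": {"type": "string", "children": None}},
            },
        }
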
+ 2 - 1
api/core/helper/code_executor/javascript/javascript_code_provider.py

@@ -18,4 +18,5 @@ class JavascriptCodeProvider(CodeNodeProvider):
                     result: arg1 + arg2
                 }
             }
-            """)
+            """
+        )

+ 2 - 1
api/core/helper/code_executor/javascript/javascript_transformer.py

@@ -21,5 +21,6 @@ class NodeJsTemplateTransformer(TemplateTransformer):
             var output_json = JSON.stringify(output_obj)
             var result = `<<RESULT>>${{output_json}}<<RESULT>>`
             console.log(result)
-            """)
+            """
+        )
         return runner_script

+ 2 - 4
api/core/helper/code_executor/jinja2/jinja2_formatter.py

@@ -10,8 +10,6 @@ class Jinja2Formatter:
         :param inputs: inputs
         :return:
         """
-        result = CodeExecutor.execute_workflow_code_template(
-            language=CodeLanguage.JINJA2, code=template, inputs=inputs
-        )
+        result = CodeExecutor.execute_workflow_code_template(language=CodeLanguage.JINJA2, code=template, inputs=inputs)
 
-        return result['result']
+        return result["result"]

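`Jinja2Formatter` delegates template rendering to the sandbox via `execute_workflow_code_template`, which keeps untrusted Jinja2 templates out of the API process. A usage sketch (the method name is not visible in the hunk; `format` as a classmethod is assumed here):

    from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter

    # Method name assumed; the hunk above only shows the body.
    rendered = Jinja2Formatter.format(template="Hello, {{ name }}!", inputs={"name": "Dify"})
    # -> "Hello, Dify!", rendered inside the sandbox rather than in-process
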
+ 1 - 3
api/core/helper/code_executor/jinja2/jinja2_transformer.py

@@ -11,9 +11,7 @@ class Jinja2TemplateTransformer(TemplateTransformer):
         :param response: response
         :return:
         """
-        return {
-            'result': cls.extract_result_str_from_response(response)
-        }
+        return {"result": cls.extract_result_str_from_response(response)}
 
     @classmethod
     def get_runner_script(cls) -> str:

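The NodeJS runner above prints its JSON output between literal `<<RESULT>>` markers, and the Jinja2 transformer's `transform_response` recovers the payload with `extract_result_str_from_response`. That method's body falls outside these hunks; a plausible implementation of the marker protocol (assumed, not the actual source):

    import re

    def extract_result_str_from_response(response: str) -> str:
        # Sandbox stdout may contain arbitrary user prints; keep only the text
        # delimited by the <<RESULT>> markers the runner scripts emit.
        match = re.search(r"<<RESULT>>(.*)<<RESULT>>", response, re.DOTALL)
        if match is None:
            raise ValueError("no <<RESULT>> markers found in sandbox output")
        return match.group(1)
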
+ 2 - 1
api/core/helper/code_executor/python3/python3_code_provider.py

@@ -17,4 +17,5 @@ class Python3CodeProvider(CodeNodeProvider):
                 return {
                     "result": arg1 + arg2,
                 }
-            """)
+            """
+        )

Some files were not shown because too many files changed in this diff