Browse Source

chore: refurish python code by applying Pylint linter rules (#8322)

Bowen Liang 7 months ago
parent
commit
a1104ab97e
100 changed files with 198 additions and 218 deletions
  1. 1 1
      api/app.py
  2. 2 2
      api/commands.py
  3. 1 1
      api/controllers/console/app/audio.py
  4. 1 1
      api/controllers/console/auth/oauth.py
  5. 3 3
      api/controllers/console/datasets/datasets_document.py
  6. 1 1
      api/controllers/console/explore/audio.py
  7. 2 2
      api/controllers/console/explore/completion.py
  8. 5 5
      api/controllers/console/explore/conversation.py
  9. 1 1
      api/controllers/console/explore/installed_app.py
  10. 2 2
      api/controllers/console/explore/message.py
  11. 1 1
      api/controllers/console/explore/parameter.py
  12. 1 1
      api/controllers/console/workspace/workspace.py
  13. 1 1
      api/controllers/service_api/app/app.py
  14. 1 1
      api/controllers/service_api/app/audio.py
  15. 2 2
      api/controllers/service_api/app/completion.py
  16. 3 3
      api/controllers/service_api/app/conversation.py
  17. 2 2
      api/controllers/service_api/app/message.py
  18. 1 1
      api/controllers/web/app.py
  19. 1 1
      api/controllers/web/audio.py
  20. 2 2
      api/controllers/web/completion.py
  21. 5 5
      api/controllers/web/conversation.py
  22. 2 2
      api/controllers/web/message.py
  23. 2 2
      api/core/agent/output_parser/cot_output_parser.py
  24. 1 1
      api/core/app/app_config/base_app_config_manager.py
  25. 3 3
      api/core/app/app_config/easy_ui_based_app/agent/manager.py
  26. 1 1
      api/core/app/app_config/easy_ui_based_app/dataset/manager.py
  27. 3 3
      api/core/app/app_config/easy_ui_based_app/variables/manager.py
  28. 2 2
      api/core/app/app_config/features/file_upload/manager.py
  29. 2 2
      api/core/app/apps/advanced_chat/app_runner.py
  30. 1 1
      api/core/app/apps/base_app_generate_response_converter.py
  31. 3 3
      api/core/app/apps/base_app_generator.py
  32. 2 2
      api/core/app/apps/base_app_queue_manager.py
  33. 3 3
      api/core/app/apps/message_based_app_generator.py
  34. 2 2
      api/core/app/apps/workflow/app_runner.py
  35. 1 1
      api/core/app/features/annotation_reply/annotation_reply.py
  36. 1 1
      api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
  37. 2 2
      api/core/app/task_pipeline/workflow_cycle_manage.py
  38. 1 1
      api/core/callback_handler/index_tool_callback_handler.py
  39. 1 1
      api/core/indexing_runner.py
  40. 1 1
      api/core/memory/token_buffer_memory.py
  41. 6 6
      api/core/model_runtime/entities/model_entities.py
  42. 1 1
      api/core/model_runtime/model_providers/anthropic/llm/llm.py
  43. 2 2
      api/core/model_runtime/model_providers/azure_openai/tts/tts.py
  44. 5 5
      api/core/model_runtime/model_providers/bedrock/llm/llm.py
  45. 4 4
      api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py
  46. 2 2
      api/core/model_runtime/model_providers/google/llm/llm.py
  47. 2 2
      api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py
  48. 1 1
      api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py
  49. 2 2
      api/core/model_runtime/model_providers/minimax/llm/chat_completion.py
  50. 2 2
      api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py
  51. 1 1
      api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py
  52. 1 1
      api/core/model_runtime/model_providers/openai/llm/llm.py
  53. 2 2
      api/core/model_runtime/model_providers/openai/tts/tts.py
  54. 0 1
      api/core/model_runtime/model_providers/openrouter/llm/llm.py
  55. 1 1
      api/core/model_runtime/model_providers/replicate/llm/llm.py
  56. 2 2
      api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py
  57. 4 4
      api/core/model_runtime/model_providers/tongyi/llm/llm.py
  58. 1 1
      api/core/model_runtime/model_providers/upstage/llm/llm.py
  59. 2 2
      api/core/model_runtime/model_providers/vertex_ai/llm/llm.py
  60. 1 2
      api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py
  61. 1 1
      api/core/model_runtime/model_providers/wenxin/llm/llm.py
  62. 1 1
      api/core/model_runtime/model_providers/xinference/xinference_helper.py
  63. 9 12
      api/core/model_runtime/model_providers/zhipuai/llm/llm.py
  64. 2 3
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py
  65. 2 2
      api/core/model_runtime/schema_validators/common_validator.py
  66. 2 2
      api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
  67. 2 2
      api/core/rag/datasource/vdb/myscale/myscale_vector.py
  68. 1 9
      api/core/rag/datasource/vdb/oracle/oraclevector.py
  69. 6 6
      api/core/rag/extractor/extract_processor.py
  70. 1 1
      api/core/rag/extractor/firecrawl/firecrawl_app.py
  71. 2 2
      api/core/rag/extractor/notion_extractor.py
  72. 1 1
      api/core/rag/retrieval/dataset_retrieval.py
  73. 1 1
      api/core/rag/splitter/text_splitter.py
  74. 1 1
      api/core/tools/provider/app_tool_provider.py
  75. 2 2
      api/core/tools/provider/builtin/aippt/tools/aippt.py
  76. 2 2
      api/core/tools/provider/builtin/azuredalle/tools/dalle3.py
  77. 1 1
      api/core/tools/provider/builtin/code/tools/simple_code.py
  78. 2 2
      api/core/tools/provider/builtin/cogview/tools/cogview3.py
  79. 2 2
      api/core/tools/provider/builtin/dalle/tools/dalle3.py
  80. 2 2
      api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py
  81. 3 3
      api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py
  82. 1 1
      api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py
  83. 1 1
      api/core/tools/provider/builtin/searchapi/tools/google.py
  84. 1 1
      api/core/tools/provider/builtin/searchapi/tools/google_jobs.py
  85. 1 1
      api/core/tools/provider/builtin/searchapi/tools/google_news.py
  86. 1 1
      api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py
  87. 1 1
      api/core/tools/provider/builtin/spider/spiderApp.py
  88. 1 1
      api/core/tools/provider/builtin/stability/tools/text2image.py
  89. 1 1
      api/core/tools/provider/builtin/vanna/tools/vanna.py
  90. 1 1
      api/core/tools/provider/builtin_tool_provider.py
  91. 9 9
      api/core/tools/provider/tool_provider.py
  92. 4 4
      api/core/tools/tool/api_tool.py
  93. 3 9
      api/core/tools/tool_engine.py
  94. 1 1
      api/core/tools/utils/message_transformer.py
  95. 1 1
      api/core/tools/utils/parser.py
  96. 1 1
      api/core/tools/utils/web_reader_tool.py
  97. 1 1
      api/core/workflow/graph_engine/entities/runtime_route_state.py
  98. 2 2
      api/core/workflow/nodes/answer/answer_stream_generate_router.py
  99. 2 2
      api/core/workflow/nodes/end/end_stream_generate_router.py
  100. 4 4
      api/core/workflow/nodes/http_request/http_executor.py

+ 1 - 1
api/app.py

@@ -164,7 +164,7 @@ def initialize_extensions(app):
 @login_manager.request_loader
 @login_manager.request_loader
 def load_user_from_request(request_from_flask_login):
 def load_user_from_request(request_from_flask_login):
     """Load user based on the request."""
     """Load user based on the request."""
-    if request.blueprint not in ["console", "inner_api"]:
+    if request.blueprint not in {"console", "inner_api"}:
         return None
         return None
     # Check if the user_id contains a dot, indicating the old format
     # Check if the user_id contains a dot, indicating the old format
     auth_header = request.headers.get("Authorization", "")
     auth_header = request.headers.get("Authorization", "")

+ 2 - 2
api/commands.py

@@ -140,9 +140,9 @@ def reset_encrypt_key_pair():
 @click.command("vdb-migrate", help="migrate vector db.")
 @click.command("vdb-migrate", help="migrate vector db.")
 @click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.")
 @click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.")
 def vdb_migrate(scope: str):
 def vdb_migrate(scope: str):
-    if scope in ["knowledge", "all"]:
+    if scope in {"knowledge", "all"}:
         migrate_knowledge_vector_database()
         migrate_knowledge_vector_database()
-    if scope in ["annotation", "all"]:
+    if scope in {"annotation", "all"}:
         migrate_annotation_vector_database()
         migrate_annotation_vector_database()
 
 
 
 

+ 1 - 1
api/controllers/console/app/audio.py

@@ -94,7 +94,7 @@ class ChatMessageTextApi(Resource):
             message_id = args.get("message_id", None)
             message_id = args.get("message_id", None)
             text = args.get("text", None)
             text = args.get("text", None)
             if (
             if (
-                app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
+                app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
                 and app_model.workflow
                 and app_model.workflow
                 and app_model.workflow.features_dict
                 and app_model.workflow.features_dict
             ):
             ):

+ 1 - 1
api/controllers/console/auth/oauth.py

@@ -71,7 +71,7 @@ class OAuthCallback(Resource):
 
 
         account = _generate_account(provider, user_info)
         account = _generate_account(provider, user_info)
         # Check account status
         # Check account status
-        if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
+        if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
             return {"error": "Account is banned or closed."}, 403
             return {"error": "Account is banned or closed."}, 403
 
 
         if account.status == AccountStatus.PENDING.value:
         if account.status == AccountStatus.PENDING.value:

+ 3 - 3
api/controllers/console/datasets/datasets_document.py

@@ -354,7 +354,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
         document_id = str(document_id)
         document_id = str(document_id)
         document = self.get_document(dataset_id, document_id)
         document = self.get_document(dataset_id, document_id)
 
 
-        if document.indexing_status in ["completed", "error"]:
+        if document.indexing_status in {"completed", "error"}:
             raise DocumentAlreadyFinishedError()
             raise DocumentAlreadyFinishedError()
 
 
         data_process_rule = document.dataset_process_rule
         data_process_rule = document.dataset_process_rule
@@ -421,7 +421,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
         info_list = []
         info_list = []
         extract_settings = []
         extract_settings = []
         for document in documents:
         for document in documents:
-            if document.indexing_status in ["completed", "error"]:
+            if document.indexing_status in {"completed", "error"}:
                 raise DocumentAlreadyFinishedError()
                 raise DocumentAlreadyFinishedError()
             data_source_info = document.data_source_info_dict
             data_source_info = document.data_source_info_dict
             # format document files info
             # format document files info
@@ -665,7 +665,7 @@ class DocumentProcessingApi(DocumentResource):
             db.session.commit()
             db.session.commit()
 
 
         elif action == "resume":
         elif action == "resume":
-            if document.indexing_status not in ["paused", "error"]:
+            if document.indexing_status not in {"paused", "error"}:
                 raise InvalidActionError("Document not in paused or error state.")
                 raise InvalidActionError("Document not in paused or error state.")
 
 
             document.paused_by = None
             document.paused_by = None

+ 1 - 1
api/controllers/console/explore/audio.py

@@ -81,7 +81,7 @@ class ChatTextApi(InstalledAppResource):
             message_id = args.get("message_id", None)
             message_id = args.get("message_id", None)
             text = args.get("text", None)
             text = args.get("text", None)
             if (
             if (
-                app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
+                app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
                 and app_model.workflow
                 and app_model.workflow
                 and app_model.workflow.features_dict
                 and app_model.workflow.features_dict
             ):
             ):

+ 2 - 2
api/controllers/console/explore/completion.py

@@ -92,7 +92,7 @@ class ChatApi(InstalledAppResource):
     def post(self, installed_app):
     def post(self, installed_app):
         app_model = installed_app.app
         app_model = installed_app.app
         app_mode = AppMode.value_of(app_model.mode)
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
             raise NotChatAppError()
 
 
         parser = reqparse.RequestParser()
         parser = reqparse.RequestParser()
@@ -140,7 +140,7 @@ class ChatStopApi(InstalledAppResource):
     def post(self, installed_app, task_id):
     def post(self, installed_app, task_id):
         app_model = installed_app.app
         app_model = installed_app.app
         app_mode = AppMode.value_of(app_model.mode)
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
             raise NotChatAppError()
 
 
         AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
         AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)

+ 5 - 5
api/controllers/console/explore/conversation.py

@@ -20,7 +20,7 @@ class ConversationListApi(InstalledAppResource):
     def get(self, installed_app):
     def get(self, installed_app):
         app_model = installed_app.app
         app_model = installed_app.app
         app_mode = AppMode.value_of(app_model.mode)
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
             raise NotChatAppError()
 
 
         parser = reqparse.RequestParser()
         parser = reqparse.RequestParser()
@@ -50,7 +50,7 @@ class ConversationApi(InstalledAppResource):
     def delete(self, installed_app, c_id):
     def delete(self, installed_app, c_id):
         app_model = installed_app.app
         app_model = installed_app.app
         app_mode = AppMode.value_of(app_model.mode)
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
             raise NotChatAppError()
 
 
         conversation_id = str(c_id)
         conversation_id = str(c_id)
@@ -68,7 +68,7 @@ class ConversationRenameApi(InstalledAppResource):
     def post(self, installed_app, c_id):
     def post(self, installed_app, c_id):
         app_model = installed_app.app
         app_model = installed_app.app
         app_mode = AppMode.value_of(app_model.mode)
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
             raise NotChatAppError()
 
 
         conversation_id = str(c_id)
         conversation_id = str(c_id)
@@ -90,7 +90,7 @@ class ConversationPinApi(InstalledAppResource):
     def patch(self, installed_app, c_id):
     def patch(self, installed_app, c_id):
         app_model = installed_app.app
         app_model = installed_app.app
         app_mode = AppMode.value_of(app_model.mode)
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
             raise NotChatAppError()
 
 
         conversation_id = str(c_id)
         conversation_id = str(c_id)
@@ -107,7 +107,7 @@ class ConversationUnPinApi(InstalledAppResource):
     def patch(self, installed_app, c_id):
     def patch(self, installed_app, c_id):
         app_model = installed_app.app
         app_model = installed_app.app
         app_mode = AppMode.value_of(app_model.mode)
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
             raise NotChatAppError()
 
 
         conversation_id = str(c_id)
         conversation_id = str(c_id)

+ 1 - 1
api/controllers/console/explore/installed_app.py

@@ -31,7 +31,7 @@ class InstalledAppsListApi(Resource):
                 "app_owner_tenant_id": installed_app.app_owner_tenant_id,
                 "app_owner_tenant_id": installed_app.app_owner_tenant_id,
                 "is_pinned": installed_app.is_pinned,
                 "is_pinned": installed_app.is_pinned,
                 "last_used_at": installed_app.last_used_at,
                 "last_used_at": installed_app.last_used_at,
-                "editable": current_user.role in ["owner", "admin"],
+                "editable": current_user.role in {"owner", "admin"},
                 "uninstallable": current_tenant_id == installed_app.app_owner_tenant_id,
                 "uninstallable": current_tenant_id == installed_app.app_owner_tenant_id,
             }
             }
             for installed_app in installed_apps
             for installed_app in installed_apps

+ 2 - 2
api/controllers/console/explore/message.py

@@ -40,7 +40,7 @@ class MessageListApi(InstalledAppResource):
         app_model = installed_app.app
         app_model = installed_app.app
 
 
         app_mode = AppMode.value_of(app_model.mode)
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
             raise NotChatAppError()
 
 
         parser = reqparse.RequestParser()
         parser = reqparse.RequestParser()
@@ -125,7 +125,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
     def get(self, installed_app, message_id):
     def get(self, installed_app, message_id):
         app_model = installed_app.app
         app_model = installed_app.app
         app_mode = AppMode.value_of(app_model.mode)
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
             raise NotChatAppError()
 
 
         message_id = str(message_id)
         message_id = str(message_id)

+ 1 - 1
api/controllers/console/explore/parameter.py

@@ -43,7 +43,7 @@ class AppParameterApi(InstalledAppResource):
         """Retrieve app parameters."""
         """Retrieve app parameters."""
         app_model = installed_app.app
         app_model = installed_app.app
 
 
-        if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
+        if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
             workflow = app_model.workflow
             workflow = app_model.workflow
             if workflow is None:
             if workflow is None:
                 raise AppUnavailableError()
                 raise AppUnavailableError()

+ 1 - 1
api/controllers/console/workspace/workspace.py

@@ -194,7 +194,7 @@ class WebappLogoWorkspaceApi(Resource):
             raise TooManyFilesError()
             raise TooManyFilesError()
 
 
         extension = file.filename.split(".")[-1]
         extension = file.filename.split(".")[-1]
-        if extension.lower() not in ["svg", "png"]:
+        if extension.lower() not in {"svg", "png"}:
             raise UnsupportedFileTypeError()
             raise UnsupportedFileTypeError()
 
 
         try:
         try:

+ 1 - 1
api/controllers/service_api/app/app.py

@@ -42,7 +42,7 @@ class AppParameterApi(Resource):
     @marshal_with(parameters_fields)
     @marshal_with(parameters_fields)
     def get(self, app_model: App):
     def get(self, app_model: App):
         """Retrieve app parameters."""
         """Retrieve app parameters."""
-        if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
+        if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
             workflow = app_model.workflow
             workflow = app_model.workflow
             if workflow is None:
             if workflow is None:
                 raise AppUnavailableError()
                 raise AppUnavailableError()

+ 1 - 1
api/controllers/service_api/app/audio.py

@@ -79,7 +79,7 @@ class TextApi(Resource):
             message_id = args.get("message_id", None)
             message_id = args.get("message_id", None)
             text = args.get("text", None)
             text = args.get("text", None)
             if (
             if (
-                app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
+                app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
                 and app_model.workflow
                 and app_model.workflow
                 and app_model.workflow.features_dict
                 and app_model.workflow.features_dict
             ):
             ):

+ 2 - 2
api/controllers/service_api/app/completion.py

@@ -96,7 +96,7 @@ class ChatApi(Resource):
     @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
     @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
     def post(self, app_model: App, end_user: EndUser):
     def post(self, app_model: App, end_user: EndUser):
         app_mode = AppMode.value_of(app_model.mode)
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
             raise NotChatAppError()
 
 
         parser = reqparse.RequestParser()
         parser = reqparse.RequestParser()
@@ -144,7 +144,7 @@ class ChatStopApi(Resource):
     @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
     @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
     def post(self, app_model: App, end_user: EndUser, task_id):
     def post(self, app_model: App, end_user: EndUser, task_id):
         app_mode = AppMode.value_of(app_model.mode)
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
             raise NotChatAppError()
 
 
         AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
         AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)

+ 3 - 3
api/controllers/service_api/app/conversation.py

@@ -18,7 +18,7 @@ class ConversationApi(Resource):
     @marshal_with(conversation_infinite_scroll_pagination_fields)
     @marshal_with(conversation_infinite_scroll_pagination_fields)
     def get(self, app_model: App, end_user: EndUser):
     def get(self, app_model: App, end_user: EndUser):
         app_mode = AppMode.value_of(app_model.mode)
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
             raise NotChatAppError()
 
 
         parser = reqparse.RequestParser()
         parser = reqparse.RequestParser()
@@ -52,7 +52,7 @@ class ConversationDetailApi(Resource):
     @marshal_with(simple_conversation_fields)
     @marshal_with(simple_conversation_fields)
     def delete(self, app_model: App, end_user: EndUser, c_id):
     def delete(self, app_model: App, end_user: EndUser, c_id):
         app_mode = AppMode.value_of(app_model.mode)
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
             raise NotChatAppError()
 
 
         conversation_id = str(c_id)
         conversation_id = str(c_id)
@@ -69,7 +69,7 @@ class ConversationRenameApi(Resource):
     @marshal_with(simple_conversation_fields)
     @marshal_with(simple_conversation_fields)
     def post(self, app_model: App, end_user: EndUser, c_id):
     def post(self, app_model: App, end_user: EndUser, c_id):
         app_mode = AppMode.value_of(app_model.mode)
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
             raise NotChatAppError()
 
 
         conversation_id = str(c_id)
         conversation_id = str(c_id)

+ 2 - 2
api/controllers/service_api/app/message.py

@@ -76,7 +76,7 @@ class MessageListApi(Resource):
     @marshal_with(message_infinite_scroll_pagination_fields)
     @marshal_with(message_infinite_scroll_pagination_fields)
     def get(self, app_model: App, end_user: EndUser):
     def get(self, app_model: App, end_user: EndUser):
         app_mode = AppMode.value_of(app_model.mode)
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
             raise NotChatAppError()
 
 
         parser = reqparse.RequestParser()
         parser = reqparse.RequestParser()
@@ -117,7 +117,7 @@ class MessageSuggestedApi(Resource):
     def get(self, app_model: App, end_user: EndUser, message_id):
     def get(self, app_model: App, end_user: EndUser, message_id):
         message_id = str(message_id)
         message_id = str(message_id)
         app_mode = AppMode.value_of(app_model.mode)
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
             raise NotChatAppError()
 
 
         try:
         try:

+ 1 - 1
api/controllers/web/app.py

@@ -41,7 +41,7 @@ class AppParameterApi(WebApiResource):
     @marshal_with(parameters_fields)
     @marshal_with(parameters_fields)
     def get(self, app_model: App, end_user):
     def get(self, app_model: App, end_user):
         """Retrieve app parameters."""
         """Retrieve app parameters."""
-        if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
+        if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
             workflow = app_model.workflow
             workflow = app_model.workflow
             if workflow is None:
             if workflow is None:
                 raise AppUnavailableError()
                 raise AppUnavailableError()

+ 1 - 1
api/controllers/web/audio.py

@@ -78,7 +78,7 @@ class TextApi(WebApiResource):
             message_id = args.get("message_id", None)
             message_id = args.get("message_id", None)
             text = args.get("text", None)
             text = args.get("text", None)
             if (
             if (
-                app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
+                app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
                 and app_model.workflow
                 and app_model.workflow
                 and app_model.workflow.features_dict
                 and app_model.workflow.features_dict
             ):
             ):

+ 2 - 2
api/controllers/web/completion.py

@@ -87,7 +87,7 @@ class CompletionStopApi(WebApiResource):
 class ChatApi(WebApiResource):
 class ChatApi(WebApiResource):
     def post(self, app_model, end_user):
     def post(self, app_model, end_user):
         app_mode = AppMode.value_of(app_model.mode)
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
             raise NotChatAppError()
 
 
         parser = reqparse.RequestParser()
         parser = reqparse.RequestParser()
@@ -136,7 +136,7 @@ class ChatApi(WebApiResource):
 class ChatStopApi(WebApiResource):
 class ChatStopApi(WebApiResource):
     def post(self, app_model, end_user, task_id):
     def post(self, app_model, end_user, task_id):
         app_mode = AppMode.value_of(app_model.mode)
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
             raise NotChatAppError()
 
 
         AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
         AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)

+ 5 - 5
api/controllers/web/conversation.py

@@ -18,7 +18,7 @@ class ConversationListApi(WebApiResource):
     @marshal_with(conversation_infinite_scroll_pagination_fields)
     @marshal_with(conversation_infinite_scroll_pagination_fields)
     def get(self, app_model, end_user):
     def get(self, app_model, end_user):
         app_mode = AppMode.value_of(app_model.mode)
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
             raise NotChatAppError()
 
 
         parser = reqparse.RequestParser()
         parser = reqparse.RequestParser()
@@ -56,7 +56,7 @@ class ConversationListApi(WebApiResource):
 class ConversationApi(WebApiResource):
 class ConversationApi(WebApiResource):
     def delete(self, app_model, end_user, c_id):
     def delete(self, app_model, end_user, c_id):
         app_mode = AppMode.value_of(app_model.mode)
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
             raise NotChatAppError()
 
 
         conversation_id = str(c_id)
         conversation_id = str(c_id)
@@ -73,7 +73,7 @@ class ConversationRenameApi(WebApiResource):
     @marshal_with(simple_conversation_fields)
     @marshal_with(simple_conversation_fields)
     def post(self, app_model, end_user, c_id):
     def post(self, app_model, end_user, c_id):
         app_mode = AppMode.value_of(app_model.mode)
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
             raise NotChatAppError()
 
 
         conversation_id = str(c_id)
         conversation_id = str(c_id)
@@ -92,7 +92,7 @@ class ConversationRenameApi(WebApiResource):
 class ConversationPinApi(WebApiResource):
 class ConversationPinApi(WebApiResource):
     def patch(self, app_model, end_user, c_id):
     def patch(self, app_model, end_user, c_id):
         app_mode = AppMode.value_of(app_model.mode)
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
             raise NotChatAppError()
 
 
         conversation_id = str(c_id)
         conversation_id = str(c_id)
@@ -108,7 +108,7 @@ class ConversationPinApi(WebApiResource):
 class ConversationUnPinApi(WebApiResource):
 class ConversationUnPinApi(WebApiResource):
     def patch(self, app_model, end_user, c_id):
     def patch(self, app_model, end_user, c_id):
         app_mode = AppMode.value_of(app_model.mode)
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
             raise NotChatAppError()
 
 
         conversation_id = str(c_id)
         conversation_id = str(c_id)

+ 2 - 2
api/controllers/web/message.py

@@ -78,7 +78,7 @@ class MessageListApi(WebApiResource):
     @marshal_with(message_infinite_scroll_pagination_fields)
     @marshal_with(message_infinite_scroll_pagination_fields)
     def get(self, app_model, end_user):
     def get(self, app_model, end_user):
         app_mode = AppMode.value_of(app_model.mode)
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
             raise NotChatAppError()
 
 
         parser = reqparse.RequestParser()
         parser = reqparse.RequestParser()
@@ -160,7 +160,7 @@ class MessageMoreLikeThisApi(WebApiResource):
 class MessageSuggestedQuestionApi(WebApiResource):
 class MessageSuggestedQuestionApi(WebApiResource):
     def get(self, app_model, end_user, message_id):
     def get(self, app_model, end_user, message_id):
         app_mode = AppMode.value_of(app_model.mode)
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotCompletionAppError()
             raise NotCompletionAppError()
 
 
         message_id = str(message_id)
         message_id = str(message_id)

+ 2 - 2
api/core/agent/output_parser/cot_output_parser.py

@@ -90,7 +90,7 @@ class CotAgentOutputParser:
 
 
                 if not in_code_block and not in_json:
                 if not in_code_block and not in_json:
                     if delta.lower() == action_str[action_idx] and action_idx == 0:
                     if delta.lower() == action_str[action_idx] and action_idx == 0:
-                        if last_character not in ["\n", " ", ""]:
+                        if last_character not in {"\n", " ", ""}:
                             index += steps
                             index += steps
                             yield delta
                             yield delta
                             continue
                             continue
@@ -117,7 +117,7 @@ class CotAgentOutputParser:
                             action_idx = 0
                             action_idx = 0
 
 
                     if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
                     if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
-                        if last_character not in ["\n", " ", ""]:
+                        if last_character not in {"\n", " ", ""}:
                             index += steps
                             index += steps
                             yield delta
                             yield delta
                             continue
                             continue

+ 1 - 1
api/core/app/app_config/base_app_config_manager.py

@@ -29,7 +29,7 @@ class BaseAppConfigManager:
         additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict)
         additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict)
 
 
         additional_features.file_upload = FileUploadConfigManager.convert(
         additional_features.file_upload = FileUploadConfigManager.convert(
-            config=config_dict, is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT]
+            config=config_dict, is_vision=app_mode in {AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT}
         )
         )
 
 
         additional_features.opening_statement, additional_features.suggested_questions = (
         additional_features.opening_statement, additional_features.suggested_questions = (

+ 3 - 3
api/core/app/app_config/easy_ui_based_app/agent/manager.py

@@ -18,7 +18,7 @@ class AgentConfigManager:
 
 
             if agent_strategy == "function_call":
             if agent_strategy == "function_call":
                 strategy = AgentEntity.Strategy.FUNCTION_CALLING
                 strategy = AgentEntity.Strategy.FUNCTION_CALLING
-            elif agent_strategy == "cot" or agent_strategy == "react":
+            elif agent_strategy in {"cot", "react"}:
                 strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
                 strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
             else:
             else:
                 # old configs, try to detect default strategy
                 # old configs, try to detect default strategy
@@ -43,10 +43,10 @@ class AgentConfigManager:
 
 
                     agent_tools.append(AgentToolEntity(**agent_tool_properties))
                     agent_tools.append(AgentToolEntity(**agent_tool_properties))
 
 
-            if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in [
+            if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in {
                 "react_router",
                 "react_router",
                 "router",
                 "router",
-            ]:
+            }:
                 agent_prompt = agent_dict.get("prompt", None) or {}
                 agent_prompt = agent_dict.get("prompt", None) or {}
                 # check model mode
                 # check model mode
                 model_mode = config.get("model", {}).get("mode", "completion")
                 model_mode = config.get("model", {}).get("mode", "completion")

+ 1 - 1
api/core/app/app_config/easy_ui_based_app/dataset/manager.py

@@ -167,7 +167,7 @@ class DatasetConfigManager:
             config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
             config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
 
 
         has_datasets = False
         has_datasets = False
-        if config["agent_mode"]["strategy"] in [PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value]:
+        if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}:
             for tool in config["agent_mode"]["tools"]:
             for tool in config["agent_mode"]["tools"]:
                 key = list(tool.keys())[0]
                 key = list(tool.keys())[0]
                 if key == "dataset":
                 if key == "dataset":

+ 3 - 3
api/core/app/app_config/easy_ui_based_app/variables/manager.py

@@ -42,12 +42,12 @@ class BasicVariablesConfigManager:
                         variable=variable["variable"], type=variable["type"], config=variable["config"]
                         variable=variable["variable"], type=variable["type"], config=variable["config"]
                     )
                     )
                 )
                 )
-            elif variable_type in [
+            elif variable_type in {
                 VariableEntityType.TEXT_INPUT,
                 VariableEntityType.TEXT_INPUT,
                 VariableEntityType.PARAGRAPH,
                 VariableEntityType.PARAGRAPH,
                 VariableEntityType.NUMBER,
                 VariableEntityType.NUMBER,
                 VariableEntityType.SELECT,
                 VariableEntityType.SELECT,
-            ]:
+            }:
                 variable = variables[variable_type]
                 variable = variables[variable_type]
                 variable_entities.append(
                 variable_entities.append(
                     VariableEntity(
                     VariableEntity(
@@ -97,7 +97,7 @@ class BasicVariablesConfigManager:
         variables = []
         variables = []
         for item in config["user_input_form"]:
         for item in config["user_input_form"]:
             key = list(item.keys())[0]
             key = list(item.keys())[0]
-            if key not in ["text-input", "select", "paragraph", "number", "external_data_tool"]:
+            if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}:
                 raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph'  or 'select'")
                 raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph'  or 'select'")
 
 
             form_item = item[key]
             form_item = item[key]

+ 2 - 2
api/core/app/app_config/features/file_upload/manager.py

@@ -54,14 +54,14 @@ class FileUploadConfigManager:
 
 
             if is_vision:
             if is_vision:
                 detail = config["file_upload"]["image"]["detail"]
                 detail = config["file_upload"]["image"]["detail"]
-                if detail not in ["high", "low"]:
+                if detail not in {"high", "low"}:
                     raise ValueError("detail must be in ['high', 'low']")
                     raise ValueError("detail must be in ['high', 'low']")
 
 
             transfer_methods = config["file_upload"]["image"]["transfer_methods"]
             transfer_methods = config["file_upload"]["image"]["transfer_methods"]
             if not isinstance(transfer_methods, list):
             if not isinstance(transfer_methods, list):
                 raise ValueError("transfer_methods must be of list type")
                 raise ValueError("transfer_methods must be of list type")
             for method in transfer_methods:
             for method in transfer_methods:
-                if method not in ["remote_url", "local_file"]:
+                if method not in {"remote_url", "local_file"}:
                     raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
                     raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
 
 
         return config, ["file_upload"]
         return config, ["file_upload"]

+ 2 - 2
api/core/app/apps/advanced_chat/app_runner.py

@@ -73,7 +73,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
             raise ValueError("Workflow not initialized")
             raise ValueError("Workflow not initialized")
 
 
         user_id = None
         user_id = None
-        if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
+        if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
             end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
             end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
             if end_user:
             if end_user:
                 user_id = end_user.session_id
                 user_id = end_user.session_id
@@ -175,7 +175,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
             user_id=self.application_generate_entity.user_id,
             user_id=self.application_generate_entity.user_id,
             user_from=(
             user_from=(
                 UserFrom.ACCOUNT
                 UserFrom.ACCOUNT
-                if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
+                if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
                 else UserFrom.END_USER
                 else UserFrom.END_USER
             ),
             ),
             invoke_from=self.application_generate_entity.invoke_from,
             invoke_from=self.application_generate_entity.invoke_from,

+ 1 - 1
api/core/app/apps/base_app_generate_response_converter.py

@@ -16,7 +16,7 @@ class AppGenerateResponseConverter(ABC):
     def convert(
     def convert(
         cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
         cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
     ) -> dict[str, Any] | Generator[str, Any, None]:
     ) -> dict[str, Any] | Generator[str, Any, None]:
-        if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
+        if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
             if isinstance(response, AppBlockingResponse):
             if isinstance(response, AppBlockingResponse):
                 return cls.convert_blocking_full_response(response)
                 return cls.convert_blocking_full_response(response)
             else:
             else:

+ 3 - 3
api/core/app/apps/base_app_generator.py

@@ -22,11 +22,11 @@ class BaseAppGenerator:
             return var.default or ""
             return var.default or ""
         if (
         if (
             var.type
             var.type
-            in (
+            in {
                 VariableEntityType.TEXT_INPUT,
                 VariableEntityType.TEXT_INPUT,
                 VariableEntityType.SELECT,
                 VariableEntityType.SELECT,
                 VariableEntityType.PARAGRAPH,
                 VariableEntityType.PARAGRAPH,
-            )
+            }
             and user_input_value
             and user_input_value
             and not isinstance(user_input_value, str)
             and not isinstance(user_input_value, str)
         ):
         ):
@@ -44,7 +44,7 @@ class BaseAppGenerator:
             options = var.options or []
             options = var.options or []
             if user_input_value not in options:
             if user_input_value not in options:
                 raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
                 raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
-        elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
+        elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}:
             if var.max_length and user_input_value and len(user_input_value) > var.max_length:
             if var.max_length and user_input_value and len(user_input_value) > var.max_length:
                 raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
                 raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
 
 

+ 2 - 2
api/core/app/apps/base_app_queue_manager.py

@@ -32,7 +32,7 @@ class AppQueueManager:
         self._user_id = user_id
         self._user_id = user_id
         self._invoke_from = invoke_from
         self._invoke_from = invoke_from
 
 
-        user_prefix = "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user"
+        user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
         redis_client.setex(
         redis_client.setex(
             AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
             AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
         )
         )
@@ -118,7 +118,7 @@ class AppQueueManager:
         if result is None:
         if result is None:
             return
             return
 
 
-        user_prefix = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user"
+        user_prefix = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
         if result.decode("utf-8") != f"{user_prefix}-{user_id}":
         if result.decode("utf-8") != f"{user_prefix}-{user_id}":
             return
             return
 
 

+ 3 - 3
api/core/app/apps/message_based_app_generator.py

@@ -148,7 +148,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
         # get from source
         # get from source
         end_user_id = None
         end_user_id = None
         account_id = None
         account_id = None
-        if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
+        if application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
             from_source = "api"
             from_source = "api"
             end_user_id = application_generate_entity.user_id
             end_user_id = application_generate_entity.user_id
         else:
         else:
@@ -165,11 +165,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
             model_provider = application_generate_entity.model_conf.provider
             model_provider = application_generate_entity.model_conf.provider
             model_id = application_generate_entity.model_conf.model
             model_id = application_generate_entity.model_conf.model
             override_model_configs = None
             override_model_configs = None
-            if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in [
+            if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in {
                 AppMode.AGENT_CHAT,
                 AppMode.AGENT_CHAT,
                 AppMode.CHAT,
                 AppMode.CHAT,
                 AppMode.COMPLETION,
                 AppMode.COMPLETION,
-            ]:
+            }:
                 override_model_configs = app_config.app_model_config_dict
                 override_model_configs = app_config.app_model_config_dict
 
 
         # get conversation introduction
         # get conversation introduction

+ 2 - 2
api/core/app/apps/workflow/app_runner.py

@@ -53,7 +53,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
         app_config = cast(WorkflowAppConfig, app_config)
         app_config = cast(WorkflowAppConfig, app_config)
 
 
         user_id = None
         user_id = None
-        if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
+        if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
             end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
             end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
             if end_user:
             if end_user:
                 user_id = end_user.session_id
                 user_id = end_user.session_id
@@ -113,7 +113,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
             user_id=self.application_generate_entity.user_id,
             user_id=self.application_generate_entity.user_id,
             user_from=(
             user_from=(
                 UserFrom.ACCOUNT
                 UserFrom.ACCOUNT
-                if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
+                if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
                 else UserFrom.END_USER
                 else UserFrom.END_USER
             ),
             ),
             invoke_from=self.application_generate_entity.invoke_from,
             invoke_from=self.application_generate_entity.invoke_from,

+ 1 - 1
api/core/app/features/annotation_reply/annotation_reply.py

@@ -63,7 +63,7 @@ class AnnotationReplyFeature:
                 score = documents[0].metadata["score"]
                 score = documents[0].metadata["score"]
                 annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
                 annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
                 if annotation:
                 if annotation:
-                    if invoke_from in [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP]:
+                    if invoke_from in {InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP}:
                         from_source = "api"
                         from_source = "api"
                     else:
                     else:
                         from_source = "console"
                         from_source = "console"

+ 1 - 1
api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py

@@ -372,7 +372,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
             self._message,
             self._message,
             application_generate_entity=self._application_generate_entity,
             application_generate_entity=self._application_generate_entity,
             conversation=self._conversation,
             conversation=self._conversation,
-            is_first_message=self._application_generate_entity.app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT]
+            is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT}
             and self._application_generate_entity.conversation_id is None,
             and self._application_generate_entity.conversation_id is None,
             extras=self._application_generate_entity.extras,
             extras=self._application_generate_entity.extras,
         )
         )

+ 2 - 2
api/core/app/task_pipeline/workflow_cycle_manage.py

@@ -383,7 +383,7 @@ class WorkflowCycleManage:
         :param workflow_node_execution: workflow node execution
         :param workflow_node_execution: workflow node execution
         :return:
         :return:
         """
         """
-        if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
+        if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
             return None
             return None
 
 
         response = NodeStartStreamResponse(
         response = NodeStartStreamResponse(
@@ -430,7 +430,7 @@ class WorkflowCycleManage:
         :param workflow_node_execution: workflow node execution
         :param workflow_node_execution: workflow node execution
         :return:
         :return:
         """
         """
-        if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
+        if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
             return None
             return None
 
 
         return NodeFinishStreamResponse(
         return NodeFinishStreamResponse(

+ 1 - 1
api/core/callback_handler/index_tool_callback_handler.py

@@ -29,7 +29,7 @@ class DatasetIndexToolCallbackHandler:
             source="app",
             source="app",
             source_app_id=self._app_id,
             source_app_id=self._app_id,
             created_by_role=(
             created_by_role=(
-                "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"
+                "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
             ),
             ),
             created_by=self._user_id,
             created_by=self._user_id,
         )
         )

+ 1 - 1
api/core/indexing_runner.py

@@ -292,7 +292,7 @@ class IndexingRunner:
         self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
         self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
     ) -> list[Document]:
     ) -> list[Document]:
         # load file
         # load file
-        if dataset_document.data_source_type not in ["upload_file", "notion_import", "website_crawl"]:
+        if dataset_document.data_source_type not in {"upload_file", "notion_import", "website_crawl"}:
             return []
             return []
 
 
         data_source_info = dataset_document.data_source_info_dict
         data_source_info = dataset_document.data_source_info_dict

+ 1 - 1
api/core/memory/token_buffer_memory.py

@@ -52,7 +52,7 @@ class TokenBufferMemory:
             files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
             files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
             if files:
             if files:
                 file_extra_config = None
                 file_extra_config = None
-                if self.conversation.mode not in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
+                if self.conversation.mode not in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
                     file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
                     file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
                 else:
                 else:
                     if message.workflow_run_id:
                     if message.workflow_run_id:

+ 6 - 6
api/core/model_runtime/entities/model_entities.py

@@ -27,17 +27,17 @@ class ModelType(Enum):
 
 
         :return: model type
         :return: model type
         """
         """
-        if origin_model_type == "text-generation" or origin_model_type == cls.LLM.value:
+        if origin_model_type in {"text-generation", cls.LLM.value}:
             return cls.LLM
             return cls.LLM
-        elif origin_model_type == "embeddings" or origin_model_type == cls.TEXT_EMBEDDING.value:
+        elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING.value}:
             return cls.TEXT_EMBEDDING
             return cls.TEXT_EMBEDDING
-        elif origin_model_type == "reranking" or origin_model_type == cls.RERANK.value:
+        elif origin_model_type in {"reranking", cls.RERANK.value}:
             return cls.RERANK
             return cls.RERANK
-        elif origin_model_type == "speech2text" or origin_model_type == cls.SPEECH2TEXT.value:
+        elif origin_model_type in {"speech2text", cls.SPEECH2TEXT.value}:
             return cls.SPEECH2TEXT
             return cls.SPEECH2TEXT
-        elif origin_model_type == "tts" or origin_model_type == cls.TTS.value:
+        elif origin_model_type in {"tts", cls.TTS.value}:
             return cls.TTS
             return cls.TTS
-        elif origin_model_type == "text2img" or origin_model_type == cls.TEXT2IMG.value:
+        elif origin_model_type in {"text2img", cls.TEXT2IMG.value}:
             return cls.TEXT2IMG
             return cls.TEXT2IMG
         elif origin_model_type == cls.MODERATION.value:
         elif origin_model_type == cls.MODERATION.value:
             return cls.MODERATION
             return cls.MODERATION

+ 1 - 1
api/core/model_runtime/model_providers/anthropic/llm/llm.py

@@ -494,7 +494,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                                     mime_type = data_split[0].replace("data:", "")
                                     mime_type = data_split[0].replace("data:", "")
                                     base64_data = data_split[1]
                                     base64_data = data_split[1]
 
 
-                                if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
+                                if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
                                     raise ValueError(
                                     raise ValueError(
                                         f"Unsupported image type {mime_type}, "
                                         f"Unsupported image type {mime_type}, "
                                         f"only support image/jpeg, image/png, image/gif, and image/webp"
                                         f"only support image/jpeg, image/png, image/gif, and image/webp"

+ 2 - 2
api/core/model_runtime/model_providers/azure_openai/tts/tts.py

@@ -85,14 +85,14 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
                     for i in range(len(sentences))
                     for i in range(len(sentences))
                 ]
                 ]
                 for future in futures:
                 for future in futures:
-                    yield from future.result().__enter__().iter_bytes(1024)
+                    yield from future.result().__enter__().iter_bytes(1024)  # noqa:PLC2801
 
 
             else:
             else:
                 response = client.audio.speech.with_streaming_response.create(
                 response = client.audio.speech.with_streaming_response.create(
                     model=model, voice=voice, response_format="mp3", input=content_text.strip()
                     model=model, voice=voice, response_format="mp3", input=content_text.strip()
                 )
                 )
 
 
-                yield from response.__enter__().iter_bytes(1024)
+                yield from response.__enter__().iter_bytes(1024)  # noqa:PLC2801
         except Exception as ex:
         except Exception as ex:
             raise InvokeBadRequestError(str(ex))
             raise InvokeBadRequestError(str(ex))
 
 

+ 5 - 5
api/core/model_runtime/model_providers/bedrock/llm/llm.py

@@ -454,7 +454,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                             base64_data = data_split[1]
                             base64_data = data_split[1]
                             image_content = base64.b64decode(base64_data)
                             image_content = base64.b64decode(base64_data)
 
 
-                        if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
+                        if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
                             raise ValueError(
                             raise ValueError(
                                 f"Unsupported image type {mime_type}, "
                                 f"Unsupported image type {mime_type}, "
                                 f"only support image/jpeg, image/png, image/gif, and image/webp"
                                 f"only support image/jpeg, image/png, image/gif, and image/webp"
@@ -886,16 +886,16 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
 
 
         if error_code == "AccessDeniedException":
         if error_code == "AccessDeniedException":
             return InvokeAuthorizationError(error_msg)
             return InvokeAuthorizationError(error_msg)
-        elif error_code in ["ResourceNotFoundException", "ValidationException"]:
+        elif error_code in {"ResourceNotFoundException", "ValidationException"}:
             return InvokeBadRequestError(error_msg)
             return InvokeBadRequestError(error_msg)
-        elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
+        elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}:
             return InvokeRateLimitError(error_msg)
             return InvokeRateLimitError(error_msg)
-        elif error_code in [
+        elif error_code in {
             "ModelTimeoutException",
             "ModelTimeoutException",
             "ModelErrorException",
             "ModelErrorException",
             "InternalServerException",
             "InternalServerException",
             "ModelNotReadyException",
             "ModelNotReadyException",
-        ]:
+        }:
             return InvokeServerUnavailableError(error_msg)
             return InvokeServerUnavailableError(error_msg)
         elif error_code == "ModelStreamErrorException":
         elif error_code == "ModelStreamErrorException":
             return InvokeConnectionError(error_msg)
             return InvokeConnectionError(error_msg)

+ 4 - 4
api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py

@@ -186,16 +186,16 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
 
 
         if error_code == "AccessDeniedException":
         if error_code == "AccessDeniedException":
             return InvokeAuthorizationError(error_msg)
             return InvokeAuthorizationError(error_msg)
-        elif error_code in ["ResourceNotFoundException", "ValidationException"]:
+        elif error_code in {"ResourceNotFoundException", "ValidationException"}:
             return InvokeBadRequestError(error_msg)
             return InvokeBadRequestError(error_msg)
-        elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
+        elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}:
             return InvokeRateLimitError(error_msg)
             return InvokeRateLimitError(error_msg)
-        elif error_code in [
+        elif error_code in {
             "ModelTimeoutException",
             "ModelTimeoutException",
             "ModelErrorException",
             "ModelErrorException",
             "InternalServerException",
             "InternalServerException",
             "ModelNotReadyException",
             "ModelNotReadyException",
-        ]:
+        }:
             return InvokeServerUnavailableError(error_msg)
             return InvokeServerUnavailableError(error_msg)
         elif error_code == "ModelStreamErrorException":
         elif error_code == "ModelStreamErrorException":
             return InvokeConnectionError(error_msg)
             return InvokeConnectionError(error_msg)

+ 2 - 2
api/core/model_runtime/model_providers/google/llm/llm.py

@@ -6,10 +6,10 @@ from collections.abc import Generator
 from typing import Optional, Union, cast
 from typing import Optional, Union, cast
 
 
 import google.ai.generativelanguage as glm
 import google.ai.generativelanguage as glm
-import google.api_core.exceptions as exceptions
 import google.generativeai as genai
 import google.generativeai as genai
-import google.generativeai.client as client
 import requests
 import requests
+from google.api_core import exceptions
+from google.generativeai import client
 from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory
 from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory
 from google.generativeai.types.content_types import to_part
 from google.generativeai.types.content_types import to_part
 from PIL import Image
 from PIL import Image

+ 2 - 2
api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py

@@ -77,7 +77,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
             if "huggingfacehub_api_type" not in credentials:
             if "huggingfacehub_api_type" not in credentials:
                 raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type must be provided.")
                 raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type must be provided.")
 
 
-            if credentials["huggingfacehub_api_type"] not in ("inference_endpoints", "hosted_inference_api"):
+            if credentials["huggingfacehub_api_type"] not in {"inference_endpoints", "hosted_inference_api"}:
                 raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type is invalid.")
                 raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type is invalid.")
 
 
             if "huggingfacehub_api_token" not in credentials:
             if "huggingfacehub_api_token" not in credentials:
@@ -94,7 +94,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
                     credentials["huggingfacehub_api_token"], model
                     credentials["huggingfacehub_api_token"], model
                 )
                 )
 
 
-            if credentials["task_type"] not in ("text2text-generation", "text-generation"):
+            if credentials["task_type"] not in {"text2text-generation", "text-generation"}:
                 raise CredentialsValidateFailedError(
                 raise CredentialsValidateFailedError(
                     "Huggingface Hub Task Type must be one of text2text-generation, text-generation."
                     "Huggingface Hub Task Type must be one of text2text-generation, text-generation."
                 )
                 )

+ 1 - 1
api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py

@@ -75,7 +75,7 @@ class TeiHelper:
         if len(model_type.keys()) < 1:
         if len(model_type.keys()) < 1:
             raise RuntimeError("model_type is empty")
             raise RuntimeError("model_type is empty")
         model_type = list(model_type.keys())[0]
         model_type = list(model_type.keys())[0]
-        if model_type not in ["embedding", "reranker"]:
+        if model_type not in {"embedding", "reranker"}:
             raise RuntimeError(f"invalid model_type: {model_type}")
             raise RuntimeError(f"invalid model_type: {model_type}")
 
 
         max_input_length = response_json.get("max_input_length", 512)
         max_input_length = response_json.get("max_input_length", 512)

+ 2 - 2
api/core/model_runtime/model_providers/minimax/llm/chat_completion.py

@@ -100,9 +100,9 @@ class MinimaxChatCompletion:
         return self._handle_chat_generate_response(response)
         return self._handle_chat_generate_response(response)
 
 
     def _handle_error(self, code: int, msg: str):
     def _handle_error(self, code: int, msg: str):
-        if code == 1000 or code == 1001 or code == 1013 or code == 1027:
+        if code in {1000, 1001, 1013, 1027}:
             raise InternalServerError(msg)
             raise InternalServerError(msg)
-        elif code == 1002 or code == 1039:
+        elif code in {1002, 1039}:
             raise RateLimitReachedError(msg)
             raise RateLimitReachedError(msg)
         elif code == 1004:
         elif code == 1004:
             raise InvalidAuthenticationError(msg)
             raise InvalidAuthenticationError(msg)

+ 2 - 2
api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py

@@ -105,9 +105,9 @@ class MinimaxChatCompletionPro:
         return self._handle_chat_generate_response(response)
         return self._handle_chat_generate_response(response)
 
 
     def _handle_error(self, code: int, msg: str):
     def _handle_error(self, code: int, msg: str):
-        if code == 1000 or code == 1001 or code == 1013 or code == 1027:
+        if code in {1000, 1001, 1013, 1027}:
             raise InternalServerError(msg)
             raise InternalServerError(msg)
-        elif code == 1002 or code == 1039:
+        elif code in {1002, 1039}:
             raise RateLimitReachedError(msg)
             raise RateLimitReachedError(msg)
         elif code == 1004:
         elif code == 1004:
             raise InvalidAuthenticationError(msg)
             raise InvalidAuthenticationError(msg)

+ 1 - 1
api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py

@@ -114,7 +114,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
             raise CredentialsValidateFailedError("Invalid api key")
             raise CredentialsValidateFailedError("Invalid api key")
 
 
     def _handle_error(self, code: int, msg: str):
     def _handle_error(self, code: int, msg: str):
-        if code == 1000 or code == 1001:
+        if code in {1000, 1001}:
             raise InternalServerError(msg)
             raise InternalServerError(msg)
         elif code == 1002:
         elif code == 1002:
             raise RateLimitReachedError(msg)
             raise RateLimitReachedError(msg)

+ 1 - 1
api/core/model_runtime/model_providers/openai/llm/llm.py

@@ -125,7 +125,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         model_mode = self.get_model_mode(base_model, credentials)
         model_mode = self.get_model_mode(base_model, credentials)
 
 
         # transform response format
         # transform response format
-        if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]:
+        if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}:
             stop = stop or []
             stop = stop or []
             if model_mode == LLMMode.CHAT:
             if model_mode == LLMMode.CHAT:
                 # chat model
                 # chat model

+ 2 - 2
api/core/model_runtime/model_providers/openai/tts/tts.py

@@ -89,14 +89,14 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
                     for i in range(len(sentences))
                     for i in range(len(sentences))
                 ]
                 ]
                 for future in futures:
                 for future in futures:
-                    yield from future.result().__enter__().iter_bytes(1024)
+                    yield from future.result().__enter__().iter_bytes(1024)  # noqa:PLC2801
 
 
             else:
             else:
                 response = client.audio.speech.with_streaming_response.create(
                 response = client.audio.speech.with_streaming_response.create(
                     model=model, voice=voice, response_format="mp3", input=content_text.strip()
                     model=model, voice=voice, response_format="mp3", input=content_text.strip()
                 )
                 )
 
 
-                yield from response.__enter__().iter_bytes(1024)
+                yield from response.__enter__().iter_bytes(1024)  # noqa:PLC2801
         except Exception as ex:
         except Exception as ex:
             raise InvokeBadRequestError(str(ex))
             raise InvokeBadRequestError(str(ex))
 
 

+ 0 - 1
api/core/model_runtime/model_providers/openrouter/llm/llm.py

@@ -12,7 +12,6 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
         credentials["endpoint_url"] = "https://openrouter.ai/api/v1"
         credentials["endpoint_url"] = "https://openrouter.ai/api/v1"
         credentials["mode"] = self.get_model_mode(model).value
         credentials["mode"] = self.get_model_mode(model).value
         credentials["function_calling_type"] = "tool_call"
         credentials["function_calling_type"] = "tool_call"
-        return
 
 
     def _invoke(
     def _invoke(
         self,
         self,

+ 1 - 1
api/core/model_runtime/model_providers/replicate/llm/llm.py

@@ -154,7 +154,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
         )
         )
 
 
         for key, value in input_properties:
         for key, value in input_properties:
-            if key not in ["system_prompt", "prompt"] and "stop" not in key:
+            if key not in {"system_prompt", "prompt"} and "stop" not in key:
                 value_type = value.get("type")
                 value_type = value.get("type")
 
 
                 if not value_type:
                 if not value_type:

+ 2 - 2
api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py

@@ -86,7 +86,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
         )
         )
 
 
         for input_property in input_properties:
         for input_property in input_properties:
-            if input_property[0] in ("text", "texts", "inputs"):
+            if input_property[0] in {"text", "texts", "inputs"}:
                 text_input_key = input_property[0]
                 text_input_key = input_property[0]
                 return text_input_key
                 return text_input_key
 
 
@@ -96,7 +96,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
     def _generate_embeddings_by_text_input_key(
     def _generate_embeddings_by_text_input_key(
         client: ReplicateClient, replicate_model_version: str, text_input_key: str, texts: list[str]
         client: ReplicateClient, replicate_model_version: str, text_input_key: str, texts: list[str]
     ) -> list[list[float]]:
     ) -> list[list[float]]:
-        if text_input_key in ("text", "inputs"):
+        if text_input_key in {"text", "inputs"}:
             embeddings = []
             embeddings = []
             for text in texts:
             for text in texts:
                 result = client.run(replicate_model_version, input={text_input_key: text})
                 result = client.run(replicate_model_version, input={text_input_key: text})

+ 4 - 4
api/core/model_runtime/model_providers/tongyi/llm/llm.py

@@ -89,7 +89,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
         :param tools: tools for tool calling
         :param tools: tools for tool calling
         :return:
         :return:
         """
         """
-        if model in ["qwen-turbo-chat", "qwen-plus-chat"]:
+        if model in {"qwen-turbo-chat", "qwen-plus-chat"}:
             model = model.replace("-chat", "")
             model = model.replace("-chat", "")
         if model == "farui-plus":
         if model == "farui-plus":
             model = "qwen-farui-plus"
             model = "qwen-farui-plus"
@@ -157,7 +157,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
 
 
         mode = self.get_model_mode(model, credentials)
         mode = self.get_model_mode(model, credentials)
 
 
-        if model in ["qwen-turbo-chat", "qwen-plus-chat"]:
+        if model in {"qwen-turbo-chat", "qwen-plus-chat"}:
             model = model.replace("-chat", "")
             model = model.replace("-chat", "")
 
 
         extra_model_kwargs = {}
         extra_model_kwargs = {}
@@ -201,7 +201,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
         :param prompt_messages: prompt messages
         :param prompt_messages: prompt messages
         :return: llm response
         :return: llm response
         """
         """
-        if response.status_code != 200 and response.status_code != HTTPStatus.OK:
+        if response.status_code not in {200, HTTPStatus.OK}:
             raise ServiceUnavailableError(response.message)
             raise ServiceUnavailableError(response.message)
         # transform assistant message to prompt message
         # transform assistant message to prompt message
         assistant_prompt_message = AssistantPromptMessage(
         assistant_prompt_message = AssistantPromptMessage(
@@ -240,7 +240,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
         full_text = ""
         full_text = ""
         tool_calls = []
         tool_calls = []
         for index, response in enumerate(responses):
         for index, response in enumerate(responses):
-            if response.status_code != 200 and response.status_code != HTTPStatus.OK:
+            if response.status_code not in {200, HTTPStatus.OK}:
                 raise ServiceUnavailableError(
                 raise ServiceUnavailableError(
                     f"Failed to invoke model {model}, status code: {response.status_code}, "
                     f"Failed to invoke model {model}, status code: {response.status_code}, "
                     f"message: {response.message}"
                     f"message: {response.message}"

+ 1 - 1
api/core/model_runtime/model_providers/upstage/llm/llm.py

@@ -93,7 +93,7 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel):
         """
         """
         Code block mode wrapper for invoking large language model
         Code block mode wrapper for invoking large language model
         """
         """
-        if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]:
+        if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}:
             stop = stop or []
             stop = stop or []
             self._transform_chat_json_prompts(
             self._transform_chat_json_prompts(
                 model=model,
                 model=model,

+ 2 - 2
api/core/model_runtime/model_providers/vertex_ai/llm/llm.py

@@ -5,7 +5,6 @@ import logging
 from collections.abc import Generator
 from collections.abc import Generator
 from typing import Optional, Union, cast
 from typing import Optional, Union, cast
 
 
-import google.api_core.exceptions as exceptions
 import google.auth.transport.requests
 import google.auth.transport.requests
 import vertexai.generative_models as glm
 import vertexai.generative_models as glm
 from anthropic import AnthropicVertex, Stream
 from anthropic import AnthropicVertex, Stream
@@ -17,6 +16,7 @@ from anthropic.types import (
     MessageStopEvent,
     MessageStopEvent,
     MessageStreamEvent,
     MessageStreamEvent,
 )
 )
+from google.api_core import exceptions
 from google.cloud import aiplatform
 from google.cloud import aiplatform
 from google.oauth2 import service_account
 from google.oauth2 import service_account
 from PIL import Image
 from PIL import Image
@@ -346,7 +346,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
                             mime_type = data_split[0].replace("data:", "")
                             mime_type = data_split[0].replace("data:", "")
                             base64_data = data_split[1]
                             base64_data = data_split[1]
 
 
-                        if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
+                        if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
                             raise ValueError(
                             raise ValueError(
                                 f"Unsupported image type {mime_type}, "
                                 f"Unsupported image type {mime_type}, "
                                 f"only support image/jpeg, image/png, image/gif, and image/webp"
                                 f"only support image/jpeg, image/png, image/gif, and image/webp"

+ 1 - 2
api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py

@@ -96,7 +96,6 @@ class Signer:
         signing_key = Signer.get_signing_secret_key_v4(credentials.sk, md.date, md.region, md.service)
         signing_key = Signer.get_signing_secret_key_v4(credentials.sk, md.date, md.region, md.service)
         sign = Util.to_hex(Util.hmac_sha256(signing_key, signing_str))
         sign = Util.to_hex(Util.hmac_sha256(signing_key, signing_str))
         request.headers["Authorization"] = Signer.build_auth_header_v4(sign, md, credentials)
         request.headers["Authorization"] = Signer.build_auth_header_v4(sign, md, credentials)
-        return
 
 
     @staticmethod
     @staticmethod
     def hashed_canonical_request_v4(request, meta):
     def hashed_canonical_request_v4(request, meta):
@@ -105,7 +104,7 @@ class Signer:
 
 
         signed_headers = {}
         signed_headers = {}
         for key in request.headers:
         for key in request.headers:
-            if key in ["Content-Type", "Content-Md5", "Host"] or key.startswith("X-"):
+            if key in {"Content-Type", "Content-Md5", "Host"} or key.startswith("X-"):
                 signed_headers[key.lower()] = request.headers[key]
                 signed_headers[key.lower()] = request.headers[key]
 
 
         if "host" in signed_headers:
         if "host" in signed_headers:

+ 1 - 1
api/core/model_runtime/model_providers/wenxin/llm/llm.py

@@ -69,7 +69,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
         """
         """
         Code block mode wrapper for invoking large language model
         Code block mode wrapper for invoking large language model
         """
         """
-        if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]:
+        if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}:
             response_format = model_parameters["response_format"]
             response_format = model_parameters["response_format"]
             stop = stop or []
             stop = stop or []
             self._transform_json_prompts(
             self._transform_json_prompts(

+ 1 - 1
api/core/model_runtime/model_providers/xinference/xinference_helper.py

@@ -103,7 +103,7 @@ class XinferenceHelper:
             model_handle_type = "embedding"
             model_handle_type = "embedding"
         elif response_json.get("model_type") == "audio":
         elif response_json.get("model_type") == "audio":
             model_handle_type = "audio"
             model_handle_type = "audio"
-            if model_family and model_family in ["ChatTTS", "CosyVoice", "FishAudio"]:
+            if model_family and model_family in {"ChatTTS", "CosyVoice", "FishAudio"}:
                 model_ability.append("text-to-audio")
                 model_ability.append("text-to-audio")
             else:
             else:
                 model_ability.append("audio-to-text")
                 model_ability.append("audio-to-text")

+ 9 - 12
api/core/model_runtime/model_providers/zhipuai/llm/llm.py

@@ -186,10 +186,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         new_prompt_messages: list[PromptMessage] = []
         new_prompt_messages: list[PromptMessage] = []
         for prompt_message in prompt_messages:
         for prompt_message in prompt_messages:
             copy_prompt_message = prompt_message.copy()
             copy_prompt_message = prompt_message.copy()
-            if copy_prompt_message.role in [PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL]:
+            if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL}:
                 if isinstance(copy_prompt_message.content, list):
                 if isinstance(copy_prompt_message.content, list):
                     # check if model is 'glm-4v'
                     # check if model is 'glm-4v'
-                    if model not in ("glm-4v", "glm-4v-plus"):
+                    if model not in {"glm-4v", "glm-4v-plus"}:
                         # not support list message
                         # not support list message
                         continue
                         continue
                     # get image and
                     # get image and
@@ -209,10 +209,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                 ):
                 ):
                     new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
                     new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
                 else:
                 else:
-                    if (
-                        copy_prompt_message.role == PromptMessageRole.USER
-                        or copy_prompt_message.role == PromptMessageRole.TOOL
-                    ):
+                    if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.TOOL}:
                         new_prompt_messages.append(copy_prompt_message)
                         new_prompt_messages.append(copy_prompt_message)
                     elif copy_prompt_message.role == PromptMessageRole.SYSTEM:
                     elif copy_prompt_message.role == PromptMessageRole.SYSTEM:
                         new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content)
                         new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content)
@@ -226,7 +223,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                 else:
                 else:
                     new_prompt_messages.append(copy_prompt_message)
                     new_prompt_messages.append(copy_prompt_message)
 
 
-        if model == "glm-4v" or model == "glm-4v-plus":
+        if model in {"glm-4v", "glm-4v-plus"}:
             params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters)
             params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters)
         else:
         else:
             params = {"model": model, "messages": [], **model_parameters}
             params = {"model": model, "messages": [], **model_parameters}
@@ -270,11 +267,11 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                 # chatglm model
                 # chatglm model
                 for prompt_message in new_prompt_messages:
                 for prompt_message in new_prompt_messages:
                     # merge system message to user message
                     # merge system message to user message
-                    if (
-                        prompt_message.role == PromptMessageRole.SYSTEM
-                        or prompt_message.role == PromptMessageRole.TOOL
-                        or prompt_message.role == PromptMessageRole.USER
-                    ):
+                    if prompt_message.role in {
+                        PromptMessageRole.SYSTEM,
+                        PromptMessageRole.TOOL,
+                        PromptMessageRole.USER,
+                    }:
                         if len(params["messages"]) > 0 and params["messages"][-1]["role"] == "user":
                         if len(params["messages"]) > 0 and params["messages"][-1]["role"] == "user":
                             params["messages"][-1]["content"] += "\n\n" + prompt_message.content
                             params["messages"][-1]["content"] += "\n\n" + prompt_message.content
                         else:
                         else:

+ 2 - 3
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py

@@ -1,5 +1,4 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
-from .fine_tuning_job import FineTuningJob as FineTuningJob
-from .fine_tuning_job import ListOfFineTuningJob as ListOfFineTuningJob
-from .fine_tuning_job_event import FineTuningJobEvent as FineTuningJobEvent
+from .fine_tuning_job import FineTuningJob, ListOfFineTuningJob
+from .fine_tuning_job_event import FineTuningJobEvent

+ 2 - 2
api/core/model_runtime/schema_validators/common_validator.py

@@ -75,7 +75,7 @@ class CommonValidator:
         if not isinstance(value, str):
         if not isinstance(value, str):
             raise ValueError(f"Variable {credential_form_schema.variable} should be string")
             raise ValueError(f"Variable {credential_form_schema.variable} should be string")
 
 
-        if credential_form_schema.type in [FormType.SELECT, FormType.RADIO]:
+        if credential_form_schema.type in {FormType.SELECT, FormType.RADIO}:
             # If the value is in options, no validation is performed
             # If the value is in options, no validation is performed
             if credential_form_schema.options:
             if credential_form_schema.options:
                 if value not in [option.value for option in credential_form_schema.options]:
                 if value not in [option.value for option in credential_form_schema.options]:
@@ -83,7 +83,7 @@ class CommonValidator:
 
 
         if credential_form_schema.type == FormType.SWITCH:
         if credential_form_schema.type == FormType.SWITCH:
             # If the value is not in ['true', 'false'], an exception is thrown
             # If the value is not in ['true', 'false'], an exception is thrown
-            if value.lower() not in ["true", "false"]:
+            if value.lower() not in {"true", "false"}:
                 raise ValueError(f"Variable {credential_form_schema.variable} should be true or false")
                 raise ValueError(f"Variable {credential_form_schema.variable} should be true or false")
 
 
             value = True if value.lower() == "true" else False
             value = True if value.lower() == "true" else False

+ 2 - 2
api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py

@@ -51,7 +51,7 @@ class ElasticSearchVector(BaseVector):
     def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
     def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
         try:
         try:
             parsed_url = urlparse(config.host)
             parsed_url = urlparse(config.host)
-            if parsed_url.scheme in ["http", "https"]:
+            if parsed_url.scheme in {"http", "https"}:
                 hosts = f"{config.host}:{config.port}"
                 hosts = f"{config.host}:{config.port}"
             else:
             else:
                 hosts = f"http://{config.host}:{config.port}"
                 hosts = f"http://{config.host}:{config.port}"
@@ -94,7 +94,7 @@ class ElasticSearchVector(BaseVector):
         return uuids
         return uuids
 
 
     def text_exists(self, id: str) -> bool:
     def text_exists(self, id: str) -> bool:
-        return self._client.exists(index=self._collection_name, id=id).__bool__()
+        return bool(self._client.exists(index=self._collection_name, id=id))
 
 
     def delete_by_ids(self, ids: list[str]) -> None:
     def delete_by_ids(self, ids: list[str]) -> None:
         for id in ids:
         for id in ids:

+ 2 - 2
api/core/rag/datasource/vdb/myscale/myscale_vector.py

@@ -35,7 +35,7 @@ class MyScaleVector(BaseVector):
         super().__init__(collection_name)
         super().__init__(collection_name)
         self._config = config
         self._config = config
         self._metric = metric
         self._metric = metric
-        self._vec_order = SortOrder.ASC if metric.upper() in ["COSINE", "L2"] else SortOrder.DESC
+        self._vec_order = SortOrder.ASC if metric.upper() in {"COSINE", "L2"} else SortOrder.DESC
         self._client = get_client(
         self._client = get_client(
             host=config.host,
             host=config.host,
             port=config.port,
             port=config.port,
@@ -92,7 +92,7 @@ class MyScaleVector(BaseVector):
 
 
     @staticmethod
     @staticmethod
     def escape_str(value: Any) -> str:
     def escape_str(value: Any) -> str:
-        return "".join(" " if c in ("\\", "'") else c for c in str(value))
+        return "".join(" " if c in {"\\", "'"} else c for c in str(value))
 
 
     def text_exists(self, id: str) -> bool:
     def text_exists(self, id: str) -> bool:
         results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'")
         results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'")

+ 1 - 9
api/core/rag/datasource/vdb/oracle/oraclevector.py

@@ -223,15 +223,7 @@ class OracleVector(BaseVector):
                 words = pseg.cut(query)
                 words = pseg.cut(query)
                 current_entity = ""
                 current_entity = ""
                 for word, pos in words:
                 for word, pos in words:
-                    if (
-                        pos == "nr"
-                        or pos == "Ng"
-                        or pos == "eng"
-                        or pos == "nz"
-                        or pos == "n"
-                        or pos == "ORG"
-                        or pos == "v"
-                    ):  # nr: 人名, ns: 地名, nt: 机构名
+                    if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}:  # nr: 人名, ns: 地名, nt: 机构名
                         current_entity += word
                         current_entity += word
                     else:
                     else:
                         if current_entity:
                         if current_entity:

+ 6 - 6
api/core/rag/extractor/extract_processor.py

@@ -98,17 +98,17 @@ class ExtractProcessor:
                 unstructured_api_url = dify_config.UNSTRUCTURED_API_URL
                 unstructured_api_url = dify_config.UNSTRUCTURED_API_URL
                 unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY
                 unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY
                 if etl_type == "Unstructured":
                 if etl_type == "Unstructured":
-                    if file_extension == ".xlsx" or file_extension == ".xls":
+                    if file_extension in {".xlsx", ".xls"}:
                         extractor = ExcelExtractor(file_path)
                         extractor = ExcelExtractor(file_path)
                     elif file_extension == ".pdf":
                     elif file_extension == ".pdf":
                         extractor = PdfExtractor(file_path)
                         extractor = PdfExtractor(file_path)
-                    elif file_extension in [".md", ".markdown"]:
+                    elif file_extension in {".md", ".markdown"}:
                         extractor = (
                         extractor = (
                             UnstructuredMarkdownExtractor(file_path, unstructured_api_url)
                             UnstructuredMarkdownExtractor(file_path, unstructured_api_url)
                             if is_automatic
                             if is_automatic
                             else MarkdownExtractor(file_path, autodetect_encoding=True)
                             else MarkdownExtractor(file_path, autodetect_encoding=True)
                         )
                         )
-                    elif file_extension in [".htm", ".html"]:
+                    elif file_extension in {".htm", ".html"}:
                         extractor = HtmlExtractor(file_path)
                         extractor = HtmlExtractor(file_path)
                     elif file_extension == ".docx":
                     elif file_extension == ".docx":
                         extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
                         extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
@@ -134,13 +134,13 @@ class ExtractProcessor:
                             else TextExtractor(file_path, autodetect_encoding=True)
                             else TextExtractor(file_path, autodetect_encoding=True)
                         )
                         )
                 else:
                 else:
-                    if file_extension == ".xlsx" or file_extension == ".xls":
+                    if file_extension in {".xlsx", ".xls"}:
                         extractor = ExcelExtractor(file_path)
                         extractor = ExcelExtractor(file_path)
                     elif file_extension == ".pdf":
                     elif file_extension == ".pdf":
                         extractor = PdfExtractor(file_path)
                         extractor = PdfExtractor(file_path)
-                    elif file_extension in [".md", ".markdown"]:
+                    elif file_extension in {".md", ".markdown"}:
                         extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
                         extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
-                    elif file_extension in [".htm", ".html"]:
+                    elif file_extension in {".htm", ".html"}:
                         extractor = HtmlExtractor(file_path)
                         extractor = HtmlExtractor(file_path)
                     elif file_extension == ".docx":
                     elif file_extension == ".docx":
                         extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
                         extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)

+ 1 - 1
api/core/rag/extractor/firecrawl/firecrawl_app.py

@@ -32,7 +32,7 @@ class FirecrawlApp:
             else:
             else:
                 raise Exception(f'Failed to scrape URL. Error: {response["error"]}')
                 raise Exception(f'Failed to scrape URL. Error: {response["error"]}')
 
 
-        elif response.status_code in [402, 409, 500]:
+        elif response.status_code in {402, 409, 500}:
             error_message = response.json().get("error", "Unknown error occurred")
             error_message = response.json().get("error", "Unknown error occurred")
             raise Exception(f"Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}")
             raise Exception(f"Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}")
         else:
         else:

+ 2 - 2
api/core/rag/extractor/notion_extractor.py

@@ -103,12 +103,12 @@ class NotionExtractor(BaseExtractor):
                     multi_select_list = property_value[type]
                     multi_select_list = property_value[type]
                     for multi_select in multi_select_list:
                     for multi_select in multi_select_list:
                         value.append(multi_select["name"])
                         value.append(multi_select["name"])
-                elif type == "rich_text" or type == "title":
+                elif type in {"rich_text", "title"}:
                     if len(property_value[type]) > 0:
                     if len(property_value[type]) > 0:
                         value = property_value[type][0]["plain_text"]
                         value = property_value[type][0]["plain_text"]
                     else:
                     else:
                         value = ""
                         value = ""
-                elif type == "select" or type == "status":
+                elif type in {"select", "status"}:
                     if property_value[type]:
                     if property_value[type]:
                         value = property_value[type]["name"]
                         value = property_value[type]["name"]
                     else:
                     else:

+ 1 - 1
api/core/rag/retrieval/dataset_retrieval.py

@@ -115,7 +115,7 @@ class DatasetRetrieval:
 
 
             available_datasets.append(dataset)
             available_datasets.append(dataset)
         all_documents = []
         all_documents = []
-        user_from = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"
+        user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
         if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
         if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
             all_documents = self.single_retrieve(
             all_documents = self.single_retrieve(
                 app_id,
                 app_id,

+ 1 - 1
api/core/rag/splitter/text_splitter.py

@@ -35,7 +35,7 @@ def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> l
             splits = re.split(separator, text)
             splits = re.split(separator, text)
     else:
     else:
         splits = list(text)
         splits = list(text)
-    return [s for s in splits if (s != "" and s != "\n")]
+    return [s for s in splits if (s not in {"", "\n"})]
 
 
 
 
 class TextSplitter(BaseDocumentTransformer, ABC):
 class TextSplitter(BaseDocumentTransformer, ABC):

+ 1 - 1
api/core/tools/provider/app_tool_provider.py

@@ -68,7 +68,7 @@ class AppToolProviderEntity(ToolProviderController):
                 label = input_form[form_type]["label"]
                 label = input_form[form_type]["label"]
                 variable_name = input_form[form_type]["variable_name"]
                 variable_name = input_form[form_type]["variable_name"]
                 options = input_form[form_type].get("options", [])
                 options = input_form[form_type].get("options", [])
-                if form_type == "paragraph" or form_type == "text-input":
+                if form_type in {"paragraph", "text-input"}:
                     tool["parameters"].append(
                     tool["parameters"].append(
                         ToolParameter(
                         ToolParameter(
                             name=variable_name,
                             name=variable_name,

+ 2 - 2
api/core/tools/provider/builtin/aippt/tools/aippt.py

@@ -168,7 +168,7 @@ class AIPPTGenerateTool(BuiltinTool):
                             pass
                             pass
                     elif event == "close":
                     elif event == "close":
                         break
                         break
-                    elif event == "error" or event == "filter":
+                    elif event in {"error", "filter"}:
                         raise Exception(f"Failed to generate outline: {data}")
                         raise Exception(f"Failed to generate outline: {data}")
 
 
         return outline
         return outline
@@ -213,7 +213,7 @@ class AIPPTGenerateTool(BuiltinTool):
                                 pass
                                 pass
                         elif event == "close":
                         elif event == "close":
                             break
                             break
-                        elif event == "error" or event == "filter":
+                        elif event in {"error", "filter"}:
                             raise Exception(f"Failed to generate content: {data}")
                             raise Exception(f"Failed to generate content: {data}")
 
 
             return content
             return content

+ 2 - 2
api/core/tools/provider/builtin/azuredalle/tools/dalle3.py

@@ -39,11 +39,11 @@ class DallE3Tool(BuiltinTool):
         n = tool_parameters.get("n", 1)
         n = tool_parameters.get("n", 1)
         # get quality
         # get quality
         quality = tool_parameters.get("quality", "standard")
         quality = tool_parameters.get("quality", "standard")
-        if quality not in ["standard", "hd"]:
+        if quality not in {"standard", "hd"}:
             return self.create_text_message("Invalid quality")
             return self.create_text_message("Invalid quality")
         # get style
         # get style
         style = tool_parameters.get("style", "vivid")
         style = tool_parameters.get("style", "vivid")
-        if style not in ["natural", "vivid"]:
+        if style not in {"natural", "vivid"}:
             return self.create_text_message("Invalid style")
             return self.create_text_message("Invalid style")
         # set extra body
         # set extra body
         seed_id = tool_parameters.get("seed_id", self._generate_random_id(8))
         seed_id = tool_parameters.get("seed_id", self._generate_random_id(8))

+ 1 - 1
api/core/tools/provider/builtin/code/tools/simple_code.py

@@ -14,7 +14,7 @@ class SimpleCode(BuiltinTool):
         language = tool_parameters.get("language", CodeLanguage.PYTHON3)
         language = tool_parameters.get("language", CodeLanguage.PYTHON3)
         code = tool_parameters.get("code", "")
         code = tool_parameters.get("code", "")
 
 
-        if language not in [CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]:
+        if language not in {CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT}:
             raise ValueError(f"Only python3 and javascript are supported, not {language}")
             raise ValueError(f"Only python3 and javascript are supported, not {language}")
 
 
         result = CodeExecutor.execute_code(language, "", code)
         result = CodeExecutor.execute_code(language, "", code)

+ 2 - 2
api/core/tools/provider/builtin/cogview/tools/cogview3.py

@@ -34,11 +34,11 @@ class CogView3Tool(BuiltinTool):
         n = tool_parameters.get("n", 1)
         n = tool_parameters.get("n", 1)
         # get quality
         # get quality
         quality = tool_parameters.get("quality", "standard")
         quality = tool_parameters.get("quality", "standard")
-        if quality not in ["standard", "hd"]:
+        if quality not in {"standard", "hd"}:
             return self.create_text_message("Invalid quality")
             return self.create_text_message("Invalid quality")
         # get style
         # get style
         style = tool_parameters.get("style", "vivid")
         style = tool_parameters.get("style", "vivid")
-        if style not in ["natural", "vivid"]:
+        if style not in {"natural", "vivid"}:
             return self.create_text_message("Invalid style")
             return self.create_text_message("Invalid style")
         # set extra body
         # set extra body
         seed_id = tool_parameters.get("seed_id", self._generate_random_id(8))
         seed_id = tool_parameters.get("seed_id", self._generate_random_id(8))

+ 2 - 2
api/core/tools/provider/builtin/dalle/tools/dalle3.py

@@ -49,11 +49,11 @@ class DallE3Tool(BuiltinTool):
         n = tool_parameters.get("n", 1)
         n = tool_parameters.get("n", 1)
         # get quality
         # get quality
         quality = tool_parameters.get("quality", "standard")
         quality = tool_parameters.get("quality", "standard")
-        if quality not in ["standard", "hd"]:
+        if quality not in {"standard", "hd"}:
             return self.create_text_message("Invalid quality")
             return self.create_text_message("Invalid quality")
         # get style
         # get style
         style = tool_parameters.get("style", "vivid")
         style = tool_parameters.get("style", "vivid")
-        if style not in ["natural", "vivid"]:
+        if style not in {"natural", "vivid"}:
             return self.create_text_message("Invalid style")
             return self.create_text_message("Invalid style")
 
 
         # call openapi dalle3
         # call openapi dalle3

+ 2 - 2
api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py

@@ -133,9 +133,9 @@ class GetWorksheetFieldsTool(BuiltinTool):
 
 
     def _extract_options(self, control: dict) -> list:
     def _extract_options(self, control: dict) -> list:
         options = []
         options = []
-        if control["type"] in [9, 10, 11]:
+        if control["type"] in {9, 10, 11}:
             options.extend([{"key": opt["key"], "value": opt["value"]} for opt in control.get("options", [])])
             options.extend([{"key": opt["key"], "value": opt["value"]} for opt in control.get("options", [])])
-        elif control["type"] in [28, 36]:
+        elif control["type"] in {28, 36}:
             itemnames = control["advancedSetting"].get("itemnames")
             itemnames = control["advancedSetting"].get("itemnames")
             if itemnames and itemnames.startswith("[{"):
             if itemnames and itemnames.startswith("[{"):
                 try:
                 try:

+ 3 - 3
api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py

@@ -183,11 +183,11 @@ class ListWorksheetRecordsTool(BuiltinTool):
         type_id = field.get("typeId")
         type_id = field.get("typeId")
         if type_id == 10:
         if type_id == 10:
             value = value if isinstance(value, str) else "、".join(value)
             value = value if isinstance(value, str) else "、".join(value)
-        elif type_id in [28, 36]:
+        elif type_id in {28, 36}:
             value = field.get("options", {}).get(value, value)
             value = field.get("options", {}).get(value, value)
-        elif type_id in [26, 27, 48, 14]:
+        elif type_id in {26, 27, 48, 14}:
             value = self.process_value(value)
             value = self.process_value(value)
-        elif type_id in [35, 29]:
+        elif type_id in {35, 29}:
             value = self.parse_cascade_or_associated(field, value)
             value = self.parse_cascade_or_associated(field, value)
         elif type_id == 40:
         elif type_id == 40:
             value = self.parse_location(value)
             value = self.parse_location(value)

+ 1 - 1
api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py

@@ -35,7 +35,7 @@ class NovitaAiModelQueryTool(BuiltinTool):
             models_data=[],
             models_data=[],
             headers=headers,
             headers=headers,
             params=params,
             params=params,
-            recursive=not (result_type == "first sd_name" or result_type == "first name sd_name pair"),
+            recursive=result_type not in {"first sd_name", "first name sd_name pair"},
         )
         )
 
 
         result_str = ""
         result_str = ""

+ 1 - 1
api/core/tools/provider/builtin/searchapi/tools/google.py

@@ -38,7 +38,7 @@ class SearchAPI:
         return {
         return {
             "engine": "google",
             "engine": "google",
             "q": query,
             "q": query,
-            **{key: value for key, value in kwargs.items() if value not in [None, ""]},
+            **{key: value for key, value in kwargs.items() if value not in {None, ""}},
         }
         }
 
 
     @staticmethod
     @staticmethod

+ 1 - 1
api/core/tools/provider/builtin/searchapi/tools/google_jobs.py

@@ -38,7 +38,7 @@ class SearchAPI:
         return {
         return {
             "engine": "google_jobs",
             "engine": "google_jobs",
             "q": query,
             "q": query,
-            **{key: value for key, value in kwargs.items() if value not in [None, ""]},
+            **{key: value for key, value in kwargs.items() if value not in {None, ""}},
         }
         }
 
 
     @staticmethod
     @staticmethod

+ 1 - 1
api/core/tools/provider/builtin/searchapi/tools/google_news.py

@@ -38,7 +38,7 @@ class SearchAPI:
         return {
         return {
             "engine": "google_news",
             "engine": "google_news",
             "q": query,
             "q": query,
-            **{key: value for key, value in kwargs.items() if value not in [None, ""]},
+            **{key: value for key, value in kwargs.items() if value not in {None, ""}},
         }
         }
 
 
     @staticmethod
     @staticmethod

+ 1 - 1
api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py

@@ -38,7 +38,7 @@ class SearchAPI:
             "engine": "youtube_transcripts",
             "engine": "youtube_transcripts",
             "video_id": video_id,
             "video_id": video_id,
             "lang": language or "en",
             "lang": language or "en",
-            **{key: value for key, value in kwargs.items() if value not in [None, ""]},
+            **{key: value for key, value in kwargs.items() if value not in {None, ""}},
         }
         }
 
 
     @staticmethod
     @staticmethod

+ 1 - 1
api/core/tools/provider/builtin/spider/spiderApp.py

@@ -214,7 +214,7 @@ class Spider:
         return requests.delete(url, headers=headers, stream=stream)
         return requests.delete(url, headers=headers, stream=stream)
 
 
     def _handle_error(self, response, action):
     def _handle_error(self, response, action):
-        if response.status_code in [402, 409, 500]:
+        if response.status_code in {402, 409, 500}:
             error_message = response.json().get("error", "Unknown error occurred")
             error_message = response.json().get("error", "Unknown error occurred")
             raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}")
             raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}")
         else:
         else:

+ 1 - 1
api/core/tools/provider/builtin/stability/tools/text2image.py

@@ -32,7 +32,7 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization):
 
 
         model = tool_parameters.get("model", "core")
         model = tool_parameters.get("model", "core")
 
 
-        if model in ["sd3", "sd3-turbo"]:
+        if model in {"sd3", "sd3-turbo"}:
             payload["model"] = tool_parameters.get("model")
             payload["model"] = tool_parameters.get("model")
 
 
         if model != "sd3-turbo":
         if model != "sd3-turbo":

+ 1 - 1
api/core/tools/provider/builtin/vanna/tools/vanna.py

@@ -38,7 +38,7 @@ class VannaTool(BuiltinTool):
         vn = VannaDefault(model=model, api_key=api_key)
         vn = VannaDefault(model=model, api_key=api_key)
 
 
         db_type = tool_parameters.get("db_type", "")
         db_type = tool_parameters.get("db_type", "")
-        if db_type in ["Postgres", "MySQL", "Hive", "ClickHouse"]:
+        if db_type in {"Postgres", "MySQL", "Hive", "ClickHouse"}:
             if not db_name:
             if not db_name:
                 return self.create_text_message("Please input database name")
                 return self.create_text_message("Please input database name")
             if not username:
             if not username:

+ 1 - 1
api/core/tools/provider/builtin_tool_provider.py

@@ -19,7 +19,7 @@ from core.tools.utils.yaml_utils import load_yaml_file
 
 
 class BuiltinToolProviderController(ToolProviderController):
 class BuiltinToolProviderController(ToolProviderController):
     def __init__(self, **data: Any) -> None:
     def __init__(self, **data: Any) -> None:
-        if self.provider_type == ToolProviderType.API or self.provider_type == ToolProviderType.APP:
+        if self.provider_type in {ToolProviderType.API, ToolProviderType.APP}:
             super().__init__(**data)
             super().__init__(**data)
             return
             return
 
 

+ 9 - 9
api/core/tools/provider/tool_provider.py

@@ -153,10 +153,10 @@ class ToolProviderController(BaseModel, ABC):
 
 
             # check type
             # check type
             credential_schema = credentials_need_to_validate[credential_name]
             credential_schema = credentials_need_to_validate[credential_name]
-            if (
-                credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT
-                or credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT
-            ):
+            if credential_schema in {
+                ToolProviderCredentials.CredentialsType.SECRET_INPUT,
+                ToolProviderCredentials.CredentialsType.TEXT_INPUT,
+            }:
                 if not isinstance(credentials[credential_name], str):
                 if not isinstance(credentials[credential_name], str):
                     raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
                     raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
 
 
@@ -184,11 +184,11 @@ class ToolProviderController(BaseModel, ABC):
             if credential_schema.default is not None:
             if credential_schema.default is not None:
                 default_value = credential_schema.default
                 default_value = credential_schema.default
                 # parse default value into the correct type
                 # parse default value into the correct type
-                if (
-                    credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT
-                    or credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT
-                    or credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT
-                ):
+                if credential_schema.type in {
+                    ToolProviderCredentials.CredentialsType.SECRET_INPUT,
+                    ToolProviderCredentials.CredentialsType.TEXT_INPUT,
+                    ToolProviderCredentials.CredentialsType.SELECT,
+                }:
                     default_value = str(default_value)
                     default_value = str(default_value)
 
 
                 credentials[credential_name] = default_value
                 credentials[credential_name] = default_value

+ 4 - 4
api/core/tools/tool/api_tool.py

@@ -5,7 +5,7 @@ from urllib.parse import urlencode
 
 
 import httpx
 import httpx
 
 
-import core.helper.ssrf_proxy as ssrf_proxy
+from core.helper import ssrf_proxy
 from core.tools.entities.tool_bundle import ApiToolBundle
 from core.tools.entities.tool_bundle import ApiToolBundle
 from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
 from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
 from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
 from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
@@ -191,7 +191,7 @@ class ApiTool(Tool):
             else:
             else:
                 body = body
                 body = body
 
 
-        if method in ("get", "head", "post", "put", "delete", "patch"):
+        if method in {"get", "head", "post", "put", "delete", "patch"}:
             response = getattr(ssrf_proxy, method)(
             response = getattr(ssrf_proxy, method)(
                 url,
                 url,
                 params=params,
                 params=params,
@@ -224,9 +224,9 @@ class ApiTool(Tool):
                     elif option["type"] == "string":
                     elif option["type"] == "string":
                         return str(value)
                         return str(value)
                     elif option["type"] == "boolean":
                     elif option["type"] == "boolean":
-                        if str(value).lower() in ["true", "1"]:
+                        if str(value).lower() in {"true", "1"}:
                             return True
                             return True
-                        elif str(value).lower() in ["false", "0"]:
+                        elif str(value).lower() in {"false", "0"}:
                             return False
                             return False
                         else:
                         else:
                             continue  # Not a boolean, try next option
                             continue  # Not a boolean, try next option

+ 3 - 9
api/core/tools/tool_engine.py

@@ -189,10 +189,7 @@ class ToolEngine:
                 result += response.message
                 result += response.message
             elif response.type == ToolInvokeMessage.MessageType.LINK:
             elif response.type == ToolInvokeMessage.MessageType.LINK:
                 result += f"result link: {response.message}. please tell user to check it."
                 result += f"result link: {response.message}. please tell user to check it."
-            elif (
-                response.type == ToolInvokeMessage.MessageType.IMAGE_LINK
-                or response.type == ToolInvokeMessage.MessageType.IMAGE
-            ):
+            elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
                 result += (
                 result += (
                     "image has been created and sent to user already, you do not need to create it,"
                     "image has been created and sent to user already, you do not need to create it,"
                     " just tell the user to check it now."
                     " just tell the user to check it now."
@@ -212,10 +209,7 @@ class ToolEngine:
         result = []
         result = []
 
 
         for response in tool_response:
         for response in tool_response:
-            if (
-                response.type == ToolInvokeMessage.MessageType.IMAGE_LINK
-                or response.type == ToolInvokeMessage.MessageType.IMAGE
-            ):
+            if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
                 mimetype = None
                 mimetype = None
                 if response.meta.get("mime_type"):
                 if response.meta.get("mime_type"):
                     mimetype = response.meta.get("mime_type")
                     mimetype = response.meta.get("mime_type")
@@ -297,7 +291,7 @@ class ToolEngine:
                 belongs_to="assistant",
                 belongs_to="assistant",
                 url=message.url,
                 url=message.url,
                 upload_file_id=None,
                 upload_file_id=None,
-                created_by_role=("account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"),
+                created_by_role=("account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"),
                 created_by=user_id,
                 created_by=user_id,
             )
             )
 
 

+ 1 - 1
api/core/tools/utils/message_transformer.py

@@ -19,7 +19,7 @@ class ToolFileMessageTransformer:
         result = []
         result = []
 
 
         for message in messages:
         for message in messages:
-            if message.type == ToolInvokeMessage.MessageType.TEXT or message.type == ToolInvokeMessage.MessageType.LINK:
+            if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}:
                 result.append(message)
                 result.append(message)
             elif message.type == ToolInvokeMessage.MessageType.IMAGE:
             elif message.type == ToolInvokeMessage.MessageType.IMAGE:
                 # try to download image
                 # try to download image

+ 1 - 1
api/core/tools/utils/parser.py

@@ -165,7 +165,7 @@ class ApiBasedToolSchemaParser:
         elif "schema" in parameter and "type" in parameter["schema"]:
         elif "schema" in parameter and "type" in parameter["schema"]:
             typ = parameter["schema"]["type"]
             typ = parameter["schema"]["type"]
 
 
-        if typ == "integer" or typ == "number":
+        if typ in {"integer", "number"}:
             return ToolParameter.ToolParameterType.NUMBER
             return ToolParameter.ToolParameterType.NUMBER
         elif typ == "boolean":
         elif typ == "boolean":
             return ToolParameter.ToolParameterType.BOOLEAN
             return ToolParameter.ToolParameterType.BOOLEAN

+ 1 - 1
api/core/tools/utils/web_reader_tool.py

@@ -313,7 +313,7 @@ def normalize_whitespace(text):
 
 
 
 
 def is_leaf(element):
 def is_leaf(element):
-    return element.name in ["p", "li"]
+    return element.name in {"p", "li"}
 
 
 
 
 def is_text(element):
 def is_text(element):

+ 1 - 1
api/core/workflow/graph_engine/entities/runtime_route_state.py

@@ -51,7 +51,7 @@ class RouteNodeState(BaseModel):
 
 
         :param run_result: run result
         :param run_result: run result
         """
         """
-        if self.status in [RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED]:
+        if self.status in {RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED}:
             raise Exception(f"Route state {self.id} already finished")
             raise Exception(f"Route state {self.id} already finished")
 
 
         if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
         if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:

+ 2 - 2
api/core/workflow/nodes/answer/answer_stream_generate_router.py

@@ -148,11 +148,11 @@ class AnswerStreamGeneratorRouter:
         for edge in reverse_edges:
         for edge in reverse_edges:
             source_node_id = edge.source_node_id
             source_node_id = edge.source_node_id
             source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
             source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
-            if source_node_type in (
+            if source_node_type in {
                 NodeType.ANSWER.value,
                 NodeType.ANSWER.value,
                 NodeType.IF_ELSE.value,
                 NodeType.IF_ELSE.value,
                 NodeType.QUESTION_CLASSIFIER.value,
                 NodeType.QUESTION_CLASSIFIER.value,
-            ):
+            }:
                 answer_dependencies[answer_node_id].append(source_node_id)
                 answer_dependencies[answer_node_id].append(source_node_id)
             else:
             else:
                 cls._recursive_fetch_answer_dependencies(
                 cls._recursive_fetch_answer_dependencies(

+ 2 - 2
api/core/workflow/nodes/end/end_stream_generate_router.py

@@ -136,10 +136,10 @@ class EndStreamGeneratorRouter:
         for edge in reverse_edges:
         for edge in reverse_edges:
             source_node_id = edge.source_node_id
             source_node_id = edge.source_node_id
             source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
             source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
-            if source_node_type in (
+            if source_node_type in {
                 NodeType.IF_ELSE.value,
                 NodeType.IF_ELSE.value,
                 NodeType.QUESTION_CLASSIFIER,
                 NodeType.QUESTION_CLASSIFIER,
-            ):
+            }:
                 end_dependencies[end_node_id].append(source_node_id)
                 end_dependencies[end_node_id].append(source_node_id)
             else:
             else:
                 cls._recursive_fetch_end_dependencies(
                 cls._recursive_fetch_end_dependencies(

+ 4 - 4
api/core/workflow/nodes/http_request/http_executor.py

@@ -6,8 +6,8 @@ from urllib.parse import urlencode
 
 
 import httpx
 import httpx
 
 
-import core.helper.ssrf_proxy as ssrf_proxy
 from configs import dify_config
 from configs import dify_config
+from core.helper import ssrf_proxy
 from core.workflow.entities.variable_entities import VariableSelector
 from core.workflow.entities.variable_entities import VariableSelector
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.http_request.entities import (
 from core.workflow.nodes.http_request.entities import (
@@ -176,7 +176,7 @@ class HttpExecutor:
             elif node_data.body.type == "x-www-form-urlencoded" and not content_type_is_set:
             elif node_data.body.type == "x-www-form-urlencoded" and not content_type_is_set:
                 self.headers["Content-Type"] = "application/x-www-form-urlencoded"
                 self.headers["Content-Type"] = "application/x-www-form-urlencoded"
 
 
-            if node_data.body.type in ["form-data", "x-www-form-urlencoded"]:
+            if node_data.body.type in {"form-data", "x-www-form-urlencoded"}:
                 body = self._to_dict(body_data)
                 body = self._to_dict(body_data)
 
 
                 if node_data.body.type == "form-data":
                 if node_data.body.type == "form-data":
@@ -187,7 +187,7 @@ class HttpExecutor:
                     self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}"
                     self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}"
                 else:
                 else:
                     self.body = urlencode(body)
                     self.body = urlencode(body)
-            elif node_data.body.type in ["json", "raw-text"]:
+            elif node_data.body.type in {"json", "raw-text"}:
                 self.body = body_data
                 self.body = body_data
             elif node_data.body.type == "none":
             elif node_data.body.type == "none":
                 self.body = ""
                 self.body = ""
@@ -258,7 +258,7 @@ class HttpExecutor:
             "follow_redirects": True,
             "follow_redirects": True,
         }
         }
 
 
-        if self.method in ("get", "head", "post", "put", "delete", "patch"):
+        if self.method in {"get", "head", "post", "put", "delete", "patch"}:
             response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs)
             response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs)
         else:
         else:
             raise ValueError(f"Invalid http method {self.method}")
             raise ValueError(f"Invalid http method {self.method}")

Some files were not shown because too many files changed in this diff