Procházet zdrojové kódy

chore: refurbish python code by applying Pylint linter rules (#8322)

Bowen Liang před 7 měsíci
rodič
revize
a1104ab97e
Změněno 100 souborů: 198 přidání a 218 odebrání
  1. 1 1
      api/app.py
  2. 2 2
      api/commands.py
  3. 1 1
      api/controllers/console/app/audio.py
  4. 1 1
      api/controllers/console/auth/oauth.py
  5. 3 3
      api/controllers/console/datasets/datasets_document.py
  6. 1 1
      api/controllers/console/explore/audio.py
  7. 2 2
      api/controllers/console/explore/completion.py
  8. 5 5
      api/controllers/console/explore/conversation.py
  9. 1 1
      api/controllers/console/explore/installed_app.py
  10. 2 2
      api/controllers/console/explore/message.py
  11. 1 1
      api/controllers/console/explore/parameter.py
  12. 1 1
      api/controllers/console/workspace/workspace.py
  13. 1 1
      api/controllers/service_api/app/app.py
  14. 1 1
      api/controllers/service_api/app/audio.py
  15. 2 2
      api/controllers/service_api/app/completion.py
  16. 3 3
      api/controllers/service_api/app/conversation.py
  17. 2 2
      api/controllers/service_api/app/message.py
  18. 1 1
      api/controllers/web/app.py
  19. 1 1
      api/controllers/web/audio.py
  20. 2 2
      api/controllers/web/completion.py
  21. 5 5
      api/controllers/web/conversation.py
  22. 2 2
      api/controllers/web/message.py
  23. 2 2
      api/core/agent/output_parser/cot_output_parser.py
  24. 1 1
      api/core/app/app_config/base_app_config_manager.py
  25. 3 3
      api/core/app/app_config/easy_ui_based_app/agent/manager.py
  26. 1 1
      api/core/app/app_config/easy_ui_based_app/dataset/manager.py
  27. 3 3
      api/core/app/app_config/easy_ui_based_app/variables/manager.py
  28. 2 2
      api/core/app/app_config/features/file_upload/manager.py
  29. 2 2
      api/core/app/apps/advanced_chat/app_runner.py
  30. 1 1
      api/core/app/apps/base_app_generate_response_converter.py
  31. 3 3
      api/core/app/apps/base_app_generator.py
  32. 2 2
      api/core/app/apps/base_app_queue_manager.py
  33. 3 3
      api/core/app/apps/message_based_app_generator.py
  34. 2 2
      api/core/app/apps/workflow/app_runner.py
  35. 1 1
      api/core/app/features/annotation_reply/annotation_reply.py
  36. 1 1
      api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
  37. 2 2
      api/core/app/task_pipeline/workflow_cycle_manage.py
  38. 1 1
      api/core/callback_handler/index_tool_callback_handler.py
  39. 1 1
      api/core/indexing_runner.py
  40. 1 1
      api/core/memory/token_buffer_memory.py
  41. 6 6
      api/core/model_runtime/entities/model_entities.py
  42. 1 1
      api/core/model_runtime/model_providers/anthropic/llm/llm.py
  43. 2 2
      api/core/model_runtime/model_providers/azure_openai/tts/tts.py
  44. 5 5
      api/core/model_runtime/model_providers/bedrock/llm/llm.py
  45. 4 4
      api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py
  46. 2 2
      api/core/model_runtime/model_providers/google/llm/llm.py
  47. 2 2
      api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py
  48. 1 1
      api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py
  49. 2 2
      api/core/model_runtime/model_providers/minimax/llm/chat_completion.py
  50. 2 2
      api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py
  51. 1 1
      api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py
  52. 1 1
      api/core/model_runtime/model_providers/openai/llm/llm.py
  53. 2 2
      api/core/model_runtime/model_providers/openai/tts/tts.py
  54. 0 1
      api/core/model_runtime/model_providers/openrouter/llm/llm.py
  55. 1 1
      api/core/model_runtime/model_providers/replicate/llm/llm.py
  56. 2 2
      api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py
  57. 4 4
      api/core/model_runtime/model_providers/tongyi/llm/llm.py
  58. 1 1
      api/core/model_runtime/model_providers/upstage/llm/llm.py
  59. 2 2
      api/core/model_runtime/model_providers/vertex_ai/llm/llm.py
  60. 1 2
      api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py
  61. 1 1
      api/core/model_runtime/model_providers/wenxin/llm/llm.py
  62. 1 1
      api/core/model_runtime/model_providers/xinference/xinference_helper.py
  63. 9 12
      api/core/model_runtime/model_providers/zhipuai/llm/llm.py
  64. 2 3
      api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py
  65. 2 2
      api/core/model_runtime/schema_validators/common_validator.py
  66. 2 2
      api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
  67. 2 2
      api/core/rag/datasource/vdb/myscale/myscale_vector.py
  68. 1 9
      api/core/rag/datasource/vdb/oracle/oraclevector.py
  69. 6 6
      api/core/rag/extractor/extract_processor.py
  70. 1 1
      api/core/rag/extractor/firecrawl/firecrawl_app.py
  71. 2 2
      api/core/rag/extractor/notion_extractor.py
  72. 1 1
      api/core/rag/retrieval/dataset_retrieval.py
  73. 1 1
      api/core/rag/splitter/text_splitter.py
  74. 1 1
      api/core/tools/provider/app_tool_provider.py
  75. 2 2
      api/core/tools/provider/builtin/aippt/tools/aippt.py
  76. 2 2
      api/core/tools/provider/builtin/azuredalle/tools/dalle3.py
  77. 1 1
      api/core/tools/provider/builtin/code/tools/simple_code.py
  78. 2 2
      api/core/tools/provider/builtin/cogview/tools/cogview3.py
  79. 2 2
      api/core/tools/provider/builtin/dalle/tools/dalle3.py
  80. 2 2
      api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py
  81. 3 3
      api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py
  82. 1 1
      api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py
  83. 1 1
      api/core/tools/provider/builtin/searchapi/tools/google.py
  84. 1 1
      api/core/tools/provider/builtin/searchapi/tools/google_jobs.py
  85. 1 1
      api/core/tools/provider/builtin/searchapi/tools/google_news.py
  86. 1 1
      api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py
  87. 1 1
      api/core/tools/provider/builtin/spider/spiderApp.py
  88. 1 1
      api/core/tools/provider/builtin/stability/tools/text2image.py
  89. 1 1
      api/core/tools/provider/builtin/vanna/tools/vanna.py
  90. 1 1
      api/core/tools/provider/builtin_tool_provider.py
  91. 9 9
      api/core/tools/provider/tool_provider.py
  92. 4 4
      api/core/tools/tool/api_tool.py
  93. 3 9
      api/core/tools/tool_engine.py
  94. 1 1
      api/core/tools/utils/message_transformer.py
  95. 1 1
      api/core/tools/utils/parser.py
  96. 1 1
      api/core/tools/utils/web_reader_tool.py
  97. 1 1
      api/core/workflow/graph_engine/entities/runtime_route_state.py
  98. 2 2
      api/core/workflow/nodes/answer/answer_stream_generate_router.py
  99. 2 2
      api/core/workflow/nodes/end/end_stream_generate_router.py
  100. 4 4
      api/core/workflow/nodes/http_request/http_executor.py

+ 1 - 1
api/app.py

@@ -164,7 +164,7 @@ def initialize_extensions(app):
 @login_manager.request_loader
 def load_user_from_request(request_from_flask_login):
     """Load user based on the request."""
-    if request.blueprint not in ["console", "inner_api"]:
+    if request.blueprint not in {"console", "inner_api"}:
         return None
     # Check if the user_id contains a dot, indicating the old format
     auth_header = request.headers.get("Authorization", "")

+ 2 - 2
api/commands.py

@@ -140,9 +140,9 @@ def reset_encrypt_key_pair():
 @click.command("vdb-migrate", help="migrate vector db.")
 @click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.")
 def vdb_migrate(scope: str):
-    if scope in ["knowledge", "all"]:
+    if scope in {"knowledge", "all"}:
         migrate_knowledge_vector_database()
-    if scope in ["annotation", "all"]:
+    if scope in {"annotation", "all"}:
         migrate_annotation_vector_database()
 
 

+ 1 - 1
api/controllers/console/app/audio.py

@@ -94,7 +94,7 @@ class ChatMessageTextApi(Resource):
             message_id = args.get("message_id", None)
             text = args.get("text", None)
             if (
-                app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
+                app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
                 and app_model.workflow
                 and app_model.workflow.features_dict
             ):

+ 1 - 1
api/controllers/console/auth/oauth.py

@@ -71,7 +71,7 @@ class OAuthCallback(Resource):
 
         account = _generate_account(provider, user_info)
         # Check account status
-        if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
+        if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
             return {"error": "Account is banned or closed."}, 403
 
         if account.status == AccountStatus.PENDING.value:

+ 3 - 3
api/controllers/console/datasets/datasets_document.py

@@ -354,7 +354,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
         document_id = str(document_id)
         document = self.get_document(dataset_id, document_id)
 
-        if document.indexing_status in ["completed", "error"]:
+        if document.indexing_status in {"completed", "error"}:
             raise DocumentAlreadyFinishedError()
 
         data_process_rule = document.dataset_process_rule
@@ -421,7 +421,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
         info_list = []
         extract_settings = []
         for document in documents:
-            if document.indexing_status in ["completed", "error"]:
+            if document.indexing_status in {"completed", "error"}:
                 raise DocumentAlreadyFinishedError()
             data_source_info = document.data_source_info_dict
             # format document files info
@@ -665,7 +665,7 @@ class DocumentProcessingApi(DocumentResource):
             db.session.commit()
 
         elif action == "resume":
-            if document.indexing_status not in ["paused", "error"]:
+            if document.indexing_status not in {"paused", "error"}:
                 raise InvalidActionError("Document not in paused or error state.")
 
             document.paused_by = None

+ 1 - 1
api/controllers/console/explore/audio.py

@@ -81,7 +81,7 @@ class ChatTextApi(InstalledAppResource):
             message_id = args.get("message_id", None)
             text = args.get("text", None)
             if (
-                app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
+                app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
                 and app_model.workflow
                 and app_model.workflow.features_dict
             ):

+ 2 - 2
api/controllers/console/explore/completion.py

@@ -92,7 +92,7 @@ class ChatApi(InstalledAppResource):
     def post(self, installed_app):
         app_model = installed_app.app
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
 
         parser = reqparse.RequestParser()
@@ -140,7 +140,7 @@ class ChatStopApi(InstalledAppResource):
     def post(self, installed_app, task_id):
         app_model = installed_app.app
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
 
         AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)

+ 5 - 5
api/controllers/console/explore/conversation.py

@@ -20,7 +20,7 @@ class ConversationListApi(InstalledAppResource):
     def get(self, installed_app):
         app_model = installed_app.app
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
 
         parser = reqparse.RequestParser()
@@ -50,7 +50,7 @@ class ConversationApi(InstalledAppResource):
     def delete(self, installed_app, c_id):
         app_model = installed_app.app
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
 
         conversation_id = str(c_id)
@@ -68,7 +68,7 @@ class ConversationRenameApi(InstalledAppResource):
     def post(self, installed_app, c_id):
         app_model = installed_app.app
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
 
         conversation_id = str(c_id)
@@ -90,7 +90,7 @@ class ConversationPinApi(InstalledAppResource):
     def patch(self, installed_app, c_id):
         app_model = installed_app.app
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
 
         conversation_id = str(c_id)
@@ -107,7 +107,7 @@ class ConversationUnPinApi(InstalledAppResource):
     def patch(self, installed_app, c_id):
         app_model = installed_app.app
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
 
         conversation_id = str(c_id)

+ 1 - 1
api/controllers/console/explore/installed_app.py

@@ -31,7 +31,7 @@ class InstalledAppsListApi(Resource):
                 "app_owner_tenant_id": installed_app.app_owner_tenant_id,
                 "is_pinned": installed_app.is_pinned,
                 "last_used_at": installed_app.last_used_at,
-                "editable": current_user.role in ["owner", "admin"],
+                "editable": current_user.role in {"owner", "admin"},
                 "uninstallable": current_tenant_id == installed_app.app_owner_tenant_id,
             }
             for installed_app in installed_apps

+ 2 - 2
api/controllers/console/explore/message.py

@@ -40,7 +40,7 @@ class MessageListApi(InstalledAppResource):
         app_model = installed_app.app
 
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
 
         parser = reqparse.RequestParser()
@@ -125,7 +125,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
     def get(self, installed_app, message_id):
         app_model = installed_app.app
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
 
         message_id = str(message_id)

+ 1 - 1
api/controllers/console/explore/parameter.py

@@ -43,7 +43,7 @@ class AppParameterApi(InstalledAppResource):
         """Retrieve app parameters."""
         app_model = installed_app.app
 
-        if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
+        if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
             workflow = app_model.workflow
             if workflow is None:
                 raise AppUnavailableError()

+ 1 - 1
api/controllers/console/workspace/workspace.py

@@ -194,7 +194,7 @@ class WebappLogoWorkspaceApi(Resource):
             raise TooManyFilesError()
 
         extension = file.filename.split(".")[-1]
-        if extension.lower() not in ["svg", "png"]:
+        if extension.lower() not in {"svg", "png"}:
             raise UnsupportedFileTypeError()
 
         try:

+ 1 - 1
api/controllers/service_api/app/app.py

@@ -42,7 +42,7 @@ class AppParameterApi(Resource):
     @marshal_with(parameters_fields)
     def get(self, app_model: App):
         """Retrieve app parameters."""
-        if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
+        if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
             workflow = app_model.workflow
             if workflow is None:
                 raise AppUnavailableError()

+ 1 - 1
api/controllers/service_api/app/audio.py

@@ -79,7 +79,7 @@ class TextApi(Resource):
             message_id = args.get("message_id", None)
             text = args.get("text", None)
             if (
-                app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
+                app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
                 and app_model.workflow
                 and app_model.workflow.features_dict
             ):

+ 2 - 2
api/controllers/service_api/app/completion.py

@@ -96,7 +96,7 @@ class ChatApi(Resource):
     @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
     def post(self, app_model: App, end_user: EndUser):
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
 
         parser = reqparse.RequestParser()
@@ -144,7 +144,7 @@ class ChatStopApi(Resource):
     @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
     def post(self, app_model: App, end_user: EndUser, task_id):
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
 
         AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)

+ 3 - 3
api/controllers/service_api/app/conversation.py

@@ -18,7 +18,7 @@ class ConversationApi(Resource):
     @marshal_with(conversation_infinite_scroll_pagination_fields)
     def get(self, app_model: App, end_user: EndUser):
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
 
         parser = reqparse.RequestParser()
@@ -52,7 +52,7 @@ class ConversationDetailApi(Resource):
     @marshal_with(simple_conversation_fields)
     def delete(self, app_model: App, end_user: EndUser, c_id):
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
 
         conversation_id = str(c_id)
@@ -69,7 +69,7 @@ class ConversationRenameApi(Resource):
     @marshal_with(simple_conversation_fields)
     def post(self, app_model: App, end_user: EndUser, c_id):
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
 
         conversation_id = str(c_id)

+ 2 - 2
api/controllers/service_api/app/message.py

@@ -76,7 +76,7 @@ class MessageListApi(Resource):
     @marshal_with(message_infinite_scroll_pagination_fields)
     def get(self, app_model: App, end_user: EndUser):
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
 
         parser = reqparse.RequestParser()
@@ -117,7 +117,7 @@ class MessageSuggestedApi(Resource):
     def get(self, app_model: App, end_user: EndUser, message_id):
         message_id = str(message_id)
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
 
         try:

+ 1 - 1
api/controllers/web/app.py

@@ -41,7 +41,7 @@ class AppParameterApi(WebApiResource):
     @marshal_with(parameters_fields)
     def get(self, app_model: App, end_user):
         """Retrieve app parameters."""
-        if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
+        if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
             workflow = app_model.workflow
             if workflow is None:
                 raise AppUnavailableError()

+ 1 - 1
api/controllers/web/audio.py

@@ -78,7 +78,7 @@ class TextApi(WebApiResource):
             message_id = args.get("message_id", None)
             text = args.get("text", None)
             if (
-                app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
+                app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
                 and app_model.workflow
                 and app_model.workflow.features_dict
             ):

+ 2 - 2
api/controllers/web/completion.py

@@ -87,7 +87,7 @@ class CompletionStopApi(WebApiResource):
 class ChatApi(WebApiResource):
     def post(self, app_model, end_user):
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
 
         parser = reqparse.RequestParser()
@@ -136,7 +136,7 @@ class ChatApi(WebApiResource):
 class ChatStopApi(WebApiResource):
     def post(self, app_model, end_user, task_id):
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
 
         AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)

+ 5 - 5
api/controllers/web/conversation.py

@@ -18,7 +18,7 @@ class ConversationListApi(WebApiResource):
     @marshal_with(conversation_infinite_scroll_pagination_fields)
     def get(self, app_model, end_user):
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
 
         parser = reqparse.RequestParser()
@@ -56,7 +56,7 @@ class ConversationListApi(WebApiResource):
 class ConversationApi(WebApiResource):
     def delete(self, app_model, end_user, c_id):
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
 
         conversation_id = str(c_id)
@@ -73,7 +73,7 @@ class ConversationRenameApi(WebApiResource):
     @marshal_with(simple_conversation_fields)
     def post(self, app_model, end_user, c_id):
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
 
         conversation_id = str(c_id)
@@ -92,7 +92,7 @@ class ConversationRenameApi(WebApiResource):
 class ConversationPinApi(WebApiResource):
     def patch(self, app_model, end_user, c_id):
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
 
         conversation_id = str(c_id)
@@ -108,7 +108,7 @@ class ConversationPinApi(WebApiResource):
 class ConversationUnPinApi(WebApiResource):
     def patch(self, app_model, end_user, c_id):
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
 
         conversation_id = str(c_id)

+ 2 - 2
api/controllers/web/message.py

@@ -78,7 +78,7 @@ class MessageListApi(WebApiResource):
     @marshal_with(message_infinite_scroll_pagination_fields)
     def get(self, app_model, end_user):
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
 
         parser = reqparse.RequestParser()
@@ -160,7 +160,7 @@ class MessageMoreLikeThisApi(WebApiResource):
 class MessageSuggestedQuestionApi(WebApiResource):
     def get(self, app_model, end_user, message_id):
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotCompletionAppError()
 
         message_id = str(message_id)

+ 2 - 2
api/core/agent/output_parser/cot_output_parser.py

@@ -90,7 +90,7 @@ class CotAgentOutputParser:
 
                 if not in_code_block and not in_json:
                     if delta.lower() == action_str[action_idx] and action_idx == 0:
-                        if last_character not in ["\n", " ", ""]:
+                        if last_character not in {"\n", " ", ""}:
                             index += steps
                             yield delta
                             continue
@@ -117,7 +117,7 @@ class CotAgentOutputParser:
                             action_idx = 0
 
                     if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
-                        if last_character not in ["\n", " ", ""]:
+                        if last_character not in {"\n", " ", ""}:
                             index += steps
                             yield delta
                             continue

+ 1 - 1
api/core/app/app_config/base_app_config_manager.py

@@ -29,7 +29,7 @@ class BaseAppConfigManager:
         additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict)
 
         additional_features.file_upload = FileUploadConfigManager.convert(
-            config=config_dict, is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT]
+            config=config_dict, is_vision=app_mode in {AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT}
         )
 
         additional_features.opening_statement, additional_features.suggested_questions = (

+ 3 - 3
api/core/app/app_config/easy_ui_based_app/agent/manager.py

@@ -18,7 +18,7 @@ class AgentConfigManager:
 
             if agent_strategy == "function_call":
                 strategy = AgentEntity.Strategy.FUNCTION_CALLING
-            elif agent_strategy == "cot" or agent_strategy == "react":
+            elif agent_strategy in {"cot", "react"}:
                 strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
             else:
                 # old configs, try to detect default strategy
@@ -43,10 +43,10 @@ class AgentConfigManager:
 
                     agent_tools.append(AgentToolEntity(**agent_tool_properties))
 
-            if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in [
+            if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in {
                 "react_router",
                 "router",
-            ]:
+            }:
                 agent_prompt = agent_dict.get("prompt", None) or {}
                 # check model mode
                 model_mode = config.get("model", {}).get("mode", "completion")

+ 1 - 1
api/core/app/app_config/easy_ui_based_app/dataset/manager.py

@@ -167,7 +167,7 @@ class DatasetConfigManager:
             config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
 
         has_datasets = False
-        if config["agent_mode"]["strategy"] in [PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value]:
+        if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}:
             for tool in config["agent_mode"]["tools"]:
                 key = list(tool.keys())[0]
                 if key == "dataset":

+ 3 - 3
api/core/app/app_config/easy_ui_based_app/variables/manager.py

@@ -42,12 +42,12 @@ class BasicVariablesConfigManager:
                         variable=variable["variable"], type=variable["type"], config=variable["config"]
                     )
                 )
-            elif variable_type in [
+            elif variable_type in {
                 VariableEntityType.TEXT_INPUT,
                 VariableEntityType.PARAGRAPH,
                 VariableEntityType.NUMBER,
                 VariableEntityType.SELECT,
-            ]:
+            }:
                 variable = variables[variable_type]
                 variable_entities.append(
                     VariableEntity(
@@ -97,7 +97,7 @@ class BasicVariablesConfigManager:
         variables = []
         for item in config["user_input_form"]:
             key = list(item.keys())[0]
-            if key not in ["text-input", "select", "paragraph", "number", "external_data_tool"]:
+            if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}:
                 raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph'  or 'select'")
 
             form_item = item[key]

+ 2 - 2
api/core/app/app_config/features/file_upload/manager.py

@@ -54,14 +54,14 @@ class FileUploadConfigManager:
 
             if is_vision:
                 detail = config["file_upload"]["image"]["detail"]
-                if detail not in ["high", "low"]:
+                if detail not in {"high", "low"}:
                     raise ValueError("detail must be in ['high', 'low']")
 
             transfer_methods = config["file_upload"]["image"]["transfer_methods"]
             if not isinstance(transfer_methods, list):
                 raise ValueError("transfer_methods must be of list type")
             for method in transfer_methods:
-                if method not in ["remote_url", "local_file"]:
+                if method not in {"remote_url", "local_file"}:
                     raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
 
         return config, ["file_upload"]

+ 2 - 2
api/core/app/apps/advanced_chat/app_runner.py

@@ -73,7 +73,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
             raise ValueError("Workflow not initialized")
 
         user_id = None
-        if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
+        if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
             end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
             if end_user:
                 user_id = end_user.session_id
@@ -175,7 +175,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
             user_id=self.application_generate_entity.user_id,
             user_from=(
                 UserFrom.ACCOUNT
-                if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
+                if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
                 else UserFrom.END_USER
             ),
             invoke_from=self.application_generate_entity.invoke_from,

+ 1 - 1
api/core/app/apps/base_app_generate_response_converter.py

@@ -16,7 +16,7 @@ class AppGenerateResponseConverter(ABC):
     def convert(
         cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
     ) -> dict[str, Any] | Generator[str, Any, None]:
-        if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
+        if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
             if isinstance(response, AppBlockingResponse):
                 return cls.convert_blocking_full_response(response)
             else:

+ 3 - 3
api/core/app/apps/base_app_generator.py

@@ -22,11 +22,11 @@ class BaseAppGenerator:
             return var.default or ""
         if (
             var.type
-            in (
+            in {
                 VariableEntityType.TEXT_INPUT,
                 VariableEntityType.SELECT,
                 VariableEntityType.PARAGRAPH,
-            )
+            }
             and user_input_value
             and not isinstance(user_input_value, str)
         ):
@@ -44,7 +44,7 @@ class BaseAppGenerator:
             options = var.options or []
             if user_input_value not in options:
                 raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
-        elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
+        elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}:
             if var.max_length and user_input_value and len(user_input_value) > var.max_length:
                 raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
 

+ 2 - 2
api/core/app/apps/base_app_queue_manager.py

@@ -32,7 +32,7 @@ class AppQueueManager:
         self._user_id = user_id
         self._invoke_from = invoke_from
 
-        user_prefix = "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user"
+        user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
         redis_client.setex(
             AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
         )
@@ -118,7 +118,7 @@ class AppQueueManager:
         if result is None:
             return
 
-        user_prefix = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user"
+        user_prefix = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
         if result.decode("utf-8") != f"{user_prefix}-{user_id}":
             return
 

+ 3 - 3
api/core/app/apps/message_based_app_generator.py

@@ -148,7 +148,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
         # get from source
         end_user_id = None
         account_id = None
-        if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
+        if application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
             from_source = "api"
             end_user_id = application_generate_entity.user_id
         else:
@@ -165,11 +165,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
             model_provider = application_generate_entity.model_conf.provider
             model_id = application_generate_entity.model_conf.model
             override_model_configs = None
-            if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in [
+            if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in {
                 AppMode.AGENT_CHAT,
                 AppMode.CHAT,
                 AppMode.COMPLETION,
-            ]:
+            }:
                 override_model_configs = app_config.app_model_config_dict
 
         # get conversation introduction

+ 2 - 2
api/core/app/apps/workflow/app_runner.py

@@ -53,7 +53,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
         app_config = cast(WorkflowAppConfig, app_config)
 
         user_id = None
-        if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
+        if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
             end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
             if end_user:
                 user_id = end_user.session_id
@@ -113,7 +113,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
             user_id=self.application_generate_entity.user_id,
             user_from=(
                 UserFrom.ACCOUNT
-                if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
+                if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
                 else UserFrom.END_USER
             ),
             invoke_from=self.application_generate_entity.invoke_from,

+ 1 - 1
api/core/app/features/annotation_reply/annotation_reply.py

@@ -63,7 +63,7 @@ class AnnotationReplyFeature:
                 score = documents[0].metadata["score"]
                 annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
                 if annotation:
-                    if invoke_from in [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP]:
+                    if invoke_from in {InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP}:
                         from_source = "api"
                     else:
                         from_source = "console"

+ 1 - 1
api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py

@@ -372,7 +372,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
             self._message,
             application_generate_entity=self._application_generate_entity,
             conversation=self._conversation,
-            is_first_message=self._application_generate_entity.app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT]
+            is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT}
             and self._application_generate_entity.conversation_id is None,
             extras=self._application_generate_entity.extras,
         )

+ 2 - 2
api/core/app/task_pipeline/workflow_cycle_manage.py

@@ -383,7 +383,7 @@ class WorkflowCycleManage:
         :param workflow_node_execution: workflow node execution
         :return:
         """
-        if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
+        if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
             return None
 
         response = NodeStartStreamResponse(
@@ -430,7 +430,7 @@ class WorkflowCycleManage:
         :param workflow_node_execution: workflow node execution
         :return:
         """
-        if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
+        if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
             return None
 
         return NodeFinishStreamResponse(

+ 1 - 1
api/core/callback_handler/index_tool_callback_handler.py

@@ -29,7 +29,7 @@ class DatasetIndexToolCallbackHandler:
             source="app",
             source_app_id=self._app_id,
             created_by_role=(
-                "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"
+                "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
             ),
             created_by=self._user_id,
         )

+ 1 - 1
api/core/indexing_runner.py

@@ -292,7 +292,7 @@ class IndexingRunner:
         self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
     ) -> list[Document]:
         # load file
-        if dataset_document.data_source_type not in ["upload_file", "notion_import", "website_crawl"]:
+        if dataset_document.data_source_type not in {"upload_file", "notion_import", "website_crawl"}:
             return []
 
         data_source_info = dataset_document.data_source_info_dict

+ 1 - 1
api/core/memory/token_buffer_memory.py

@@ -52,7 +52,7 @@ class TokenBufferMemory:
             files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
             if files:
                 file_extra_config = None
-                if self.conversation.mode not in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
+                if self.conversation.mode not in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
                     file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
                 else:
                     if message.workflow_run_id:

+ 6 - 6
api/core/model_runtime/entities/model_entities.py

@@ -27,17 +27,17 @@ class ModelType(Enum):
 
         :return: model type
         """
-        if origin_model_type == "text-generation" or origin_model_type == cls.LLM.value:
+        if origin_model_type in {"text-generation", cls.LLM.value}:
             return cls.LLM
-        elif origin_model_type == "embeddings" or origin_model_type == cls.TEXT_EMBEDDING.value:
+        elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING.value}:
             return cls.TEXT_EMBEDDING
-        elif origin_model_type == "reranking" or origin_model_type == cls.RERANK.value:
+        elif origin_model_type in {"reranking", cls.RERANK.value}:
             return cls.RERANK
-        elif origin_model_type == "speech2text" or origin_model_type == cls.SPEECH2TEXT.value:
+        elif origin_model_type in {"speech2text", cls.SPEECH2TEXT.value}:
             return cls.SPEECH2TEXT
-        elif origin_model_type == "tts" or origin_model_type == cls.TTS.value:
+        elif origin_model_type in {"tts", cls.TTS.value}:
             return cls.TTS
-        elif origin_model_type == "text2img" or origin_model_type == cls.TEXT2IMG.value:
+        elif origin_model_type in {"text2img", cls.TEXT2IMG.value}:
             return cls.TEXT2IMG
         elif origin_model_type == cls.MODERATION.value:
             return cls.MODERATION

+ 1 - 1
api/core/model_runtime/model_providers/anthropic/llm/llm.py

@@ -494,7 +494,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                                     mime_type = data_split[0].replace("data:", "")
                                     base64_data = data_split[1]
 
-                                if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
+                                if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
                                     raise ValueError(
                                         f"Unsupported image type {mime_type}, "
                                         f"only support image/jpeg, image/png, image/gif, and image/webp"

+ 2 - 2
api/core/model_runtime/model_providers/azure_openai/tts/tts.py

@@ -85,14 +85,14 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
                     for i in range(len(sentences))
                 ]
                 for future in futures:
-                    yield from future.result().__enter__().iter_bytes(1024)
+                    yield from future.result().__enter__().iter_bytes(1024)  # noqa:PLC2801
 
             else:
                 response = client.audio.speech.with_streaming_response.create(
                     model=model, voice=voice, response_format="mp3", input=content_text.strip()
                 )
 
-                yield from response.__enter__().iter_bytes(1024)
+                yield from response.__enter__().iter_bytes(1024)  # noqa:PLC2801
         except Exception as ex:
             raise InvokeBadRequestError(str(ex))
 

+ 5 - 5
api/core/model_runtime/model_providers/bedrock/llm/llm.py

@@ -454,7 +454,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                             base64_data = data_split[1]
                             image_content = base64.b64decode(base64_data)
 
-                        if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
+                        if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
                             raise ValueError(
                                 f"Unsupported image type {mime_type}, "
                                 f"only support image/jpeg, image/png, image/gif, and image/webp"
@@ -886,16 +886,16 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
 
         if error_code == "AccessDeniedException":
             return InvokeAuthorizationError(error_msg)
-        elif error_code in ["ResourceNotFoundException", "ValidationException"]:
+        elif error_code in {"ResourceNotFoundException", "ValidationException"}:
             return InvokeBadRequestError(error_msg)
-        elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
+        elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}:
             return InvokeRateLimitError(error_msg)
-        elif error_code in [
+        elif error_code in {
             "ModelTimeoutException",
             "ModelErrorException",
             "InternalServerException",
             "ModelNotReadyException",
-        ]:
+        }:
             return InvokeServerUnavailableError(error_msg)
         elif error_code == "ModelStreamErrorException":
             return InvokeConnectionError(error_msg)

+ 4 - 4
api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py

@@ -186,16 +186,16 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
 
         if error_code == "AccessDeniedException":
             return InvokeAuthorizationError(error_msg)
-        elif error_code in ["ResourceNotFoundException", "ValidationException"]:
+        elif error_code in {"ResourceNotFoundException", "ValidationException"}:
             return InvokeBadRequestError(error_msg)
-        elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
+        elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}:
             return InvokeRateLimitError(error_msg)
-        elif error_code in [
+        elif error_code in {
             "ModelTimeoutException",
             "ModelErrorException",
             "InternalServerException",
             "ModelNotReadyException",
-        ]:
+        }:
             return InvokeServerUnavailableError(error_msg)
         elif error_code == "ModelStreamErrorException":
             return InvokeConnectionError(error_msg)

+ 2 - 2
api/core/model_runtime/model_providers/google/llm/llm.py

@@ -6,10 +6,10 @@ from collections.abc import Generator
 from typing import Optional, Union, cast
 
 import google.ai.generativelanguage as glm
-import google.api_core.exceptions as exceptions
 import google.generativeai as genai
-import google.generativeai.client as client
 import requests
+from google.api_core import exceptions
+from google.generativeai import client
 from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory
 from google.generativeai.types.content_types import to_part
 from PIL import Image

+ 2 - 2
api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py

@@ -77,7 +77,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
             if "huggingfacehub_api_type" not in credentials:
                 raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type must be provided.")
 
-            if credentials["huggingfacehub_api_type"] not in ("inference_endpoints", "hosted_inference_api"):
+            if credentials["huggingfacehub_api_type"] not in {"inference_endpoints", "hosted_inference_api"}:
                 raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type is invalid.")
 
             if "huggingfacehub_api_token" not in credentials:
@@ -94,7 +94,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
                     credentials["huggingfacehub_api_token"], model
                 )
 
-            if credentials["task_type"] not in ("text2text-generation", "text-generation"):
+            if credentials["task_type"] not in {"text2text-generation", "text-generation"}:
                 raise CredentialsValidateFailedError(
                     "Huggingface Hub Task Type must be one of text2text-generation, text-generation."
                 )

+ 1 - 1
api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py

@@ -75,7 +75,7 @@ class TeiHelper:
         if len(model_type.keys()) < 1:
             raise RuntimeError("model_type is empty")
         model_type = list(model_type.keys())[0]
-        if model_type not in ["embedding", "reranker"]:
+        if model_type not in {"embedding", "reranker"}:
             raise RuntimeError(f"invalid model_type: {model_type}")
 
         max_input_length = response_json.get("max_input_length", 512)

+ 2 - 2
api/core/model_runtime/model_providers/minimax/llm/chat_completion.py

@@ -100,9 +100,9 @@ class MinimaxChatCompletion:
         return self._handle_chat_generate_response(response)
 
     def _handle_error(self, code: int, msg: str):
-        if code == 1000 or code == 1001 or code == 1013 or code == 1027:
+        if code in {1000, 1001, 1013, 1027}:
             raise InternalServerError(msg)
-        elif code == 1002 or code == 1039:
+        elif code in {1002, 1039}:
             raise RateLimitReachedError(msg)
         elif code == 1004:
             raise InvalidAuthenticationError(msg)

+ 2 - 2
api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py

@@ -105,9 +105,9 @@ class MinimaxChatCompletionPro:
         return self._handle_chat_generate_response(response)
 
     def _handle_error(self, code: int, msg: str):
-        if code == 1000 or code == 1001 or code == 1013 or code == 1027:
+        if code in {1000, 1001, 1013, 1027}:
             raise InternalServerError(msg)
-        elif code == 1002 or code == 1039:
+        elif code in {1002, 1039}:
             raise RateLimitReachedError(msg)
         elif code == 1004:
             raise InvalidAuthenticationError(msg)

+ 1 - 1
api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py

@@ -114,7 +114,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
             raise CredentialsValidateFailedError("Invalid api key")
 
     def _handle_error(self, code: int, msg: str):
-        if code == 1000 or code == 1001:
+        if code in {1000, 1001}:
             raise InternalServerError(msg)
         elif code == 1002:
             raise RateLimitReachedError(msg)

+ 1 - 1
api/core/model_runtime/model_providers/openai/llm/llm.py

@@ -125,7 +125,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         model_mode = self.get_model_mode(base_model, credentials)
 
         # transform response format
-        if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]:
+        if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}:
             stop = stop or []
             if model_mode == LLMMode.CHAT:
                 # chat model

+ 2 - 2
api/core/model_runtime/model_providers/openai/tts/tts.py

@@ -89,14 +89,14 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
                     for i in range(len(sentences))
                 ]
                 for future in futures:
-                    yield from future.result().__enter__().iter_bytes(1024)
+                    yield from future.result().__enter__().iter_bytes(1024)  # noqa:PLC2801
 
             else:
                 response = client.audio.speech.with_streaming_response.create(
                     model=model, voice=voice, response_format="mp3", input=content_text.strip()
                 )
 
-                yield from response.__enter__().iter_bytes(1024)
+                yield from response.__enter__().iter_bytes(1024)  # noqa:PLC2801
         except Exception as ex:
             raise InvokeBadRequestError(str(ex))
 

+ 0 - 1
api/core/model_runtime/model_providers/openrouter/llm/llm.py

@@ -12,7 +12,6 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
         credentials["endpoint_url"] = "https://openrouter.ai/api/v1"
         credentials["mode"] = self.get_model_mode(model).value
         credentials["function_calling_type"] = "tool_call"
-        return
 
     def _invoke(
         self,

+ 1 - 1
api/core/model_runtime/model_providers/replicate/llm/llm.py

@@ -154,7 +154,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
         )
 
         for key, value in input_properties:
-            if key not in ["system_prompt", "prompt"] and "stop" not in key:
+            if key not in {"system_prompt", "prompt"} and "stop" not in key:
                 value_type = value.get("type")
 
                 if not value_type:

+ 2 - 2
api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py

@@ -86,7 +86,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
         )
 
         for input_property in input_properties:
-            if input_property[0] in ("text", "texts", "inputs"):
+            if input_property[0] in {"text", "texts", "inputs"}:
                 text_input_key = input_property[0]
                 return text_input_key
 
@@ -96,7 +96,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
     def _generate_embeddings_by_text_input_key(
         client: ReplicateClient, replicate_model_version: str, text_input_key: str, texts: list[str]
     ) -> list[list[float]]:
-        if text_input_key in ("text", "inputs"):
+        if text_input_key in {"text", "inputs"}:
             embeddings = []
             for text in texts:
                 result = client.run(replicate_model_version, input={text_input_key: text})

+ 4 - 4
api/core/model_runtime/model_providers/tongyi/llm/llm.py

@@ -89,7 +89,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
         :param tools: tools for tool calling
         :return:
         """
-        if model in ["qwen-turbo-chat", "qwen-plus-chat"]:
+        if model in {"qwen-turbo-chat", "qwen-plus-chat"}:
             model = model.replace("-chat", "")
         if model == "farui-plus":
             model = "qwen-farui-plus"
@@ -157,7 +157,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
 
         mode = self.get_model_mode(model, credentials)
 
-        if model in ["qwen-turbo-chat", "qwen-plus-chat"]:
+        if model in {"qwen-turbo-chat", "qwen-plus-chat"}:
             model = model.replace("-chat", "")
 
         extra_model_kwargs = {}
@@ -201,7 +201,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
         :param prompt_messages: prompt messages
         :return: llm response
         """
-        if response.status_code != 200 and response.status_code != HTTPStatus.OK:
+        if response.status_code not in {200, HTTPStatus.OK}:
             raise ServiceUnavailableError(response.message)
         # transform assistant message to prompt message
         assistant_prompt_message = AssistantPromptMessage(
@@ -240,7 +240,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
         full_text = ""
         tool_calls = []
         for index, response in enumerate(responses):
-            if response.status_code != 200 and response.status_code != HTTPStatus.OK:
+            if response.status_code not in {200, HTTPStatus.OK}:
                 raise ServiceUnavailableError(
                     f"Failed to invoke model {model}, status code: {response.status_code}, "
                     f"message: {response.message}"

+ 1 - 1
api/core/model_runtime/model_providers/upstage/llm/llm.py

@@ -93,7 +93,7 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel):
         """
         Code block mode wrapper for invoking large language model
         """
-        if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]:
+        if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}:
             stop = stop or []
             self._transform_chat_json_prompts(
                 model=model,

+ 2 - 2
api/core/model_runtime/model_providers/vertex_ai/llm/llm.py

@@ -5,7 +5,6 @@ import logging
 from collections.abc import Generator
 from typing import Optional, Union, cast
 
-import google.api_core.exceptions as exceptions
 import google.auth.transport.requests
 import vertexai.generative_models as glm
 from anthropic import AnthropicVertex, Stream
@@ -17,6 +16,7 @@ from anthropic.types import (
     MessageStopEvent,
     MessageStreamEvent,
 )
+from google.api_core import exceptions
 from google.cloud import aiplatform
 from google.oauth2 import service_account
 from PIL import Image
@@ -346,7 +346,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
                             mime_type = data_split[0].replace("data:", "")
                             base64_data = data_split[1]
 
-                        if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
+                        if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
                             raise ValueError(
                                 f"Unsupported image type {mime_type}, "
                                 f"only support image/jpeg, image/png, image/gif, and image/webp"

+ 1 - 2
api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py

@@ -96,7 +96,6 @@ class Signer:
         signing_key = Signer.get_signing_secret_key_v4(credentials.sk, md.date, md.region, md.service)
         sign = Util.to_hex(Util.hmac_sha256(signing_key, signing_str))
         request.headers["Authorization"] = Signer.build_auth_header_v4(sign, md, credentials)
-        return
 
     @staticmethod
     def hashed_canonical_request_v4(request, meta):
@@ -105,7 +104,7 @@ class Signer:
 
         signed_headers = {}
         for key in request.headers:
-            if key in ["Content-Type", "Content-Md5", "Host"] or key.startswith("X-"):
+            if key in {"Content-Type", "Content-Md5", "Host"} or key.startswith("X-"):
                 signed_headers[key.lower()] = request.headers[key]
 
         if "host" in signed_headers:

+ 1 - 1
api/core/model_runtime/model_providers/wenxin/llm/llm.py

@@ -69,7 +69,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
         """
         Code block mode wrapper for invoking large language model
         """
-        if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]:
+        if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}:
             response_format = model_parameters["response_format"]
             stop = stop or []
             self._transform_json_prompts(

+ 1 - 1
api/core/model_runtime/model_providers/xinference/xinference_helper.py

@@ -103,7 +103,7 @@ class XinferenceHelper:
             model_handle_type = "embedding"
         elif response_json.get("model_type") == "audio":
             model_handle_type = "audio"
-            if model_family and model_family in ["ChatTTS", "CosyVoice", "FishAudio"]:
+            if model_family and model_family in {"ChatTTS", "CosyVoice", "FishAudio"}:
                 model_ability.append("text-to-audio")
             else:
                 model_ability.append("audio-to-text")

+ 9 - 12
api/core/model_runtime/model_providers/zhipuai/llm/llm.py

@@ -186,10 +186,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         new_prompt_messages: list[PromptMessage] = []
         for prompt_message in prompt_messages:
             copy_prompt_message = prompt_message.copy()
-            if copy_prompt_message.role in [PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL]:
+            if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL}:
                 if isinstance(copy_prompt_message.content, list):
                     # check if model is 'glm-4v'
-                    if model not in ("glm-4v", "glm-4v-plus"):
+                    if model not in {"glm-4v", "glm-4v-plus"}:
                         # not support list message
                         continue
                     # get image and
@@ -209,10 +209,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                 ):
                     new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
                 else:
-                    if (
-                        copy_prompt_message.role == PromptMessageRole.USER
-                        or copy_prompt_message.role == PromptMessageRole.TOOL
-                    ):
+                    if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.TOOL}:
                         new_prompt_messages.append(copy_prompt_message)
                     elif copy_prompt_message.role == PromptMessageRole.SYSTEM:
                         new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content)
@@ -226,7 +223,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                 else:
                     new_prompt_messages.append(copy_prompt_message)
 
-        if model == "glm-4v" or model == "glm-4v-plus":
+        if model in {"glm-4v", "glm-4v-plus"}:
             params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters)
         else:
             params = {"model": model, "messages": [], **model_parameters}
@@ -270,11 +267,11 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                 # chatglm model
                 for prompt_message in new_prompt_messages:
                     # merge system message to user message
-                    if (
-                        prompt_message.role == PromptMessageRole.SYSTEM
-                        or prompt_message.role == PromptMessageRole.TOOL
-                        or prompt_message.role == PromptMessageRole.USER
-                    ):
+                    if prompt_message.role in {
+                        PromptMessageRole.SYSTEM,
+                        PromptMessageRole.TOOL,
+                        PromptMessageRole.USER,
+                    }:
                         if len(params["messages"]) > 0 and params["messages"][-1]["role"] == "user":
                             params["messages"][-1]["content"] += "\n\n" + prompt_message.content
                         else:

+ 2 - 3
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py

@@ -1,5 +1,4 @@
 from __future__ import annotations
 
-from .fine_tuning_job import FineTuningJob as FineTuningJob
-from .fine_tuning_job import ListOfFineTuningJob as ListOfFineTuningJob
-from .fine_tuning_job_event import FineTuningJobEvent as FineTuningJobEvent
+from .fine_tuning_job import FineTuningJob, ListOfFineTuningJob
+from .fine_tuning_job_event import FineTuningJobEvent

+ 2 - 2
api/core/model_runtime/schema_validators/common_validator.py

@@ -75,7 +75,7 @@ class CommonValidator:
         if not isinstance(value, str):
             raise ValueError(f"Variable {credential_form_schema.variable} should be string")
 
-        if credential_form_schema.type in [FormType.SELECT, FormType.RADIO]:
+        if credential_form_schema.type in {FormType.SELECT, FormType.RADIO}:
             # If the value is in options, no validation is performed
             if credential_form_schema.options:
                 if value not in [option.value for option in credential_form_schema.options]:
@@ -83,7 +83,7 @@ class CommonValidator:
 
         if credential_form_schema.type == FormType.SWITCH:
             # If the value is not in ['true', 'false'], an exception is thrown
-            if value.lower() not in ["true", "false"]:
+            if value.lower() not in {"true", "false"}:
                 raise ValueError(f"Variable {credential_form_schema.variable} should be true or false")
 
             value = True if value.lower() == "true" else False

+ 2 - 2
api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py

@@ -51,7 +51,7 @@ class ElasticSearchVector(BaseVector):
     def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
         try:
             parsed_url = urlparse(config.host)
-            if parsed_url.scheme in ["http", "https"]:
+            if parsed_url.scheme in {"http", "https"}:
                 hosts = f"{config.host}:{config.port}"
             else:
                 hosts = f"http://{config.host}:{config.port}"
@@ -94,7 +94,7 @@ class ElasticSearchVector(BaseVector):
         return uuids
 
     def text_exists(self, id: str) -> bool:
-        return self._client.exists(index=self._collection_name, id=id).__bool__()
+        return bool(self._client.exists(index=self._collection_name, id=id))
 
     def delete_by_ids(self, ids: list[str]) -> None:
         for id in ids:

+ 2 - 2
api/core/rag/datasource/vdb/myscale/myscale_vector.py

@@ -35,7 +35,7 @@ class MyScaleVector(BaseVector):
         super().__init__(collection_name)
         self._config = config
         self._metric = metric
-        self._vec_order = SortOrder.ASC if metric.upper() in ["COSINE", "L2"] else SortOrder.DESC
+        self._vec_order = SortOrder.ASC if metric.upper() in {"COSINE", "L2"} else SortOrder.DESC
         self._client = get_client(
             host=config.host,
             port=config.port,
@@ -92,7 +92,7 @@ class MyScaleVector(BaseVector):
 
     @staticmethod
     def escape_str(value: Any) -> str:
-        return "".join(" " if c in ("\\", "'") else c for c in str(value))
+        return "".join(" " if c in {"\\", "'"} else c for c in str(value))
 
     def text_exists(self, id: str) -> bool:
         results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'")

+ 1 - 9
api/core/rag/datasource/vdb/oracle/oraclevector.py

@@ -223,15 +223,7 @@ class OracleVector(BaseVector):
                 words = pseg.cut(query)
                 current_entity = ""
                 for word, pos in words:
-                    if (
-                        pos == "nr"
-                        or pos == "Ng"
-                        or pos == "eng"
-                        or pos == "nz"
-                        or pos == "n"
-                        or pos == "ORG"
-                        or pos == "v"
-                    ):  # nr: 人名, ns: 地名, nt: 机构名
+                    if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}:  # nr: 人名, ns: 地名, nt: 机构名
                         current_entity += word
                     else:
                         if current_entity:

+ 6 - 6
api/core/rag/extractor/extract_processor.py

@@ -98,17 +98,17 @@ class ExtractProcessor:
                 unstructured_api_url = dify_config.UNSTRUCTURED_API_URL
                 unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY
                 if etl_type == "Unstructured":
-                    if file_extension == ".xlsx" or file_extension == ".xls":
+                    if file_extension in {".xlsx", ".xls"}:
                         extractor = ExcelExtractor(file_path)
                     elif file_extension == ".pdf":
                         extractor = PdfExtractor(file_path)
-                    elif file_extension in [".md", ".markdown"]:
+                    elif file_extension in {".md", ".markdown"}:
                         extractor = (
                             UnstructuredMarkdownExtractor(file_path, unstructured_api_url)
                             if is_automatic
                             else MarkdownExtractor(file_path, autodetect_encoding=True)
                         )
-                    elif file_extension in [".htm", ".html"]:
+                    elif file_extension in {".htm", ".html"}:
                         extractor = HtmlExtractor(file_path)
                     elif file_extension == ".docx":
                         extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
@@ -134,13 +134,13 @@ class ExtractProcessor:
                             else TextExtractor(file_path, autodetect_encoding=True)
                         )
                 else:
-                    if file_extension == ".xlsx" or file_extension == ".xls":
+                    if file_extension in {".xlsx", ".xls"}:
                         extractor = ExcelExtractor(file_path)
                     elif file_extension == ".pdf":
                         extractor = PdfExtractor(file_path)
-                    elif file_extension in [".md", ".markdown"]:
+                    elif file_extension in {".md", ".markdown"}:
                         extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
-                    elif file_extension in [".htm", ".html"]:
+                    elif file_extension in {".htm", ".html"}:
                         extractor = HtmlExtractor(file_path)
                     elif file_extension == ".docx":
                         extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)

+ 1 - 1
api/core/rag/extractor/firecrawl/firecrawl_app.py

@@ -32,7 +32,7 @@ class FirecrawlApp:
             else:
                 raise Exception(f'Failed to scrape URL. Error: {response["error"]}')
 
-        elif response.status_code in [402, 409, 500]:
+        elif response.status_code in {402, 409, 500}:
             error_message = response.json().get("error", "Unknown error occurred")
             raise Exception(f"Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}")
         else:

+ 2 - 2
api/core/rag/extractor/notion_extractor.py

@@ -103,12 +103,12 @@ class NotionExtractor(BaseExtractor):
                     multi_select_list = property_value[type]
                     for multi_select in multi_select_list:
                         value.append(multi_select["name"])
-                elif type == "rich_text" or type == "title":
+                elif type in {"rich_text", "title"}:
                     if len(property_value[type]) > 0:
                         value = property_value[type][0]["plain_text"]
                     else:
                         value = ""
-                elif type == "select" or type == "status":
+                elif type in {"select", "status"}:
                     if property_value[type]:
                         value = property_value[type]["name"]
                     else:

+ 1 - 1
api/core/rag/retrieval/dataset_retrieval.py

@@ -115,7 +115,7 @@ class DatasetRetrieval:
 
             available_datasets.append(dataset)
         all_documents = []
-        user_from = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"
+        user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
         if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
             all_documents = self.single_retrieve(
                 app_id,

+ 1 - 1
api/core/rag/splitter/text_splitter.py

@@ -35,7 +35,7 @@ def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> l
             splits = re.split(separator, text)
     else:
         splits = list(text)
-    return [s for s in splits if (s != "" and s != "\n")]
+    return [s for s in splits if (s not in {"", "\n"})]
 
 
 class TextSplitter(BaseDocumentTransformer, ABC):

+ 1 - 1
api/core/tools/provider/app_tool_provider.py

@@ -68,7 +68,7 @@ class AppToolProviderEntity(ToolProviderController):
                 label = input_form[form_type]["label"]
                 variable_name = input_form[form_type]["variable_name"]
                 options = input_form[form_type].get("options", [])
-                if form_type == "paragraph" or form_type == "text-input":
+                if form_type in {"paragraph", "text-input"}:
                     tool["parameters"].append(
                         ToolParameter(
                             name=variable_name,

+ 2 - 2
api/core/tools/provider/builtin/aippt/tools/aippt.py

@@ -168,7 +168,7 @@ class AIPPTGenerateTool(BuiltinTool):
                             pass
                     elif event == "close":
                         break
-                    elif event == "error" or event == "filter":
+                    elif event in {"error", "filter"}:
                         raise Exception(f"Failed to generate outline: {data}")
 
         return outline
@@ -213,7 +213,7 @@ class AIPPTGenerateTool(BuiltinTool):
                                 pass
                         elif event == "close":
                             break
-                        elif event == "error" or event == "filter":
+                        elif event in {"error", "filter"}:
                             raise Exception(f"Failed to generate content: {data}")
 
             return content

+ 2 - 2
api/core/tools/provider/builtin/azuredalle/tools/dalle3.py

@@ -39,11 +39,11 @@ class DallE3Tool(BuiltinTool):
         n = tool_parameters.get("n", 1)
         # get quality
         quality = tool_parameters.get("quality", "standard")
-        if quality not in ["standard", "hd"]:
+        if quality not in {"standard", "hd"}:
             return self.create_text_message("Invalid quality")
         # get style
         style = tool_parameters.get("style", "vivid")
-        if style not in ["natural", "vivid"]:
+        if style not in {"natural", "vivid"}:
             return self.create_text_message("Invalid style")
         # set extra body
         seed_id = tool_parameters.get("seed_id", self._generate_random_id(8))

+ 1 - 1
api/core/tools/provider/builtin/code/tools/simple_code.py

@@ -14,7 +14,7 @@ class SimpleCode(BuiltinTool):
         language = tool_parameters.get("language", CodeLanguage.PYTHON3)
         code = tool_parameters.get("code", "")
 
-        if language not in [CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]:
+        if language not in {CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT}:
             raise ValueError(f"Only python3 and javascript are supported, not {language}")
 
         result = CodeExecutor.execute_code(language, "", code)

+ 2 - 2
api/core/tools/provider/builtin/cogview/tools/cogview3.py

@@ -34,11 +34,11 @@ class CogView3Tool(BuiltinTool):
         n = tool_parameters.get("n", 1)
         # get quality
         quality = tool_parameters.get("quality", "standard")
-        if quality not in ["standard", "hd"]:
+        if quality not in {"standard", "hd"}:
             return self.create_text_message("Invalid quality")
         # get style
         style = tool_parameters.get("style", "vivid")
-        if style not in ["natural", "vivid"]:
+        if style not in {"natural", "vivid"}:
             return self.create_text_message("Invalid style")
         # set extra body
         seed_id = tool_parameters.get("seed_id", self._generate_random_id(8))

+ 2 - 2
api/core/tools/provider/builtin/dalle/tools/dalle3.py

@@ -49,11 +49,11 @@ class DallE3Tool(BuiltinTool):
         n = tool_parameters.get("n", 1)
         # get quality
         quality = tool_parameters.get("quality", "standard")
-        if quality not in ["standard", "hd"]:
+        if quality not in {"standard", "hd"}:
             return self.create_text_message("Invalid quality")
         # get style
         style = tool_parameters.get("style", "vivid")
-        if style not in ["natural", "vivid"]:
+        if style not in {"natural", "vivid"}:
             return self.create_text_message("Invalid style")
 
         # call openapi dalle3

+ 2 - 2
api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py

@@ -133,9 +133,9 @@ class GetWorksheetFieldsTool(BuiltinTool):
 
     def _extract_options(self, control: dict) -> list:
         options = []
-        if control["type"] in [9, 10, 11]:
+        if control["type"] in {9, 10, 11}:
             options.extend([{"key": opt["key"], "value": opt["value"]} for opt in control.get("options", [])])
-        elif control["type"] in [28, 36]:
+        elif control["type"] in {28, 36}:
             itemnames = control["advancedSetting"].get("itemnames")
             if itemnames and itemnames.startswith("[{"):
                 try:

+ 3 - 3
api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py

@@ -183,11 +183,11 @@ class ListWorksheetRecordsTool(BuiltinTool):
         type_id = field.get("typeId")
         if type_id == 10:
             value = value if isinstance(value, str) else "、".join(value)
-        elif type_id in [28, 36]:
+        elif type_id in {28, 36}:
             value = field.get("options", {}).get(value, value)
-        elif type_id in [26, 27, 48, 14]:
+        elif type_id in {26, 27, 48, 14}:
             value = self.process_value(value)
-        elif type_id in [35, 29]:
+        elif type_id in {35, 29}:
             value = self.parse_cascade_or_associated(field, value)
         elif type_id == 40:
             value = self.parse_location(value)

+ 1 - 1
api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py

@@ -35,7 +35,7 @@ class NovitaAiModelQueryTool(BuiltinTool):
             models_data=[],
             headers=headers,
             params=params,
-            recursive=not (result_type == "first sd_name" or result_type == "first name sd_name pair"),
+            recursive=result_type not in {"first sd_name", "first name sd_name pair"},
         )
 
         result_str = ""

+ 1 - 1
api/core/tools/provider/builtin/searchapi/tools/google.py

@@ -38,7 +38,7 @@ class SearchAPI:
         return {
             "engine": "google",
             "q": query,
-            **{key: value for key, value in kwargs.items() if value not in [None, ""]},
+            **{key: value for key, value in kwargs.items() if value not in {None, ""}},
         }
 
     @staticmethod

+ 1 - 1
api/core/tools/provider/builtin/searchapi/tools/google_jobs.py

@@ -38,7 +38,7 @@ class SearchAPI:
         return {
             "engine": "google_jobs",
             "q": query,
-            **{key: value for key, value in kwargs.items() if value not in [None, ""]},
+            **{key: value for key, value in kwargs.items() if value not in {None, ""}},
         }
 
     @staticmethod

+ 1 - 1
api/core/tools/provider/builtin/searchapi/tools/google_news.py

@@ -38,7 +38,7 @@ class SearchAPI:
         return {
             "engine": "google_news",
             "q": query,
-            **{key: value for key, value in kwargs.items() if value not in [None, ""]},
+            **{key: value for key, value in kwargs.items() if value not in {None, ""}},
         }
 
     @staticmethod

+ 1 - 1
api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py

@@ -38,7 +38,7 @@ class SearchAPI:
             "engine": "youtube_transcripts",
             "video_id": video_id,
             "lang": language or "en",
-            **{key: value for key, value in kwargs.items() if value not in [None, ""]},
+            **{key: value for key, value in kwargs.items() if value not in {None, ""}},
         }
 
     @staticmethod

+ 1 - 1
api/core/tools/provider/builtin/spider/spiderApp.py

@@ -214,7 +214,7 @@ class Spider:
         return requests.delete(url, headers=headers, stream=stream)
 
     def _handle_error(self, response, action):
-        if response.status_code in [402, 409, 500]:
+        if response.status_code in {402, 409, 500}:
             error_message = response.json().get("error", "Unknown error occurred")
             raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}")
         else:

+ 1 - 1
api/core/tools/provider/builtin/stability/tools/text2image.py

@@ -32,7 +32,7 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization):
 
         model = tool_parameters.get("model", "core")
 
-        if model in ["sd3", "sd3-turbo"]:
+        if model in {"sd3", "sd3-turbo"}:
             payload["model"] = tool_parameters.get("model")
 
         if model != "sd3-turbo":

+ 1 - 1
api/core/tools/provider/builtin/vanna/tools/vanna.py

@@ -38,7 +38,7 @@ class VannaTool(BuiltinTool):
         vn = VannaDefault(model=model, api_key=api_key)
 
         db_type = tool_parameters.get("db_type", "")
-        if db_type in ["Postgres", "MySQL", "Hive", "ClickHouse"]:
+        if db_type in {"Postgres", "MySQL", "Hive", "ClickHouse"}:
             if not db_name:
                 return self.create_text_message("Please input database name")
             if not username:

+ 1 - 1
api/core/tools/provider/builtin_tool_provider.py

@@ -19,7 +19,7 @@ from core.tools.utils.yaml_utils import load_yaml_file
 
 class BuiltinToolProviderController(ToolProviderController):
     def __init__(self, **data: Any) -> None:
-        if self.provider_type == ToolProviderType.API or self.provider_type == ToolProviderType.APP:
+        if self.provider_type in {ToolProviderType.API, ToolProviderType.APP}:
             super().__init__(**data)
             return
 

+ 9 - 9
api/core/tools/provider/tool_provider.py

@@ -153,10 +153,10 @@ class ToolProviderController(BaseModel, ABC):
 
             # check type
             credential_schema = credentials_need_to_validate[credential_name]
-            if (
-                credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT
-                or credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT
-            ):
+            if credential_schema in {
+                ToolProviderCredentials.CredentialsType.SECRET_INPUT,
+                ToolProviderCredentials.CredentialsType.TEXT_INPUT,
+            }:
                 if not isinstance(credentials[credential_name], str):
                     raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
 
@@ -184,11 +184,11 @@ class ToolProviderController(BaseModel, ABC):
             if credential_schema.default is not None:
                 default_value = credential_schema.default
                 # parse default value into the correct type
-                if (
-                    credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT
-                    or credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT
-                    or credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT
-                ):
+                if credential_schema.type in {
+                    ToolProviderCredentials.CredentialsType.SECRET_INPUT,
+                    ToolProviderCredentials.CredentialsType.TEXT_INPUT,
+                    ToolProviderCredentials.CredentialsType.SELECT,
+                }:
                     default_value = str(default_value)
 
                 credentials[credential_name] = default_value

+ 4 - 4
api/core/tools/tool/api_tool.py

@@ -5,7 +5,7 @@ from urllib.parse import urlencode
 
 import httpx
 
-import core.helper.ssrf_proxy as ssrf_proxy
+from core.helper import ssrf_proxy
 from core.tools.entities.tool_bundle import ApiToolBundle
 from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
 from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
@@ -191,7 +191,7 @@ class ApiTool(Tool):
             else:
                 body = body
 
-        if method in ("get", "head", "post", "put", "delete", "patch"):
+        if method in {"get", "head", "post", "put", "delete", "patch"}:
             response = getattr(ssrf_proxy, method)(
                 url,
                 params=params,
@@ -224,9 +224,9 @@ class ApiTool(Tool):
                     elif option["type"] == "string":
                         return str(value)
                     elif option["type"] == "boolean":
-                        if str(value).lower() in ["true", "1"]:
+                        if str(value).lower() in {"true", "1"}:
                             return True
-                        elif str(value).lower() in ["false", "0"]:
+                        elif str(value).lower() in {"false", "0"}:
                             return False
                         else:
                             continue  # Not a boolean, try next option

+ 3 - 9
api/core/tools/tool_engine.py

@@ -189,10 +189,7 @@ class ToolEngine:
                 result += response.message
             elif response.type == ToolInvokeMessage.MessageType.LINK:
                 result += f"result link: {response.message}. please tell user to check it."
-            elif (
-                response.type == ToolInvokeMessage.MessageType.IMAGE_LINK
-                or response.type == ToolInvokeMessage.MessageType.IMAGE
-            ):
+            elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
                 result += (
                     "image has been created and sent to user already, you do not need to create it,"
                     " just tell the user to check it now."
@@ -212,10 +209,7 @@ class ToolEngine:
         result = []
 
         for response in tool_response:
-            if (
-                response.type == ToolInvokeMessage.MessageType.IMAGE_LINK
-                or response.type == ToolInvokeMessage.MessageType.IMAGE
-            ):
+            if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
                 mimetype = None
                 if response.meta.get("mime_type"):
                     mimetype = response.meta.get("mime_type")
@@ -297,7 +291,7 @@ class ToolEngine:
                 belongs_to="assistant",
                 url=message.url,
                 upload_file_id=None,
-                created_by_role=("account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"),
+                created_by_role=("account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"),
                 created_by=user_id,
             )
 

+ 1 - 1
api/core/tools/utils/message_transformer.py

@@ -19,7 +19,7 @@ class ToolFileMessageTransformer:
         result = []
 
         for message in messages:
-            if message.type == ToolInvokeMessage.MessageType.TEXT or message.type == ToolInvokeMessage.MessageType.LINK:
+            if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}:
                 result.append(message)
             elif message.type == ToolInvokeMessage.MessageType.IMAGE:
                 # try to download image

+ 1 - 1
api/core/tools/utils/parser.py

@@ -165,7 +165,7 @@ class ApiBasedToolSchemaParser:
         elif "schema" in parameter and "type" in parameter["schema"]:
             typ = parameter["schema"]["type"]
 
-        if typ == "integer" or typ == "number":
+        if typ in {"integer", "number"}:
             return ToolParameter.ToolParameterType.NUMBER
         elif typ == "boolean":
             return ToolParameter.ToolParameterType.BOOLEAN

+ 1 - 1
api/core/tools/utils/web_reader_tool.py

@@ -313,7 +313,7 @@ def normalize_whitespace(text):
 
 
 def is_leaf(element):
-    return element.name in ["p", "li"]
+    return element.name in {"p", "li"}
 
 
 def is_text(element):

+ 1 - 1
api/core/workflow/graph_engine/entities/runtime_route_state.py

@@ -51,7 +51,7 @@ class RouteNodeState(BaseModel):
 
         :param run_result: run result
         """
-        if self.status in [RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED]:
+        if self.status in {RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED}:
             raise Exception(f"Route state {self.id} already finished")
 
         if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:

+ 2 - 2
api/core/workflow/nodes/answer/answer_stream_generate_router.py

@@ -148,11 +148,11 @@ class AnswerStreamGeneratorRouter:
         for edge in reverse_edges:
             source_node_id = edge.source_node_id
             source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
-            if source_node_type in (
+            if source_node_type in {
                 NodeType.ANSWER.value,
                 NodeType.IF_ELSE.value,
                 NodeType.QUESTION_CLASSIFIER.value,
-            ):
+            }:
                 answer_dependencies[answer_node_id].append(source_node_id)
             else:
                 cls._recursive_fetch_answer_dependencies(

+ 2 - 2
api/core/workflow/nodes/end/end_stream_generate_router.py

@@ -136,10 +136,10 @@ class EndStreamGeneratorRouter:
         for edge in reverse_edges:
             source_node_id = edge.source_node_id
             source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
-            if source_node_type in (
+            if source_node_type in {
                 NodeType.IF_ELSE.value,
                 NodeType.QUESTION_CLASSIFIER,
-            ):
+            }:
                 end_dependencies[end_node_id].append(source_node_id)
             else:
                 cls._recursive_fetch_end_dependencies(

+ 4 - 4
api/core/workflow/nodes/http_request/http_executor.py

@@ -6,8 +6,8 @@ from urllib.parse import urlencode
 
 import httpx
 
-import core.helper.ssrf_proxy as ssrf_proxy
 from configs import dify_config
+from core.helper import ssrf_proxy
 from core.workflow.entities.variable_entities import VariableSelector
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.http_request.entities import (
@@ -176,7 +176,7 @@ class HttpExecutor:
             elif node_data.body.type == "x-www-form-urlencoded" and not content_type_is_set:
                 self.headers["Content-Type"] = "application/x-www-form-urlencoded"
 
-            if node_data.body.type in ["form-data", "x-www-form-urlencoded"]:
+            if node_data.body.type in {"form-data", "x-www-form-urlencoded"}:
                 body = self._to_dict(body_data)
 
                 if node_data.body.type == "form-data":
@@ -187,7 +187,7 @@ class HttpExecutor:
                     self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}"
                 else:
                     self.body = urlencode(body)
-            elif node_data.body.type in ["json", "raw-text"]:
+            elif node_data.body.type in {"json", "raw-text"}:
                 self.body = body_data
             elif node_data.body.type == "none":
                 self.body = ""
@@ -258,7 +258,7 @@ class HttpExecutor:
             "follow_redirects": True,
         }
 
-        if self.method in ("get", "head", "post", "put", "delete", "patch"):
+        if self.method in {"get", "head", "post", "put", "delete", "patch"}:
             response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs)
         else:
             raise ValueError(f"Invalid http method {self.method}")

Některé soubory nejsou zobrazeny, neboť je v těchto rozdílových datech změněno mnoho souborů