
chore(api): Introduce Ruff Formatter. (#7291)

-LAN- 8 months ago
parent
commit
3571292fbf
61 changed files with 1319 additions and 1339 deletions
  1. 4 0
      .github/workflows/style.yml
  2. 82 79
      api/app.py
  3. 217 202
      api/commands.py
  4. 1 1
      api/constants/__init__.py
  5. 19 20
      api/constants/languages.py
  6. 48 51
      api/constants/model_template.py
  7. 2 2
      api/contexts/__init__.py
  8. 4 4
      api/events/app_event.py
  9. 1 1
      api/events/dataset_event.py
  10. 1 1
      api/events/document_event.py
  11. 8 2
      api/events/event_handlers/clean_when_dataset_deleted.py
  12. 3 3
      api/events/event_handlers/clean_when_document_deleted.py
  13. 14 10
      api/events/event_handlers/create_document_index.py
  14. 1 1
      api/events/event_handlers/create_installed_app_when_app_created.py
  15. 5 5
      api/events/event_handlers/create_site_record_when_app_created.py
  16. 4 4
      api/events/event_handlers/deduct_quota_when_messaeg_created.py
  17. 3 3
      api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py
  18. 1 1
      api/events/event_handlers/document_index_event.py
  19. 10 16
      api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py
  20. 9 14
      api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py
  21. 3 3
      api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py
  22. 1 1
      api/events/message_event.py
  23. 2 2
      api/events/tenant_event.py
  24. 10 13
      api/extensions/ext_celery.py
  25. 5 6
      api/extensions/ext_compress.py
  26. 5 5
      api/extensions/ext_database.py
  27. 37 35
      api/extensions/ext_mail.py
  28. 15 12
      api/extensions/ext_redis.py
  29. 7 10
      api/extensions/ext_sentry.py
  30. 13 25
      api/extensions/ext_storage.py
  31. 6 7
      api/extensions/storage/aliyun_storage.py
  32. 9 9
      api/extensions/storage/azure_storage.py
  33. 3 2
      api/extensions/storage/base_storage.py
  34. 8 7
      api/extensions/storage/google_storage.py
  35. 14 15
      api/extensions/storage/local_storage.py
  36. 13 12
      api/extensions/storage/oci_storage.py
  37. 17 16
      api/extensions/storage/s3_storage.py
  38. 9 10
      api/extensions/storage/tencent_storage.py
  39. 3 3
      api/fields/annotation_fields.py
  40. 7 7
      api/fields/api_based_extension_fields.py
  41. 111 115
      api/fields/app_fields.py
  42. 138 141
      api/fields/conversation_fields.py
  43. 12 12
      api/fields/conversation_variable_fields.py
  44. 32 40
      api/fields/data_source_fields.py
  45. 44 55
      api/fields/dataset_fields.py
  46. 59 61
      api/fields/document_fields.py
  47. 4 4
      api/fields/end_user_fields.py
  48. 11 11
      api/fields/file_fields.py
  49. 30 30
      api/fields/hit_testing_fields.py
  50. 13 15
      api/fields/installed_app_fields.py
  51. 22 28
      api/fields/member_fields.py
  52. 60 64
      api/fields/message_fields.py
  53. 24 24
      api/fields/segment_fields.py
  54. 1 6
      api/fields/tag_fields.py
  55. 9 9
      api/fields/workflow_app_log_fields.py
  56. 26 26
      api/fields/workflow_fields.py
  57. 25 25
      api/fields/workflow_run_fields.py
  58. 12 1
      api/pyproject.toml
  59. 13 8
      api/schedule/clean_embedding_cache_task.py
  60. 46 44
      api/schedule/clean_unused_datasets_task.py
  61. 3 0
      dev/reformat
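
Nearly all of the churn above is mechanical: Ruff's formatter rewrites single-quoted strings to double quotes, adds trailing commas to multi-line calls, and replaces backslash continuations with parenthesized, fluent query chains. The hand-written pieces are the new CI step in .github/workflows/style.yml, the formatter configuration presumably carried by the 12 lines added to api/pyproject.toml (not shown in this excerpt), and the three-line addition to dev/reformat. A representative before/after, taken from the api/commands.py hunks below:

    # Before: hand-wrapped with backslash continuations, single quotes
    account = db.session.query(Account). \
        filter(Account.email == email). \
        one_or_none()

    # After running `ruff format`: one fluent chain, double quotes
    account = db.session.query(Account).filter(Account.email == email).one_or_none()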

+ 4 - 0
.github/workflows/style.yml

@@ -45,6 +45,10 @@ jobs:
         if: steps.changed-files.outputs.any_changed == 'true'
         run: poetry run -C api dotenv-linter ./api/.env.example ./web/.env.example
 
+      - name: Ruff formatter check
+        if: steps.changed-files.outputs.any_changed == 'true'
+        run: poetry run -C api ruff format --check ./api
+
       - name: Lint hints
         if: failure()
         run: echo "Please run 'dev/reformat' to fix the fixable linting errors."
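
Note that "ruff format --check" only verifies: it exits non-zero when any file would be reformatted but writes nothing, which is what makes it usable as a CI gate here. The "Lint hints" step that follows then points contributors at dev/reformat; the three lines added to that script are not shown in this excerpt, but presumably they apply the same formatter without --check (along the lines of "poetry run -C api ruff format ./api") so the fixes can be made locally.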

+ 82 - 79
api/app.py

@@ -1,6 +1,6 @@
 import os

-if os.environ.get("DEBUG", "false").lower() != 'true':
+if os.environ.get("DEBUG", "false").lower() != "true":
     from gevent import monkey

     monkey.patch_all()
@@ -57,7 +57,7 @@ warnings.simplefilter("ignore", ResourceWarning)
 if os.name == "nt":
     os.system('tzutil /s "UTC"')
 else:
-    os.environ['TZ'] = 'UTC'
+    os.environ["TZ"] = "UTC"
     time.tzset()


@@ -70,13 +70,14 @@ class DifyApp(Flask):
 # -------------


-config_type = os.getenv('EDITION', default='SELF_HOSTED')  # ce edition first
+config_type = os.getenv("EDITION", default="SELF_HOSTED")  # ce edition first


 # ----------------------------
 # Application Factory Function
 # ----------------------------

+
 def create_flask_app_with_configs() -> Flask:
     """
     create a raw flask app
@@ -92,7 +93,7 @@ def create_flask_app_with_configs() -> Flask:
         elif isinstance(value, int | float | bool):
             os.environ[key] = str(value)
         elif value is None:
-            os.environ[key] = ''
+            os.environ[key] = ""

     return dify_app

@@ -100,10 +101,10 @@ def create_flask_app_with_configs() -> Flask:
 def create_app() -> Flask:
     app = create_flask_app_with_configs()

-    app.secret_key = app.config['SECRET_KEY']
+    app.secret_key = app.config["SECRET_KEY"]

     log_handlers = None
-    log_file = app.config.get('LOG_FILE')
+    log_file = app.config.get("LOG_FILE")
     if log_file:
         log_dir = os.path.dirname(log_file)
         os.makedirs(log_dir, exist_ok=True)
@@ -111,23 +112,24 @@ def create_app() -> Flask:
             RotatingFileHandler(
                 filename=log_file,
                 maxBytes=1024 * 1024 * 1024,
-                backupCount=5
+                backupCount=5,
             ),
-            logging.StreamHandler(sys.stdout)
+            logging.StreamHandler(sys.stdout),
         ]

     logging.basicConfig(
-        level=app.config.get('LOG_LEVEL'),
-        format=app.config.get('LOG_FORMAT'),
-        datefmt=app.config.get('LOG_DATEFORMAT'),
+        level=app.config.get("LOG_LEVEL"),
+        format=app.config.get("LOG_FORMAT"),
+        datefmt=app.config.get("LOG_DATEFORMAT"),
         handlers=log_handlers,
-        force=True
+        force=True,
     )
-    log_tz = app.config.get('LOG_TZ')
+    log_tz = app.config.get("LOG_TZ")
     if log_tz:
         from datetime import datetime

         import pytz
+
         timezone = pytz.timezone(log_tz)

         def time_converter(seconds):
@@ -162,24 +164,24 @@ def initialize_extensions(app):
 @login_manager.request_loader
 def load_user_from_request(request_from_flask_login):
     """Load user based on the request."""
-    if request.blueprint not in ['console', 'inner_api']:
+    if request.blueprint not in ["console", "inner_api"]:
         return None
     # Check if the user_id contains a dot, indicating the old format
-    auth_header = request.headers.get('Authorization', '')
+    auth_header = request.headers.get("Authorization", "")
     if not auth_header:
-        auth_token = request.args.get('_token')
+        auth_token = request.args.get("_token")
         if not auth_token:
-            raise Unauthorized('Invalid Authorization token.')
+            raise Unauthorized("Invalid Authorization token.")
     else:
-        if ' ' not in auth_header:
-            raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
+        if " " not in auth_header:
+            raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
         auth_scheme, auth_token = auth_header.split(None, 1)
         auth_scheme = auth_scheme.lower()
-        if auth_scheme != 'bearer':
-            raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
+        if auth_scheme != "bearer":
+            raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")

     decoded = PassportService().verify(auth_token)
-    user_id = decoded.get('user_id')
+    user_id = decoded.get("user_id")

     account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token)
     if account:
@@ -190,10 +192,11 @@ def load_user_from_request(request_from_flask_login):
 @login_manager.unauthorized_handler
 def unauthorized_handler():
     """Handle unauthorized requests."""
-    return Response(json.dumps({
-        'code': 'unauthorized',
-        'message': "Unauthorized."
-    }), status=401, content_type="application/json")
+    return Response(
+        json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
+        status=401,
+        content_type="application/json",
+    )


 # register blueprint routers
@@ -204,38 +207,36 @@ def register_blueprints(app):
     from controllers.service_api import bp as service_api_bp
     from controllers.web import bp as web_bp

-    CORS(service_api_bp,
-         allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
-         methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
-         )
+    CORS(
+        service_api_bp,
+        allow_headers=["Content-Type", "Authorization", "X-App-Code"],
+        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
+    )
     app.register_blueprint(service_api_bp)

-    CORS(web_bp,
-         resources={
-             r"/*": {"origins": app.config['WEB_API_CORS_ALLOW_ORIGINS']}},
-         supports_credentials=True,
-         allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
-         methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'],
-         expose_headers=['X-Version', 'X-Env']
-         )
+    CORS(
+        web_bp,
+        resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}},
+        supports_credentials=True,
+        allow_headers=["Content-Type", "Authorization", "X-App-Code"],
+        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
+        expose_headers=["X-Version", "X-Env"],
+    )

     app.register_blueprint(web_bp)

-    CORS(console_app_bp,
-         resources={
-             r"/*": {"origins": app.config['CONSOLE_CORS_ALLOW_ORIGINS']}},
-         supports_credentials=True,
-         allow_headers=['Content-Type', 'Authorization'],
-         methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'],
-         expose_headers=['X-Version', 'X-Env']
-         )
+    CORS(
+        console_app_bp,
+        resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}},
+        supports_credentials=True,
+        allow_headers=["Content-Type", "Authorization"],
+        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
+        expose_headers=["X-Version", "X-Env"],
+    )

     app.register_blueprint(console_app_bp)

-    CORS(files_bp,
-         allow_headers=['Content-Type'],
-         methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
-         )
+    CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"])
     app.register_blueprint(files_bp)

     app.register_blueprint(inner_api_bp)
@@ -245,29 +246,29 @@ def register_blueprints(app):
 app = create_app()
 celery = app.extensions["celery"]

-if app.config.get('TESTING'):
+if app.config.get("TESTING"):
     print("App is running in TESTING mode")


 @app.after_request
 def after_request(response):
     """Add Version headers to the response."""
-    response.set_cookie('remember_token', '', expires=0)
-    response.headers.add('X-Version', app.config['CURRENT_VERSION'])
-    response.headers.add('X-Env', app.config['DEPLOY_ENV'])
+    response.set_cookie("remember_token", "", expires=0)
+    response.headers.add("X-Version", app.config["CURRENT_VERSION"])
+    response.headers.add("X-Env", app.config["DEPLOY_ENV"])
     return response


-@app.route('/health')
+@app.route("/health")
 def health():
-    return Response(json.dumps({
-        'pid': os.getpid(),
-        'status': 'ok',
-        'version': app.config['CURRENT_VERSION']
-    }), status=200, content_type="application/json")
+    return Response(
+        json.dumps({"pid": os.getpid(), "status": "ok", "version": app.config["CURRENT_VERSION"]}),
+        status=200,
+        content_type="application/json",
+    )


-@app.route('/threads')
+@app.route("/threads")
 def threads():
     num_threads = threading.active_count()
     threads = threading.enumerate()
@@ -278,32 +279,34 @@ def threads():
         thread_id = thread.ident
         is_alive = thread.is_alive()

-        thread_list.append({
-            'name': thread_name,
-            'id': thread_id,
-            'is_alive': is_alive
-        })
+        thread_list.append(
+            {
+                "name": thread_name,
+                "id": thread_id,
+                "is_alive": is_alive,
+            }
+        )

     return {
-        'pid': os.getpid(),
-        'thread_num': num_threads,
-        'threads': thread_list
+        "pid": os.getpid(),
+        "thread_num": num_threads,
+        "threads": thread_list,
     }


-@app.route('/db-pool-stat')
+@app.route("/db-pool-stat")
 def pool_stat():
     engine = db.engine
     return {
-        'pid': os.getpid(),
-        'pool_size': engine.pool.size(),
-        'checked_in_connections': engine.pool.checkedin(),
-        'checked_out_connections': engine.pool.checkedout(),
-        'overflow_connections': engine.pool.overflow(),
-        'connection_timeout': engine.pool.timeout(),
-        'recycle_time': db.engine.pool._recycle
+        "pid": os.getpid(),
+        "pool_size": engine.pool.size(),
+        "checked_in_connections": engine.pool.checkedin(),
+        "checked_out_connections": engine.pool.checkedout(),
+        "overflow_connections": engine.pool.overflow(),
+        "connection_timeout": engine.pool.timeout(),
+        "recycle_time": db.engine.pool._recycle,
     }


-if __name__ == '__main__':
-    app.run(host='0.0.0.0', port=5001)
+if __name__ == "__main__":
+    app.run(host="0.0.0.0", port=5001)

+ 217 - 202
api/commands.py

@@ -27,32 +27,29 @@ from models.provider import Provider, ProviderModel
 from services.account_service import RegisterService, TenantService


-@click.command('reset-password', help='Reset the account password.')
-@click.option('--email', prompt=True, help='The email address of the account whose password you need to reset')
-@click.option('--new-password', prompt=True, help='the new password.')
-@click.option('--password-confirm', prompt=True, help='the new password confirm.')
+@click.command("reset-password", help="Reset the account password.")
+@click.option("--email", prompt=True, help="The email address of the account whose password you need to reset")
+@click.option("--new-password", prompt=True, help="the new password.")
+@click.option("--password-confirm", prompt=True, help="the new password confirm.")
 def reset_password(email, new_password, password_confirm):
     """
     Reset password of owner account
     Only available in SELF_HOSTED mode
     """
     if str(new_password).strip() != str(password_confirm).strip():
-        click.echo(click.style('sorry. The two passwords do not match.', fg='red'))
+        click.echo(click.style("sorry. The two passwords do not match.", fg="red"))
         return

-    account = db.session.query(Account). \
-        filter(Account.email == email). \
-        one_or_none()
+    account = db.session.query(Account).filter(Account.email == email).one_or_none()

     if not account:
-        click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red'))
+        click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red"))
         return

     try:
         valid_password(new_password)
     except:
-        click.echo(
-            click.style('sorry. The passwords must match {} '.format(password_pattern), fg='red'))
+        click.echo(click.style("sorry. The passwords must match {} ".format(password_pattern), fg="red"))
         return

     # generate password salt
@@ -65,80 +62,87 @@ def reset_password(email, new_password, password_confirm):
     account.password = base64_password_hashed
     account.password_salt = base64_salt
     db.session.commit()
-    click.echo(click.style('Congratulations! Password has been reset.', fg='green'))
+    click.echo(click.style("Congratulations! Password has been reset.", fg="green"))


-@click.command('reset-email', help='Reset the account email.')
-@click.option('--email', prompt=True, help='The old email address of the account whose email you need to reset')
-@click.option('--new-email', prompt=True, help='the new email.')
-@click.option('--email-confirm', prompt=True, help='the new email confirm.')
+@click.command("reset-email", help="Reset the account email.")
+@click.option("--email", prompt=True, help="The old email address of the account whose email you need to reset")
+@click.option("--new-email", prompt=True, help="the new email.")
+@click.option("--email-confirm", prompt=True, help="the new email confirm.")
 def reset_email(email, new_email, email_confirm):
     """
     Replace account email
     :return:
     """
     if str(new_email).strip() != str(email_confirm).strip():
-        click.echo(click.style('Sorry, new email and confirm email do not match.', fg='red'))
+        click.echo(click.style("Sorry, new email and confirm email do not match.", fg="red"))
         return

-    account = db.session.query(Account). \
-        filter(Account.email == email). \
-        one_or_none()
+    account = db.session.query(Account).filter(Account.email == email).one_or_none()

     if not account:
-        click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red'))
+        click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red"))
         return

     try:
         email_validate(new_email)
     except:
-        click.echo(
-            click.style('sorry. {} is not a valid email. '.format(email), fg='red'))
+        click.echo(click.style("sorry. {} is not a valid email. ".format(email), fg="red"))
         return

     account.email = new_email
     db.session.commit()
-    click.echo(click.style('Congratulations!, email has been reset.', fg='green'))
-
-
-@click.command('reset-encrypt-key-pair', help='Reset the asymmetric key pair of workspace for encrypt LLM credentials. '
-                                              'After the reset, all LLM credentials will become invalid, '
-                                              'requiring re-entry.'
-                                              'Only support SELF_HOSTED mode.')
-@click.confirmation_option(prompt=click.style('Are you sure you want to reset encrypt key pair?'
-                                              ' this operation cannot be rolled back!', fg='red'))
+    click.echo(click.style("Congratulations!, email has been reset.", fg="green"))
+
+
+@click.command(
+    "reset-encrypt-key-pair",
+    help="Reset the asymmetric key pair of workspace for encrypt LLM credentials. "
+    "After the reset, all LLM credentials will become invalid, "
+    "requiring re-entry."
+    "Only support SELF_HOSTED mode.",
+)
+@click.confirmation_option(
+    prompt=click.style(
+        "Are you sure you want to reset encrypt key pair?" " this operation cannot be rolled back!", fg="red"
+    )
+)
 def reset_encrypt_key_pair():
     """
     Reset the encrypted key pair of workspace for encrypt LLM credentials.
     After the reset, all LLM credentials will become invalid, requiring re-entry.
     Only support SELF_HOSTED mode.
     """
-    if dify_config.EDITION != 'SELF_HOSTED':
-        click.echo(click.style('Sorry, only support SELF_HOSTED mode.', fg='red'))
+    if dify_config.EDITION != "SELF_HOSTED":
+        click.echo(click.style("Sorry, only support SELF_HOSTED mode.", fg="red"))
         return

     tenants = db.session.query(Tenant).all()
     for tenant in tenants:
         if not tenant:
-            click.echo(click.style('Sorry, no workspace found. Please enter /install to initialize.', fg='red'))
+            click.echo(click.style("Sorry, no workspace found. Please enter /install to initialize.", fg="red"))
             return

         tenant.encrypt_public_key = generate_key_pair(tenant.id)

-        db.session.query(Provider).filter(Provider.provider_type == 'custom', Provider.tenant_id == tenant.id).delete()
+        db.session.query(Provider).filter(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
         db.session.query(ProviderModel).filter(ProviderModel.tenant_id == tenant.id).delete()
         db.session.commit()

-        click.echo(click.style('Congratulations! '
-                               'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green'))
+        click.echo(
+            click.style(
+                "Congratulations! " "the asymmetric key pair of workspace {} has been reset.".format(tenant.id),
+                fg="green",
+            )
+        )


-@click.command('vdb-migrate', help='migrate vector db.')
-@click.option('--scope', default='all', prompt=False, help='The scope of vector database to migrate, Default is All.')
+@click.command("vdb-migrate", help="migrate vector db.")
+@click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.")
 def vdb_migrate(scope: str):
-    if scope in ['knowledge', 'all']:
+    if scope in ["knowledge", "all"]:
         migrate_knowledge_vector_database()
-    if scope in ['annotation', 'all']:
+    if scope in ["annotation", "all"]:
         migrate_annotation_vector_database()


@@ -146,7 +150,7 @@ def migrate_annotation_vector_database():
     """
     """
     Migrate annotation datas to target vector database .
     Migrate annotation datas to target vector database .
     """
     """
-    click.echo(click.style('Start migrate annotation data.', fg='green'))
+    click.echo(click.style("Start migrate annotation data.", fg="green"))
     create_count = 0
     create_count = 0
     skipped_count = 0
     skipped_count = 0
     total_count = 0
     total_count = 0
@@ -154,98 +158,103 @@ def migrate_annotation_vector_database():
     while True:
         try:
             # get apps info
-            apps = db.session.query(App).filter(
-                App.status == 'normal'
-            ).order_by(App.created_at.desc()).paginate(page=page, per_page=50)
+            apps = (
+                db.session.query(App)
+                .filter(App.status == "normal")
+                .order_by(App.created_at.desc())
+                .paginate(page=page, per_page=50)
+            )
         except NotFound:
             break

         page += 1
         for app in apps:
             total_count = total_count + 1
-            click.echo(f'Processing the {total_count} app {app.id}. '
-                       + f'{create_count} created, {skipped_count} skipped.')
+            click.echo(
+                f"Processing the {total_count} app {app.id}. " + f"{create_count} created, {skipped_count} skipped."
+            )
             try:
-                click.echo('Create app annotation index: {}'.format(app.id))
-                app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
-                    AppAnnotationSetting.app_id == app.id
-                ).first()
+                click.echo("Create app annotation index: {}".format(app.id))
+                app_annotation_setting = (
+                    db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app.id).first()
+                )

                 if not app_annotation_setting:
                     skipped_count = skipped_count + 1
-                    click.echo('App annotation setting is disabled: {}'.format(app.id))
+                    click.echo("App annotation setting is disabled: {}".format(app.id))
                     continue
                 # get dataset_collection_binding info
-                dataset_collection_binding = db.session.query(DatasetCollectionBinding).filter(
-                    DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id
-                ).first()
+                dataset_collection_binding = (
+                    db.session.query(DatasetCollectionBinding)
+                    .filter(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
+                    .first()
+                )
                 if not dataset_collection_binding:
-                    click.echo('App annotation collection binding is not exist: {}'.format(app.id))
+                    click.echo("App annotation collection binding is not exist: {}".format(app.id))
                     continue
                 annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all()
                 dataset = Dataset(
                     id=app.id,
                     tenant_id=app.tenant_id,
-                    indexing_technique='high_quality',
+                    indexing_technique="high_quality",
                     embedding_model_provider=dataset_collection_binding.provider_name,
                     embedding_model=dataset_collection_binding.model_name,
-                    collection_binding_id=dataset_collection_binding.id
+                    collection_binding_id=dataset_collection_binding.id,
                 )
                 documents = []
                 if annotations:
                     for annotation in annotations:
                         document = Document(
                             page_content=annotation.question,
-                            metadata={
-                                "annotation_id": annotation.id,
-                                "app_id": app.id,
-                                "doc_id": annotation.id
-                            }
+                            metadata={"annotation_id": annotation.id, "app_id": app.id, "doc_id": annotation.id},
                         )
                         documents.append(document)

-                vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
+                vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
                 click.echo(f"Start to migrate annotation, app_id: {app.id}.")

                 try:
                     vector.delete()
-                    click.echo(
-                        click.style(f'Successfully delete vector index for app: {app.id}.',
-                                    fg='green'))
+                    click.echo(click.style(f"Successfully delete vector index for app: {app.id}.", fg="green"))
                 except Exception as e:
-                    click.echo(
-                        click.style(f'Failed to delete vector index for app {app.id}.',
-                                    fg='red'))
+                    click.echo(click.style(f"Failed to delete vector index for app {app.id}.", fg="red"))
                     raise e
                 if documents:
                     try:
-                        click.echo(click.style(
-                            f'Start to created vector index with {len(documents)} annotations for app {app.id}.',
-                            fg='green'))
-                        vector.create(documents)
                         click.echo(
-                            click.style(f'Successfully created vector index for app {app.id}.', fg='green'))
+                            click.style(
+                                f"Start to created vector index with {len(documents)} annotations for app {app.id}.",
+                                fg="green",
+                            )
+                        )
+                        vector.create(documents)
+                        click.echo(click.style(f"Successfully created vector index for app {app.id}.", fg="green"))
                     except Exception as e:
-                        click.echo(click.style(f'Failed to created vector index for app {app.id}.', fg='red'))
+                        click.echo(click.style(f"Failed to created vector index for app {app.id}.", fg="red"))
                         raise e
-                click.echo(f'Successfully migrated app annotation {app.id}.')
+                click.echo(f"Successfully migrated app annotation {app.id}.")
                 create_count += 1
             except Exception as e:
                 click.echo(
-                    click.style('Create app annotation index error: {} {}'.format(e.__class__.__name__, str(e)),
-                                fg='red'))
+                    click.style(
+                        "Create app annotation index error: {} {}".format(e.__class__.__name__, str(e)), fg="red"
+                    )
+                )
                 continue

     click.echo(
-        click.style(f'Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.',
-                    fg='green'))
+        click.style(
+            f"Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.",
+            fg="green",
+        )
+    )


 def migrate_knowledge_vector_database():
     """
     Migrate vector database datas to target vector database .
     """
-    click.echo(click.style('Start migrate vector db.', fg='green'))
+    click.echo(click.style("Start migrate vector db.", fg="green"))
     create_count = 0
     skipped_count = 0
     total_count = 0
@@ -253,87 +262,77 @@ def migrate_knowledge_vector_database():
     page = 1
     while True:
         try:
-            datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
-                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
+            datasets = (
+                db.session.query(Dataset)
+                .filter(Dataset.indexing_technique == "high_quality")
+                .order_by(Dataset.created_at.desc())
+                .paginate(page=page, per_page=50)
+            )
         except NotFound:
             break

         page += 1
         for dataset in datasets:
             total_count = total_count + 1
-            click.echo(f'Processing the {total_count} dataset {dataset.id}. '
-                       + f'{create_count} created, {skipped_count} skipped.')
+            click.echo(
+                f"Processing the {total_count} dataset {dataset.id}. "
+                + f"{create_count} created, {skipped_count} skipped."
+            )
             try:
-                click.echo('Create dataset vdb index: {}'.format(dataset.id))
+                click.echo("Create dataset vdb index: {}".format(dataset.id))
                 if dataset.index_struct_dict:
-                    if dataset.index_struct_dict['type'] == vector_type:
+                    if dataset.index_struct_dict["type"] == vector_type:
                         skipped_count = skipped_count + 1
                         continue
-                collection_name = ''
+                collection_name = ""
                 if vector_type == VectorType.WEAVIATE:
                     dataset_id = dataset.id
                     collection_name = Dataset.gen_collection_name_by_id(dataset_id)
-                    index_struct_dict = {
-                        "type": VectorType.WEAVIATE,
-                        "vector_store": {"class_prefix": collection_name}
-                    }
+                    index_struct_dict = {"type": VectorType.WEAVIATE, "vector_store": {"class_prefix": collection_name}}
                     dataset.index_struct = json.dumps(index_struct_dict)
                 elif vector_type == VectorType.QDRANT:
                     if dataset.collection_binding_id:
-                        dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
-                            filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
-                            one_or_none()
+                        dataset_collection_binding = (
+                            db.session.query(DatasetCollectionBinding)
+                            .filter(DatasetCollectionBinding.id == dataset.collection_binding_id)
+                            .one_or_none()
+                        )
                         if dataset_collection_binding:
                             collection_name = dataset_collection_binding.collection_name
                         else:
-                            raise ValueError('Dataset Collection Bindings is not exist!')
+                            raise ValueError("Dataset Collection Bindings is not exist!")
                     else:
                         dataset_id = dataset.id
                         collection_name = Dataset.gen_collection_name_by_id(dataset_id)
-                    index_struct_dict = {
-                        "type": VectorType.QDRANT,
-                        "vector_store": {"class_prefix": collection_name}
-                    }
+                    index_struct_dict = {"type": VectorType.QDRANT, "vector_store": {"class_prefix": collection_name}}
                     dataset.index_struct = json.dumps(index_struct_dict)

                 elif vector_type == VectorType.MILVUS:
                     dataset_id = dataset.id
                     collection_name = Dataset.gen_collection_name_by_id(dataset_id)
-                    index_struct_dict = {
-                        "type": VectorType.MILVUS,
-                        "vector_store": {"class_prefix": collection_name}
-                    }
+                    index_struct_dict = {"type": VectorType.MILVUS, "vector_store": {"class_prefix": collection_name}}
                     dataset.index_struct = json.dumps(index_struct_dict)
                 elif vector_type == VectorType.RELYT:
                     dataset_id = dataset.id
                     collection_name = Dataset.gen_collection_name_by_id(dataset_id)
-                    index_struct_dict = {
-                        "type": 'relyt',
-                        "vector_store": {"class_prefix": collection_name}
-                    }
+                    index_struct_dict = {"type": "relyt", "vector_store": {"class_prefix": collection_name}}
                     dataset.index_struct = json.dumps(index_struct_dict)
                 elif vector_type == VectorType.TENCENT:
                     dataset_id = dataset.id
                     collection_name = Dataset.gen_collection_name_by_id(dataset_id)
-                    index_struct_dict = {
-                        "type": VectorType.TENCENT,
-                        "vector_store": {"class_prefix": collection_name}
-                    }
+                    index_struct_dict = {"type": VectorType.TENCENT, "vector_store": {"class_prefix": collection_name}}
                     dataset.index_struct = json.dumps(index_struct_dict)
                 elif vector_type == VectorType.PGVECTOR:
                     dataset_id = dataset.id
                     collection_name = Dataset.gen_collection_name_by_id(dataset_id)
-                    index_struct_dict = {
-                        "type": VectorType.PGVECTOR,
-                        "vector_store": {"class_prefix": collection_name}
-                    }
+                    index_struct_dict = {"type": VectorType.PGVECTOR, "vector_store": {"class_prefix": collection_name}}
                     dataset.index_struct = json.dumps(index_struct_dict)
                 elif vector_type == VectorType.OPENSEARCH:
                     dataset_id = dataset.id
                     collection_name = Dataset.gen_collection_name_by_id(dataset_id)
                     index_struct_dict = {
                         "type": VectorType.OPENSEARCH,
-                        "vector_store": {"class_prefix": collection_name}
+                        "vector_store": {"class_prefix": collection_name},
                     }
                     dataset.index_struct = json.dumps(index_struct_dict)
                 elif vector_type == VectorType.ANALYTICDB:
@@ -341,16 +340,13 @@ def migrate_knowledge_vector_database():
                     collection_name = Dataset.gen_collection_name_by_id(dataset_id)
                     index_struct_dict = {
                         "type": VectorType.ANALYTICDB,
-                        "vector_store": {"class_prefix": collection_name}
+                        "vector_store": {"class_prefix": collection_name},
                     }
                     dataset.index_struct = json.dumps(index_struct_dict)
                 elif vector_type == VectorType.ELASTICSEARCH:
                     dataset_id = dataset.id
                     index_name = Dataset.gen_collection_name_by_id(dataset_id)
-                    index_struct_dict = {
-                        "type": 'elasticsearch',
-                        "vector_store": {"class_prefix": index_name}
-                    }
+                    index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}}
                     dataset.index_struct = json.dumps(index_struct_dict)
                 else:
                     raise ValueError(f"Vector store {vector_type} is not supported.")
@@ -361,29 +357,41 @@ def migrate_knowledge_vector_database():
                 try:
                     vector.delete()
                     click.echo(
-                        click.style(f'Successfully delete vector index {collection_name} for dataset {dataset.id}.',
-                                    fg='green'))
+                        click.style(
+                            f"Successfully delete vector index {collection_name} for dataset {dataset.id}.", fg="green"
+                        )
+                    )
                 except Exception as e:
                     click.echo(
-                        click.style(f'Failed to delete vector index {collection_name} for dataset {dataset.id}.',
-                                    fg='red'))
+                        click.style(
+                            f"Failed to delete vector index {collection_name} for dataset {dataset.id}.", fg="red"
+                        )
+                    )
                     raise e

-                dataset_documents = db.session.query(DatasetDocument).filter(
-                    DatasetDocument.dataset_id == dataset.id,
-                    DatasetDocument.indexing_status == 'completed',
-                    DatasetDocument.enabled == True,
-                    DatasetDocument.archived == False,
-                ).all()
+                dataset_documents = (
+                    db.session.query(DatasetDocument)
+                    .filter(
+                        DatasetDocument.dataset_id == dataset.id,
+                        DatasetDocument.indexing_status == "completed",
+                        DatasetDocument.enabled == True,
+                        DatasetDocument.archived == False,
+                    )
+                    .all()
+                )

                 documents = []
                 segments_count = 0
                 for dataset_document in dataset_documents:
-                    segments = db.session.query(DocumentSegment).filter(
-                        DocumentSegment.document_id == dataset_document.id,
-                        DocumentSegment.status == 'completed',
-                        DocumentSegment.enabled == True
-                    ).all()
+                    segments = (
+                        db.session.query(DocumentSegment)
+                        .filter(
+                            DocumentSegment.document_id == dataset_document.id,
+                            DocumentSegment.status == "completed",
+                            DocumentSegment.enabled == True,
+                        )
+                        .all()
+                    )

                     for segment in segments:
                         document = Document(
@@ -393,7 +401,7 @@ def migrate_knowledge_vector_database():
                                 "doc_hash": segment.index_node_hash,
                                 "doc_hash": segment.index_node_hash,
                                 "document_id": segment.document_id,
                                 "document_id": segment.document_id,
                                 "dataset_id": segment.dataset_id,
                                 "dataset_id": segment.dataset_id,
-                            }
+                            },
                         )
                         )
 
 
                         documents.append(document)
                         documents.append(document)
@@ -401,37 +409,43 @@ def migrate_knowledge_vector_database():

                 if documents:
                     try:
-                        click.echo(click.style(
-                            f'Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.',
-                            fg='green'))
+                        click.echo(
+                            click.style(
+                                f"Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.",
+                                fg="green",
+                            )
+                        )
                         vector.create(documents)
                         click.echo(
-                            click.style(f'Successfully created vector index for dataset {dataset.id}.', fg='green'))
+                            click.style(f"Successfully created vector index for dataset {dataset.id}.", fg="green")
+                        )
                     except Exception as e:
-                        click.echo(click.style(f'Failed to created vector index for dataset {dataset.id}.', fg='red'))
+                        click.echo(click.style(f"Failed to created vector index for dataset {dataset.id}.", fg="red"))
                         raise e
                 db.session.add(dataset)
                 db.session.commit()
-                click.echo(f'Successfully migrated dataset {dataset.id}.')
+                click.echo(f"Successfully migrated dataset {dataset.id}.")
                 create_count += 1
             except Exception as e:
                 db.session.rollback()
                 click.echo(
-                    click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
-                                fg='red'))
+                    click.style("Create dataset index error: {} {}".format(e.__class__.__name__, str(e)), fg="red")
+                )
                 continue

     click.echo(
-        click.style(f'Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.',
-                    fg='green'))
+        click.style(
+            f"Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.", fg="green"
+        )
+    )


-@click.command('convert-to-agent-apps', help='Convert Agent Assistant to Agent App.')
+@click.command("convert-to-agent-apps", help="Convert Agent Assistant to Agent App.")
 def convert_to_agent_apps():
     """
     Convert Agent Assistant to Agent App.
     """
-    click.echo(click.style('Start convert to agent apps.', fg='green'))
+    click.echo(click.style("Start convert to agent apps.", fg="green"))

     proceeded_app_ids = []

@@ -466,7 +480,7 @@ def convert_to_agent_apps():
                 break

         for app in apps:
-            click.echo('Converting app: {}'.format(app.id))
+            click.echo("Converting app: {}".format(app.id))

             try:
                 app.mode = AppMode.AGENT_CHAT.value
@@ -478,137 +492,139 @@ def convert_to_agent_apps():
                 )
                 )
 
 
                 db.session.commit()
                 db.session.commit()
-                click.echo(click.style('Converted app: {}'.format(app.id), fg='green'))
+                click.echo(click.style("Converted app: {}".format(app.id), fg="green"))
             except Exception as e:
             except Exception as e:
-                click.echo(
-                    click.style('Convert app error: {} {}'.format(e.__class__.__name__,
-                                                                  str(e)), fg='red'))
+                click.echo(click.style("Convert app error: {} {}".format(e.__class__.__name__, str(e)), fg="red"))
 
 
-    click.echo(click.style('Congratulations! Converted {} agent apps.'.format(len(proceeded_app_ids)), fg='green'))
+    click.echo(click.style("Congratulations! Converted {} agent apps.".format(len(proceeded_app_ids)), fg="green"))
 
 
 
 
-@click.command('add-qdrant-doc-id-index', help='add qdrant doc_id index.')
-@click.option('--field', default='metadata.doc_id', prompt=False, help='index field , default is metadata.doc_id.')
+@click.command("add-qdrant-doc-id-index", help="add qdrant doc_id index.")
+@click.option("--field", default="metadata.doc_id", prompt=False, help="index field , default is metadata.doc_id.")
 def add_qdrant_doc_id_index(field: str):
-    click.echo(click.style('Start add qdrant doc_id index.', fg='green'))
+    click.echo(click.style("Start add qdrant doc_id index.", fg="green"))
     vector_type = dify_config.VECTOR_STORE
     if vector_type != "qdrant":
-        click.echo(click.style('Sorry, only support qdrant vector store.', fg='red'))
+        click.echo(click.style("Sorry, only support qdrant vector store.", fg="red"))
         return
     create_count = 0

     try:
         bindings = db.session.query(DatasetCollectionBinding).all()
         if not bindings:
-            click.echo(click.style('Sorry, no dataset collection bindings found.', fg='red'))
+            click.echo(click.style("Sorry, no dataset collection bindings found.", fg="red"))
             return
         import qdrant_client
         from qdrant_client.http.exceptions import UnexpectedResponse
         from qdrant_client.http.models import PayloadSchemaType

         from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig
+
         for binding in bindings:
             if dify_config.QDRANT_URL is None:
-                raise ValueError('Qdrant url is required.')
+                raise ValueError("Qdrant url is required.")
             qdrant_config = QdrantConfig(
                 endpoint=dify_config.QDRANT_URL,
                 api_key=dify_config.QDRANT_API_KEY,
                 root_path=current_app.root_path,
                 timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
                 grpc_port=dify_config.QDRANT_GRPC_PORT,
-                prefer_grpc=dify_config.QDRANT_GRPC_ENABLED
+                prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,
             )
             try:
                 client = qdrant_client.QdrantClient(**qdrant_config.to_qdrant_params())
                 # create payload index
-                client.create_payload_index(binding.collection_name, field,
-                                            field_schema=PayloadSchemaType.KEYWORD)
+                client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD)
                 create_count += 1
             except UnexpectedResponse as e:
                 # Collection does not exist, so return
                 if e.status_code == 404:
-                    click.echo(click.style(f'Collection not found, collection_name:{binding.collection_name}.', fg='red'))
+                    click.echo(
+                        click.style(f"Collection not found, collection_name:{binding.collection_name}.", fg="red")
+                    )
                     continue
                 # Some other error occurred, so re-raise the exception
                 else:
-                    click.echo(click.style(f'Failed to create qdrant index, collection_name:{binding.collection_name}.', fg='red'))
+                    click.echo(
+                        click.style(
+                            f"Failed to create qdrant index, collection_name:{binding.collection_name}.", fg="red"
+                        )
+                    )

     except Exception as e:
-        click.echo(click.style('Failed to create qdrant client.', fg='red'))
+        click.echo(click.style("Failed to create qdrant client.", fg="red"))

-    click.echo(
-        click.style(f'Congratulations! Create {create_count} collection indexes.',
-                    fg='green'))
+    click.echo(click.style(f"Congratulations! Create {create_count} collection indexes.", fg="green"))


-@click.command('create-tenant', help='Create account and tenant.')
-@click.option('--email', prompt=True, help='The email address of the tenant account.')
-@click.option('--language', prompt=True, help='Account language, default: en-US.')
+@click.command("create-tenant", help="Create account and tenant.")
+@click.option("--email", prompt=True, help="The email address of the tenant account.")
+@click.option("--language", prompt=True, help="Account language, default: en-US.")
 def create_tenant(email: str, language: Optional[str] = None):
     """
     Create tenant account
     """
     if not email:
-        click.echo(click.style('Sorry, email is required.', fg='red'))
+        click.echo(click.style("Sorry, email is required.", fg="red"))
         return

     # Create account
     email = email.strip()

-    if '@' not in email:
-        click.echo(click.style('Sorry, invalid email address.', fg='red'))
+    if "@" not in email:
+        click.echo(click.style("Sorry, invalid email address.", fg="red"))
         return

-    account_name = email.split('@')[0]
+    account_name = email.split("@")[0]

     if language not in languages:
-        language = 'en-US'
+        language = "en-US"

     # generate random password
     new_password = secrets.token_urlsafe(16)

     # register account
-    account = RegisterService.register(
-        email=email,
-        name=account_name,
-        password=new_password,
-        language=language
-    )
+    account = RegisterService.register(email=email, name=account_name, password=new_password, language=language)

     TenantService.create_owner_tenant_if_not_exist(account)

-    click.echo(click.style('Congratulations! Account and tenant created.\n'
-                           'Account: {}\nPassword: {}'.format(email, new_password), fg='green'))
+    click.echo(
+        click.style(
+            "Congratulations! Account and tenant created.\n" "Account: {}\nPassword: {}".format(email, new_password),
+            fg="green",
+        )
+    )


-@click.command('upgrade-db', help='upgrade the database')
+@click.command("upgrade-db", help="upgrade the database")
 def upgrade_db():
-    click.echo('Preparing database migration...')
-    lock = redis_client.lock(name='db_upgrade_lock', timeout=60)
+    click.echo("Preparing database migration...")
+    lock = redis_client.lock(name="db_upgrade_lock", timeout=60)
     if lock.acquire(blocking=False):
         try:
-            click.echo(click.style('Start database migration.', fg='green'))
+            click.echo(click.style("Start database migration.", fg="green"))

             # run db migration
             import flask_migrate
+
             flask_migrate.upgrade()

-            click.echo(click.style('Database migration successful!', fg='green'))
+            click.echo(click.style("Database migration successful!", fg="green"))

         except Exception as e:
-            logging.exception(f'Database migration failed, error: {e}')
+            logging.exception(f"Database migration failed, error: {e}")
         finally:
             lock.release()
     else:
-        click.echo('Database migration skipped')
+        click.echo("Database migration skipped")


-@click.command('fix-app-site-missing', help='Fix app related site missing issue.')
+@click.command("fix-app-site-missing", help="Fix app related site missing issue.")
 def fix_app_site_missing():
     """
     Fix app related site missing issue.
     """
-    click.echo(click.style('Start fix app related site missing issue.', fg='green'))
+    click.echo(click.style("Start fix app related site missing issue.", fg="green"))

     failed_app_ids = []
     while True:
@@ -639,15 +655,14 @@ where sites.id is null limit 1000"""
                         app_was_created.send(app, account=account)
                 except Exception as e:
                     failed_app_ids.append(app_id)
-                    click.echo(click.style('Fix app {} related site missing issue failed!'.format(app_id), fg='red'))
-                    logging.exception(f'Fix app related site missing issue failed, error: {e}')
+                    click.echo(click.style("Fix app {} related site missing issue failed!".format(app_id), fg="red"))
+                    logging.exception(f"Fix app related site missing issue failed, error: {e}")
                     continue

             if not processed_count:
                 break

-
-    click.echo(click.style('Congratulations! Fix app related site missing issue successful!', fg='green'))
+    click.echo(click.style("Congratulations! Fix app related site missing issue successful!", fg="green"))


 def register_commands(app):

+ 1 - 1
api/constants/__init__.py

@@ -1 +1 @@
-HIDDEN_VALUE = '[__HIDDEN__]'
+HIDDEN_VALUE = "[__HIDDEN__]"

+ 19 - 20
api/constants/languages.py

@@ -1,22 +1,22 @@
 language_timezone_mapping = {
-    'en-US': 'America/New_York',
-    'zh-Hans': 'Asia/Shanghai',
-    'zh-Hant': 'Asia/Taipei',
-    'pt-BR': 'America/Sao_Paulo',
-    'es-ES': 'Europe/Madrid',
-    'fr-FR': 'Europe/Paris',
-    'de-DE': 'Europe/Berlin',
-    'ja-JP': 'Asia/Tokyo',
-    'ko-KR': 'Asia/Seoul',
-    'ru-RU': 'Europe/Moscow',
-    'it-IT': 'Europe/Rome',
-    'uk-UA': 'Europe/Kyiv',
-    'vi-VN': 'Asia/Ho_Chi_Minh',
-    'ro-RO': 'Europe/Bucharest',
-    'pl-PL': 'Europe/Warsaw',
-    'hi-IN': 'Asia/Kolkata',
-    'tr-TR': 'Europe/Istanbul',
-    'fa-IR': 'Asia/Tehran',
+    "en-US": "America/New_York",
+    "zh-Hans": "Asia/Shanghai",
+    "zh-Hant": "Asia/Taipei",
+    "pt-BR": "America/Sao_Paulo",
+    "es-ES": "Europe/Madrid",
+    "fr-FR": "Europe/Paris",
+    "de-DE": "Europe/Berlin",
+    "ja-JP": "Asia/Tokyo",
+    "ko-KR": "Asia/Seoul",
+    "ru-RU": "Europe/Moscow",
+    "it-IT": "Europe/Rome",
+    "uk-UA": "Europe/Kyiv",
+    "vi-VN": "Asia/Ho_Chi_Minh",
+    "ro-RO": "Europe/Bucharest",
+    "pl-PL": "Europe/Warsaw",
+    "hi-IN": "Asia/Kolkata",
+    "tr-TR": "Europe/Istanbul",
+    "fa-IR": "Asia/Tehran",
 }

 languages = list(language_timezone_mapping.keys())
@@ -26,6 +26,5 @@ def supported_language(lang):
     if lang in languages:
         return lang

-    error = ('{lang} is not a valid language.'
-             .format(lang=lang))
+    error = "{lang} is not a valid language.".format(lang=lang)
     raise ValueError(error)

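Note: a minimal sketch of how this mapping is typically consumed, illustrative only and not part of the commit; the import path assumes the api/ package root is on sys.path.

# Illustrative usage of the constants reformatted above (not in the diff).
from constants.languages import language_timezone_mapping, supported_language

def default_timezone_for(lang: str) -> str:
    lang = supported_language(lang)  # raises ValueError for unknown codes
    return language_timezone_mapping[lang]

print(default_timezone_for("fr-FR"))  # -> Europe/Paris
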
+ 48 - 51
api/constants/model_template.py

@@ -5,82 +5,79 @@ from models.model import AppMode
 default_app_templates = {
     # workflow default mode
     AppMode.WORKFLOW: {
-        'app': {
-            'mode': AppMode.WORKFLOW.value,
-            'enable_site': True,
-            'enable_api': True
+        "app": {
+            "mode": AppMode.WORKFLOW.value,
+            "enable_site": True,
+            "enable_api": True,
         }
     },
-
     # completion default mode
     AppMode.COMPLETION: {
-        'app': {
-            'mode': AppMode.COMPLETION.value,
-            'enable_site': True,
-            'enable_api': True
+        "app": {
+            "mode": AppMode.COMPLETION.value,
+            "enable_site": True,
+            "enable_api": True,
         },
-        'model_config': {
-            'model': {
+        "model_config": {
+            "model": {
                 "provider": "openai",
                 "name": "gpt-4o",
                 "mode": "chat",
-                "completion_params": {}
+                "completion_params": {},
             },
-            'user_input_form': json.dumps([
-                {
-                    "paragraph": {
-                        "label": "Query",
-                        "variable": "query",
-                        "required": True,
-                        "default": ""
-                    }
-                }
-            ]),
-            'pre_prompt': '{{query}}'
+            "user_input_form": json.dumps(
+                [
+                    {
+                        "paragraph": {
+                            "label": "Query",
+                            "variable": "query",
+                            "required": True,
+                            "default": "",
+                        },
+                    },
+                ]
+            ),
+            "pre_prompt": "{{query}}",
         },
-
     },
-
     # chat default mode
     AppMode.CHAT: {
-        'app': {
-            'mode': AppMode.CHAT.value,
-            'enable_site': True,
-            'enable_api': True
+        "app": {
+            "mode": AppMode.CHAT.value,
+            "enable_site": True,
+            "enable_api": True,
         },
-        'model_config': {
-            'model': {
+        "model_config": {
+            "model": {
                 "provider": "openai",
                 "name": "gpt-4o",
                 "mode": "chat",
-                "completion_params": {}
-            }
-        }
+                "completion_params": {},
+            },
+        },
     },
-
     # advanced-chat default mode
     AppMode.ADVANCED_CHAT: {
-        'app': {
-            'mode': AppMode.ADVANCED_CHAT.value,
-            'enable_site': True,
-            'enable_api': True
-        }
+        "app": {
+            "mode": AppMode.ADVANCED_CHAT.value,
+            "enable_site": True,
+            "enable_api": True,
+        },
     },
-
     # agent-chat default mode
     AppMode.AGENT_CHAT: {
-        'app': {
-            'mode': AppMode.AGENT_CHAT.value,
-            'enable_site': True,
-            'enable_api': True
+        "app": {
+            "mode": AppMode.AGENT_CHAT.value,
+            "enable_site": True,
+            "enable_api": True,
         },
-        'model_config': {
-            'model': {
+        "model_config": {
+            "model": {
                 "provider": "openai",
                 "name": "gpt-4o",
                 "mode": "chat",
-                "completion_params": {}
-            }
-        }
-    }
+                "completion_params": {},
+            },
+        },
+    },
 }

+ 2 - 2
api/contexts/__init__.py

@@ -2,6 +2,6 @@ from contextvars import ContextVar

 from core.workflow.entities.variable_pool import VariablePool

-tenant_id: ContextVar[str] = ContextVar('tenant_id')
+tenant_id: ContextVar[str] = ContextVar("tenant_id")

-workflow_variable_pool: ContextVar[VariablePool] = ContextVar('workflow_variable_pool')
+workflow_variable_pool: ContextVar[VariablePool] = ContextVar("workflow_variable_pool")

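For context, the ContextVar pattern above isolates per-request state; values are scoped to the current thread or async task. A standalone sketch (names local to the example):

# Standalone sketch of the ContextVar pattern; not part of the diff.
from contextvars import ContextVar

tenant_id: ContextVar[str] = ContextVar("tenant_id")

token = tenant_id.set("tenant-123")
try:
    print(tenant_id.get())  # -> tenant-123
finally:
    tenant_id.reset(token)  # restore the previous (unset) state
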
+ 4 - 4
api/events/app_event.py

@@ -1,13 +1,13 @@
 from blinker import signal

 # sender: app
-app_was_created = signal('app-was-created')
+app_was_created = signal("app-was-created")

 # sender: app, kwargs: app_model_config
-app_model_config_was_updated = signal('app-model-config-was-updated')
+app_model_config_was_updated = signal("app-model-config-was-updated")

 # sender: app, kwargs: published_workflow
-app_published_workflow_was_updated = signal('app-published-workflow-was-updated')
+app_published_workflow_was_updated = signal("app-published-workflow-was-updated")

 # sender: app, kwargs: synced_draft_workflow
-app_draft_workflow_was_synced = signal('app-draft-workflow-was-synced')
+app_draft_workflow_was_synced = signal("app-draft-workflow-was-synced")

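These blinker signals are fired elsewhere in the codebase; a minimal sketch of the connect/send flow they rely on (the handler below is hypothetical):

# Hypothetical handler wired to a signal like the ones above.
from blinker import signal

app_was_created = signal("app-was-created")

@app_was_created.connect
def handle(sender, **kwargs):
    print(f"app created: {sender!r}, account={kwargs.get('account')}")

app_was_created.send("app-1", account="alice")
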
+ 1 - 1
api/events/dataset_event.py

@@ -1,4 +1,4 @@
 from blinker import signal

 # sender: dataset
-dataset_was_deleted = signal('dataset-was-deleted')
+dataset_was_deleted = signal("dataset-was-deleted")

+ 1 - 1
api/events/document_event.py

@@ -1,4 +1,4 @@
 from blinker import signal

 # sender: document
-document_was_deleted = signal('document-was-deleted')
+document_was_deleted = signal("document-was-deleted")

+ 8 - 2
api/events/event_handlers/clean_when_dataset_deleted.py

@@ -5,5 +5,11 @@ from tasks.clean_dataset_task import clean_dataset_task
 @dataset_was_deleted.connect
 def handle(sender, **kwargs):
     dataset = sender
-    clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique,
-                             dataset.index_struct, dataset.collection_binding_id, dataset.doc_form)
+    clean_dataset_task.delay(
+        dataset.id,
+        dataset.tenant_id,
+        dataset.indexing_technique,
+        dataset.index_struct,
+        dataset.collection_binding_id,
+        dataset.doc_form,
+    )

+ 3 - 3
api/events/event_handlers/clean_when_document_deleted.py

@@ -5,7 +5,7 @@ from tasks.clean_document_task import clean_document_task
 @document_was_deleted.connect
 def handle(sender, **kwargs):
     document_id = sender
-    dataset_id = kwargs.get('dataset_id')
-    doc_form = kwargs.get('doc_form')
-    file_id = kwargs.get('file_id')
+    dataset_id = kwargs.get("dataset_id")
+    doc_form = kwargs.get("doc_form")
+    file_id = kwargs.get("file_id")
     clean_document_task.delay(document_id, dataset_id, doc_form, file_id)

+ 14 - 10
api/events/event_handlers/create_document_index.py

@@ -14,21 +14,25 @@ from models.dataset import Document
 @document_index_created.connect
 def handle(sender, **kwargs):
     dataset_id = sender
-    document_ids = kwargs.get('document_ids', None)
+    document_ids = kwargs.get("document_ids", None)
     documents = []
     start_at = time.perf_counter()
     for document_id in document_ids:
-        logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))
+        logging.info(click.style("Start process document: {}".format(document_id), fg="green"))

-        document = db.session.query(Document).filter(
-            Document.id == document_id,
-            Document.dataset_id == dataset_id
-        ).first()
+        document = (
+            db.session.query(Document)
+            .filter(
+                Document.id == document_id,
+                Document.dataset_id == dataset_id,
+            )
+            .first()
+        )

         if not document:
-            raise NotFound('Document not found')
+            raise NotFound("Document not found")

-        document.indexing_status = 'parsing'
+        document.indexing_status = "parsing"
         document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
         documents.append(document)
         db.session.add(document)
@@ -38,8 +42,8 @@ def handle(sender, **kwargs):
         indexing_runner = IndexingRunner()
         indexing_runner.run(documents)
         end_at = time.perf_counter()
-        logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green'))
+        logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green"))
     except DocumentIsPausedException as ex:
-        logging.info(click.style(str(ex), fg='yellow'))
+        logging.info(click.style(str(ex), fg="yellow"))
     except Exception:
         pass

+ 1 - 1
api/events/event_handlers/create_installed_app_when_app_created.py

@@ -10,7 +10,7 @@ def handle(sender, **kwargs):
     installed_app = InstalledApp(
         tenant_id=app.tenant_id,
         app_id=app.id,
-        app_owner_tenant_id=app.tenant_id
+        app_owner_tenant_id=app.tenant_id,
     )
     db.session.add(installed_app)
     db.session.commit()

+ 5 - 5
api/events/event_handlers/create_site_record_when_app_created.py

@@ -7,15 +7,15 @@ from models.model import Site
 def handle(sender, **kwargs):
     """Create site record when an app is created."""
     app = sender
-    account = kwargs.get('account')
+    account = kwargs.get("account")
     site = Site(
         app_id=app.id,
         title=app.name,
-        icon = app.icon,
-        icon_background = app.icon_background,
+        icon=app.icon,
+        icon_background=app.icon_background,
         default_language=account.interface_language,
-        customize_token_strategy='not_allow',
-        code=Site.generate_code(16)
+        customize_token_strategy="not_allow",
+        code=Site.generate_code(16),
     )

     db.session.add(site)

+ 4 - 4
api/events/event_handlers/deduct_quota_when_messaeg_created.py

@@ -8,7 +8,7 @@ from models.provider import Provider, ProviderType
 @message_was_created.connect
 def handle(sender, **kwargs):
     message = sender
-    application_generate_entity = kwargs.get('application_generate_entity')
+    application_generate_entity = kwargs.get("application_generate_entity")

     if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
         return
@@ -39,7 +39,7 @@ def handle(sender, **kwargs):
         elif quota_unit == QuotaUnit.CREDITS:
             used_quota = 1

-            if 'gpt-4' in model_config.model:
+            if "gpt-4" in model_config.model:
                 used_quota = 20
         else:
             used_quota = 1
@@ -50,6 +50,6 @@ def handle(sender, **kwargs):
             Provider.provider_name == model_config.provider,
             Provider.provider_type == ProviderType.SYSTEM.value,
             Provider.quota_type == system_configuration.current_quota_type.value,
-            Provider.quota_limit > Provider.quota_used
-        ).update({'quota_used': Provider.quota_used + used_quota})
+            Provider.quota_limit > Provider.quota_used,
+        ).update({"quota_used": Provider.quota_used + used_quota})
         db.session.commit()

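The guarded UPDATE above (filtering on Provider.quota_limit > Provider.quota_used) deducts quota atomically in SQL instead of read-modify-write in Python, so the update becomes a no-op once quota is exhausted. A self-contained sketch of that pattern against SQLite; the schema and values are illustrative, not Dify's models:

# Sketch of a guarded, atomic quota deduction (illustrative schema).
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Provider(Base):
    __tablename__ = "providers"
    id = Column(Integer, primary_key=True)
    provider_name = Column(String)
    quota_limit = Column(Integer)
    quota_used = Column(Integer, default=0)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Provider(provider_name="openai", quota_limit=200, quota_used=190))
    session.commit()

    used_quota = 20  # e.g. a gpt-4 message under a credits-style quota unit
    session.query(Provider).filter(
        Provider.provider_name == "openai",
        Provider.quota_limit > Provider.quota_used,  # skip if already exhausted
    ).update({"quota_used": Provider.quota_used + used_quota})
    session.commit()
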
+ 3 - 3
api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py

@@ -8,8 +8,8 @@ from events.app_event import app_draft_workflow_was_synced
 @app_draft_workflow_was_synced.connect
 def handle(sender, **kwargs):
     app = sender
-    for node_data in kwargs.get('synced_draft_workflow').graph_dict.get('nodes', []):
-        if node_data.get('data', {}).get('type') == NodeType.TOOL.value:
+    for node_data in kwargs.get("synced_draft_workflow").graph_dict.get("nodes", []):
+        if node_data.get("data", {}).get("type") == NodeType.TOOL.value:
             try:
                 tool_entity = ToolEntity(**node_data["data"])
                 tool_runtime = ToolManager.get_tool_runtime(
@@ -23,7 +23,7 @@ def handle(sender, **kwargs):
                     tool_runtime=tool_runtime,
                     provider_name=tool_entity.provider_name,
                     provider_type=tool_entity.provider_type,
-                    identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}'
+                    identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}',
                 )
                 manager.delete_tool_parameters_cache()
             except:

+ 1 - 1
api/events/event_handlers/document_index_event.py

@@ -1,4 +1,4 @@
 from blinker import signal

 # sender: document
-document_index_created = signal('document-index-created')
+document_index_created = signal("document-index-created")

+ 10 - 16
api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py

@@ -7,13 +7,11 @@ from models.model import AppModelConfig
 @app_model_config_was_updated.connect
 def handle(sender, **kwargs):
     app = sender
-    app_model_config = kwargs.get('app_model_config')
+    app_model_config = kwargs.get("app_model_config")

     dataset_ids = get_dataset_ids_from_model_config(app_model_config)

-    app_dataset_joins = db.session.query(AppDatasetJoin).filter(
-        AppDatasetJoin.app_id == app.id
-    ).all()
+    app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all()

     removed_dataset_ids = []
     if not app_dataset_joins:
@@ -29,16 +27,12 @@ def handle(sender, **kwargs):
     if removed_dataset_ids:
         for dataset_id in removed_dataset_ids:
             db.session.query(AppDatasetJoin).filter(
-                AppDatasetJoin.app_id == app.id,
-                AppDatasetJoin.dataset_id == dataset_id
+                AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
             ).delete()

     if added_dataset_ids:
         for dataset_id in added_dataset_ids:
-            app_dataset_join = AppDatasetJoin(
-                app_id=app.id,
-                dataset_id=dataset_id
-            )
+            app_dataset_join = AppDatasetJoin(app_id=app.id, dataset_id=dataset_id)
             db.session.add(app_dataset_join)

     db.session.commit()
@@ -51,7 +45,7 @@ def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set:

     agent_mode = app_model_config.agent_mode_dict

-    tools = agent_mode.get('tools', []) or []
+    tools = agent_mode.get("tools", []) or []
     for tool in tools:
         if len(list(tool.keys())) != 1:
             continue
@@ -63,11 +57,11 @@ def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set:

     # get dataset from dataset_configs
     dataset_configs = app_model_config.dataset_configs_dict
-    datasets = dataset_configs.get('datasets', {}) or {}
-    for dataset in datasets.get('datasets', []) or []:
+    datasets = dataset_configs.get("datasets", {}) or {}
+    for dataset in datasets.get("datasets", []) or []:
         keys = list(dataset.keys())
-        if len(keys) == 1 and keys[0] == 'dataset':
-            if dataset['dataset'].get('id'):
-                dataset_ids.add(dataset['dataset'].get('id'))
+        if len(keys) == 1 and keys[0] == "dataset":
+            if dataset["dataset"].get("id"):
+                dataset_ids.add(dataset["dataset"].get("id"))

     return dataset_ids

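The handler above reconciles AppDatasetJoin rows by diffing the existing joins against the dataset ids referenced in the config. The core set arithmetic, reduced to a standalone sketch (ids are illustrative):

# Set-diff reconciliation sketch; not the shipped handler.
def diff_joins(current_ids: set[str], desired_ids: set[str]) -> tuple[set[str], set[str]]:
    removed = current_ids - desired_ids  # joins to delete
    added = desired_ids - current_ids    # joins to insert
    return removed, added

print(diff_joins({"a", "b"}, {"b", "c"}))  # -> ({'a'}, {'c'})
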
+ 9 - 14
api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py

@@ -11,13 +11,11 @@ from models.workflow import Workflow
 @app_published_workflow_was_updated.connect
 def handle(sender, **kwargs):
     app = sender
-    published_workflow = kwargs.get('published_workflow')
+    published_workflow = kwargs.get("published_workflow")
     published_workflow = cast(Workflow, published_workflow)

     dataset_ids = get_dataset_ids_from_workflow(published_workflow)
-    app_dataset_joins = db.session.query(AppDatasetJoin).filter(
-        AppDatasetJoin.app_id == app.id
-    ).all()
+    app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all()

     removed_dataset_ids = []
     if not app_dataset_joins:
@@ -33,16 +31,12 @@ def handle(sender, **kwargs):
     if removed_dataset_ids:
         for dataset_id in removed_dataset_ids:
             db.session.query(AppDatasetJoin).filter(
-                AppDatasetJoin.app_id == app.id,
-                AppDatasetJoin.dataset_id == dataset_id
+                AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
             ).delete()

     if added_dataset_ids:
         for dataset_id in added_dataset_ids:
-            app_dataset_join = AppDatasetJoin(
-                app_id=app.id,
-                dataset_id=dataset_id
-            )
+            app_dataset_join = AppDatasetJoin(app_id=app.id, dataset_id=dataset_id)
             db.session.add(app_dataset_join)

     db.session.commit()
@@ -54,18 +48,19 @@ def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set:
     if not graph:
         return dataset_ids

-    nodes = graph.get('nodes', [])
+    nodes = graph.get("nodes", [])

     # fetch all knowledge retrieval nodes
-    knowledge_retrieval_nodes = [node for node in nodes
-                                 if node.get('data', {}).get('type') == NodeType.KNOWLEDGE_RETRIEVAL.value]
+    knowledge_retrieval_nodes = [
+        node for node in nodes if node.get("data", {}).get("type") == NodeType.KNOWLEDGE_RETRIEVAL.value
+    ]

     if not knowledge_retrieval_nodes:
         return dataset_ids

     for node in knowledge_retrieval_nodes:
         try:
-            node_data = KnowledgeRetrievalNodeData(**node.get('data', {}))
+            node_data = KnowledgeRetrievalNodeData(**node.get("data", {}))
             dataset_ids.update(node_data.dataset_ids)
         except Exception as e:
             continue

+ 3 - 3
api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py

@@ -9,13 +9,13 @@ from models.provider import Provider
 @message_was_created.connect
 def handle(sender, **kwargs):
     message = sender
-    application_generate_entity = kwargs.get('application_generate_entity')
+    application_generate_entity = kwargs.get("application_generate_entity")

     if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
         return

     db.session.query(Provider).filter(
         Provider.tenant_id == application_generate_entity.app_config.tenant_id,
-        Provider.provider_name == application_generate_entity.model_conf.provider
-    ).update({'last_used': datetime.now(timezone.utc).replace(tzinfo=None)})
+        Provider.provider_name == application_generate_entity.model_conf.provider,
+    ).update({"last_used": datetime.now(timezone.utc).replace(tzinfo=None)})
     db.session.commit()

+ 1 - 1
api/events/message_event.py

@@ -1,4 +1,4 @@
 from blinker import signal

 # sender: message, kwargs: conversation
-message_was_created = signal('message-was-created')
+message_was_created = signal("message-was-created")

+ 2 - 2
api/events/tenant_event.py

@@ -1,7 +1,7 @@
 from blinker import signal

 # sender: tenant
-tenant_was_created = signal('tenant-was-created')
+tenant_was_created = signal("tenant-was-created")

 # sender: tenant
-tenant_was_updated = signal('tenant-was-updated')
+tenant_was_updated = signal("tenant-was-updated")

+ 10 - 13
api/extensions/ext_celery.py

@@ -17,7 +17,7 @@ def init_app(app: Flask) -> Celery:
         backend=app.config["CELERY_BACKEND"],
         task_ignore_result=True,
     )
-    
+
     # Add SSL options to the Celery configuration
     ssl_options = {
         "ssl_cert_reqs": None,
@@ -35,7 +35,7 @@ def init_app(app: Flask) -> Celery:
         celery_app.conf.update(
             broker_use_ssl=ssl_options,  # Add the SSL options to the broker configuration
         )
-        
+
     celery_app.set_default()
     app.extensions["celery"] = celery_app

@@ -45,18 +45,15 @@ def init_app(app: Flask) -> Celery:
     ]
     day = app.config["CELERY_BEAT_SCHEDULER_TIME"]
     beat_schedule = {
-        'clean_embedding_cache_task': {
-            'task': 'schedule.clean_embedding_cache_task.clean_embedding_cache_task',
-            'schedule': timedelta(days=day),
+        "clean_embedding_cache_task": {
+            "task": "schedule.clean_embedding_cache_task.clean_embedding_cache_task",
+            "schedule": timedelta(days=day),
+        },
+        "clean_unused_datasets_task": {
+            "task": "schedule.clean_unused_datasets_task.clean_unused_datasets_task",
+            "schedule": timedelta(days=day),
         },
-        'clean_unused_datasets_task': {
-            'task': 'schedule.clean_unused_datasets_task.clean_unused_datasets_task',
-            'schedule': timedelta(days=day),
-        }
     }
-    celery_app.conf.update(
-        beat_schedule=beat_schedule,
-        imports=imports
-    )
+    celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)

     return celery_app
+ 5 - 6
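The beat_schedule dict above is standard Celery beat configuration; a minimal standalone sketch of the same shape (broker URL and interval are placeholders):

# Celery beat schedule sketch; nothing connects until a worker runs.
from datetime import timedelta
from celery import Celery

celery_app = Celery("sketch", broker="redis://localhost:6379/0")
celery_app.conf.update(
    beat_schedule={
        "clean_embedding_cache_task": {
            "task": "schedule.clean_embedding_cache_task.clean_embedding_cache_task",
            "schedule": timedelta(days=1),
        },
    },
    imports=["schedule.clean_embedding_cache_task"],
)
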
api/extensions/ext_compress.py

@@ -2,15 +2,14 @@ from flask import Flask
 
 
 
 
 def init_app(app: Flask):
 def init_app(app: Flask):
-    if app.config.get('API_COMPRESSION_ENABLED'):
+    if app.config.get("API_COMPRESSION_ENABLED"):
         from flask_compress import Compress
         from flask_compress import Compress
 
 
-        app.config['COMPRESS_MIMETYPES'] = [
-            'application/json',
-            'image/svg+xml',
-            'text/html',
+        app.config["COMPRESS_MIMETYPES"] = [
+            "application/json",
+            "image/svg+xml",
+            "text/html",
         ]
         ]
 
 
         compress = Compress()
         compress = Compress()
         compress.init_app(app)
         compress.init_app(app)
-

+ 5 - 5
api/extensions/ext_database.py

@@ -2,11 +2,11 @@ from flask_sqlalchemy import SQLAlchemy
 from sqlalchemy import MetaData
 from sqlalchemy import MetaData

 POSTGRES_INDEXES_NAMING_CONVENTION = {
-    'ix': '%(column_0_label)s_idx',
-    'uq': '%(table_name)s_%(column_0_name)s_key',
-    'ck': '%(table_name)s_%(constraint_name)s_check',
-    'fk': '%(table_name)s_%(column_0_name)s_fkey',
-    'pk': '%(table_name)s_pkey',
+    "ix": "%(column_0_label)s_idx",
+    "uq": "%(table_name)s_%(column_0_name)s_key",
+    "ck": "%(table_name)s_%(constraint_name)s_check",
+    "fk": "%(table_name)s_%(column_0_name)s_fkey",
+    "pk": "%(table_name)s_pkey",
 }

 metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION)
+ 37 - 35
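The naming convention above gives constraints and indexes deterministic names. A quick check of what the "ix" token should produce for a column index, per my reading of SQLAlchemy's %(column_0_label)s substitution (worth verifying against your SQLAlchemy version):

# Sketch: the "ix" convention should name this index "accounts_email_idx".
from sqlalchemy import Column, Integer, MetaData, String, Table

metadata = MetaData(naming_convention={"ix": "%(column_0_label)s_idx"})
accounts = Table(
    "accounts",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("email", String, index=True),
)
print([index.name for index in accounts.indexes])
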
api/extensions/ext_mail.py

@@ -14,67 +14,69 @@ class Mail:
         return self._client is not None
         return self._client is not None

     def init_app(self, app: Flask):
-        if app.config.get('MAIL_TYPE'):
-            if app.config.get('MAIL_DEFAULT_SEND_FROM'):
-                self._default_send_from = app.config.get('MAIL_DEFAULT_SEND_FROM')
-            
-            if app.config.get('MAIL_TYPE') == 'resend':
-                api_key = app.config.get('RESEND_API_KEY')
+        if app.config.get("MAIL_TYPE"):
+            if app.config.get("MAIL_DEFAULT_SEND_FROM"):
+                self._default_send_from = app.config.get("MAIL_DEFAULT_SEND_FROM")
+
+            if app.config.get("MAIL_TYPE") == "resend":
+                api_key = app.config.get("RESEND_API_KEY")
                 if not api_key:
-                    raise ValueError('RESEND_API_KEY is not set')
+                    raise ValueError("RESEND_API_KEY is not set")

-                api_url = app.config.get('RESEND_API_URL')
+                api_url = app.config.get("RESEND_API_URL")
                 if api_url:
                     resend.api_url = api_url

                 resend.api_key = api_key
                 self._client = resend.Emails
-            elif app.config.get('MAIL_TYPE') == 'smtp':
+            elif app.config.get("MAIL_TYPE") == "smtp":
                 from libs.smtp import SMTPClient
-                if not app.config.get('SMTP_SERVER') or not app.config.get('SMTP_PORT'):
-                    raise ValueError('SMTP_SERVER and SMTP_PORT are required for smtp mail type')
-                if not app.config.get('SMTP_USE_TLS') and app.config.get('SMTP_OPPORTUNISTIC_TLS'):
-                    raise ValueError('SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS')
+
+                if not app.config.get("SMTP_SERVER") or not app.config.get("SMTP_PORT"):
+                    raise ValueError("SMTP_SERVER and SMTP_PORT are required for smtp mail type")
+                if not app.config.get("SMTP_USE_TLS") and app.config.get("SMTP_OPPORTUNISTIC_TLS"):
+                    raise ValueError("SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS")
                 self._client = SMTPClient(
-                    server=app.config.get('SMTP_SERVER'),
-                    port=app.config.get('SMTP_PORT'),
-                    username=app.config.get('SMTP_USERNAME'),
-                    password=app.config.get('SMTP_PASSWORD'),
-                    _from=app.config.get('MAIL_DEFAULT_SEND_FROM'),
-                    use_tls=app.config.get('SMTP_USE_TLS'),
-                    opportunistic_tls=app.config.get('SMTP_OPPORTUNISTIC_TLS')
+                    server=app.config.get("SMTP_SERVER"),
+                    port=app.config.get("SMTP_PORT"),
+                    username=app.config.get("SMTP_USERNAME"),
+                    password=app.config.get("SMTP_PASSWORD"),
+                    _from=app.config.get("MAIL_DEFAULT_SEND_FROM"),
+                    use_tls=app.config.get("SMTP_USE_TLS"),
+                    opportunistic_tls=app.config.get("SMTP_OPPORTUNISTIC_TLS"),
                 )
             else:
-                raise ValueError('Unsupported mail type {}'.format(app.config.get('MAIL_TYPE')))
+                raise ValueError("Unsupported mail type {}".format(app.config.get("MAIL_TYPE")))
         else:
-            logging.warning('MAIL_TYPE is not set')
-            
+            logging.warning("MAIL_TYPE is not set")

     def send(self, to: str, subject: str, html: str, from_: Optional[str] = None):
         if not self._client:
-            raise ValueError('Mail client is not initialized')
+            raise ValueError("Mail client is not initialized")

         if not from_ and self._default_send_from:
             from_ = self._default_send_from

         if not from_:
-            raise ValueError('mail from is not set')
+            raise ValueError("mail from is not set")

         if not to:
-            raise ValueError('mail to is not set')
+            raise ValueError("mail to is not set")

         if not subject:
-            raise ValueError('mail subject is not set')
+            raise ValueError("mail subject is not set")

         if not html:
-            raise ValueError('mail html is not set')
-
-        self._client.send({
-            "from": from_,
-            "to": to,
-            "subject": subject,
-            "html": html
-        })
+            raise ValueError("mail html is not set")
+
+        self._client.send(
+            {
+                "from": from_,
+                "to": to,
+                "subject": subject,
+                "html": html,
+            }
+        )


 def init_app(app: Flask):
+ 15 - 12
api/extensions/ext_redis.py

@@ -6,18 +6,21 @@ redis_client = redis.Redis()
 
 
 def init_app(app):
 def init_app(app):
     connection_class = Connection
     connection_class = Connection
-    if app.config.get('REDIS_USE_SSL'):
+    if app.config.get("REDIS_USE_SSL"):
         connection_class = SSLConnection
         connection_class = SSLConnection
 
 
-    redis_client.connection_pool = redis.ConnectionPool(**{
-        'host': app.config.get('REDIS_HOST'),
-        'port': app.config.get('REDIS_PORT'),
-        'username': app.config.get('REDIS_USERNAME'),
-        'password': app.config.get('REDIS_PASSWORD'),
-        'db': app.config.get('REDIS_DB'),
-        'encoding': 'utf-8',
-        'encoding_errors': 'strict',
-        'decode_responses': False
-    }, connection_class=connection_class)
+    redis_client.connection_pool = redis.ConnectionPool(
+        **{
+            "host": app.config.get("REDIS_HOST"),
+            "port": app.config.get("REDIS_PORT"),
+            "username": app.config.get("REDIS_USERNAME"),
+            "password": app.config.get("REDIS_PASSWORD"),
+            "db": app.config.get("REDIS_DB"),
+            "encoding": "utf-8",
+            "encoding_errors": "strict",
+            "decode_responses": False,
+        },
+        connection_class=connection_class,
+    )
 
 
-    app.extensions['redis'] = redis_client
+    app.extensions["redis"] = redis_client

+ 7 - 10
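The pattern worth noting above is the module-level client whose connection_pool is swapped at init time, so importers keep a stable handle. A standalone sketch with placeholder connection settings:

# Shared-client sketch; host/port/db are placeholders, not Dify config.
import redis
from redis.connection import Connection, SSLConnection

redis_client = redis.Redis()

def init(use_ssl: bool = False):
    connection_class = SSLConnection if use_ssl else Connection
    redis_client.connection_pool = redis.ConnectionPool(
        host="localhost",
        port=6379,
        db=0,
        encoding="utf-8",
        decode_responses=False,
        connection_class=connection_class,
    )
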
api/extensions/ext_sentry.py

@@ -5,16 +5,13 @@ from werkzeug.exceptions import HTTPException
 
 
 
 
 def init_app(app):
 def init_app(app):
-    if app.config.get('SENTRY_DSN'):
+    if app.config.get("SENTRY_DSN"):
         sentry_sdk.init(
         sentry_sdk.init(
-            dsn=app.config.get('SENTRY_DSN'),
-            integrations=[
-                FlaskIntegration(),
-                CeleryIntegration()
-            ],
+            dsn=app.config.get("SENTRY_DSN"),
+            integrations=[FlaskIntegration(), CeleryIntegration()],
             ignore_errors=[HTTPException, ValueError],
             ignore_errors=[HTTPException, ValueError],
-            traces_sample_rate=app.config.get('SENTRY_TRACES_SAMPLE_RATE', 1.0),
-            profiles_sample_rate=app.config.get('SENTRY_PROFILES_SAMPLE_RATE', 1.0),
-            environment=app.config.get('DEPLOY_ENV'),
-            release=f"dify-{app.config.get('CURRENT_VERSION')}-{app.config.get('COMMIT_SHA')}"
+            traces_sample_rate=app.config.get("SENTRY_TRACES_SAMPLE_RATE", 1.0),
+            profiles_sample_rate=app.config.get("SENTRY_PROFILES_SAMPLE_RATE", 1.0),
+            environment=app.config.get("DEPLOY_ENV"),
+            release=f"dify-{app.config.get('CURRENT_VERSION')}-{app.config.get('COMMIT_SHA')}",
         )
         )

+ 13 - 25
api/extensions/ext_storage.py

@@ -17,31 +17,19 @@ class Storage:
         self.storage_runner = None
         self.storage_runner = None

     def init_app(self, app: Flask):
-        storage_type = app.config.get('STORAGE_TYPE')
-        if storage_type == 's3':
-            self.storage_runner = S3Storage(
-                app=app
-            )
-        elif storage_type == 'azure-blob':
-            self.storage_runner = AzureStorage(
-                app=app
-            )
-        elif storage_type == 'aliyun-oss':
-            self.storage_runner = AliyunStorage(
-                app=app
-            )
-        elif storage_type == 'google-storage':
-            self.storage_runner = GoogleStorage(
-                app=app
-            )
-        elif storage_type == 'tencent-cos':
-            self.storage_runner = TencentStorage(
-                app=app
-            )
-        elif storage_type == 'oci-storage':
-            self.storage_runner = OCIStorage(
-                app=app
-            )
+        storage_type = app.config.get("STORAGE_TYPE")
+        if storage_type == "s3":
+            self.storage_runner = S3Storage(app=app)
+        elif storage_type == "azure-blob":
+            self.storage_runner = AzureStorage(app=app)
+        elif storage_type == "aliyun-oss":
+            self.storage_runner = AliyunStorage(app=app)
+        elif storage_type == "google-storage":
+            self.storage_runner = GoogleStorage(app=app)
+        elif storage_type == "tencent-cos":
+            self.storage_runner = TencentStorage(app=app)
+        elif storage_type == "oci-storage":
+            self.storage_runner = OCIStorage(app=app)
         else:
             self.storage_runner = LocalStorage(app=app)


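The if/elif chain above is a plain type dispatch; an equivalent table-driven sketch reusing the classes ext_storage.py already imports (a sketch under that assumption, not the shipped code):

# Table-driven dispatch sketch; imports mirror those in ext_storage.py.
from flask import Flask
from extensions.storage.aliyun_storage import AliyunStorage
from extensions.storage.azure_storage import AzureStorage
from extensions.storage.google_storage import GoogleStorage
from extensions.storage.local_storage import LocalStorage
from extensions.storage.oci_storage import OCIStorage
from extensions.storage.s3_storage import S3Storage
from extensions.storage.tencent_storage import TencentStorage

STORAGE_CLASSES = {
    "s3": S3Storage,
    "azure-blob": AzureStorage,
    "aliyun-oss": AliyunStorage,
    "google-storage": GoogleStorage,
    "tencent-cos": TencentStorage,
    "oci-storage": OCIStorage,
}

def make_storage_runner(app: Flask):
    storage_type = app.config.get("STORAGE_TYPE")
    return STORAGE_CLASSES.get(storage_type, LocalStorage)(app=app)
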
+ 6 - 7
api/extensions/storage/aliyun_storage.py

@@ -8,23 +8,22 @@ from extensions.storage.base_storage import BaseStorage


 class AliyunStorage(BaseStorage):
-    """Implementation for aliyun storage.
-    """
+    """Implementation for aliyun storage."""

     def __init__(self, app: Flask):
         super().__init__(app)

         app_config = self.app.config
-        self.bucket_name = app_config.get('ALIYUN_OSS_BUCKET_NAME')
+        self.bucket_name = app_config.get("ALIYUN_OSS_BUCKET_NAME")
         oss_auth_method = aliyun_s3.Auth
         region = None
-        if app_config.get('ALIYUN_OSS_AUTH_VERSION') == 'v4':
+        if app_config.get("ALIYUN_OSS_AUTH_VERSION") == "v4":
             oss_auth_method = aliyun_s3.AuthV4
-            region = app_config.get('ALIYUN_OSS_REGION')
-        oss_auth = oss_auth_method(app_config.get('ALIYUN_OSS_ACCESS_KEY'), app_config.get('ALIYUN_OSS_SECRET_KEY'))
+            region = app_config.get("ALIYUN_OSS_REGION")
+        oss_auth = oss_auth_method(app_config.get("ALIYUN_OSS_ACCESS_KEY"), app_config.get("ALIYUN_OSS_SECRET_KEY"))
         self.client = aliyun_s3.Bucket(
             oss_auth,
-            app_config.get('ALIYUN_OSS_ENDPOINT'),
+            app_config.get("ALIYUN_OSS_ENDPOINT"),
             self.bucket_name,
             connect_timeout=30,
             region=region,

+ 9 - 9
api/extensions/storage/azure_storage.py

@@ -9,16 +9,15 @@ from extensions.storage.base_storage import BaseStorage


 class AzureStorage(BaseStorage):
-    """Implementation for azure storage.
-    """
+    """Implementation for azure storage."""

     def __init__(self, app: Flask):
         super().__init__(app)
         app_config = self.app.config
-        self.bucket_name = app_config.get('AZURE_BLOB_CONTAINER_NAME')
-        self.account_url = app_config.get('AZURE_BLOB_ACCOUNT_URL')
-        self.account_name = app_config.get('AZURE_BLOB_ACCOUNT_NAME')
-        self.account_key = app_config.get('AZURE_BLOB_ACCOUNT_KEY')
+        self.bucket_name = app_config.get("AZURE_BLOB_CONTAINER_NAME")
+        self.account_url = app_config.get("AZURE_BLOB_ACCOUNT_URL")
+        self.account_name = app_config.get("AZURE_BLOB_ACCOUNT_NAME")
+        self.account_key = app_config.get("AZURE_BLOB_ACCOUNT_KEY")

     def save(self, filename, data):
         client = self._sync_client()
@@ -39,6 +38,7 @@ class AzureStorage(BaseStorage):
             blob = client.get_blob_client(container=self.bucket_name, blob=filename)
             blob_data = blob.download_blob()
             yield from blob_data.chunks()
+
         return generate(filename)

     def download(self, filename, target_filepath):
@@ -62,17 +62,17 @@ class AzureStorage(BaseStorage):
         blob_container.delete_blob(filename)

     def _sync_client(self):
-        cache_key = 'azure_blob_sas_token_{}_{}'.format(self.account_name, self.account_key)
+        cache_key = "azure_blob_sas_token_{}_{}".format(self.account_name, self.account_key)
         cache_result = redis_client.get(cache_key)
         if cache_result is not None:
-            sas_token = cache_result.decode('utf-8')
+            sas_token = cache_result.decode("utf-8")
         else:
             sas_token = generate_account_sas(
                 account_name=self.account_name,
                 account_key=self.account_key,
                 resource_types=ResourceTypes(service=True, container=True, object=True),
                 permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True),
-                expiry=datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=1)
+                expiry=datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=1),
             )
             redis_client.set(cache_key, sas_token, ex=3000)
         return BlobServiceClient(account_url=self.account_url, credential=sas_token)

+ 3 - 2
api/extensions/storage/base_storage.py

@@ -1,4 +1,5 @@
 """Abstract interface for file storage implementations."""
 """Abstract interface for file storage implementations."""
+
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
 from collections.abc import Generator
 from collections.abc import Generator
 
 
@@ -6,8 +7,8 @@ from flask import Flask
 
 
 
 
 class BaseStorage(ABC):
 class BaseStorage(ABC):
-    """Interface for file storage.
-    """
+    """Interface for file storage."""
+
     app = None

     def __init__(self, app: Flask):

+ 8 - 7
api/extensions/storage/google_storage.py

@@ -11,16 +11,16 @@ from extensions.storage.base_storage import BaseStorage


 class GoogleStorage(BaseStorage):
-    """Implementation for google storage.
-    """
+    """Implementation for google storage."""
+
     def __init__(self, app: Flask):
         super().__init__(app)
         app_config = self.app.config
-        self.bucket_name = app_config.get('GOOGLE_STORAGE_BUCKET_NAME')
-        service_account_json_str = app_config.get('GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64')
+        self.bucket_name = app_config.get("GOOGLE_STORAGE_BUCKET_NAME")
+        service_account_json_str = app_config.get("GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64")
         # if service_account_json_str is empty, use Application Default Credentials
         if service_account_json_str:
-            service_account_json = base64.b64decode(service_account_json_str).decode('utf-8')
+            service_account_json = base64.b64decode(service_account_json_str).decode("utf-8")
             # convert str to object
             service_account_obj = json.loads(service_account_json)
             self.client = GoogleCloudStorage.Client.from_service_account_info(service_account_obj)
@@ -43,9 +43,10 @@ class GoogleStorage(BaseStorage):
         def generate(filename: str = filename) -> Generator:
             bucket = self.client.get_bucket(self.bucket_name)
             blob = bucket.get_blob(filename)
-            with closing(blob.open(mode='rb')) as blob_stream:
+            with closing(blob.open(mode="rb")) as blob_stream:
                 while chunk := blob_stream.read(4096):
                     yield chunk
+
         return generate()

     def download(self, filename, target_filepath):
@@ -60,4 +61,4 @@ class GoogleStorage(BaseStorage):

     def delete(self, filename):
         bucket = self.client.get_bucket(self.bucket_name)
-        bucket.delete_blob(filename)
+        bucket.delete_blob(filename)

+ 14 - 15
api/extensions/storage/local_storage.py

@@ -8,21 +8,20 @@ from extensions.storage.base_storage import BaseStorage


 class LocalStorage(BaseStorage):
-    """Implementation for local storage.
-    """
+    """Implementation for local storage."""
 
 
     def __init__(self, app: Flask):
     def __init__(self, app: Flask):
         super().__init__(app)
         super().__init__(app)
-        folder = self.app.config.get('STORAGE_LOCAL_PATH')
+        folder = self.app.config.get("STORAGE_LOCAL_PATH")
         if not os.path.isabs(folder):
             folder = os.path.join(app.root_path, folder)
         self.folder = folder

     def save(self, filename, data):
-        if not self.folder or self.folder.endswith('/'):
+        if not self.folder or self.folder.endswith("/"):
             filename = self.folder + filename
         else:
-            filename = self.folder + '/' + filename
+            filename = self.folder + "/" + filename
 
 
         folder = os.path.dirname(filename)
         folder = os.path.dirname(filename)
         os.makedirs(folder, exist_ok=True)
         os.makedirs(folder, exist_ok=True)
@@ -31,10 +30,10 @@ class LocalStorage(BaseStorage):
             f.write(data)

     def load_once(self, filename: str) -> bytes:
-        if not self.folder or self.folder.endswith('/'):
+        if not self.folder or self.folder.endswith("/"):
             filename = self.folder + filename
         else:
-            filename = self.folder + '/' + filename
+            filename = self.folder + "/" + filename
 
 
         if not os.path.exists(filename):
         if not os.path.exists(filename):
             raise FileNotFoundError("File not found")
             raise FileNotFoundError("File not found")
@@ -46,10 +45,10 @@ class LocalStorage(BaseStorage):

     def load_stream(self, filename: str) -> Generator:
         def generate(filename: str = filename) -> Generator:
-            if not self.folder or self.folder.endswith('/'):
+            if not self.folder or self.folder.endswith("/"):
                 filename = self.folder + filename
             else:
-                filename = self.folder + '/' + filename
+                filename = self.folder + "/" + filename
 
 
             if not os.path.exists(filename):
             if not os.path.exists(filename):
                 raise FileNotFoundError("File not found")
                 raise FileNotFoundError("File not found")
@@ -61,10 +60,10 @@ class LocalStorage(BaseStorage):
         return generate()

     def download(self, filename, target_filepath):
-        if not self.folder or self.folder.endswith('/'):
+        if not self.folder or self.folder.endswith("/"):
             filename = self.folder + filename
         else:
-            filename = self.folder + '/' + filename
+            filename = self.folder + "/" + filename
 
 
         if not os.path.exists(filename):
         if not os.path.exists(filename):
             raise FileNotFoundError("File not found")
             raise FileNotFoundError("File not found")
@@ -72,17 +71,17 @@ class LocalStorage(BaseStorage):
         shutil.copyfile(filename, target_filepath)

     def exists(self, filename):
-        if not self.folder or self.folder.endswith('/'):
+        if not self.folder or self.folder.endswith("/"):
             filename = self.folder + filename
         else:
-            filename = self.folder + '/' + filename
+            filename = self.folder + "/" + filename
 
 
         return os.path.exists(filename)
         return os.path.exists(filename)
 
 
     def delete(self, filename):
     def delete(self, filename):
-        if not self.folder or self.folder.endswith('/'):
+        if not self.folder or self.folder.endswith("/"):
             filename = self.folder + filename
         else:
-            filename = self.folder + '/' + filename
+            filename = self.folder + "/" + filename
         if os.path.exists(filename):
             os.remove(filename)

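The local backend above repeats the same folder/filename concatenation in save, load_once, load_stream, download, exists, and delete; Ruff only normalizes the quotes. A hypothetical helper, not part of this commit, that captures the repeated branch:

import os


def build_path(folder: str, filename: str) -> str:
    # Mirrors the repeated branch above: avoid a double slash when the
    # configured folder already ends with "/" (or is empty).
    if not folder or folder.endswith("/"):
        return folder + filename
    return folder + "/" + filename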
+ 13 - 12
api/extensions/storage/oci_storage.py

@@ -12,14 +12,14 @@ class OCIStorage(BaseStorage):
     def __init__(self, app: Flask):
         super().__init__(app)
         app_config = self.app.config
-        self.bucket_name = app_config.get('OCI_BUCKET_NAME')
+        self.bucket_name = app_config.get("OCI_BUCKET_NAME")
         self.client = boto3.client(
-                    's3',
-                    aws_secret_access_key=app_config.get('OCI_SECRET_KEY'),
-                    aws_access_key_id=app_config.get('OCI_ACCESS_KEY'),
-                    endpoint_url=app_config.get('OCI_ENDPOINT'),
-                    region_name=app_config.get('OCI_REGION')
-                )
+            "s3",
+            aws_secret_access_key=app_config.get("OCI_SECRET_KEY"),
+            aws_access_key_id=app_config.get("OCI_ACCESS_KEY"),
+            endpoint_url=app_config.get("OCI_ENDPOINT"),
+            region_name=app_config.get("OCI_REGION"),
+        )

     def save(self, filename, data):
         self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)
@@ -27,9 +27,9 @@ class OCIStorage(BaseStorage):
     def load_once(self, filename: str) -> bytes:
         try:
             with closing(self.client) as client:
-                data = client.get_object(Bucket=self.bucket_name, Key=filename)['Body'].read()
+                data = client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
         except ClientError as ex:
-            if ex.response['Error']['Code'] == 'NoSuchKey':
+            if ex.response["Error"]["Code"] == "NoSuchKey":
                 raise FileNotFoundError("File not found")
             else:
                 raise
@@ -40,12 +40,13 @@ class OCIStorage(BaseStorage):
             try:
                 with closing(self.client) as client:
                     response = client.get_object(Bucket=self.bucket_name, Key=filename)
-                    yield from response['Body'].iter_chunks()
+                    yield from response["Body"].iter_chunks()
             except ClientError as ex:
-                if ex.response['Error']['Code'] == 'NoSuchKey':
+                if ex.response["Error"]["Code"] == "NoSuchKey":
                     raise FileNotFoundError("File not found")
                 else:
                     raise
+
         return generate()

     def download(self, filename, target_filepath):
@@ -61,4 +62,4 @@ class OCIStorage(BaseStorage):
                 return False

     def delete(self, filename):
-        self.client.delete_object(Bucket=self.bucket_name, Key=filename)
+        self.client.delete_object(Bucket=self.bucket_name, Key=filename)

+ 17 - 16
api/extensions/storage/s3_storage.py

@@ -10,24 +10,24 @@ from extensions.storage.base_storage import BaseStorage


 class S3Storage(BaseStorage):
-    """Implementation for s3 storage.
-    """
+    """Implementation for s3 storage."""
+
     def __init__(self, app: Flask):
         super().__init__(app)
         app_config = self.app.config
-        self.bucket_name = app_config.get('S3_BUCKET_NAME')
-        if app_config.get('S3_USE_AWS_MANAGED_IAM'):
+        self.bucket_name = app_config.get("S3_BUCKET_NAME")
+        if app_config.get("S3_USE_AWS_MANAGED_IAM"):
             session = boto3.Session()
-            self.client = session.client('s3')
+            self.client = session.client("s3")
         else:
             self.client = boto3.client(
-                        's3',
-                        aws_secret_access_key=app_config.get('S3_SECRET_KEY'),
-                        aws_access_key_id=app_config.get('S3_ACCESS_KEY'),
-                        endpoint_url=app_config.get('S3_ENDPOINT'),
-                        region_name=app_config.get('S3_REGION'),
-                        config=Config(s3={'addressing_style': app_config.get('S3_ADDRESS_STYLE')})
-                    )
+                "s3",
+                aws_secret_access_key=app_config.get("S3_SECRET_KEY"),
+                aws_access_key_id=app_config.get("S3_ACCESS_KEY"),
+                endpoint_url=app_config.get("S3_ENDPOINT"),
+                region_name=app_config.get("S3_REGION"),
+                config=Config(s3={"addressing_style": app_config.get("S3_ADDRESS_STYLE")}),
+            )

     def save(self, filename, data):
         self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)
@@ -35,9 +35,9 @@ class S3Storage(BaseStorage):
     def load_once(self, filename: str) -> bytes:
         try:
             with closing(self.client) as client:
-                data = client.get_object(Bucket=self.bucket_name, Key=filename)['Body'].read()
+                data = client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
         except ClientError as ex:
-            if ex.response['Error']['Code'] == 'NoSuchKey':
+            if ex.response["Error"]["Code"] == "NoSuchKey":
                 raise FileNotFoundError("File not found")
             else:
                 raise
@@ -48,12 +48,13 @@ class S3Storage(BaseStorage):
             try:
                 with closing(self.client) as client:
                     response = client.get_object(Bucket=self.bucket_name, Key=filename)
-                    yield from response['Body'].iter_chunks()
+                    yield from response["Body"].iter_chunks()
             except ClientError as ex:
-                if ex.response['Error']['Code'] == 'NoSuchKey':
+                if ex.response["Error"]["Code"] == "NoSuchKey":
                     raise FileNotFoundError("File not found")
                 else:
                     raise
+
         return generate()

     def download(self, filename, target_filepath):

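Both the OCI and S3 backends above translate botocore's "NoSuchKey" ClientError into the built-in FileNotFoundError, so callers can treat a missing object the same way across backends. A hedged usage sketch, where `storage` is an initialized backend instance and `handle_chunk` is a stand-in consumer:

try:
    # load_stream yields raw byte chunks from the object.
    for chunk in storage.load_stream("path/to/object.bin"):
        handle_chunk(chunk)  # hypothetical consumer
except FileNotFoundError:
    print("object does not exist in the bucket")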
+ 9 - 10
api/extensions/storage/tencent_storage.py

@@ -7,18 +7,17 @@ from extensions.storage.base_storage import BaseStorage


 class TencentStorage(BaseStorage):
-    """Implementation for tencent cos storage.
-    """
+    """Implementation for tencent cos storage."""
 
 
     def __init__(self, app: Flask):
     def __init__(self, app: Flask):
         super().__init__(app)
         super().__init__(app)
         app_config = self.app.config
         app_config = self.app.config
-        self.bucket_name = app_config.get('TENCENT_COS_BUCKET_NAME')
+        self.bucket_name = app_config.get("TENCENT_COS_BUCKET_NAME")
         config = CosConfig(
-            Region=app_config.get('TENCENT_COS_REGION'),
-            SecretId=app_config.get('TENCENT_COS_SECRET_ID'),
-            SecretKey=app_config.get('TENCENT_COS_SECRET_KEY'),
-            Scheme=app_config.get('TENCENT_COS_SCHEME'),
+            Region=app_config.get("TENCENT_COS_REGION"),
+            SecretId=app_config.get("TENCENT_COS_SECRET_ID"),
+            SecretKey=app_config.get("TENCENT_COS_SECRET_KEY"),
+            Scheme=app_config.get("TENCENT_COS_SCHEME"),
         )
         self.client = CosS3Client(config)

@@ -26,19 +25,19 @@ class TencentStorage(BaseStorage):
         self.client.put_object(Bucket=self.bucket_name, Body=data, Key=filename)

     def load_once(self, filename: str) -> bytes:
-        data = self.client.get_object(Bucket=self.bucket_name, Key=filename)['Body'].get_raw_stream().read()
+        data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read()
         return data

     def load_stream(self, filename: str) -> Generator:
         def generate(filename: str = filename) -> Generator:
             response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
-            yield from response['Body'].get_stream(chunk_size=4096)
+            yield from response["Body"].get_stream(chunk_size=4096)
 
 
         return generate()
         return generate()
 
 
     def download(self, filename, target_filepath):
     def download(self, filename, target_filepath):
         response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
         response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
-        response['Body'].get_stream_to_file(target_filepath)
+        response["Body"].get_stream_to_file(target_filepath)
 
 
     def exists(self, filename):
     def exists(self, filename):
         return self.client.object_exists(Bucket=self.bucket_name, Key=filename)
         return self.client.object_exists(Bucket=self.bucket_name, Key=filename)

+ 3 - 3
api/fields/annotation_fields.py

@@ -5,7 +5,7 @@ from libs.helper import TimestampField
 annotation_fields = {
     "id": fields.String,
     "question": fields.String,
-    "answer": fields.Raw(attribute='content'),
+    "answer": fields.Raw(attribute="content"),
     "hit_count": fields.Integer,
     "hit_count": fields.Integer,
     "created_at": TimestampField,
     "created_at": TimestampField,
     # 'account': fields.Nested(simple_account_fields, allow_null=True)
     # 'account': fields.Nested(simple_account_fields, allow_null=True)
@@ -21,8 +21,8 @@ annotation_hit_history_fields = {
     "score": fields.Float,
     "score": fields.Float,
     "question": fields.String,
     "question": fields.String,
     "created_at": TimestampField,
     "created_at": TimestampField,
-    "match": fields.String(attribute='annotation_question'),
-    "response": fields.String(attribute='annotation_content')
+    "match": fields.String(attribute="annotation_question"),
+    "response": fields.String(attribute="annotation_content"),
 }

 annotation_hit_history_list_fields = {

+ 7 - 7
api/fields/api_based_extension_fields.py

@@ -8,16 +8,16 @@ class HiddenAPIKey(fields.Raw):
         api_key = obj.api_key
         # If the length of the api_key is less than 8 characters, show the first and last characters
         if len(api_key) <= 8:
-            return api_key[0] + '******' + api_key[-1]
+            return api_key[0] + "******" + api_key[-1]
         # If the api_key is greater than 8 characters, show the first three and the last three characters
         else:
-            return api_key[:3] + '******' + api_key[-3:]
+            return api_key[:3] + "******" + api_key[-3:]
 
 
 
 
 api_based_extension_fields = {
 api_based_extension_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'api_endpoint': fields.String,
-    'api_key': HiddenAPIKey,
-    'created_at': TimestampField
+    "id": fields.String,
+    "name": fields.String,
+    "api_endpoint": fields.String,
+    "api_key": HiddenAPIKey,
+    "created_at": TimestampField,
 }

+ 111 - 115
api/fields/app_fields.py

@@ -3,157 +3,153 @@ from flask_restful import fields
 from libs.helper import TimestampField

 app_detail_kernel_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'description': fields.String,
-    'mode': fields.String(attribute='mode_compatible_with_agent'),
-    'icon': fields.String,
-    'icon_background': fields.String,
+    "id": fields.String,
+    "name": fields.String,
+    "description": fields.String,
+    "mode": fields.String(attribute="mode_compatible_with_agent"),
+    "icon": fields.String,
+    "icon_background": fields.String,
 }

 related_app_list = {
-    'data': fields.List(fields.Nested(app_detail_kernel_fields)),
-    'total': fields.Integer,
+    "data": fields.List(fields.Nested(app_detail_kernel_fields)),
+    "total": fields.Integer,
 }

 model_config_fields = {
-    'opening_statement': fields.String,
-    'suggested_questions': fields.Raw(attribute='suggested_questions_list'),
-    'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
-    'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
-    'text_to_speech': fields.Raw(attribute='text_to_speech_dict'),
-    'retriever_resource': fields.Raw(attribute='retriever_resource_dict'),
-    'annotation_reply': fields.Raw(attribute='annotation_reply_dict'),
-    'more_like_this': fields.Raw(attribute='more_like_this_dict'),
-    'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
-    'external_data_tools': fields.Raw(attribute='external_data_tools_list'),
-    'model': fields.Raw(attribute='model_dict'),
-    'user_input_form': fields.Raw(attribute='user_input_form_list'),
-    'dataset_query_variable': fields.String,
-    'pre_prompt': fields.String,
-    'agent_mode': fields.Raw(attribute='agent_mode_dict'),
-    'prompt_type': fields.String,
-    'chat_prompt_config': fields.Raw(attribute='chat_prompt_config_dict'),
-    'completion_prompt_config': fields.Raw(attribute='completion_prompt_config_dict'),
-    'dataset_configs': fields.Raw(attribute='dataset_configs_dict'),
-    'file_upload': fields.Raw(attribute='file_upload_dict'),
-    'created_at': TimestampField
+    "opening_statement": fields.String,
+    "suggested_questions": fields.Raw(attribute="suggested_questions_list"),
+    "suggested_questions_after_answer": fields.Raw(attribute="suggested_questions_after_answer_dict"),
+    "speech_to_text": fields.Raw(attribute="speech_to_text_dict"),
+    "text_to_speech": fields.Raw(attribute="text_to_speech_dict"),
+    "retriever_resource": fields.Raw(attribute="retriever_resource_dict"),
+    "annotation_reply": fields.Raw(attribute="annotation_reply_dict"),
+    "more_like_this": fields.Raw(attribute="more_like_this_dict"),
+    "sensitive_word_avoidance": fields.Raw(attribute="sensitive_word_avoidance_dict"),
+    "external_data_tools": fields.Raw(attribute="external_data_tools_list"),
+    "model": fields.Raw(attribute="model_dict"),
+    "user_input_form": fields.Raw(attribute="user_input_form_list"),
+    "dataset_query_variable": fields.String,
+    "pre_prompt": fields.String,
+    "agent_mode": fields.Raw(attribute="agent_mode_dict"),
+    "prompt_type": fields.String,
+    "chat_prompt_config": fields.Raw(attribute="chat_prompt_config_dict"),
+    "completion_prompt_config": fields.Raw(attribute="completion_prompt_config_dict"),
+    "dataset_configs": fields.Raw(attribute="dataset_configs_dict"),
+    "file_upload": fields.Raw(attribute="file_upload_dict"),
+    "created_at": TimestampField,
 }

 app_detail_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'description': fields.String,
-    'mode': fields.String(attribute='mode_compatible_with_agent'),
-    'icon': fields.String,
-    'icon_background': fields.String,
-    'enable_site': fields.Boolean,
-    'enable_api': fields.Boolean,
-    'model_config': fields.Nested(model_config_fields, attribute='app_model_config', allow_null=True),
-    'tracing': fields.Raw,
-    'created_at': TimestampField
+    "id": fields.String,
+    "name": fields.String,
+    "description": fields.String,
+    "mode": fields.String(attribute="mode_compatible_with_agent"),
+    "icon": fields.String,
+    "icon_background": fields.String,
+    "enable_site": fields.Boolean,
+    "enable_api": fields.Boolean,
+    "model_config": fields.Nested(model_config_fields, attribute="app_model_config", allow_null=True),
+    "tracing": fields.Raw,
+    "created_at": TimestampField,
 }

 prompt_config_fields = {
-    'prompt_template': fields.String,
+    "prompt_template": fields.String,
 }

 model_config_partial_fields = {
-    'model': fields.Raw(attribute='model_dict'),
-    'pre_prompt': fields.String,
+    "model": fields.Raw(attribute="model_dict"),
+    "pre_prompt": fields.String,
 }

-tag_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'type': fields.String
-}
+tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String}

 app_partial_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'max_active_requests': fields.Raw(),
-    'description': fields.String(attribute='desc_or_prompt'),
-    'mode': fields.String(attribute='mode_compatible_with_agent'),
-    'icon': fields.String,
-    'icon_background': fields.String,
-    'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config', allow_null=True),
-    'created_at': TimestampField,
-    'tags': fields.List(fields.Nested(tag_fields))
+    "id": fields.String,
+    "name": fields.String,
+    "max_active_requests": fields.Raw(),
+    "description": fields.String(attribute="desc_or_prompt"),
+    "mode": fields.String(attribute="mode_compatible_with_agent"),
+    "icon": fields.String,
+    "icon_background": fields.String,
+    "model_config": fields.Nested(model_config_partial_fields, attribute="app_model_config", allow_null=True),
+    "created_at": TimestampField,
+    "tags": fields.List(fields.Nested(tag_fields)),
 }


 app_pagination_fields = {
-    'page': fields.Integer,
-    'limit': fields.Integer(attribute='per_page'),
-    'total': fields.Integer,
-    'has_more': fields.Boolean(attribute='has_next'),
-    'data': fields.List(fields.Nested(app_partial_fields), attribute='items')
+    "page": fields.Integer,
+    "limit": fields.Integer(attribute="per_page"),
+    "total": fields.Integer,
+    "has_more": fields.Boolean(attribute="has_next"),
+    "data": fields.List(fields.Nested(app_partial_fields), attribute="items"),
 }

 template_fields = {
-    'name': fields.String,
-    'icon': fields.String,
-    'icon_background': fields.String,
-    'description': fields.String,
-    'mode': fields.String,
-    'model_config': fields.Nested(model_config_fields),
+    "name": fields.String,
+    "icon": fields.String,
+    "icon_background": fields.String,
+    "description": fields.String,
+    "mode": fields.String,
+    "model_config": fields.Nested(model_config_fields),
 }

 template_list_fields = {
-    'data': fields.List(fields.Nested(template_fields)),
+    "data": fields.List(fields.Nested(template_fields)),
 }

 site_fields = {
-    'access_token': fields.String(attribute='code'),
-    'code': fields.String,
-    'title': fields.String,
-    'icon': fields.String,
-    'icon_background': fields.String,
-    'description': fields.String,
-    'default_language': fields.String,
-    'chat_color_theme': fields.String,
-    'chat_color_theme_inverted': fields.Boolean,
-    'customize_domain': fields.String,
-    'copyright': fields.String,
-    'privacy_policy': fields.String,
-    'custom_disclaimer': fields.String,
-    'customize_token_strategy': fields.String,
-    'prompt_public': fields.Boolean,
-    'app_base_url': fields.String,
-    'show_workflow_steps': fields.Boolean,
+    "access_token": fields.String(attribute="code"),
+    "code": fields.String,
+    "title": fields.String,
+    "icon": fields.String,
+    "icon_background": fields.String,
+    "description": fields.String,
+    "default_language": fields.String,
+    "chat_color_theme": fields.String,
+    "chat_color_theme_inverted": fields.Boolean,
+    "customize_domain": fields.String,
+    "copyright": fields.String,
+    "privacy_policy": fields.String,
+    "custom_disclaimer": fields.String,
+    "customize_token_strategy": fields.String,
+    "prompt_public": fields.Boolean,
+    "app_base_url": fields.String,
+    "show_workflow_steps": fields.Boolean,
 }

 app_detail_fields_with_site = {
-    'id': fields.String,
-    'name': fields.String,
-    'description': fields.String,
-    'mode': fields.String(attribute='mode_compatible_with_agent'),
-    'icon': fields.String,
-    'icon_background': fields.String,
-    'enable_site': fields.Boolean,
-    'enable_api': fields.Boolean,
-    'model_config': fields.Nested(model_config_fields, attribute='app_model_config', allow_null=True),
-    'site': fields.Nested(site_fields),
-    'api_base_url': fields.String,
-    'created_at': TimestampField,
-    'deleted_tools': fields.List(fields.String),
+    "id": fields.String,
+    "name": fields.String,
+    "description": fields.String,
+    "mode": fields.String(attribute="mode_compatible_with_agent"),
+    "icon": fields.String,
+    "icon_background": fields.String,
+    "enable_site": fields.Boolean,
+    "enable_api": fields.Boolean,
+    "model_config": fields.Nested(model_config_fields, attribute="app_model_config", allow_null=True),
+    "site": fields.Nested(site_fields),
+    "api_base_url": fields.String,
+    "created_at": TimestampField,
+    "deleted_tools": fields.List(fields.String),
 }

 app_site_fields = {
-    'app_id': fields.String,
-    'access_token': fields.String(attribute='code'),
-    'code': fields.String,
-    'title': fields.String,
-    'icon': fields.String,
-    'icon_background': fields.String,
-    'description': fields.String,
-    'default_language': fields.String,
-    'customize_domain': fields.String,
-    'copyright': fields.String,
-    'privacy_policy': fields.String,
-    'custom_disclaimer': fields.String,
-    'customize_token_strategy': fields.String,
-    'prompt_public': fields.Boolean,
-    'show_workflow_steps': fields.Boolean,
+    "app_id": fields.String,
+    "access_token": fields.String(attribute="code"),
+    "code": fields.String,
+    "title": fields.String,
+    "icon": fields.String,
+    "icon_background": fields.String,
+    "description": fields.String,
+    "default_language": fields.String,
+    "customize_domain": fields.String,
+    "copyright": fields.String,
+    "privacy_policy": fields.String,
+    "custom_disclaimer": fields.String,
+    "customize_token_strategy": fields.String,
+    "prompt_public": fields.Boolean,
+    "show_workflow_steps": fields.Boolean,
 }

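The dicts in this and the following fields modules are flask-restful marshalling schemas; the commit only rewrites them with double quotes, trailing commas, and one-line literals where they fit. A minimal sketch of how such a dict is consumed, assuming flask_restful is installed (the sample data is invented):

from flask_restful import fields, marshal

# Schema copied from the diff above.
tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String}

tag = {"id": "t1", "name": "demo", "type": "app", "extra": "dropped"}
# marshal() keeps only the declared keys and coerces values to the field types.
print(marshal(tag, tag_fields))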
+ 138 - 141
api/fields/conversation_fields.py

@@ -6,205 +6,202 @@ from libs.helper import TimestampField

 class MessageTextField(fields.Raw):
     def format(self, value):
-        return value[0]['text'] if value else ''
+        return value[0]["text"] if value else ""
 
 
 
 
 feedback_fields = {
 feedback_fields = {
-    'rating': fields.String,
-    'content': fields.String,
-    'from_source': fields.String,
-    'from_end_user_id': fields.String,
-    'from_account': fields.Nested(simple_account_fields, allow_null=True),
+    "rating": fields.String,
+    "content": fields.String,
+    "from_source": fields.String,
+    "from_end_user_id": fields.String,
+    "from_account": fields.Nested(simple_account_fields, allow_null=True),
 }

 annotation_fields = {
-    'id': fields.String,
-    'question': fields.String,
-    'content': fields.String,
-    'account': fields.Nested(simple_account_fields, allow_null=True),
-    'created_at': TimestampField
+    "id": fields.String,
+    "question": fields.String,
+    "content": fields.String,
+    "account": fields.Nested(simple_account_fields, allow_null=True),
+    "created_at": TimestampField,
 }

 annotation_hit_history_fields = {
-    'annotation_id': fields.String(attribute='id'),
-    'annotation_create_account': fields.Nested(simple_account_fields, allow_null=True),
-    'created_at': TimestampField
+    "annotation_id": fields.String(attribute="id"),
+    "annotation_create_account": fields.Nested(simple_account_fields, allow_null=True),
+    "created_at": TimestampField,
 }

 message_file_fields = {
-    'id': fields.String,
-    'type': fields.String,
-    'url': fields.String,
-    'belongs_to': fields.String(default='user'),
+    "id": fields.String,
+    "type": fields.String,
+    "url": fields.String,
+    "belongs_to": fields.String(default="user"),
 }

 agent_thought_fields = {
-    'id': fields.String,
-    'chain_id': fields.String,
-    'message_id': fields.String,
-    'position': fields.Integer,
-    'thought': fields.String,
-    'tool': fields.String,
-    'tool_labels': fields.Raw,
-    'tool_input': fields.String,
-    'created_at': TimestampField,
-    'observation': fields.String,
-    'files': fields.List(fields.String),
+    "id": fields.String,
+    "chain_id": fields.String,
+    "message_id": fields.String,
+    "position": fields.Integer,
+    "thought": fields.String,
+    "tool": fields.String,
+    "tool_labels": fields.Raw,
+    "tool_input": fields.String,
+    "created_at": TimestampField,
+    "observation": fields.String,
+    "files": fields.List(fields.String),
 }

 message_detail_fields = {
-    'id': fields.String,
-    'conversation_id': fields.String,
-    'inputs': fields.Raw,
-    'query': fields.String,
-    'message': fields.Raw,
-    'message_tokens': fields.Integer,
-    'answer': fields.String(attribute='re_sign_file_url_answer'),
-    'answer_tokens': fields.Integer,
-    'provider_response_latency': fields.Float,
-    'from_source': fields.String,
-    'from_end_user_id': fields.String,
-    'from_account_id': fields.String,
-    'feedbacks': fields.List(fields.Nested(feedback_fields)),
-    'workflow_run_id': fields.String,
-    'annotation': fields.Nested(annotation_fields, allow_null=True),
-    'annotation_hit_history': fields.Nested(annotation_hit_history_fields, allow_null=True),
-    'created_at': TimestampField,
-    'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)),
-    'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
-    'metadata': fields.Raw(attribute='message_metadata_dict'),
-    'status': fields.String,
-    'error': fields.String,
-}
-
-feedback_stat_fields = {
-    'like': fields.Integer,
-    'dislike': fields.Integer
-}
+    "id": fields.String,
+    "conversation_id": fields.String,
+    "inputs": fields.Raw,
+    "query": fields.String,
+    "message": fields.Raw,
+    "message_tokens": fields.Integer,
+    "answer": fields.String(attribute="re_sign_file_url_answer"),
+    "answer_tokens": fields.Integer,
+    "provider_response_latency": fields.Float,
+    "from_source": fields.String,
+    "from_end_user_id": fields.String,
+    "from_account_id": fields.String,
+    "feedbacks": fields.List(fields.Nested(feedback_fields)),
+    "workflow_run_id": fields.String,
+    "annotation": fields.Nested(annotation_fields, allow_null=True),
+    "annotation_hit_history": fields.Nested(annotation_hit_history_fields, allow_null=True),
+    "created_at": TimestampField,
+    "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
+    "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
+    "metadata": fields.Raw(attribute="message_metadata_dict"),
+    "status": fields.String,
+    "error": fields.String,
+}
+
+feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer}

 model_config_fields = {
-    'opening_statement': fields.String,
-    'suggested_questions': fields.Raw,
-    'model': fields.Raw,
-    'user_input_form': fields.Raw,
-    'pre_prompt': fields.String,
-    'agent_mode': fields.Raw,
+    "opening_statement": fields.String,
+    "suggested_questions": fields.Raw,
+    "model": fields.Raw,
+    "user_input_form": fields.Raw,
+    "pre_prompt": fields.String,
+    "agent_mode": fields.Raw,
 }

 simple_configs_fields = {
-    'prompt_template': fields.String,
+    "prompt_template": fields.String,
 }

 simple_model_config_fields = {
-    'model': fields.Raw(attribute='model_dict'),
-    'pre_prompt': fields.String,
+    "model": fields.Raw(attribute="model_dict"),
+    "pre_prompt": fields.String,
 }

 simple_message_detail_fields = {
-    'inputs': fields.Raw,
-    'query': fields.String,
-    'message': MessageTextField,
-    'answer': fields.String,
+    "inputs": fields.Raw,
+    "query": fields.String,
+    "message": MessageTextField,
+    "answer": fields.String,
 }

 conversation_fields = {
-    'id': fields.String,
-    'status': fields.String,
-    'from_source': fields.String,
-    'from_end_user_id': fields.String,
-    'from_end_user_session_id': fields.String(),
-    'from_account_id': fields.String,
-    'read_at': TimestampField,
-    'created_at': TimestampField,
-    'annotation': fields.Nested(annotation_fields, allow_null=True),
-    'model_config': fields.Nested(simple_model_config_fields),
-    'user_feedback_stats': fields.Nested(feedback_stat_fields),
-    'admin_feedback_stats': fields.Nested(feedback_stat_fields),
-    'message': fields.Nested(simple_message_detail_fields, attribute='first_message')
+    "id": fields.String,
+    "status": fields.String,
+    "from_source": fields.String,
+    "from_end_user_id": fields.String,
+    "from_end_user_session_id": fields.String(),
+    "from_account_id": fields.String,
+    "read_at": TimestampField,
+    "created_at": TimestampField,
+    "annotation": fields.Nested(annotation_fields, allow_null=True),
+    "model_config": fields.Nested(simple_model_config_fields),
+    "user_feedback_stats": fields.Nested(feedback_stat_fields),
+    "admin_feedback_stats": fields.Nested(feedback_stat_fields),
+    "message": fields.Nested(simple_message_detail_fields, attribute="first_message"),
 }

 conversation_pagination_fields = {
-    'page': fields.Integer,
-    'limit': fields.Integer(attribute='per_page'),
-    'total': fields.Integer,
-    'has_more': fields.Boolean(attribute='has_next'),
-    'data': fields.List(fields.Nested(conversation_fields), attribute='items')
+    "page": fields.Integer,
+    "limit": fields.Integer(attribute="per_page"),
+    "total": fields.Integer,
+    "has_more": fields.Boolean(attribute="has_next"),
+    "data": fields.List(fields.Nested(conversation_fields), attribute="items"),
 }

 conversation_message_detail_fields = {
-    'id': fields.String,
-    'status': fields.String,
-    'from_source': fields.String,
-    'from_end_user_id': fields.String,
-    'from_account_id': fields.String,
-    'created_at': TimestampField,
-    'model_config': fields.Nested(model_config_fields),
-    'message': fields.Nested(message_detail_fields, attribute='first_message'),
+    "id": fields.String,
+    "status": fields.String,
+    "from_source": fields.String,
+    "from_end_user_id": fields.String,
+    "from_account_id": fields.String,
+    "created_at": TimestampField,
+    "model_config": fields.Nested(model_config_fields),
+    "message": fields.Nested(message_detail_fields, attribute="first_message"),
 }

 conversation_with_summary_fields = {
-    'id': fields.String,
-    'status': fields.String,
-    'from_source': fields.String,
-    'from_end_user_id': fields.String,
-    'from_end_user_session_id': fields.String,
-    'from_account_id': fields.String,
-    'name': fields.String,
-    'summary': fields.String(attribute='summary_or_query'),
-    'read_at': TimestampField,
-    'created_at': TimestampField,
-    'annotated': fields.Boolean,
-    'model_config': fields.Nested(simple_model_config_fields),
-    'message_count': fields.Integer,
-    'user_feedback_stats': fields.Nested(feedback_stat_fields),
-    'admin_feedback_stats': fields.Nested(feedback_stat_fields)
+    "id": fields.String,
+    "status": fields.String,
+    "from_source": fields.String,
+    "from_end_user_id": fields.String,
+    "from_end_user_session_id": fields.String,
+    "from_account_id": fields.String,
+    "name": fields.String,
+    "summary": fields.String(attribute="summary_or_query"),
+    "read_at": TimestampField,
+    "created_at": TimestampField,
+    "annotated": fields.Boolean,
+    "model_config": fields.Nested(simple_model_config_fields),
+    "message_count": fields.Integer,
+    "user_feedback_stats": fields.Nested(feedback_stat_fields),
+    "admin_feedback_stats": fields.Nested(feedback_stat_fields),
 }

 conversation_with_summary_pagination_fields = {
-    'page': fields.Integer,
-    'limit': fields.Integer(attribute='per_page'),
-    'total': fields.Integer,
-    'has_more': fields.Boolean(attribute='has_next'),
-    'data': fields.List(fields.Nested(conversation_with_summary_fields), attribute='items')
+    "page": fields.Integer,
+    "limit": fields.Integer(attribute="per_page"),
+    "total": fields.Integer,
+    "has_more": fields.Boolean(attribute="has_next"),
+    "data": fields.List(fields.Nested(conversation_with_summary_fields), attribute="items"),
 }

 conversation_detail_fields = {
-    'id': fields.String,
-    'status': fields.String,
-    'from_source': fields.String,
-    'from_end_user_id': fields.String,
-    'from_account_id': fields.String,
-    'created_at': TimestampField,
-    'annotated': fields.Boolean,
-    'introduction': fields.String,
-    'model_config': fields.Nested(model_config_fields),
-    'message_count': fields.Integer,
-    'user_feedback_stats': fields.Nested(feedback_stat_fields),
-    'admin_feedback_stats': fields.Nested(feedback_stat_fields)
+    "id": fields.String,
+    "status": fields.String,
+    "from_source": fields.String,
+    "from_end_user_id": fields.String,
+    "from_account_id": fields.String,
+    "created_at": TimestampField,
+    "annotated": fields.Boolean,
+    "introduction": fields.String,
+    "model_config": fields.Nested(model_config_fields),
+    "message_count": fields.Integer,
+    "user_feedback_stats": fields.Nested(feedback_stat_fields),
+    "admin_feedback_stats": fields.Nested(feedback_stat_fields),
 }

 simple_conversation_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'inputs': fields.Raw,
-    'status': fields.String,
-    'introduction': fields.String,
-    'created_at': TimestampField
+    "id": fields.String,
+    "name": fields.String,
+    "inputs": fields.Raw,
+    "status": fields.String,
+    "introduction": fields.String,
+    "created_at": TimestampField,
 }

 conversation_infinite_scroll_pagination_fields = {
-    'limit': fields.Integer,
-    'has_more': fields.Boolean,
-    'data': fields.List(fields.Nested(simple_conversation_fields))
+    "limit": fields.Integer,
+    "has_more": fields.Boolean,
+    "data": fields.List(fields.Nested(simple_conversation_fields)),
 }

 conversation_with_model_config_fields = {
     **simple_conversation_fields,
-    'model_config': fields.Raw,
+    "model_config": fields.Raw,
 }

 conversation_with_model_config_infinite_scroll_pagination_fields = {
-    'limit': fields.Integer,
-    'has_more': fields.Boolean,
-    'data': fields.List(fields.Nested(conversation_with_model_config_fields))
+    "limit": fields.Integer,
+    "has_more": fields.Boolean,
+    "data": fields.List(fields.Nested(conversation_with_model_config_fields)),
 }

+ 12 - 12
api/fields/conversation_variable_fields.py

@@ -3,19 +3,19 @@ from flask_restful import fields
 from libs.helper import TimestampField

 conversation_variable_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'value_type': fields.String(attribute='value_type.value'),
-    'value': fields.String,
-    'description': fields.String,
-    'created_at': TimestampField,
-    'updated_at': TimestampField,
+    "id": fields.String,
+    "name": fields.String,
+    "value_type": fields.String(attribute="value_type.value"),
+    "value": fields.String,
+    "description": fields.String,
+    "created_at": TimestampField,
+    "updated_at": TimestampField,
 }

 paginated_conversation_variable_fields = {
-    'page': fields.Integer,
-    'limit': fields.Integer,
-    'total': fields.Integer,
-    'has_more': fields.Boolean,
-    'data': fields.List(fields.Nested(conversation_variable_fields), attribute='data'),
+    "page": fields.Integer,
+    "limit": fields.Integer,
+    "total": fields.Integer,
+    "has_more": fields.Boolean,
+    "data": fields.List(fields.Nested(conversation_variable_fields), attribute="data"),
 }

+ 32 - 40
api/fields/data_source_fields.py

@@ -2,64 +2,56 @@ from flask_restful import fields

 from libs.helper import TimestampField

-integrate_icon_fields = {
-    'type': fields.String,
-    'url': fields.String,
-    'emoji': fields.String
-}
+integrate_icon_fields = {"type": fields.String, "url": fields.String, "emoji": fields.String}

 integrate_page_fields = {
-    'page_name': fields.String,
-    'page_id': fields.String,
-    'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
-    'is_bound': fields.Boolean,
-    'parent_id': fields.String,
-    'type': fields.String
+    "page_name": fields.String,
+    "page_id": fields.String,
+    "page_icon": fields.Nested(integrate_icon_fields, allow_null=True),
+    "is_bound": fields.Boolean,
+    "parent_id": fields.String,
+    "type": fields.String,
 }

 integrate_workspace_fields = {
-    'workspace_name': fields.String,
-    'workspace_id': fields.String,
-    'workspace_icon': fields.String,
-    'pages': fields.List(fields.Nested(integrate_page_fields))
+    "workspace_name": fields.String,
+    "workspace_id": fields.String,
+    "workspace_icon": fields.String,
+    "pages": fields.List(fields.Nested(integrate_page_fields)),
 }

 integrate_notion_info_list_fields = {
-    'notion_info': fields.List(fields.Nested(integrate_workspace_fields)),
+    "notion_info": fields.List(fields.Nested(integrate_workspace_fields)),
 }

-integrate_icon_fields = {
-    'type': fields.String,
-    'url': fields.String,
-    'emoji': fields.String
-}
+integrate_icon_fields = {"type": fields.String, "url": fields.String, "emoji": fields.String}

 integrate_page_fields = {
-    'page_name': fields.String,
-    'page_id': fields.String,
-    'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
-    'parent_id': fields.String,
-    'type': fields.String
+    "page_name": fields.String,
+    "page_id": fields.String,
+    "page_icon": fields.Nested(integrate_icon_fields, allow_null=True),
+    "parent_id": fields.String,
+    "type": fields.String,
 }

 integrate_workspace_fields = {
-    'workspace_name': fields.String,
-    'workspace_id': fields.String,
-    'workspace_icon': fields.String,
-    'pages': fields.List(fields.Nested(integrate_page_fields)),
-    'total': fields.Integer
+    "workspace_name": fields.String,
+    "workspace_id": fields.String,
+    "workspace_icon": fields.String,
+    "pages": fields.List(fields.Nested(integrate_page_fields)),
+    "total": fields.Integer,
 }

 integrate_fields = {
-    'id': fields.String,
-    'provider': fields.String,
-    'created_at': TimestampField,
-    'is_bound': fields.Boolean,
-    'disabled': fields.Boolean,
-    'link': fields.String,
-    'source_info': fields.Nested(integrate_workspace_fields)
+    "id": fields.String,
+    "provider": fields.String,
+    "created_at": TimestampField,
+    "is_bound": fields.Boolean,
+    "disabled": fields.Boolean,
+    "link": fields.String,
+    "source_info": fields.Nested(integrate_workspace_fields),
 }

 integrate_list_fields = {
-    'data': fields.List(fields.Nested(integrate_fields)),
-}
+    "data": fields.List(fields.Nested(integrate_fields)),
+}

+ 44 - 55
api/fields/dataset_fields.py

@@ -3,73 +3,64 @@ from flask_restful import fields
 from libs.helper import TimestampField

 dataset_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'description': fields.String,
-    'permission': fields.String,
-    'data_source_type': fields.String,
-    'indexing_technique': fields.String,
-    'created_by': fields.String,
-    'created_at': TimestampField,
+    "id": fields.String,
+    "name": fields.String,
+    "description": fields.String,
+    "permission": fields.String,
+    "data_source_type": fields.String,
+    "indexing_technique": fields.String,
+    "created_by": fields.String,
+    "created_at": TimestampField,
 }

-reranking_model_fields = {
-    'reranking_provider_name': fields.String,
-    'reranking_model_name': fields.String
-}
+reranking_model_fields = {"reranking_provider_name": fields.String, "reranking_model_name": fields.String}

-keyword_setting_fields = {
-    'keyword_weight': fields.Float
-}
+keyword_setting_fields = {"keyword_weight": fields.Float}

 vector_setting_fields = {
-    'vector_weight': fields.Float,
-    'embedding_model_name': fields.String,
-    'embedding_provider_name': fields.String,
+    "vector_weight": fields.Float,
+    "embedding_model_name": fields.String,
+    "embedding_provider_name": fields.String,
 }

 weighted_score_fields = {
-    'keyword_setting': fields.Nested(keyword_setting_fields),
-    'vector_setting': fields.Nested(vector_setting_fields),
+    "keyword_setting": fields.Nested(keyword_setting_fields),
+    "vector_setting": fields.Nested(vector_setting_fields),
 }

 dataset_retrieval_model_fields = {
-    'search_method': fields.String,
-    'reranking_enable': fields.Boolean,
-    'reranking_mode': fields.String,
-    'reranking_model': fields.Nested(reranking_model_fields),
-    'weights': fields.Nested(weighted_score_fields, allow_null=True),
-    'top_k': fields.Integer,
-    'score_threshold_enabled': fields.Boolean,
-    'score_threshold': fields.Float
+    "search_method": fields.String,
+    "reranking_enable": fields.Boolean,
+    "reranking_mode": fields.String,
+    "reranking_model": fields.Nested(reranking_model_fields),
+    "weights": fields.Nested(weighted_score_fields, allow_null=True),
+    "top_k": fields.Integer,
+    "score_threshold_enabled": fields.Boolean,
+    "score_threshold": fields.Float,
 }

-tag_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'type': fields.String
-}
+tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String}

 dataset_detail_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'description': fields.String,
-    'provider': fields.String,
-    'permission': fields.String,
-    'data_source_type': fields.String,
-    'indexing_technique': fields.String,
-    'app_count': fields.Integer,
-    'document_count': fields.Integer,
-    'word_count': fields.Integer,
-    'created_by': fields.String,
-    'created_at': TimestampField,
-    'updated_by': fields.String,
-    'updated_at': TimestampField,
-    'embedding_model': fields.String,
-    'embedding_model_provider': fields.String,
-    'embedding_available': fields.Boolean,
-    'retrieval_model_dict': fields.Nested(dataset_retrieval_model_fields),
-    'tags': fields.List(fields.Nested(tag_fields))
+    "id": fields.String,
+    "name": fields.String,
+    "description": fields.String,
+    "provider": fields.String,
+    "permission": fields.String,
+    "data_source_type": fields.String,
+    "indexing_technique": fields.String,
+    "app_count": fields.Integer,
+    "document_count": fields.Integer,
+    "word_count": fields.Integer,
+    "created_by": fields.String,
+    "created_at": TimestampField,
+    "updated_by": fields.String,
+    "updated_at": TimestampField,
+    "embedding_model": fields.String,
+    "embedding_model_provider": fields.String,
+    "embedding_available": fields.Boolean,
+    "retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields),
+    "tags": fields.List(fields.Nested(tag_fields)),
 }

 dataset_query_detail_fields = {
@@ -79,7 +70,5 @@ dataset_query_detail_fields = {
     "source_app_id": fields.String,
     "source_app_id": fields.String,
     "created_by_role": fields.String,
     "created_by_role": fields.String,
     "created_by": fields.String,
     "created_by": fields.String,
-    "created_at": TimestampField
+    "created_at": TimestampField,
 }
-
-

+ 59 - 61
api/fields/document_fields.py

@@ -4,75 +4,73 @@ from fields.dataset_fields import dataset_fields
 from libs.helper import TimestampField

 document_fields = {
-    'id': fields.String,
-    'position': fields.Integer,
-    'data_source_type': fields.String,
-    'data_source_info': fields.Raw(attribute='data_source_info_dict'),
-    'data_source_detail_dict': fields.Raw(attribute='data_source_detail_dict'),
-    'dataset_process_rule_id': fields.String,
-    'name': fields.String,
-    'created_from': fields.String,
-    'created_by': fields.String,
-    'created_at': TimestampField,
-    'tokens': fields.Integer,
-    'indexing_status': fields.String,
-    'error': fields.String,
-    'enabled': fields.Boolean,
-    'disabled_at': TimestampField,
-    'disabled_by': fields.String,
-    'archived': fields.Boolean,
-    'display_status': fields.String,
-    'word_count': fields.Integer,
-    'hit_count': fields.Integer,
-    'doc_form': fields.String,
+    "id": fields.String,
+    "position": fields.Integer,
+    "data_source_type": fields.String,
+    "data_source_info": fields.Raw(attribute="data_source_info_dict"),
+    "data_source_detail_dict": fields.Raw(attribute="data_source_detail_dict"),
+    "dataset_process_rule_id": fields.String,
+    "name": fields.String,
+    "created_from": fields.String,
+    "created_by": fields.String,
+    "created_at": TimestampField,
+    "tokens": fields.Integer,
+    "indexing_status": fields.String,
+    "error": fields.String,
+    "enabled": fields.Boolean,
+    "disabled_at": TimestampField,
+    "disabled_by": fields.String,
+    "archived": fields.Boolean,
+    "display_status": fields.String,
+    "word_count": fields.Integer,
+    "hit_count": fields.Integer,
+    "doc_form": fields.String,
 }
 
 
 document_with_segments_fields = {
-    'id': fields.String,
-    'position': fields.Integer,
-    'data_source_type': fields.String,
-    'data_source_info': fields.Raw(attribute='data_source_info_dict'),
-    'data_source_detail_dict': fields.Raw(attribute='data_source_detail_dict'),
-    'dataset_process_rule_id': fields.String,
-    'name': fields.String,
-    'created_from': fields.String,
-    'created_by': fields.String,
-    'created_at': TimestampField,
-    'tokens': fields.Integer,
-    'indexing_status': fields.String,
-    'error': fields.String,
-    'enabled': fields.Boolean,
-    'disabled_at': TimestampField,
-    'disabled_by': fields.String,
-    'archived': fields.Boolean,
-    'display_status': fields.String,
-    'word_count': fields.Integer,
-    'hit_count': fields.Integer,
-    'completed_segments': fields.Integer,
-    'total_segments': fields.Integer
+    "id": fields.String,
+    "position": fields.Integer,
+    "data_source_type": fields.String,
+    "data_source_info": fields.Raw(attribute="data_source_info_dict"),
+    "data_source_detail_dict": fields.Raw(attribute="data_source_detail_dict"),
+    "dataset_process_rule_id": fields.String,
+    "name": fields.String,
+    "created_from": fields.String,
+    "created_by": fields.String,
+    "created_at": TimestampField,
+    "tokens": fields.Integer,
+    "indexing_status": fields.String,
+    "error": fields.String,
+    "enabled": fields.Boolean,
+    "disabled_at": TimestampField,
+    "disabled_by": fields.String,
+    "archived": fields.Boolean,
+    "display_status": fields.String,
+    "word_count": fields.Integer,
+    "hit_count": fields.Integer,
+    "completed_segments": fields.Integer,
+    "total_segments": fields.Integer,
 }
 
 
 dataset_and_document_fields = {
-    'dataset': fields.Nested(dataset_fields),
-    'documents': fields.List(fields.Nested(document_fields)),
-    'batch': fields.String
+    "dataset": fields.Nested(dataset_fields),
+    "documents": fields.List(fields.Nested(document_fields)),
+    "batch": fields.String,
 }
 
 
 document_status_fields = {
-    'id': fields.String,
-    'indexing_status': fields.String,
-    'processing_started_at': TimestampField,
-    'parsing_completed_at': TimestampField,
-    'cleaning_completed_at': TimestampField,
-    'splitting_completed_at': TimestampField,
-    'completed_at': TimestampField,
-    'paused_at': TimestampField,
-    'error': fields.String,
-    'stopped_at': TimestampField,
-    'completed_segments': fields.Integer,
-    'total_segments': fields.Integer,
+    "id": fields.String,
+    "indexing_status": fields.String,
+    "processing_started_at": TimestampField,
+    "parsing_completed_at": TimestampField,
+    "cleaning_completed_at": TimestampField,
+    "splitting_completed_at": TimestampField,
+    "completed_at": TimestampField,
+    "paused_at": TimestampField,
+    "error": fields.String,
+    "stopped_at": TimestampField,
+    "completed_segments": fields.Integer,
+    "total_segments": fields.Integer,
 }
 
 
-document_status_fields_list = {
-    'data': fields.List(fields.Nested(document_status_fields))
-}
+document_status_fields_list = {"data": fields.List(fields.Nested(document_status_fields))}
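Several keys above remap a model property on output via attribute=, e.g. data_source_info is read from data_source_info_dict. A small sketch of that remapping, assuming only flask_restful's public API; the Doc class is a hypothetical stand-in for the real Document model:

    from flask_restful import fields, marshal

    class Doc:  # hypothetical stand-in, not from the codebase
        name = "handbook.pdf"

        @property
        def data_source_info_dict(self):
            return {"upload_file_id": "f-1"}

    doc_view = {
        "name": fields.String,
        "data_source_info": fields.Raw(attribute="data_source_info_dict"),
    }
    print(dict(marshal(Doc(), doc_view)))
    # {'name': 'handbook.pdf', 'data_source_info': {'upload_file_id': 'f-1'}}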

+ 4 - 4
api/fields/end_user_fields.py

@@ -1,8 +1,8 @@
 from flask_restful import fields
 
 
 simple_end_user_fields = {
-    'id': fields.String,
-    'type': fields.String,
-    'is_anonymous': fields.Boolean,
-    'session_id': fields.String,
+    "id": fields.String,
+    "type": fields.String,
+    "is_anonymous": fields.Boolean,
+    "session_id": fields.String,
 }

+ 11 - 11
api/fields/file_fields.py

@@ -3,17 +3,17 @@ from flask_restful import fields
 from libs.helper import TimestampField
 
 
 upload_config_fields = {
-    'file_size_limit': fields.Integer,
-    'batch_count_limit': fields.Integer,
-    'image_file_size_limit': fields.Integer,
+    "file_size_limit": fields.Integer,
+    "batch_count_limit": fields.Integer,
+    "image_file_size_limit": fields.Integer,
 }
 
 
 file_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'size': fields.Integer,
-    'extension': fields.String,
-    'mime_type': fields.String,
-    'created_by': fields.String,
-    'created_at': TimestampField,
-}
+    "id": fields.String,
+    "name": fields.String,
+    "size": fields.Integer,
+    "extension": fields.String,
+    "mime_type": fields.String,
+    "created_by": fields.String,
+    "created_at": TimestampField,
+}

+ 30 - 30
api/fields/hit_testing_fields.py

@@ -3,39 +3,39 @@ from flask_restful import fields
 from libs.helper import TimestampField
 
 
 document_fields = {
-    'id': fields.String,
-    'data_source_type': fields.String,
-    'name': fields.String,
-    'doc_type': fields.String,
+    "id": fields.String,
+    "data_source_type": fields.String,
+    "name": fields.String,
+    "doc_type": fields.String,
 }
 
 
 segment_fields = {
-    'id': fields.String,
-    'position': fields.Integer,
-    'document_id': fields.String,
-    'content': fields.String,
-    'answer': fields.String,
-    'word_count': fields.Integer,
-    'tokens': fields.Integer,
-    'keywords': fields.List(fields.String),
-    'index_node_id': fields.String,
-    'index_node_hash': fields.String,
-    'hit_count': fields.Integer,
-    'enabled': fields.Boolean,
-    'disabled_at': TimestampField,
-    'disabled_by': fields.String,
-    'status': fields.String,
-    'created_by': fields.String,
-    'created_at': TimestampField,
-    'indexing_at': TimestampField,
-    'completed_at': TimestampField,
-    'error': fields.String,
-    'stopped_at': TimestampField,
-    'document': fields.Nested(document_fields),
+    "id": fields.String,
+    "position": fields.Integer,
+    "document_id": fields.String,
+    "content": fields.String,
+    "answer": fields.String,
+    "word_count": fields.Integer,
+    "tokens": fields.Integer,
+    "keywords": fields.List(fields.String),
+    "index_node_id": fields.String,
+    "index_node_hash": fields.String,
+    "hit_count": fields.Integer,
+    "enabled": fields.Boolean,
+    "disabled_at": TimestampField,
+    "disabled_by": fields.String,
+    "status": fields.String,
+    "created_by": fields.String,
+    "created_at": TimestampField,
+    "indexing_at": TimestampField,
+    "completed_at": TimestampField,
+    "error": fields.String,
+    "stopped_at": TimestampField,
+    "document": fields.Nested(document_fields),
 }
 
 
 hit_testing_record_fields = {
-    'segment': fields.Nested(segment_fields),
-    'score': fields.Float,
-    'tsne_position': fields.Raw
-}
+    "segment": fields.Nested(segment_fields),
+    "score": fields.Float,
+    "tsne_position": fields.Raw,
+}

+ 13 - 15
api/fields/installed_app_fields.py

@@ -3,23 +3,21 @@ from flask_restful import fields
 from libs.helper import TimestampField
 
 
 app_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'mode': fields.String,
-    'icon': fields.String,
-    'icon_background': fields.String
+    "id": fields.String,
+    "name": fields.String,
+    "mode": fields.String,
+    "icon": fields.String,
+    "icon_background": fields.String,
 }
 
 
 installed_app_fields = {
-    'id': fields.String,
-    'app': fields.Nested(app_fields),
-    'app_owner_tenant_id': fields.String,
-    'is_pinned': fields.Boolean,
-    'last_used_at': TimestampField,
-    'editable': fields.Boolean,
-    'uninstallable': fields.Boolean
+    "id": fields.String,
+    "app": fields.Nested(app_fields),
+    "app_owner_tenant_id": fields.String,
+    "is_pinned": fields.Boolean,
+    "last_used_at": TimestampField,
+    "editable": fields.Boolean,
+    "uninstallable": fields.Boolean,
 }
 
 
-installed_app_list_fields = {
-    'installed_apps': fields.List(fields.Nested(installed_app_fields))
-}
+installed_app_list_fields = {"installed_apps": fields.List(fields.Nested(installed_app_fields))}

+ 22 - 28
api/fields/member_fields.py

@@ -2,38 +2,32 @@ from flask_restful import fields
 
 
 from libs.helper import TimestampField
 
 
-simple_account_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'email': fields.String
-}
+simple_account_fields = {"id": fields.String, "name": fields.String, "email": fields.String}
 
 
 account_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'avatar': fields.String,
-    'email': fields.String,
-    'is_password_set': fields.Boolean,
-    'interface_language': fields.String,
-    'interface_theme': fields.String,
-    'timezone': fields.String,
-    'last_login_at': TimestampField,
-    'last_login_ip': fields.String,
-    'created_at': TimestampField
+    "id": fields.String,
+    "name": fields.String,
+    "avatar": fields.String,
+    "email": fields.String,
+    "is_password_set": fields.Boolean,
+    "interface_language": fields.String,
+    "interface_theme": fields.String,
+    "timezone": fields.String,
+    "last_login_at": TimestampField,
+    "last_login_ip": fields.String,
+    "created_at": TimestampField,
 }
 
 
 account_with_role_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'avatar': fields.String,
-    'email': fields.String,
-    'last_login_at': TimestampField,
-    'last_active_at': TimestampField,
-    'created_at': TimestampField,
-    'role': fields.String,
-    'status': fields.String,
+    "id": fields.String,
+    "name": fields.String,
+    "avatar": fields.String,
+    "email": fields.String,
+    "last_login_at": TimestampField,
+    "last_active_at": TimestampField,
+    "created_at": TimestampField,
+    "role": fields.String,
+    "status": fields.String,
 }
 
 
-account_with_role_list_fields = {
-    'accounts': fields.List(fields.Nested(account_with_role_fields))
-}
+account_with_role_list_fields = {"accounts": fields.List(fields.Nested(account_with_role_fields))}

+ 60 - 64
api/fields/message_fields.py

@@ -3,83 +3,79 @@ from flask_restful import fields
 from fields.conversation_fields import message_file_fields
 from libs.helper import TimestampField
 
 
-feedback_fields = {
-    'rating': fields.String
-}
+feedback_fields = {"rating": fields.String}
 
 
 retriever_resource_fields = {
-    'id': fields.String,
-    'message_id': fields.String,
-    'position': fields.Integer,
-    'dataset_id': fields.String,
-    'dataset_name': fields.String,
-    'document_id': fields.String,
-    'document_name': fields.String,
-    'data_source_type': fields.String,
-    'segment_id': fields.String,
-    'score': fields.Float,
-    'hit_count': fields.Integer,
-    'word_count': fields.Integer,
-    'segment_position': fields.Integer,
-    'index_node_hash': fields.String,
-    'content': fields.String,
-    'created_at': TimestampField
+    "id": fields.String,
+    "message_id": fields.String,
+    "position": fields.Integer,
+    "dataset_id": fields.String,
+    "dataset_name": fields.String,
+    "document_id": fields.String,
+    "document_name": fields.String,
+    "data_source_type": fields.String,
+    "segment_id": fields.String,
+    "score": fields.Float,
+    "hit_count": fields.Integer,
+    "word_count": fields.Integer,
+    "segment_position": fields.Integer,
+    "index_node_hash": fields.String,
+    "content": fields.String,
+    "created_at": TimestampField,
 }
 
 
-feedback_fields = {
-    'rating': fields.String
-}
+feedback_fields = {"rating": fields.String}
 
 
 agent_thought_fields = {
-    'id': fields.String,
-    'chain_id': fields.String,
-    'message_id': fields.String,
-    'position': fields.Integer,
-    'thought': fields.String,
-    'tool': fields.String,
-    'tool_labels': fields.Raw,
-    'tool_input': fields.String,
-    'created_at': TimestampField,
-    'observation': fields.String,
-    'files': fields.List(fields.String)
+    "id": fields.String,
+    "chain_id": fields.String,
+    "message_id": fields.String,
+    "position": fields.Integer,
+    "thought": fields.String,
+    "tool": fields.String,
+    "tool_labels": fields.Raw,
+    "tool_input": fields.String,
+    "created_at": TimestampField,
+    "observation": fields.String,
+    "files": fields.List(fields.String),
 }
 
 
 retriever_resource_fields = {
-    'id': fields.String,
-    'message_id': fields.String,
-    'position': fields.Integer,
-    'dataset_id': fields.String,
-    'dataset_name': fields.String,
-    'document_id': fields.String,
-    'document_name': fields.String,
-    'data_source_type': fields.String,
-    'segment_id': fields.String,
-    'score': fields.Float,
-    'hit_count': fields.Integer,
-    'word_count': fields.Integer,
-    'segment_position': fields.Integer,
-    'index_node_hash': fields.String,
-    'content': fields.String,
-    'created_at': TimestampField
+    "id": fields.String,
+    "message_id": fields.String,
+    "position": fields.Integer,
+    "dataset_id": fields.String,
+    "dataset_name": fields.String,
+    "document_id": fields.String,
+    "document_name": fields.String,
+    "data_source_type": fields.String,
+    "segment_id": fields.String,
+    "score": fields.Float,
+    "hit_count": fields.Integer,
+    "word_count": fields.Integer,
+    "segment_position": fields.Integer,
+    "index_node_hash": fields.String,
+    "content": fields.String,
+    "created_at": TimestampField,
 }
 
 
 message_fields = {
-    'id': fields.String,
-    'conversation_id': fields.String,
-    'inputs': fields.Raw,
-    'query': fields.String,
-    'answer': fields.String(attribute='re_sign_file_url_answer'),
-    'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
-    'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
-    'created_at': TimestampField,
-    'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)),
-    'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
-    'status': fields.String,
-    'error': fields.String,
+    "id": fields.String,
+    "conversation_id": fields.String,
+    "inputs": fields.Raw,
+    "query": fields.String,
+    "answer": fields.String(attribute="re_sign_file_url_answer"),
+    "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
+    "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
+    "created_at": TimestampField,
+    "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
+    "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
+    "status": fields.String,
+    "error": fields.String,
 }
 
 
 message_infinite_scroll_pagination_fields = {
-    'limit': fields.Integer,
-    'has_more': fields.Boolean,
-    'data': fields.List(fields.Nested(message_fields))
+    "limit": fields.Integer,
+    "has_more": fields.Boolean,
+    "data": fields.List(fields.Nested(message_fields)),
 }
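Two things worth noting in this module: feedback_fields and retriever_resource_fields are each defined twice (the formatter reformats both copies; deduplication is a linter's job, not a formatter's), and feedback uses fields.Nested(..., allow_null=True), so a message without feedback marshals to None instead of a stub dict. A hedged sketch of the allow_null behaviour; the Message class is invented:

    from flask_restful import fields, marshal

    feedback_fields = {"rating": fields.String}
    message_view = {
        "id": fields.String,
        "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
    }

    class Message:  # hypothetical stand-in
        id = "m-1"
        user_feedback = None  # no feedback recorded yet

    print(dict(marshal(Message(), message_view)))
    # {'id': 'm-1', 'feedback': None} -- without allow_null it would be {'rating': None}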

+ 24 - 24
api/fields/segment_fields.py

@@ -3,31 +3,31 @@ from flask_restful import fields
 from libs.helper import TimestampField
 
 
 segment_fields = {
-    'id': fields.String,
-    'position': fields.Integer,
-    'document_id': fields.String,
-    'content': fields.String,
-    'answer': fields.String,
-    'word_count': fields.Integer,
-    'tokens': fields.Integer,
-    'keywords': fields.List(fields.String),
-    'index_node_id': fields.String,
-    'index_node_hash': fields.String,
-    'hit_count': fields.Integer,
-    'enabled': fields.Boolean,
-    'disabled_at': TimestampField,
-    'disabled_by': fields.String,
-    'status': fields.String,
-    'created_by': fields.String,
-    'created_at': TimestampField,
-    'indexing_at': TimestampField,
-    'completed_at': TimestampField,
-    'error': fields.String,
-    'stopped_at': TimestampField
+    "id": fields.String,
+    "position": fields.Integer,
+    "document_id": fields.String,
+    "content": fields.String,
+    "answer": fields.String,
+    "word_count": fields.Integer,
+    "tokens": fields.Integer,
+    "keywords": fields.List(fields.String),
+    "index_node_id": fields.String,
+    "index_node_hash": fields.String,
+    "hit_count": fields.Integer,
+    "enabled": fields.Boolean,
+    "disabled_at": TimestampField,
+    "disabled_by": fields.String,
+    "status": fields.String,
+    "created_by": fields.String,
+    "created_at": TimestampField,
+    "indexing_at": TimestampField,
+    "completed_at": TimestampField,
+    "error": fields.String,
+    "stopped_at": TimestampField,
 }
 
 
 segment_list_response = {
-    'data': fields.List(fields.Nested(segment_fields)),
-    'has_more': fields.Boolean,
-    'limit': fields.Integer
+    "data": fields.List(fields.Nested(segment_fields)),
+    "has_more": fields.Boolean,
+    "limit": fields.Integer,
 }

+ 1 - 6
api/fields/tag_fields.py

@@ -1,8 +1,3 @@
 from flask_restful import fields
 
 
-tag_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'type': fields.String,
-    'binding_count': fields.String
-}
+tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String, "binding_count": fields.String}

+ 9 - 9
api/fields/workflow_app_log_fields.py

@@ -7,18 +7,18 @@ from libs.helper import TimestampField
 
 
 workflow_app_log_partial_fields = {
     "id": fields.String,
-    "workflow_run": fields.Nested(workflow_run_for_log_fields, attribute='workflow_run', allow_null=True),
+    "workflow_run": fields.Nested(workflow_run_for_log_fields, attribute="workflow_run", allow_null=True),
     "created_from": fields.String,
     "created_from": fields.String,
     "created_by_role": fields.String,
     "created_by_role": fields.String,
-    "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True),
-    "created_by_end_user": fields.Nested(simple_end_user_fields, attribute='created_by_end_user', allow_null=True),
-    "created_at": TimestampField
+    "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
+    "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True),
+    "created_at": TimestampField,
 }
 
 
 workflow_app_log_pagination_fields = {
-    'page': fields.Integer,
-    'limit': fields.Integer(attribute='per_page'),
-    'total': fields.Integer,
-    'has_more': fields.Boolean(attribute='has_next'),
-    'data': fields.List(fields.Nested(workflow_app_log_partial_fields), attribute='items')
+    "page": fields.Integer,
+    "limit": fields.Integer(attribute="per_page"),
+    "total": fields.Integer,
+    "has_more": fields.Boolean(attribute="has_next"),
+    "data": fields.List(fields.Nested(workflow_app_log_partial_fields), attribute="items"),
 }
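The pagination dict exists to translate Flask-SQLAlchemy's Pagination attribute names (per_page, has_next, items) into the API's limit / has_more / data. A sketch with a stand-in object, assuming nothing beyond flask_restful:

    from flask_restful import fields, marshal

    class FakePagination:  # stand-in for flask_sqlalchemy's Pagination
        page = 1
        per_page = 20
        total = 57
        has_next = True
        items = []

    page_view = {
        "page": fields.Integer,
        "limit": fields.Integer(attribute="per_page"),
        "total": fields.Integer,
        "has_more": fields.Boolean(attribute="has_next"),
        "data": fields.List(fields.Raw, attribute="items"),
    }
    print(dict(marshal(FakePagination(), page_view)))
    # {'page': 1, 'limit': 20, 'total': 57, 'has_more': True, 'data': []}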

+ 26 - 26
api/fields/workflow_fields.py

@@ -13,43 +13,43 @@ class EnvironmentVariableField(fields.Raw):
         # Mask secret variables values in environment_variables
         if isinstance(value, SecretVariable):
             return {
-                'id': value.id,
-                'name': value.name,
-                'value': encrypter.obfuscated_token(value.value),
-                'value_type': value.value_type.value,
+                "id": value.id,
+                "name": value.name,
+                "value": encrypter.obfuscated_token(value.value),
+                "value_type": value.value_type.value,
             }
         if isinstance(value, Variable):
             return {
-                'id': value.id,
-                'name': value.name,
-                'value': value.value,
-                'value_type': value.value_type.value,
+                "id": value.id,
+                "name": value.name,
+                "value": value.value,
+                "value_type": value.value_type.value,
             }
         if isinstance(value, dict):
-            value_type = value.get('value_type')
+            value_type = value.get("value_type")
             if value_type not in ENVIRONMENT_VARIABLE_SUPPORTED_TYPES:
-                raise ValueError(f'Unsupported environment variable value type: {value_type}')
+                raise ValueError(f"Unsupported environment variable value type: {value_type}")
             return value


 conversation_variable_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'value_type': fields.String(attribute='value_type.value'),
-    'value': fields.Raw,
-    'description': fields.String,
+    "id": fields.String,
+    "name": fields.String,
+    "value_type": fields.String(attribute="value_type.value"),
+    "value": fields.Raw,
+    "description": fields.String,
 }
 
 
 workflow_fields = {
-    'id': fields.String,
-    'graph': fields.Raw(attribute='graph_dict'),
-    'features': fields.Raw(attribute='features_dict'),
-    'hash': fields.String(attribute='unique_hash'),
-    'created_by': fields.Nested(simple_account_fields, attribute='created_by_account'),
-    'created_at': TimestampField,
-    'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True),
-    'updated_at': TimestampField,
-    'tool_published': fields.Boolean,
-    'environment_variables': fields.List(EnvironmentVariableField()),
-    'conversation_variables': fields.List(fields.Nested(conversation_variable_fields)),
+    "id": fields.String,
+    "graph": fields.Raw(attribute="graph_dict"),
+    "features": fields.Raw(attribute="features_dict"),
+    "hash": fields.String(attribute="unique_hash"),
+    "created_by": fields.Nested(simple_account_fields, attribute="created_by_account"),
+    "created_at": TimestampField,
+    "updated_by": fields.Nested(simple_account_fields, attribute="updated_by_account", allow_null=True),
+    "updated_at": TimestampField,
+    "tool_published": fields.Boolean,
+    "environment_variables": fields.List(EnvironmentVariableField()),
+    "conversation_variables": fields.List(fields.Nested(conversation_variable_fields)),
 }
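EnvironmentVariableField above illustrates the custom-field pattern: subclass fields.Raw and override its formatting hook so the schema itself masks secret values at serialization time. A minimal sketch of the same idea using the documented format() hook; MaskedTokenField is an invented name, not part of the codebase:

    from flask_restful import fields, marshal

    class MaskedTokenField(fields.Raw):  # hypothetical, echoing EnvironmentVariableField
        def format(self, value):
            return value[:2] + "*" * (len(value) - 2)

    print(dict(marshal({"token": "sk-abcdef"}, {"token": MaskedTokenField()})))
    # {'token': 'sk*******'}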

+ 25 - 25
api/fields/workflow_run_fields.py

@@ -13,7 +13,7 @@ workflow_run_for_log_fields = {
     "total_tokens": fields.Integer,
     "total_tokens": fields.Integer,
     "total_steps": fields.Integer,
     "total_steps": fields.Integer,
     "created_at": TimestampField,
     "created_at": TimestampField,
-    "finished_at": TimestampField
+    "finished_at": TimestampField,
 }
 
 
 workflow_run_for_list_fields = {
@@ -24,9 +24,9 @@ workflow_run_for_list_fields = {
     "elapsed_time": fields.Float,
     "elapsed_time": fields.Float,
     "total_tokens": fields.Integer,
     "total_tokens": fields.Integer,
     "total_steps": fields.Integer,
     "total_steps": fields.Integer,
-    "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True),
+    "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
     "created_at": TimestampField,
     "created_at": TimestampField,
-    "finished_at": TimestampField
+    "finished_at": TimestampField,
 }
 
 
 advanced_chat_workflow_run_for_list_fields = {
@@ -39,40 +39,40 @@ advanced_chat_workflow_run_for_list_fields = {
     "elapsed_time": fields.Float,
     "elapsed_time": fields.Float,
     "total_tokens": fields.Integer,
     "total_tokens": fields.Integer,
     "total_steps": fields.Integer,
     "total_steps": fields.Integer,
-    "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True),
+    "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
     "created_at": TimestampField,
     "created_at": TimestampField,
-    "finished_at": TimestampField
+    "finished_at": TimestampField,
 }
 
 
 advanced_chat_workflow_run_pagination_fields = {
-    'limit': fields.Integer(attribute='limit'),
-    'has_more': fields.Boolean(attribute='has_more'),
-    'data': fields.List(fields.Nested(advanced_chat_workflow_run_for_list_fields), attribute='data')
+    "limit": fields.Integer(attribute="limit"),
+    "has_more": fields.Boolean(attribute="has_more"),
+    "data": fields.List(fields.Nested(advanced_chat_workflow_run_for_list_fields), attribute="data"),
 }
 
 
 workflow_run_pagination_fields = {
-    'limit': fields.Integer(attribute='limit'),
-    'has_more': fields.Boolean(attribute='has_more'),
-    'data': fields.List(fields.Nested(workflow_run_for_list_fields), attribute='data')
+    "limit": fields.Integer(attribute="limit"),
+    "has_more": fields.Boolean(attribute="has_more"),
+    "data": fields.List(fields.Nested(workflow_run_for_list_fields), attribute="data"),
 }
 
 
 workflow_run_detail_fields = {
     "id": fields.String,
     "sequence_number": fields.Integer,
     "version": fields.String,
-    "graph": fields.Raw(attribute='graph_dict'),
-    "inputs": fields.Raw(attribute='inputs_dict'),
+    "graph": fields.Raw(attribute="graph_dict"),
+    "inputs": fields.Raw(attribute="inputs_dict"),
     "status": fields.String,
     "status": fields.String,
-    "outputs": fields.Raw(attribute='outputs_dict'),
+    "outputs": fields.Raw(attribute="outputs_dict"),
     "error": fields.String,
     "error": fields.String,
     "elapsed_time": fields.Float,
     "elapsed_time": fields.Float,
     "total_tokens": fields.Integer,
     "total_tokens": fields.Integer,
     "total_steps": fields.Integer,
     "total_steps": fields.Integer,
     "created_by_role": fields.String,
     "created_by_role": fields.String,
-    "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True),
-    "created_by_end_user": fields.Nested(simple_end_user_fields, attribute='created_by_end_user', allow_null=True),
+    "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
+    "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True),
     "created_at": TimestampField,
     "created_at": TimestampField,
-    "finished_at": TimestampField
+    "finished_at": TimestampField,
 }
 
 
 workflow_run_node_execution_fields = {
@@ -82,21 +82,21 @@ workflow_run_node_execution_fields = {
     "node_id": fields.String,
     "node_id": fields.String,
     "node_type": fields.String,
     "node_type": fields.String,
     "title": fields.String,
     "title": fields.String,
-    "inputs": fields.Raw(attribute='inputs_dict'),
-    "process_data": fields.Raw(attribute='process_data_dict'),
-    "outputs": fields.Raw(attribute='outputs_dict'),
+    "inputs": fields.Raw(attribute="inputs_dict"),
+    "process_data": fields.Raw(attribute="process_data_dict"),
+    "outputs": fields.Raw(attribute="outputs_dict"),
     "status": fields.String,
     "status": fields.String,
     "error": fields.String,
     "error": fields.String,
     "elapsed_time": fields.Float,
     "elapsed_time": fields.Float,
-    "execution_metadata": fields.Raw(attribute='execution_metadata_dict'),
+    "execution_metadata": fields.Raw(attribute="execution_metadata_dict"),
     "extras": fields.Raw,
     "extras": fields.Raw,
     "created_at": TimestampField,
     "created_at": TimestampField,
     "created_by_role": fields.String,
     "created_by_role": fields.String,
-    "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True),
-    "created_by_end_user": fields.Nested(simple_end_user_fields, attribute='created_by_end_user', allow_null=True),
-    "finished_at": TimestampField
+    "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
+    "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True),
+    "finished_at": TimestampField,
 }
 
 
 workflow_run_node_execution_list_fields = {
-    'data': fields.List(fields.Nested(workflow_run_node_execution_fields)),
+    "data": fields.List(fields.Nested(workflow_run_node_execution_fields)),
 }

+ 12 - 1
api/pyproject.toml

@@ -69,7 +69,18 @@ ignore = [
 ]
 
 
 [tool.ruff.format]
-quote-style = "single"
+exclude = [
+    "core/**/*.py",
+    "controllers/**/*.py",
+    "models/**/*.py",
+    "utils/**/*.py",
+    "migrations/**/*",
+    "services/**/*.py",
+    "tasks/**/*.py",
+    "tests/**/*.py",
+    "libs/**/*.py",
+    "configs/**/*.py",
+]
 
 
 [tool.pytest_env]
 OPENAI_API_KEY = "sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii"

+ 13 - 8
api/schedule/clean_embedding_cache_task.py

@@ -11,27 +11,32 @@ from extensions.ext_database import db
 from models.dataset import Embedding


-@app.celery.task(queue='dataset')
+@app.celery.task(queue="dataset")
 def clean_embedding_cache_task():
-    click.echo(click.style('Start clean embedding cache.', fg='green'))
+    click.echo(click.style("Start clean embedding cache.", fg="green"))
     clean_days = int(dify_config.CLEAN_DAY_SETTING)
     start_at = time.perf_counter()
     thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days)
     while True:
         try:
-            embedding_ids = db.session.query(Embedding.id).filter(Embedding.created_at < thirty_days_ago) \
-                .order_by(Embedding.created_at.desc()).limit(100).all()
+            embedding_ids = (
+                db.session.query(Embedding.id)
+                .filter(Embedding.created_at < thirty_days_ago)
+                .order_by(Embedding.created_at.desc())
+                .limit(100)
+                .all()
+            )
             embedding_ids = [embedding_id[0] for embedding_id in embedding_ids]
         except NotFound:
             break
         if embedding_ids:
             for embedding_id in embedding_ids:
-                db.session.execute(text(
-                    "DELETE FROM embeddings WHERE id = :embedding_id"
-                ), {'embedding_id': embedding_id})
+                db.session.execute(
+                    text("DELETE FROM embeddings WHERE id = :embedding_id"), {"embedding_id": embedding_id}
+                )
 
 
             db.session.commit()
         else:
             break
     end_at = time.perf_counter()
-    click.echo(click.style('Cleaned embedding cache from db success latency: {}'.format(end_at - start_at), fg='green'))
+    click.echo(click.style("Cleaned embedding cache from db success latency: {}".format(end_at - start_at), fg="green"))
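The loop above deletes matched rows in batches of 100, binding each id through the :embedding_id placeholder rather than formatting it into the SQL string. A standalone sketch of that bound-parameter pattern against an in-memory SQLite database (the table name is borrowed from the task; everything else is illustrative):

    from sqlalchemy import create_engine, text

    engine = create_engine("sqlite:///:memory:")
    with engine.begin() as conn:
        conn.execute(text("CREATE TABLE embeddings (id TEXT)"))
        conn.execute(text("INSERT INTO embeddings VALUES ('e1'), ('e2')"))
        # the driver binds :embedding_id, so no value is interpolated into the SQL
        conn.execute(text("DELETE FROM embeddings WHERE id = :embedding_id"), {"embedding_id": "e1"})
        remaining = conn.execute(text("SELECT id FROM embeddings")).fetchall()
    print(remaining)  # [('e2',)]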

+ 46 - 44
api/schedule/clean_unused_datasets_task.py

@@ -12,9 +12,9 @@ from extensions.ext_database import db
 from models.dataset import Dataset, DatasetQuery, Document


-@app.celery.task(queue='dataset')
+@app.celery.task(queue="dataset")
 def clean_unused_datasets_task():
-    click.echo(click.style('Start clean unused datasets indexes.', fg='green'))
+    click.echo(click.style("Start clean unused datasets indexes.", fg="green"))
     clean_days = dify_config.CLEAN_DAY_SETTING
     start_at = time.perf_counter()
     thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days)
@@ -22,40 +22,44 @@ def clean_unused_datasets_task():
     while True:
         try:
             # Subquery for counting new documents
-            document_subquery_new = db.session.query(
-                Document.dataset_id,
-                func.count(Document.id).label('document_count')
-            ).filter(
-                Document.indexing_status == 'completed',
-                Document.enabled == True,
-                Document.archived == False,
-                Document.updated_at > thirty_days_ago
-            ).group_by(Document.dataset_id).subquery()
+            document_subquery_new = (
+                db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
+                .filter(
+                    Document.indexing_status == "completed",
+                    Document.enabled == True,
+                    Document.archived == False,
+                    Document.updated_at > thirty_days_ago,
+                )
+                .group_by(Document.dataset_id)
+                .subquery()
+            )
 
 
             # Subquery for counting old documents
-            document_subquery_old = db.session.query(
-                Document.dataset_id,
-                func.count(Document.id).label('document_count')
-            ).filter(
-                Document.indexing_status == 'completed',
-                Document.enabled == True,
-                Document.archived == False,
-                Document.updated_at < thirty_days_ago
-            ).group_by(Document.dataset_id).subquery()
+            document_subquery_old = (
+                db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
+                .filter(
+                    Document.indexing_status == "completed",
+                    Document.enabled == True,
+                    Document.archived == False,
+                    Document.updated_at < thirty_days_ago,
+                )
+                .group_by(Document.dataset_id)
+                .subquery()
+            )
 
 
             # Main query with join and filter
-            datasets = (db.session.query(Dataset)
-                        .outerjoin(
-                document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id
-            ).outerjoin(
-                document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id
-            ).filter(
-                Dataset.created_at < thirty_days_ago,
-                func.coalesce(document_subquery_new.c.document_count, 0) == 0,
-                func.coalesce(document_subquery_old.c.document_count, 0) > 0
-            ).order_by(
-                Dataset.created_at.desc()
-            ).paginate(page=page, per_page=50))
+            datasets = (
+                db.session.query(Dataset)
+                .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
+                .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
+                .filter(
+                    Dataset.created_at < thirty_days_ago,
+                    func.coalesce(document_subquery_new.c.document_count, 0) == 0,
+                    func.coalesce(document_subquery_old.c.document_count, 0) > 0,
+                )
+                .order_by(Dataset.created_at.desc())
+                .paginate(page=page, per_page=50)
+            )
 
 
         except NotFound:
             break
@@ -63,10 +67,11 @@ def clean_unused_datasets_task():
             break
         page += 1
         for dataset in datasets:
-            dataset_query = db.session.query(DatasetQuery).filter(
-                DatasetQuery.created_at > thirty_days_ago,
-                DatasetQuery.dataset_id == dataset.id
-            ).all()
+            dataset_query = (
+                db.session.query(DatasetQuery)
+                .filter(DatasetQuery.created_at > thirty_days_ago, DatasetQuery.dataset_id == dataset.id)
+                .all()
+            )
             if not dataset_query or len(dataset_query) == 0:
                 try:
                     # remove index
@@ -74,17 +79,14 @@ def clean_unused_datasets_task():
                     index_processor.clean(dataset, None)
 
 
                     # update document
-                    update_params = {
-                        Document.enabled: False
-                    }
+                    update_params = {Document.enabled: False}
 
 
                     Document.query.filter_by(dataset_id=dataset.id).update(update_params)
                     db.session.commit()
-                    click.echo(click.style('Cleaned unused dataset {} from db success!'.format(dataset.id),
-                                           fg='green'))
+                    click.echo(click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green"))
                 except Exception as e:
                     click.echo(
-                        click.style('clean dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
-                                    fg='red'))
+                        click.style("clean dataset index error: {} {}".format(e.__class__.__name__, str(e)), fg="red")
+                    )
     end_at = time.perf_counter()
-    click.echo(click.style('Cleaned unused dataset from db success latency: {}'.format(end_at - start_at), fg='green'))
+    click.echo(click.style("Cleaned unused dataset from db success latency: {}".format(end_at - start_at), fg="green"))
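Reformatted, the main query's intent is easier to see: a dataset qualifies for cleanup when its outer-joined count of recently-updated documents coalesces to 0 while its count of older documents stays positive. A tiny sketch of the COALESCE-over-LEFT-JOIN trick the query relies on, with an invented schema:

    from sqlalchemy import create_engine, text

    engine = create_engine("sqlite:///:memory:")
    with engine.begin() as conn:
        conn.execute(text("CREATE TABLE datasets (id TEXT)"))
        conn.execute(text("CREATE TABLE doc_counts (dataset_id TEXT, n INTEGER)"))
        conn.execute(text("INSERT INTO datasets VALUES ('d1'), ('d2')"))
        conn.execute(text("INSERT INTO doc_counts VALUES ('d1', 3)"))
        rows = conn.execute(text(
            "SELECT d.id, COALESCE(c.n, 0) AS n FROM datasets d "
            "LEFT JOIN doc_counts c ON c.dataset_id = d.id ORDER BY d.id"
        )).fetchall()
    print(rows)  # an unmatched LEFT JOIN row coalesces to 0: [('d1', 3), ('d2', 0)]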

+ 3 - 0
dev/reformat

@@ -11,5 +11,8 @@ fi
 # run ruff linter
 ruff check --fix ./api
 
 
+# run ruff formatter
+ruff format ./api
+
 # run dotenv-linter linter
 dotenv-linter ./api/.env.example ./web/.env.example