浏览代码

chore(api/controllers): Apply Ruff Formatter. (#7645)

-LAN- 8 月之前
父节点
当前提交
13be84e4d4
共有 100 个文件被更改,包括 3760 次插入3882 次删除
  1. 0 2
      api/controllers/__init__.py
  2. 1 1
      api/controllers/console/__init__.py
  3. 49 46
      api/controllers/console/admin.py
  4. 56 60
      api/controllers/console/apikey.py
  5. 6 7
      api/controllers/console/app/advanced_prompt_template.py
  6. 6 9
      api/controllers/console/app/agent.py
  7. 62 70
      api/controllers/console/app/annotation.py
  8. 85 91
      api/controllers/console/app/app.py
  9. 25 25
      api/controllers/console/app/audio.py
  10. 25 34
      api/controllers/console/app/completion.py
  11. 100 93
      api/controllers/console/app/conversation.py
  12. 12 12
      api/controllers/console/app/conversation_variables.py
  13. 27 23
      api/controllers/console/app/error.py
  14. 9 9
      api/controllers/console/app/generator.py
  15. 75 72
      api/controllers/console/app/message.py
  16. 19 24
      api/controllers/console/app/model_config.py
  17. 12 21
      api/controllers/console/app/ops_trace.py
  18. 35 38
      api/controllers/console/app/site.py
  19. 131 161
      api/controllers/console/app/statistic.py
  20. 88 113
      api/controllers/console/app/workflow.py
  21. 6 7
      api/controllers/console/app/workflow_app_log.py
  22. 11 19
      api/controllers/console/app/workflow_run.py
  23. 100 91
      api/controllers/console/app/workflow_statistic.py
  24. 12 12
      api/controllers/console/app/wraps.py
  25. 27 26
      api/controllers/console/auth/activate.py
  26. 20 19
      api/controllers/console/auth/data_source_bearer_auth.py
  27. 35 37
      api/controllers/console/auth/data_source_oauth.py
  28. 5 6
      api/controllers/console/auth/error.py
  29. 17 20
      api/controllers/console/auth/forgot_password.py
  30. 17 15
      api/controllers/console/auth/login.py
  31. 13 13
      api/controllers/console/auth/oauth.py
  32. 7 11
      api/controllers/console/billing/billing.py
  33. 93 90
      api/controllers/console/datasets/data_source.py
  34. 254 220
      api/controllers/console/datasets/datasets.py
  35. 260 281
      api/controllers/console/datasets/datasets_document.py
  36. 100 114
      api/controllers/console/datasets/datasets_segments.py
  37. 15 15
      api/controllers/console/datasets/error.py
  38. 12 14
      api/controllers/console/datasets/file.py
  39. 9 9
      api/controllers/console/datasets/hit_testing.py
  40. 7 9
      api/controllers/console/datasets/website.py
  41. 16 10
      api/controllers/console/error.py
  42. 23 26
      api/controllers/console/explore/audio.py
  43. 35 32
      api/controllers/console/explore/completion.py
  44. 33 22
      api/controllers/console/explore/conversation.py
  45. 4 4
      api/controllers/console/explore/error.py
  46. 38 38
      api/controllers/console/explore/installed_app.py
  47. 34 21
      api/controllers/console/explore/message.py
  48. 52 48
      api/controllers/console/explore/parameter.py
  49. 21 21
      api/controllers/console/explore/recommended_app.py
  50. 31 25
      api/controllers/console/explore/saved_message.py
  51. 8 12
      api/controllers/console/explore/workflow.py
  52. 14 10
      api/controllers/console/explore/wraps.py
  53. 19 25
      api/controllers/console/extension.py
  54. 2 3
      api/controllers/console/feature.py
  55. 16 16
      api/controllers/console/init_validate.py
  56. 2 5
      api/controllers/console/ping.py
  57. 14 23
      api/controllers/console/setup.py
  58. 32 49
      api/controllers/console/tag/tags.py
  59. 16 20
      api/controllers/console/version.py
  60. 75 81
      api/controllers/console/workspace/account.py
  61. 6 6
      api/controllers/console/workspace/error.py
  62. 39 23
      api/controllers/console/workspace/load_balancing_config.py
  63. 39 41
      api/controllers/console/workspace/members.py
  64. 64 74
      api/controllers/console/workspace/model_providers.py
  65. 145 139
      api/controllers/console/workspace/models.py
  66. 228 172
      api/controllers/console/workspace/tool_providers.py
  67. 73 70
      api/controllers/console/workspace/workspace.py
  68. 21 18
      api/controllers/console/wraps.py
  69. 1 1
      api/controllers/files/__init__.py
  70. 11 16
      api/controllers/files/image_preview.py
  71. 16 13
      api/controllers/files/tool_files.py
  72. 1 2
      api/controllers/inner_api/__init__.py
  73. 8 13
      api/controllers/inner_api/workspace/workspace.py
  74. 10 10
      api/controllers/inner_api/wraps.py
  75. 1 1
      api/controllers/service_api/__init__.py
  76. 52 52
      api/controllers/service_api/app/app.py
  77. 23 25
      api/controllers/service_api/app/audio.py
  78. 24 28
      api/controllers/service_api/app/completion.py
  79. 20 22
      api/controllers/service_api/app/conversation.py
  80. 25 21
      api/controllers/service_api/app/error.py
  81. 3 5
      api/controllers/service_api/app/file.py
  82. 57 61
      api/controllers/service_api/app/message.py
  83. 23 26
      api/controllers/service_api/app/workflow.py
  84. 46 42
      api/controllers/service_api/dataset/dataset.py
  85. 119 176
      api/controllers/service_api/dataset/document.py
  86. 13 13
      api/controllers/service_api/dataset/error.py
  87. 53 74
      api/controllers/service_api/dataset/segment.py
  88. 1 1
      api/controllers/service_api/index.py
  89. 63 46
      api/controllers/service_api/wraps.py
  90. 1 1
      api/controllers/web/__init__.py
  91. 50 46
      api/controllers/web/app.py
  92. 24 26
      api/controllers/web/audio.py
  93. 25 34
      api/controllers/web/completion.py
  94. 24 27
      api/controllers/web/conversation.py
  95. 28 24
      api/controllers/web/error.py
  96. 1 1
      api/controllers/web/feature.py
  97. 3 4
      api/controllers/web/file.py
  98. 53 55
      api/controllers/web/message.py
  99. 15 18
      api/controllers/web/passport.py
  100. 23 25
      api/controllers/web/saved_message.py

+ 0 - 2
api/controllers/__init__.py

@@ -1,3 +1 @@
 
-
-

+ 1 - 1
api/controllers/console/__init__.py

@@ -2,7 +2,7 @@ from flask import Blueprint
 
 from libs.external_api import ExternalApi
 
-bp = Blueprint('console', __name__, url_prefix='/console/api')
+bp = Blueprint("console", __name__, url_prefix="/console/api")
 api = ExternalApi(bp)
 
 # Import other controllers

+ 49 - 46
api/controllers/console/admin.py

@@ -15,24 +15,24 @@ from models.model import App, InstalledApp, RecommendedApp
 def admin_required(view):
     @wraps(view)
     def decorated(*args, **kwargs):
-        if not os.getenv('ADMIN_API_KEY'):
-            raise Unauthorized('API key is invalid.')
+        if not os.getenv("ADMIN_API_KEY"):
+            raise Unauthorized("API key is invalid.")
 
-        auth_header = request.headers.get('Authorization')
+        auth_header = request.headers.get("Authorization")
         if auth_header is None:
-            raise Unauthorized('Authorization header is missing.')
+            raise Unauthorized("Authorization header is missing.")
 
-        if ' ' not in auth_header:
-            raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
+        if " " not in auth_header:
+            raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
 
         auth_scheme, auth_token = auth_header.split(None, 1)
         auth_scheme = auth_scheme.lower()
 
-        if auth_scheme != 'bearer':
-            raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
+        if auth_scheme != "bearer":
+            raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
 
-        if os.getenv('ADMIN_API_KEY') != auth_token:
-            raise Unauthorized('API key is invalid.')
+        if os.getenv("ADMIN_API_KEY") != auth_token:
+            raise Unauthorized("API key is invalid.")
 
         return view(*args, **kwargs)
 
@@ -44,37 +44,41 @@ class InsertExploreAppListApi(Resource):
     @admin_required
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('app_id', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('desc', type=str, location='json')
-        parser.add_argument('copyright', type=str, location='json')
-        parser.add_argument('privacy_policy', type=str, location='json')
-        parser.add_argument('custom_disclaimer', type=str, location='json')
-        parser.add_argument('language', type=supported_language, required=True, nullable=False, location='json')
-        parser.add_argument('category', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('position', type=int, required=True, nullable=False, location='json')
+        parser.add_argument("app_id", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("desc", type=str, location="json")
+        parser.add_argument("copyright", type=str, location="json")
+        parser.add_argument("privacy_policy", type=str, location="json")
+        parser.add_argument("custom_disclaimer", type=str, location="json")
+        parser.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
+        parser.add_argument("category", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("position", type=int, required=True, nullable=False, location="json")
         args = parser.parse_args()
 
-        app = App.query.filter(App.id == args['app_id']).first()
+        app = App.query.filter(App.id == args["app_id"]).first()
         if not app:
             raise NotFound(f'App \'{args["app_id"]}\' is not found')
 
         site = app.site
         if not site:
-            desc = args['desc'] if args['desc'] else ''
-            copy_right = args['copyright'] if args['copyright'] else ''
-            privacy_policy = args['privacy_policy'] if args['privacy_policy'] else ''
-            custom_disclaimer = args['custom_disclaimer'] if args['custom_disclaimer'] else ''
+            desc = args["desc"] if args["desc"] else ""
+            copy_right = args["copyright"] if args["copyright"] else ""
+            privacy_policy = args["privacy_policy"] if args["privacy_policy"] else ""
+            custom_disclaimer = args["custom_disclaimer"] if args["custom_disclaimer"] else ""
         else:
-            desc = site.description if site.description else \
-                args['desc'] if args['desc'] else ''
-            copy_right = site.copyright if site.copyright else \
-                args['copyright'] if args['copyright'] else ''
-            privacy_policy = site.privacy_policy if site.privacy_policy else \
-                args['privacy_policy'] if args['privacy_policy']  else ''
-            custom_disclaimer = site.custom_disclaimer if site.custom_disclaimer else \
-                args['custom_disclaimer'] if args['custom_disclaimer'] else ''
+            desc = site.description if site.description else args["desc"] if args["desc"] else ""
+            copy_right = site.copyright if site.copyright else args["copyright"] if args["copyright"] else ""
+            privacy_policy = (
+                site.privacy_policy if site.privacy_policy else args["privacy_policy"] if args["privacy_policy"] else ""
+            )
+            custom_disclaimer = (
+                site.custom_disclaimer
+                if site.custom_disclaimer
+                else args["custom_disclaimer"]
+                if args["custom_disclaimer"]
+                else ""
+            )
 
-        recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first()
+        recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
 
         if not recommended_app:
             recommended_app = RecommendedApp(
@@ -83,9 +87,9 @@ class InsertExploreAppListApi(Resource):
                 copyright=copy_right,
                 privacy_policy=privacy_policy,
                 custom_disclaimer=custom_disclaimer,
-                language=args['language'],
-                category=args['category'],
-                position=args['position']
+                language=args["language"],
+                category=args["category"],
+                position=args["position"],
             )
 
             db.session.add(recommended_app)
@@ -93,21 +97,21 @@ class InsertExploreAppListApi(Resource):
             app.is_public = True
             db.session.commit()
 
-            return {'result': 'success'}, 201
+            return {"result": "success"}, 201
         else:
             recommended_app.description = desc
             recommended_app.copyright = copy_right
             recommended_app.privacy_policy = privacy_policy
             recommended_app.custom_disclaimer = custom_disclaimer
-            recommended_app.language = args['language']
-            recommended_app.category = args['category']
-            recommended_app.position = args['position']
+            recommended_app.language = args["language"]
+            recommended_app.category = args["category"]
+            recommended_app.position = args["position"]
 
             app.is_public = True
 
             db.session.commit()
 
-            return {'result': 'success'}, 200
+            return {"result": "success"}, 200
 
 
 class InsertExploreAppApi(Resource):
@@ -116,15 +120,14 @@ class InsertExploreAppApi(Resource):
     def delete(self, app_id):
         recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == str(app_id)).first()
         if not recommended_app:
-            return {'result': 'success'}, 204
+            return {"result": "success"}, 204
 
         app = App.query.filter(App.id == recommended_app.app_id).first()
         if app:
             app.is_public = False
 
         installed_apps = InstalledApp.query.filter(
-            InstalledApp.app_id == recommended_app.app_id,
-            InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id
+            InstalledApp.app_id == recommended_app.app_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id
         ).all()
 
         for installed_app in installed_apps:
@@ -133,8 +136,8 @@ class InsertExploreAppApi(Resource):
         db.session.delete(recommended_app)
         db.session.commit()
 
-        return {'result': 'success'}, 204
+        return {"result": "success"}, 204
 
 
-api.add_resource(InsertExploreAppListApi, '/admin/insert-explore-apps')
-api.add_resource(InsertExploreAppApi, '/admin/insert-explore-apps/<uuid:app_id>')
+api.add_resource(InsertExploreAppListApi, "/admin/insert-explore-apps")
+api.add_resource(InsertExploreAppApi, "/admin/insert-explore-apps/<uuid:app_id>")

+ 56 - 60
api/controllers/console/apikey.py

@@ -14,26 +14,21 @@ from .setup import setup_required
 from .wraps import account_initialization_required
 
 api_key_fields = {
-    'id': fields.String,
-    'type': fields.String,
-    'token': fields.String,
-    'last_used_at': TimestampField,
-    'created_at': TimestampField
+    "id": fields.String,
+    "type": fields.String,
+    "token": fields.String,
+    "last_used_at": TimestampField,
+    "created_at": TimestampField,
 }
 
-api_key_list = {
-    'data': fields.List(fields.Nested(api_key_fields), attribute="items")
-}
+api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="items")}
 
 
 def _get_resource(resource_id, tenant_id, resource_model):
-    resource = resource_model.query.filter_by(
-        id=resource_id, tenant_id=tenant_id
-    ).first()
+    resource = resource_model.query.filter_by(id=resource_id, tenant_id=tenant_id).first()
 
     if resource is None:
-        flask_restful.abort(
-            404, message=f"{resource_model.__name__} not found.")
+        flask_restful.abort(404, message=f"{resource_model.__name__} not found.")
 
     return resource
 
@@ -50,30 +45,32 @@ class BaseApiKeyListResource(Resource):
     @marshal_with(api_key_list)
     def get(self, resource_id):
         resource_id = str(resource_id)
-        _get_resource(resource_id, current_user.current_tenant_id,
-                      self.resource_model)
-        keys = db.session.query(ApiToken). \
-            filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id). \
-            all()
+        _get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
+        keys = (
+            db.session.query(ApiToken)
+            .filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
+            .all()
+        )
         return {"items": keys}
 
     @marshal_with(api_key_fields)
     def post(self, resource_id):
         resource_id = str(resource_id)
-        _get_resource(resource_id, current_user.current_tenant_id,
-                      self.resource_model)
+        _get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
         if not current_user.is_admin_or_owner:
             raise Forbidden()
 
-        current_key_count = db.session.query(ApiToken). \
-            filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id). \
-            count()
+        current_key_count = (
+            db.session.query(ApiToken)
+            .filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
+            .count()
+        )
 
         if current_key_count >= self.max_keys:
             flask_restful.abort(
                 400,
                 message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
-                code='max_keys_exceeded'
+                code="max_keys_exceeded",
             )
 
         key = ApiToken.generate_api_key(self.token_prefix, 24)
@@ -97,79 +94,78 @@ class BaseApiKeyResource(Resource):
     def delete(self, resource_id, api_key_id):
         resource_id = str(resource_id)
         api_key_id = str(api_key_id)
-        _get_resource(resource_id, current_user.current_tenant_id,
-                      self.resource_model)
+        _get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
 
         # The role of the current user in the ta table must be admin or owner
         if not current_user.is_admin_or_owner:
             raise Forbidden()
 
-        key = db.session.query(ApiToken). \
-            filter(getattr(ApiToken, self.resource_id_field) == resource_id, ApiToken.type == self.resource_type, ApiToken.id == api_key_id). \
-            first()
+        key = (
+            db.session.query(ApiToken)
+            .filter(
+                getattr(ApiToken, self.resource_id_field) == resource_id,
+                ApiToken.type == self.resource_type,
+                ApiToken.id == api_key_id,
+            )
+            .first()
+        )
 
         if key is None:
-            flask_restful.abort(404, message='API key not found')
+            flask_restful.abort(404, message="API key not found")
 
         db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
         db.session.commit()
 
-        return {'result': 'success'}, 204
+        return {"result": "success"}, 204
 
 
 class AppApiKeyListResource(BaseApiKeyListResource):
-
     def after_request(self, resp):
-        resp.headers['Access-Control-Allow-Origin'] = '*'
-        resp.headers['Access-Control-Allow-Credentials'] = 'true'
+        resp.headers["Access-Control-Allow-Origin"] = "*"
+        resp.headers["Access-Control-Allow-Credentials"] = "true"
         return resp
 
-    resource_type = 'app'
+    resource_type = "app"
     resource_model = App
-    resource_id_field = 'app_id'
-    token_prefix = 'app-'
+    resource_id_field = "app_id"
+    token_prefix = "app-"
 
 
 class AppApiKeyResource(BaseApiKeyResource):
-
     def after_request(self, resp):
-        resp.headers['Access-Control-Allow-Origin'] = '*'
-        resp.headers['Access-Control-Allow-Credentials'] = 'true'
+        resp.headers["Access-Control-Allow-Origin"] = "*"
+        resp.headers["Access-Control-Allow-Credentials"] = "true"
         return resp
 
-    resource_type = 'app'
+    resource_type = "app"
     resource_model = App
-    resource_id_field = 'app_id'
+    resource_id_field = "app_id"
 
 
 class DatasetApiKeyListResource(BaseApiKeyListResource):
-
     def after_request(self, resp):
-        resp.headers['Access-Control-Allow-Origin'] = '*'
-        resp.headers['Access-Control-Allow-Credentials'] = 'true'
+        resp.headers["Access-Control-Allow-Origin"] = "*"
+        resp.headers["Access-Control-Allow-Credentials"] = "true"
         return resp
 
-    resource_type = 'dataset'
+    resource_type = "dataset"
     resource_model = Dataset
-    resource_id_field = 'dataset_id'
-    token_prefix = 'ds-'
+    resource_id_field = "dataset_id"
+    token_prefix = "ds-"
 
 
 class DatasetApiKeyResource(BaseApiKeyResource):
-
     def after_request(self, resp):
-        resp.headers['Access-Control-Allow-Origin'] = '*'
-        resp.headers['Access-Control-Allow-Credentials'] = 'true'
+        resp.headers["Access-Control-Allow-Origin"] = "*"
+        resp.headers["Access-Control-Allow-Credentials"] = "true"
         return resp
-    resource_type = 'dataset'
+
+    resource_type = "dataset"
     resource_model = Dataset
-    resource_id_field = 'dataset_id'
+    resource_id_field = "dataset_id"
 
 
-api.add_resource(AppApiKeyListResource, '/apps/<uuid:resource_id>/api-keys')
-api.add_resource(AppApiKeyResource,
-                 '/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>')
-api.add_resource(DatasetApiKeyListResource,
-                 '/datasets/<uuid:resource_id>/api-keys')
-api.add_resource(DatasetApiKeyResource,
-                 '/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>')
+api.add_resource(AppApiKeyListResource, "/apps/<uuid:resource_id>/api-keys")
+api.add_resource(AppApiKeyResource, "/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
+api.add_resource(DatasetApiKeyListResource, "/datasets/<uuid:resource_id>/api-keys")
+api.add_resource(DatasetApiKeyResource, "/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>")

+ 6 - 7
api/controllers/console/app/advanced_prompt_template.py

@@ -8,19 +8,18 @@ from services.advanced_prompt_template_service import AdvancedPromptTemplateServ
 
 
 class AdvancedPromptTemplateList(Resource):
-    
     @setup_required
     @login_required
     @account_initialization_required
     def get(self):
-         
         parser = reqparse.RequestParser()
-        parser.add_argument('app_mode', type=str, required=True, location='args')
-        parser.add_argument('model_mode', type=str, required=True, location='args')
-        parser.add_argument('has_context', type=str, required=False, default='true', location='args')
-        parser.add_argument('model_name', type=str, required=True, location='args')
+        parser.add_argument("app_mode", type=str, required=True, location="args")
+        parser.add_argument("model_mode", type=str, required=True, location="args")
+        parser.add_argument("has_context", type=str, required=False, default="true", location="args")
+        parser.add_argument("model_name", type=str, required=True, location="args")
         args = parser.parse_args()
 
         return AdvancedPromptTemplateService.get_prompt(args)
 
-api.add_resource(AdvancedPromptTemplateList, '/app/prompt-templates')
+
+api.add_resource(AdvancedPromptTemplateList, "/app/prompt-templates")

+ 6 - 9
api/controllers/console/app/agent.py

@@ -18,15 +18,12 @@ class AgentLogApi(Resource):
     def get(self, app_model):
         """Get agent logs"""
         parser = reqparse.RequestParser()
-        parser.add_argument('message_id', type=uuid_value, required=True, location='args')
-        parser.add_argument('conversation_id', type=uuid_value, required=True, location='args')
+        parser.add_argument("message_id", type=uuid_value, required=True, location="args")
+        parser.add_argument("conversation_id", type=uuid_value, required=True, location="args")
 
         args = parser.parse_args()
 
-        return AgentService.get_agent_logs(
-            app_model,
-            args['conversation_id'],
-            args['message_id']
-        )
-    
-api.add_resource(AgentLogApi, '/apps/<uuid:app_id>/agent/logs')
+        return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"])
+
+
+api.add_resource(AgentLogApi, "/apps/<uuid:app_id>/agent/logs")

+ 62 - 70
api/controllers/console/app/annotation.py

@@ -21,23 +21,23 @@ class AnnotationReplyActionApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_resource_check('annotation')
+    @cloud_edition_billing_resource_check("annotation")
     def post(self, app_id, action):
         if not current_user.is_editor:
             raise Forbidden()
 
         app_id = str(app_id)
         parser = reqparse.RequestParser()
-        parser.add_argument('score_threshold', required=True, type=float, location='json')
-        parser.add_argument('embedding_provider_name', required=True, type=str, location='json')
-        parser.add_argument('embedding_model_name', required=True, type=str, location='json')
+        parser.add_argument("score_threshold", required=True, type=float, location="json")
+        parser.add_argument("embedding_provider_name", required=True, type=str, location="json")
+        parser.add_argument("embedding_model_name", required=True, type=str, location="json")
         args = parser.parse_args()
-        if action == 'enable':
+        if action == "enable":
             result = AppAnnotationService.enable_app_annotation(args, app_id)
-        elif action == 'disable':
+        elif action == "disable":
             result = AppAnnotationService.disable_app_annotation(app_id)
         else:
-            raise ValueError('Unsupported annotation reply action')
+            raise ValueError("Unsupported annotation reply action")
         return result, 200
 
 
@@ -66,7 +66,7 @@ class AppAnnotationSettingUpdateApi(Resource):
         annotation_setting_id = str(annotation_setting_id)
 
         parser = reqparse.RequestParser()
-        parser.add_argument('score_threshold', required=True, type=float, location='json')
+        parser.add_argument("score_threshold", required=True, type=float, location="json")
         args = parser.parse_args()
 
         result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
@@ -77,28 +77,24 @@ class AnnotationReplyActionStatusApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_resource_check('annotation')
+    @cloud_edition_billing_resource_check("annotation")
     def get(self, app_id, job_id, action):
         if not current_user.is_editor:
             raise Forbidden()
 
         job_id = str(job_id)
-        app_annotation_job_key = '{}_app_annotation_job_{}'.format(action, str(job_id))
+        app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id))
         cache_result = redis_client.get(app_annotation_job_key)
         if cache_result is None:
             raise ValueError("The job is not exist.")
 
         job_status = cache_result.decode()
-        error_msg = ''
-        if job_status == 'error':
-            app_annotation_error_key = '{}_app_annotation_error_{}'.format(action, str(job_id))
+        error_msg = ""
+        if job_status == "error":
+            app_annotation_error_key = "{}_app_annotation_error_{}".format(action, str(job_id))
             error_msg = redis_client.get(app_annotation_error_key).decode()
 
-        return {
-            'job_id': job_id,
-            'job_status': job_status,
-            'error_msg': error_msg
-        }, 200
+        return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
 
 
 class AnnotationListApi(Resource):
@@ -109,18 +105,18 @@ class AnnotationListApi(Resource):
         if not current_user.is_editor:
             raise Forbidden()
 
-        page = request.args.get('page', default=1, type=int)
-        limit = request.args.get('limit', default=20, type=int)
-        keyword = request.args.get('keyword', default=None, type=str)
+        page = request.args.get("page", default=1, type=int)
+        limit = request.args.get("limit", default=20, type=int)
+        keyword = request.args.get("keyword", default=None, type=str)
 
         app_id = str(app_id)
         annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
         response = {
-            'data': marshal(annotation_list, annotation_fields),
-            'has_more': len(annotation_list) == limit,
-            'limit': limit,
-            'total': total,
-            'page': page
+            "data": marshal(annotation_list, annotation_fields),
+            "has_more": len(annotation_list) == limit,
+            "limit": limit,
+            "total": total,
+            "page": page,
         }
         return response, 200
 
@@ -135,9 +131,7 @@ class AnnotationExportApi(Resource):
 
         app_id = str(app_id)
         annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
-        response = {
-            'data': marshal(annotation_list, annotation_fields)
-        }
+        response = {"data": marshal(annotation_list, annotation_fields)}
         return response, 200
 
 
@@ -145,7 +139,7 @@ class AnnotationCreateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_resource_check('annotation')
+    @cloud_edition_billing_resource_check("annotation")
     @marshal_with(annotation_fields)
     def post(self, app_id):
         if not current_user.is_editor:
@@ -153,8 +147,8 @@ class AnnotationCreateApi(Resource):
 
         app_id = str(app_id)
         parser = reqparse.RequestParser()
-        parser.add_argument('question', required=True, type=str, location='json')
-        parser.add_argument('answer', required=True, type=str, location='json')
+        parser.add_argument("question", required=True, type=str, location="json")
+        parser.add_argument("answer", required=True, type=str, location="json")
         args = parser.parse_args()
         annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
         return annotation
@@ -164,7 +158,7 @@ class AnnotationUpdateDeleteApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_resource_check('annotation')
+    @cloud_edition_billing_resource_check("annotation")
     @marshal_with(annotation_fields)
     def post(self, app_id, annotation_id):
         if not current_user.is_editor:
@@ -173,8 +167,8 @@ class AnnotationUpdateDeleteApi(Resource):
         app_id = str(app_id)
         annotation_id = str(annotation_id)
         parser = reqparse.RequestParser()
-        parser.add_argument('question', required=True, type=str, location='json')
-        parser.add_argument('answer', required=True, type=str, location='json')
+        parser.add_argument("question", required=True, type=str, location="json")
+        parser.add_argument("answer", required=True, type=str, location="json")
         args = parser.parse_args()
         annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
         return annotation
@@ -189,29 +183,29 @@ class AnnotationUpdateDeleteApi(Resource):
         app_id = str(app_id)
         annotation_id = str(annotation_id)
         AppAnnotationService.delete_app_annotation(app_id, annotation_id)
-        return {'result': 'success'}, 200
+        return {"result": "success"}, 200
 
 
 class AnnotationBatchImportApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_resource_check('annotation')
+    @cloud_edition_billing_resource_check("annotation")
     def post(self, app_id):
         if not current_user.is_editor:
             raise Forbidden()
 
         app_id = str(app_id)
         # get file from request
-        file = request.files['file']
+        file = request.files["file"]
         # check file
-        if 'file' not in request.files:
+        if "file" not in request.files:
             raise NoFileUploadedError()
 
         if len(request.files) > 1:
             raise TooManyFilesError()
         # check file type
-        if not file.filename.endswith('.csv'):
+        if not file.filename.endswith(".csv"):
             raise ValueError("Invalid file type. Only CSV files are allowed")
         return AppAnnotationService.batch_import_app_annotations(app_id, file)
 
@@ -220,27 +214,23 @@ class AnnotationBatchImportStatusApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_resource_check('annotation')
+    @cloud_edition_billing_resource_check("annotation")
     def get(self, app_id, job_id):
         if not current_user.is_editor:
             raise Forbidden()
 
         job_id = str(job_id)
-        indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id))
+        indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id))
         cache_result = redis_client.get(indexing_cache_key)
         if cache_result is None:
             raise ValueError("The job is not exist.")
         job_status = cache_result.decode()
-        error_msg = ''
-        if job_status == 'error':
-            indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id))
+        error_msg = ""
+        if job_status == "error":
+            indexing_error_msg_key = "app_annotation_batch_import_error_msg_{}".format(str(job_id))
             error_msg = redis_client.get(indexing_error_msg_key).decode()
 
-        return {
-            'job_id': job_id,
-            'job_status': job_status,
-            'error_msg': error_msg
-        }, 200
+        return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
 
 
 class AnnotationHitHistoryListApi(Resource):
@@ -251,30 +241,32 @@ class AnnotationHitHistoryListApi(Resource):
         if not current_user.is_editor:
             raise Forbidden()
 
-        page = request.args.get('page', default=1, type=int)
-        limit = request.args.get('limit', default=20, type=int)
+        page = request.args.get("page", default=1, type=int)
+        limit = request.args.get("limit", default=20, type=int)
         app_id = str(app_id)
         annotation_id = str(annotation_id)
-        annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(app_id, annotation_id,
-                                                                                               page, limit)
+        annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(
+            app_id, annotation_id, page, limit
+        )
         response = {
-            'data': marshal(annotation_hit_history_list, annotation_hit_history_fields),
-            'has_more': len(annotation_hit_history_list) == limit,
-            'limit': limit,
-            'total': total,
-            'page': page
+            "data": marshal(annotation_hit_history_list, annotation_hit_history_fields),
+            "has_more": len(annotation_hit_history_list) == limit,
+            "limit": limit,
+            "total": total,
+            "page": page,
         }
         return response
 
 
-api.add_resource(AnnotationReplyActionApi, '/apps/<uuid:app_id>/annotation-reply/<string:action>')
-api.add_resource(AnnotationReplyActionStatusApi,
-                 '/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>')
-api.add_resource(AnnotationListApi, '/apps/<uuid:app_id>/annotations')
-api.add_resource(AnnotationExportApi, '/apps/<uuid:app_id>/annotations/export')
-api.add_resource(AnnotationUpdateDeleteApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>')
-api.add_resource(AnnotationBatchImportApi, '/apps/<uuid:app_id>/annotations/batch-import')
-api.add_resource(AnnotationBatchImportStatusApi, '/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>')
-api.add_resource(AnnotationHitHistoryListApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories')
-api.add_resource(AppAnnotationSettingDetailApi, '/apps/<uuid:app_id>/annotation-setting')
-api.add_resource(AppAnnotationSettingUpdateApi, '/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>')
+api.add_resource(AnnotationReplyActionApi, "/apps/<uuid:app_id>/annotation-reply/<string:action>")
+api.add_resource(
+    AnnotationReplyActionStatusApi, "/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>"
+)
+api.add_resource(AnnotationListApi, "/apps/<uuid:app_id>/annotations")
+api.add_resource(AnnotationExportApi, "/apps/<uuid:app_id>/annotations/export")
+api.add_resource(AnnotationUpdateDeleteApi, "/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
+api.add_resource(AnnotationBatchImportApi, "/apps/<uuid:app_id>/annotations/batch-import")
+api.add_resource(AnnotationBatchImportStatusApi, "/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>")
+api.add_resource(AnnotationHitHistoryListApi, "/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories")
+api.add_resource(AppAnnotationSettingDetailApi, "/apps/<uuid:app_id>/annotation-setting")
+api.add_resource(AppAnnotationSettingUpdateApi, "/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>")

+ 85 - 91
api/controllers/console/app/app.py

@@ -18,27 +18,35 @@ from libs.login import login_required
 from services.app_dsl_service import AppDslService
 from services.app_service import AppService
 
-ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion']
+ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
 
 
 class AppListApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
     def get(self):
         """Get app list"""
+
         def uuid_list(value):
             try:
-                return [str(uuid.UUID(v)) for v in value.split(',')]
+                return [str(uuid.UUID(v)) for v in value.split(",")]
             except ValueError:
                 abort(400, message="Invalid UUID format in tag_ids.")
+
         parser = reqparse.RequestParser()
-        parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args')
-        parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args')
-        parser.add_argument('mode', type=str, choices=['chat', 'workflow', 'agent-chat', 'channel', 'all'], default='all', location='args', required=False)
-        parser.add_argument('name', type=str, location='args', required=False)
-        parser.add_argument('tag_ids', type=uuid_list, location='args', required=False)
+        parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
+        parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
+        parser.add_argument(
+            "mode",
+            type=str,
+            choices=["chat", "workflow", "agent-chat", "channel", "all"],
+            default="all",
+            location="args",
+            required=False,
+        )
+        parser.add_argument("name", type=str, location="args", required=False)
+        parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)
 
         args = parser.parse_args()
 
@@ -46,7 +54,7 @@ class AppListApi(Resource):
         app_service = AppService()
         app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args)
         if not app_pagination:
-            return {'data': [], 'total': 0, 'page': 1, 'limit': 20, 'has_more': False}
+            return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
 
         return marshal(app_pagination, app_pagination_fields)
 
@@ -54,23 +62,23 @@ class AppListApi(Resource):
     @login_required
     @account_initialization_required
     @marshal_with(app_detail_fields)
-    @cloud_edition_billing_resource_check('apps')
+    @cloud_edition_billing_resource_check("apps")
     def post(self):
         """Create app"""
         parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, required=True, location='json')
-        parser.add_argument('description', type=str, location='json')
-        parser.add_argument('mode', type=str, choices=ALLOW_CREATE_APP_MODES, location='json')
-        parser.add_argument('icon_type', type=str, location='json')
-        parser.add_argument('icon', type=str, location='json')
-        parser.add_argument('icon_background', type=str, location='json')
+        parser.add_argument("name", type=str, required=True, location="json")
+        parser.add_argument("description", type=str, location="json")
+        parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
+        parser.add_argument("icon_type", type=str, location="json")
+        parser.add_argument("icon", type=str, location="json")
+        parser.add_argument("icon_background", type=str, location="json")
         args = parser.parse_args()
 
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
 
-        if 'mode' not in args or args['mode'] is None:
+        if "mode" not in args or args["mode"] is None:
             raise BadRequest("mode is required")
 
         app_service = AppService()
@@ -84,7 +92,7 @@ class AppImportApi(Resource):
     @login_required
     @account_initialization_required
     @marshal_with(app_detail_fields_with_site)
-    @cloud_edition_billing_resource_check('apps')
+    @cloud_edition_billing_resource_check("apps")
     def post(self):
         """Import app"""
         # The role of the current user in the ta table must be admin, owner, or editor
@@ -92,19 +100,16 @@ class AppImportApi(Resource):
             raise Forbidden()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('data', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('name', type=str, location='json')
-        parser.add_argument('description', type=str, location='json')
-        parser.add_argument('icon_type', type=str, location='json')
-        parser.add_argument('icon', type=str, location='json')
-        parser.add_argument('icon_background', type=str, location='json')
+        parser.add_argument("data", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("name", type=str, location="json")
+        parser.add_argument("description", type=str, location="json")
+        parser.add_argument("icon_type", type=str, location="json")
+        parser.add_argument("icon", type=str, location="json")
+        parser.add_argument("icon_background", type=str, location="json")
         args = parser.parse_args()
 
         app = AppDslService.import_and_create_new_app(
-            tenant_id=current_user.current_tenant_id,
-            data=args['data'],
-            args=args,
-            account=current_user
+            tenant_id=current_user.current_tenant_id, data=args["data"], args=args, account=current_user
         )
 
         return app, 201
@@ -115,7 +120,7 @@ class AppImportFromUrlApi(Resource):
     @login_required
     @account_initialization_required
     @marshal_with(app_detail_fields_with_site)
-    @cloud_edition_billing_resource_check('apps')
+    @cloud_edition_billing_resource_check("apps")
     def post(self):
         """Import app from url"""
         # The role of the current user in the ta table must be admin, owner, or editor
@@ -123,25 +128,21 @@ class AppImportFromUrlApi(Resource):
             raise Forbidden()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('url', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('name', type=str, location='json')
-        parser.add_argument('description', type=str, location='json')
-        parser.add_argument('icon', type=str, location='json')
-        parser.add_argument('icon_background', type=str, location='json')
+        parser.add_argument("url", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("name", type=str, location="json")
+        parser.add_argument("description", type=str, location="json")
+        parser.add_argument("icon", type=str, location="json")
+        parser.add_argument("icon_background", type=str, location="json")
         args = parser.parse_args()
 
         app = AppDslService.import_and_create_new_app_from_url(
-            tenant_id=current_user.current_tenant_id,
-            url=args['url'],
-            args=args,
-            account=current_user
+            tenant_id=current_user.current_tenant_id, url=args["url"], args=args, account=current_user
         )
 
         return app, 201
 
 
 class AppApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -165,14 +166,14 @@ class AppApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
-        
+
         parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('description', type=str, location='json')
-        parser.add_argument('icon_type', type=str, location='json')
-        parser.add_argument('icon', type=str, location='json')
-        parser.add_argument('icon_background', type=str, location='json')
-        parser.add_argument('max_active_requests', type=int, location='json')
+        parser.add_argument("name", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("description", type=str, location="json")
+        parser.add_argument("icon_type", type=str, location="json")
+        parser.add_argument("icon", type=str, location="json")
+        parser.add_argument("icon_background", type=str, location="json")
+        parser.add_argument("max_active_requests", type=int, location="json")
         args = parser.parse_args()
 
         app_service = AppService()
@@ -193,7 +194,7 @@ class AppApi(Resource):
         app_service = AppService()
         app_service.delete_app(app_model)
 
-        return {'result': 'success'}, 204
+        return {"result": "success"}, 204
 
 
 class AppCopyApi(Resource):
@@ -209,19 +210,16 @@ class AppCopyApi(Resource):
             raise Forbidden()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, location='json')
-        parser.add_argument('description', type=str, location='json')
-        parser.add_argument('icon_type', type=str, location='json')
-        parser.add_argument('icon', type=str, location='json')
-        parser.add_argument('icon_background', type=str, location='json')
+        parser.add_argument("name", type=str, location="json")
+        parser.add_argument("description", type=str, location="json")
+        parser.add_argument("icon_type", type=str, location="json")
+        parser.add_argument("icon", type=str, location="json")
+        parser.add_argument("icon_background", type=str, location="json")
         args = parser.parse_args()
 
         data = AppDslService.export_dsl(app_model=app_model, include_secret=True)
         app = AppDslService.import_and_create_new_app(
-            tenant_id=current_user.current_tenant_id,
-            data=data,
-            args=args,
-            account=current_user
+            tenant_id=current_user.current_tenant_id, data=data, args=args, account=current_user
         )
 
         return app, 201
@@ -240,12 +238,10 @@ class AppExportApi(Resource):
 
         # Add include_secret params
         parser = reqparse.RequestParser()
-        parser.add_argument('include_secret', type=inputs.boolean, default=False, location='args')
+        parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
         args = parser.parse_args()
 
-        return {
-            "data": AppDslService.export_dsl(app_model=app_model, include_secret=args['include_secret'])
-        }
+        return {"data": AppDslService.export_dsl(app_model=app_model, include_secret=args["include_secret"])}
 
 
 class AppNameApi(Resource):
@@ -258,13 +254,13 @@ class AppNameApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
-        
+
         parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, required=True, location='json')
+        parser.add_argument("name", type=str, required=True, location="json")
         args = parser.parse_args()
 
         app_service = AppService()
-        app_model = app_service.update_app_name(app_model, args.get('name'))
+        app_model = app_service.update_app_name(app_model, args.get("name"))
 
         return app_model
 
@@ -279,14 +275,14 @@ class AppIconApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
-        
+
         parser = reqparse.RequestParser()
-        parser.add_argument('icon', type=str, location='json')
-        parser.add_argument('icon_background', type=str, location='json')
+        parser.add_argument("icon", type=str, location="json")
+        parser.add_argument("icon_background", type=str, location="json")
         args = parser.parse_args()
 
         app_service = AppService()
-        app_model = app_service.update_app_icon(app_model, args.get('icon'), args.get('icon_background'))
+        app_model = app_service.update_app_icon(app_model, args.get("icon"), args.get("icon_background"))
 
         return app_model
 
@@ -301,13 +297,13 @@ class AppSiteStatus(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
-        
+
         parser = reqparse.RequestParser()
-        parser.add_argument('enable_site', type=bool, required=True, location='json')
+        parser.add_argument("enable_site", type=bool, required=True, location="json")
         args = parser.parse_args()
 
         app_service = AppService()
-        app_model = app_service.update_app_site_status(app_model, args.get('enable_site'))
+        app_model = app_service.update_app_site_status(app_model, args.get("enable_site"))
 
         return app_model
 
@@ -322,13 +318,13 @@ class AppApiStatus(Resource):
         # The role of the current user in the ta table must be admin or owner
         if not current_user.is_admin_or_owner:
             raise Forbidden()
-        
+
         parser = reqparse.RequestParser()
-        parser.add_argument('enable_api', type=bool, required=True, location='json')
+        parser.add_argument("enable_api", type=bool, required=True, location="json")
         args = parser.parse_args()
 
         app_service = AppService()
-        app_model = app_service.update_app_api_status(app_model, args.get('enable_api'))
+        app_model = app_service.update_app_api_status(app_model, args.get("enable_api"))
 
         return app_model
 
@@ -339,9 +335,7 @@ class AppTraceApi(Resource):
     @account_initialization_required
     def get(self, app_id):
         """Get app trace"""
-        app_trace_config = OpsTraceManager.get_app_tracing_config(
-            app_id=app_id
-        )
+        app_trace_config = OpsTraceManager.get_app_tracing_config(app_id=app_id)
 
         return app_trace_config
 
@@ -353,27 +347,27 @@ class AppTraceApi(Resource):
         if not current_user.is_admin_or_owner:
             raise Forbidden()
         parser = reqparse.RequestParser()
-        parser.add_argument('enabled', type=bool, required=True, location='json')
-        parser.add_argument('tracing_provider', type=str, required=True, location='json')
+        parser.add_argument("enabled", type=bool, required=True, location="json")
+        parser.add_argument("tracing_provider", type=str, required=True, location="json")
         args = parser.parse_args()
 
         OpsTraceManager.update_app_tracing_config(
             app_id=app_id,
-            enabled=args['enabled'],
-            tracing_provider=args['tracing_provider'],
+            enabled=args["enabled"],
+            tracing_provider=args["tracing_provider"],
         )
 
         return {"result": "success"}
 
 
-api.add_resource(AppListApi, '/apps')
-api.add_resource(AppImportApi, '/apps/import')
-api.add_resource(AppImportFromUrlApi, '/apps/import/url')
-api.add_resource(AppApi, '/apps/<uuid:app_id>')
-api.add_resource(AppCopyApi, '/apps/<uuid:app_id>/copy')
-api.add_resource(AppExportApi, '/apps/<uuid:app_id>/export')
-api.add_resource(AppNameApi, '/apps/<uuid:app_id>/name')
-api.add_resource(AppIconApi, '/apps/<uuid:app_id>/icon')
-api.add_resource(AppSiteStatus, '/apps/<uuid:app_id>/site-enable')
-api.add_resource(AppApiStatus, '/apps/<uuid:app_id>/api-enable')
-api.add_resource(AppTraceApi, '/apps/<uuid:app_id>/trace')
+api.add_resource(AppListApi, "/apps")
+api.add_resource(AppImportApi, "/apps/import")
+api.add_resource(AppImportFromUrlApi, "/apps/import/url")
+api.add_resource(AppApi, "/apps/<uuid:app_id>")
+api.add_resource(AppCopyApi, "/apps/<uuid:app_id>/copy")
+api.add_resource(AppExportApi, "/apps/<uuid:app_id>/export")
+api.add_resource(AppNameApi, "/apps/<uuid:app_id>/name")
+api.add_resource(AppIconApi, "/apps/<uuid:app_id>/icon")
+api.add_resource(AppSiteStatus, "/apps/<uuid:app_id>/site-enable")
+api.add_resource(AppApiStatus, "/apps/<uuid:app_id>/api-enable")
+api.add_resource(AppTraceApi, "/apps/<uuid:app_id>/trace")

+ 25 - 25
api/controllers/console/app/audio.py

@@ -39,7 +39,7 @@ class ChatMessageAudioApi(Resource):
     @account_initialization_required
     @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
     def post(self, app_model):
-        file = request.files['file']
+        file = request.files["file"]
 
         try:
             response = AudioService.transcript_asr(
@@ -85,31 +85,31 @@ class ChatMessageTextApi(Resource):
 
         try:
             parser = reqparse.RequestParser()
-            parser.add_argument('message_id', type=str, location='json')
-            parser.add_argument('text', type=str, location='json')
-            parser.add_argument('voice', type=str, location='json')
-            parser.add_argument('streaming', type=bool, location='json')
+            parser.add_argument("message_id", type=str, location="json")
+            parser.add_argument("text", type=str, location="json")
+            parser.add_argument("voice", type=str, location="json")
+            parser.add_argument("streaming", type=bool, location="json")
             args = parser.parse_args()
 
-            message_id = args.get('message_id', None)
-            text = args.get('text', None)
-            if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
-                    and app_model.workflow
-                    and app_model.workflow.features_dict):
-                text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
-                voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
+            message_id = args.get("message_id", None)
+            text = args.get("text", None)
+            if (
+                app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
+                and app_model.workflow
+                and app_model.workflow.features_dict
+            ):
+                text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
+                voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
             else:
                 try:
-                    voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get(
-                        'voice')
+                    voice = (
+                        args.get("voice")
+                        if args.get("voice")
+                        else app_model.app_model_config.text_to_speech_dict.get("voice")
+                    )
                 except Exception:
                     voice = None
-            response = AudioService.transcript_tts(
-                app_model=app_model,
-                text=text,
-                message_id=message_id,
-                voice=voice
-            )
+            response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice)
             return response
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
@@ -145,12 +145,12 @@ class TextModesApi(Resource):
     def get(self, app_model):
         try:
             parser = reqparse.RequestParser()
-            parser.add_argument('language', type=str, required=True, location='args')
+            parser.add_argument("language", type=str, required=True, location="args")
             args = parser.parse_args()
 
             response = AudioService.transcript_tts_voices(
                 tenant_id=app_model.tenant_id,
-                language=args['language'],
+                language=args["language"],
             )
 
             return response
@@ -179,6 +179,6 @@ class TextModesApi(Resource):
             raise InternalServerError()
 
 
-api.add_resource(ChatMessageAudioApi, '/apps/<uuid:app_id>/audio-to-text')
-api.add_resource(ChatMessageTextApi, '/apps/<uuid:app_id>/text-to-audio')
-api.add_resource(TextModesApi, '/apps/<uuid:app_id>/text-to-audio/voices')
+api.add_resource(ChatMessageAudioApi, "/apps/<uuid:app_id>/audio-to-text")
+api.add_resource(ChatMessageTextApi, "/apps/<uuid:app_id>/text-to-audio")
+api.add_resource(TextModesApi, "/apps/<uuid:app_id>/text-to-audio/voices")

+ 25 - 34
api/controllers/console/app/completion.py

@@ -35,33 +35,28 @@ from services.app_generate_service import AppGenerateService
 
 # define completion message api for user
 class CompletionMessageApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
     @get_app_model(mode=AppMode.COMPLETION)
     def post(self, app_model):
         parser = reqparse.RequestParser()
-        parser.add_argument('inputs', type=dict, required=True, location='json')
-        parser.add_argument('query', type=str, location='json', default='')
-        parser.add_argument('files', type=list, required=False, location='json')
-        parser.add_argument('model_config', type=dict, required=True, location='json')
-        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
-        parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
+        parser.add_argument("inputs", type=dict, required=True, location="json")
+        parser.add_argument("query", type=str, location="json", default="")
+        parser.add_argument("files", type=list, required=False, location="json")
+        parser.add_argument("model_config", type=dict, required=True, location="json")
+        parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
+        parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
         args = parser.parse_args()
 
-        streaming = args['response_mode'] != 'blocking'
-        args['auto_generate_name'] = False
+        streaming = args["response_mode"] != "blocking"
+        args["auto_generate_name"] = False
 
         account = flask_login.current_user
 
         try:
             response = AppGenerateService.generate(
-                app_model=app_model,
-                user=account,
-                args=args,
-                invoke_from=InvokeFrom.DEBUGGER,
-                streaming=streaming
+                app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
             )
 
             return helper.compact_generate_response(response)
@@ -97,7 +92,7 @@ class CompletionMessageStopApi(Resource):
 
         AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
 
-        return {'result': 'success'}, 200
+        return {"result": "success"}, 200
 
 
 class ChatMessageApi(Resource):
@@ -107,27 +102,23 @@ class ChatMessageApi(Resource):
     @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
     def post(self, app_model):
         parser = reqparse.RequestParser()
-        parser.add_argument('inputs', type=dict, required=True, location='json')
-        parser.add_argument('query', type=str, required=True, location='json')
-        parser.add_argument('files', type=list, required=False, location='json')
-        parser.add_argument('model_config', type=dict, required=True, location='json')
-        parser.add_argument('conversation_id', type=uuid_value, location='json')
-        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
-        parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
+        parser.add_argument("inputs", type=dict, required=True, location="json")
+        parser.add_argument("query", type=str, required=True, location="json")
+        parser.add_argument("files", type=list, required=False, location="json")
+        parser.add_argument("model_config", type=dict, required=True, location="json")
+        parser.add_argument("conversation_id", type=uuid_value, location="json")
+        parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
+        parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
         args = parser.parse_args()
 
-        streaming = args['response_mode'] != 'blocking'
-        args['auto_generate_name'] = False
+        streaming = args["response_mode"] != "blocking"
+        args["auto_generate_name"] = False
 
         account = flask_login.current_user
 
         try:
             response = AppGenerateService.generate(
-                app_model=app_model,
-                user=account,
-                args=args,
-                invoke_from=InvokeFrom.DEBUGGER,
-                streaming=streaming
+                app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
             )
 
             return helper.compact_generate_response(response)
@@ -163,10 +154,10 @@ class ChatMessageStopApi(Resource):
 
         AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
 
-        return {'result': 'success'}, 200
+        return {"result": "success"}, 200
 
 
-api.add_resource(CompletionMessageApi, '/apps/<uuid:app_id>/completion-messages')
-api.add_resource(CompletionMessageStopApi, '/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop')
-api.add_resource(ChatMessageApi, '/apps/<uuid:app_id>/chat-messages')
-api.add_resource(ChatMessageStopApi, '/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop')
+api.add_resource(CompletionMessageApi, "/apps/<uuid:app_id>/completion-messages")
+api.add_resource(CompletionMessageStopApi, "/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop")
+api.add_resource(ChatMessageApi, "/apps/<uuid:app_id>/chat-messages")
+api.add_resource(ChatMessageStopApi, "/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop")

+ 100 - 93
api/controllers/console/app/conversation.py

@@ -26,7 +26,6 @@ from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotat
 
 
 class CompletionConversationApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -36,24 +35,23 @@ class CompletionConversationApi(Resource):
         if not current_user.is_editor:
             raise Forbidden()
         parser = reqparse.RequestParser()
-        parser.add_argument('keyword', type=str, location='args')
-        parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
-        parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
-        parser.add_argument('annotation_status', type=str,
-                            choices=['annotated', 'not_annotated', 'all'], default='all', location='args')
-        parser.add_argument('page', type=int_range(1, 99999), default=1, location='args')
-        parser.add_argument('limit', type=int_range(1, 100), default=20, location='args')
+        parser.add_argument("keyword", type=str, location="args")
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument(
+            "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
+        )
+        parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
+        parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
         args = parser.parse_args()
 
-        query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == 'completion')
+        query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == "completion")
 
-        if args['keyword']:
-            query = query.join(
-                Message, Message.conversation_id == Conversation.id
-            ).filter(
+        if args["keyword"]:
+            query = query.join(Message, Message.conversation_id == Conversation.id).filter(
                 or_(
-                    Message.query.ilike('%{}%'.format(args['keyword'])),
-                    Message.answer.ilike('%{}%'.format(args['keyword']))
+                    Message.query.ilike("%{}%".format(args["keyword"])),
+                    Message.answer.ilike("%{}%".format(args["keyword"])),
                 )
             )
 
@@ -61,8 +59,8 @@ class CompletionConversationApi(Resource):
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
 
-        if args['start']:
-            start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
+        if args["start"]:
+            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
             start_datetime = start_datetime.replace(second=0)
 
             start_datetime_timezone = timezone.localize(start_datetime)
@@ -70,8 +68,8 @@ class CompletionConversationApi(Resource):
 
             query = query.where(Conversation.created_at >= start_datetime_utc)
 
-        if args['end']:
-            end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
+        if args["end"]:
+            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
             end_datetime = end_datetime.replace(second=59)
 
             end_datetime_timezone = timezone.localize(end_datetime)
@@ -79,29 +77,25 @@ class CompletionConversationApi(Resource):
 
             query = query.where(Conversation.created_at < end_datetime_utc)
 
-        if args['annotation_status'] == "annotated":
+        if args["annotation_status"] == "annotated":
             query = query.options(joinedload(Conversation.message_annotations)).join(
                 MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
             )
-        elif args['annotation_status'] == "not_annotated":
-            query = query.outerjoin(
-                MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
-            ).group_by(Conversation.id).having(func.count(MessageAnnotation.id) == 0)
+        elif args["annotation_status"] == "not_annotated":
+            query = (
+                query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
+                .group_by(Conversation.id)
+                .having(func.count(MessageAnnotation.id) == 0)
+            )
 
         query = query.order_by(Conversation.created_at.desc())
 
-        conversations = db.paginate(
-            query,
-            page=args['page'],
-            per_page=args['limit'],
-            error_out=False
-        )
+        conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
 
         return conversations
 
 
 class CompletionConversationDetailApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -123,8 +117,11 @@ class CompletionConversationDetailApi(Resource):
             raise Forbidden()
         conversation_id = str(conversation_id)
 
-        conversation = db.session.query(Conversation) \
-            .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first()
+        conversation = (
+            db.session.query(Conversation)
+            .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
+            .first()
+        )
 
         if not conversation:
             raise NotFound("Conversation Not Exists.")
@@ -132,11 +129,10 @@ class CompletionConversationDetailApi(Resource):
         conversation.is_deleted = True
         db.session.commit()
 
-        return {'result': 'success'}, 204
+        return {"result": "success"}, 204
 
 
 class ChatConversationApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -146,22 +142,28 @@ class ChatConversationApi(Resource):
         if not current_user.is_editor:
             raise Forbidden()
         parser = reqparse.RequestParser()
-        parser.add_argument('keyword', type=str, location='args')
-        parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
-        parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
-        parser.add_argument('annotation_status', type=str,
-                            choices=['annotated', 'not_annotated', 'all'], default='all', location='args')
-        parser.add_argument('message_count_gte', type=int_range(1, 99999), required=False, location='args')
-        parser.add_argument('page', type=int_range(1, 99999), required=False, default=1, location='args')
-        parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
-        parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'],
-                            required=False, default='-updated_at', location='args')
+        parser.add_argument("keyword", type=str, location="args")
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument(
+            "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
+        )
+        parser.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
+        parser.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
+        parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
+        parser.add_argument(
+            "sort_by",
+            type=str,
+            choices=["created_at", "-created_at", "updated_at", "-updated_at"],
+            required=False,
+            default="-updated_at",
+            location="args",
+        )
         args = parser.parse_args()
 
         subquery = (
             db.session.query(
-                Conversation.id.label('conversation_id'),
-                EndUser.session_id.label('from_end_user_session_id')
+                Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id")
             )
             .outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id)
             .subquery()
@@ -169,28 +171,31 @@ class ChatConversationApi(Resource):
 
         query = db.select(Conversation).where(Conversation.app_id == app_model.id)
 
-        if args['keyword']:
-            keyword_filter = '%{}%'.format(args['keyword'])
-            query = query.join(
-                Message, Message.conversation_id == Conversation.id,
-            ).join(
-                subquery, subquery.c.conversation_id == Conversation.id
-            ).filter(
-                or_(
-                    Message.query.ilike(keyword_filter),
-                    Message.answer.ilike(keyword_filter),
-                    Conversation.name.ilike(keyword_filter),
-                    Conversation.introduction.ilike(keyword_filter),
-                    subquery.c.from_end_user_session_id.ilike(keyword_filter)
-                ),
+        if args["keyword"]:
+            keyword_filter = "%{}%".format(args["keyword"])
+            query = (
+                query.join(
+                    Message,
+                    Message.conversation_id == Conversation.id,
+                )
+                .join(subquery, subquery.c.conversation_id == Conversation.id)
+                .filter(
+                    or_(
+                        Message.query.ilike(keyword_filter),
+                        Message.answer.ilike(keyword_filter),
+                        Conversation.name.ilike(keyword_filter),
+                        Conversation.introduction.ilike(keyword_filter),
+                        subquery.c.from_end_user_session_id.ilike(keyword_filter),
+                    ),
+                )
             )
 
         account = current_user
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
 
-        if args['start']:
-            start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
+        if args["start"]:
+            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
             start_datetime = start_datetime.replace(second=0)
 
             start_datetime_timezone = timezone.localize(start_datetime)
@@ -198,8 +203,8 @@ class ChatConversationApi(Resource):
 
             query = query.where(Conversation.created_at >= start_datetime_utc)
 
-        if args['end']:
-            end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
+        if args["end"]:
+            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
             end_datetime = end_datetime.replace(second=59)
 
             end_datetime_timezone = timezone.localize(end_datetime)
@@ -207,50 +212,46 @@ class ChatConversationApi(Resource):
 
             query = query.where(Conversation.created_at < end_datetime_utc)
 
-        if args['annotation_status'] == "annotated":
+        if args["annotation_status"] == "annotated":
             query = query.options(joinedload(Conversation.message_annotations)).join(
                 MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
             )
-        elif args['annotation_status'] == "not_annotated":
-            query = query.outerjoin(
-                MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
-            ).group_by(Conversation.id).having(func.count(MessageAnnotation.id) == 0)
+        elif args["annotation_status"] == "not_annotated":
+            query = (
+                query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
+                .group_by(Conversation.id)
+                .having(func.count(MessageAnnotation.id) == 0)
+            )
 
-        if args['message_count_gte'] and args['message_count_gte'] >= 1:
+        if args["message_count_gte"] and args["message_count_gte"] >= 1:
             query = (
                 query.options(joinedload(Conversation.messages))
                 .join(Message, Message.conversation_id == Conversation.id)
                 .group_by(Conversation.id)
-                .having(func.count(Message.id) >= args['message_count_gte'])
+                .having(func.count(Message.id) >= args["message_count_gte"])
             )
 
         if app_model.mode == AppMode.ADVANCED_CHAT.value:
             query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value)
 
-        match args['sort_by']:
-            case 'created_at':
+        match args["sort_by"]:
+            case "created_at":
                 query = query.order_by(Conversation.created_at.asc())
-            case '-created_at':
+            case "-created_at":
                 query = query.order_by(Conversation.created_at.desc())
-            case 'updated_at':
+            case "updated_at":
                 query = query.order_by(Conversation.updated_at.asc())
-            case '-updated_at':
+            case "-updated_at":
                 query = query.order_by(Conversation.updated_at.desc())
             case _:
                 query = query.order_by(Conversation.created_at.desc())
 
-        conversations = db.paginate(
-            query,
-            page=args['page'],
-            per_page=args['limit'],
-            error_out=False
-        )
+        conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
 
         return conversations
 
 
 class ChatConversationDetailApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -272,8 +273,11 @@ class ChatConversationDetailApi(Resource):
             raise Forbidden()
         conversation_id = str(conversation_id)
 
-        conversation = db.session.query(Conversation) \
-            .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first()
+        conversation = (
+            db.session.query(Conversation)
+            .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
+            .first()
+        )
 
         if not conversation:
             raise NotFound("Conversation Not Exists.")
@@ -281,18 +285,21 @@ class ChatConversationDetailApi(Resource):
         conversation.is_deleted = True
         db.session.commit()
 
-        return {'result': 'success'}, 204
+        return {"result": "success"}, 204
 
 
-api.add_resource(CompletionConversationApi, '/apps/<uuid:app_id>/completion-conversations')
-api.add_resource(CompletionConversationDetailApi, '/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>')
-api.add_resource(ChatConversationApi, '/apps/<uuid:app_id>/chat-conversations')
-api.add_resource(ChatConversationDetailApi, '/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>')
+api.add_resource(CompletionConversationApi, "/apps/<uuid:app_id>/completion-conversations")
+api.add_resource(CompletionConversationDetailApi, "/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>")
+api.add_resource(ChatConversationApi, "/apps/<uuid:app_id>/chat-conversations")
+api.add_resource(ChatConversationDetailApi, "/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>")
 
 
 def _get_conversation(app_model, conversation_id):
-    conversation = db.session.query(Conversation) \
-        .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first()
+    conversation = (
+        db.session.query(Conversation)
+        .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
+        .first()
+    )
 
     if not conversation:
         raise NotFound("Conversation Not Exists.")

+ 12 - 12
api/controllers/console/app/conversation_variables.py

@@ -21,7 +21,7 @@ class ConversationVariablesApi(Resource):
     @marshal_with(paginated_conversation_variable_fields)
     def get(self, app_model):
         parser = reqparse.RequestParser()
-        parser.add_argument('conversation_id', type=str, location='args')
+        parser.add_argument("conversation_id", type=str, location="args")
         args = parser.parse_args()
 
         stmt = (
@@ -29,10 +29,10 @@ class ConversationVariablesApi(Resource):
             .where(ConversationVariable.app_id == app_model.id)
             .order_by(ConversationVariable.created_at)
         )
-        if args['conversation_id']:
-            stmt = stmt.where(ConversationVariable.conversation_id == args['conversation_id'])
+        if args["conversation_id"]:
+            stmt = stmt.where(ConversationVariable.conversation_id == args["conversation_id"])
         else:
-            raise ValueError('conversation_id is required')
+            raise ValueError("conversation_id is required")
 
         # NOTE: This is a temporary solution to avoid performance issues.
         page = 1
@@ -43,14 +43,14 @@ class ConversationVariablesApi(Resource):
             rows = session.scalars(stmt).all()
 
         return {
-            'page': page,
-            'limit': page_size,
-            'total': len(rows),
-            'has_more': False,
-            'data': [
+            "page": page,
+            "limit": page_size,
+            "total": len(rows),
+            "has_more": False,
+            "data": [
                 {
-                    'created_at': row.created_at,
-                    'updated_at': row.updated_at,
+                    "created_at": row.created_at,
+                    "updated_at": row.updated_at,
                     **row.to_variable().model_dump(),
                 }
                 for row in rows
@@ -58,4 +58,4 @@ class ConversationVariablesApi(Resource):
         }
 
 
-api.add_resource(ConversationVariablesApi, '/apps/<uuid:app_id>/conversation-variables')
+api.add_resource(ConversationVariablesApi, "/apps/<uuid:app_id>/conversation-variables")

+ 27 - 23
api/controllers/console/app/error.py

@@ -2,116 +2,120 @@ from libs.exception import BaseHTTPException
 
 
 class AppNotFoundError(BaseHTTPException):
-    error_code = 'app_not_found'
+    error_code = "app_not_found"
     description = "App not found."
     code = 404
 
 
 class ProviderNotInitializeError(BaseHTTPException):
-    error_code = 'provider_not_initialize'
-    description = "No valid model provider credentials found. " \
-                  "Please go to Settings -> Model Provider to complete your provider credentials."
+    error_code = "provider_not_initialize"
+    description = (
+        "No valid model provider credentials found. "
+        "Please go to Settings -> Model Provider to complete your provider credentials."
+    )
     code = 400
 
 
 class ProviderQuotaExceededError(BaseHTTPException):
-    error_code = 'provider_quota_exceeded'
-    description = "Your quota for Dify Hosted Model Provider has been exhausted. " \
-                  "Please go to Settings -> Model Provider to complete your own provider credentials."
+    error_code = "provider_quota_exceeded"
+    description = (
+        "Your quota for Dify Hosted Model Provider has been exhausted. "
+        "Please go to Settings -> Model Provider to complete your own provider credentials."
+    )
     code = 400
 
 
 class ProviderModelCurrentlyNotSupportError(BaseHTTPException):
-    error_code = 'model_currently_not_support'
+    error_code = "model_currently_not_support"
     description = "Dify Hosted OpenAI trial currently not support the GPT-4 model."
     code = 400
 
 
 class ConversationCompletedError(BaseHTTPException):
-    error_code = 'conversation_completed'
+    error_code = "conversation_completed"
     description = "The conversation has ended. Please start a new conversation."
     code = 400
 
 
 class AppUnavailableError(BaseHTTPException):
-    error_code = 'app_unavailable'
+    error_code = "app_unavailable"
     description = "App unavailable, please check your app configurations."
     code = 400
 
 
 class CompletionRequestError(BaseHTTPException):
-    error_code = 'completion_request_error'
+    error_code = "completion_request_error"
     description = "Completion request failed."
     code = 400
 
 
 class AppMoreLikeThisDisabledError(BaseHTTPException):
-    error_code = 'app_more_like_this_disabled'
+    error_code = "app_more_like_this_disabled"
     description = "The 'More like this' feature is disabled. Please refresh your page."
     code = 403
 
 
 class NoAudioUploadedError(BaseHTTPException):
-    error_code = 'no_audio_uploaded'
+    error_code = "no_audio_uploaded"
     description = "Please upload your audio."
     code = 400
 
 
 class AudioTooLargeError(BaseHTTPException):
-    error_code = 'audio_too_large'
+    error_code = "audio_too_large"
     description = "Audio size exceeded. {message}"
     code = 413
 
 
 class UnsupportedAudioTypeError(BaseHTTPException):
-    error_code = 'unsupported_audio_type'
+    error_code = "unsupported_audio_type"
     description = "Audio type not allowed."
     code = 415
 
 
 class ProviderNotSupportSpeechToTextError(BaseHTTPException):
-    error_code = 'provider_not_support_speech_to_text'
+    error_code = "provider_not_support_speech_to_text"
     description = "Provider not support speech to text."
     code = 400
 
 
 class NoFileUploadedError(BaseHTTPException):
-    error_code = 'no_file_uploaded'
+    error_code = "no_file_uploaded"
     description = "Please upload your file."
     code = 400
 
 
 class TooManyFilesError(BaseHTTPException):
-    error_code = 'too_many_files'
+    error_code = "too_many_files"
     description = "Only one file is allowed."
     code = 400
 
 
 class DraftWorkflowNotExist(BaseHTTPException):
-    error_code = 'draft_workflow_not_exist'
+    error_code = "draft_workflow_not_exist"
     description = "Draft workflow need to be initialized."
     code = 400
 
 
 class DraftWorkflowNotSync(BaseHTTPException):
-    error_code = 'draft_workflow_not_sync'
+    error_code = "draft_workflow_not_sync"
     description = "Workflow graph might have been modified, please refresh and resubmit."
     code = 400
 
 
 class TracingConfigNotExist(BaseHTTPException):
-    error_code = 'trace_config_not_exist'
+    error_code = "trace_config_not_exist"
     description = "Trace config not exist."
     code = 400
 
 
 class TracingConfigIsExist(BaseHTTPException):
-    error_code = 'trace_config_is_exist'
+    error_code = "trace_config_is_exist"
     description = "Trace config is exist."
     code = 400
 
 
 class TracingConfigCheckError(BaseHTTPException):
-    error_code = 'trace_config_check_error'
+    error_code = "trace_config_check_error"
     description = "Invalid Credentials."
     code = 400

+ 9 - 9
api/controllers/console/app/generator.py

@@ -24,21 +24,21 @@ class RuleGenerateApi(Resource):
     @account_initialization_required
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('instruction', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('model_config', type=dict, required=True, nullable=False, location='json')
-        parser.add_argument('no_variable', type=bool, required=True, default=False, location='json')
+        parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
+        parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
         args = parser.parse_args()
 
         account = current_user
-        PROMPT_GENERATION_MAX_TOKENS = int(os.getenv('PROMPT_GENERATION_MAX_TOKENS', '512'))
+        PROMPT_GENERATION_MAX_TOKENS = int(os.getenv("PROMPT_GENERATION_MAX_TOKENS", "512"))
 
         try:
             rules = LLMGenerator.generate_rule_config(
                 tenant_id=account.current_tenant_id,
-                instruction=args['instruction'],
-                model_config=args['model_config'],
-                no_variable=args['no_variable'],
-                rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS
+                instruction=args["instruction"],
+                model_config=args["model_config"],
+                no_variable=args["no_variable"],
+                rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS,
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -52,4 +52,4 @@ class RuleGenerateApi(Resource):
         return rules
 
 
-api.add_resource(RuleGenerateApi, '/rule-generate')
+api.add_resource(RuleGenerateApi, "/rule-generate")

+ 75 - 72
api/controllers/console/app/message.py

@@ -33,9 +33,9 @@ from services.message_service import MessageService
 
 class ChatMessageListApi(Resource):
     message_infinite_scroll_pagination_fields = {
-        'limit': fields.Integer,
-        'has_more': fields.Boolean,
-        'data': fields.List(fields.Nested(message_detail_fields))
+        "limit": fields.Integer,
+        "has_more": fields.Boolean,
+        "data": fields.List(fields.Nested(message_detail_fields)),
     }
 
     @setup_required
@@ -45,55 +45,69 @@ class ChatMessageListApi(Resource):
     @marshal_with(message_infinite_scroll_pagination_fields)
     def get(self, app_model):
         parser = reqparse.RequestParser()
-        parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
-        parser.add_argument('first_id', type=uuid_value, location='args')
-        parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
+        parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
+        parser.add_argument("first_id", type=uuid_value, location="args")
+        parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
         args = parser.parse_args()
 
-        conversation = db.session.query(Conversation).filter(
-            Conversation.id == args['conversation_id'],
-            Conversation.app_id == app_model.id
-        ).first()
+        conversation = (
+            db.session.query(Conversation)
+            .filter(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id)
+            .first()
+        )
 
         if not conversation:
             raise NotFound("Conversation Not Exists.")
 
-        if args['first_id']:
-            first_message = db.session.query(Message) \
-                .filter(Message.conversation_id == conversation.id, Message.id == args['first_id']).first()
+        if args["first_id"]:
+            first_message = (
+                db.session.query(Message)
+                .filter(Message.conversation_id == conversation.id, Message.id == args["first_id"])
+                .first()
+            )
 
             if not first_message:
                 raise NotFound("First message not found")
 
-            history_messages = db.session.query(Message).filter(
-                Message.conversation_id == conversation.id,
-                Message.created_at < first_message.created_at,
-                Message.id != first_message.id
-            ) \
-                .order_by(Message.created_at.desc()).limit(args['limit']).all()
+            history_messages = (
+                db.session.query(Message)
+                .filter(
+                    Message.conversation_id == conversation.id,
+                    Message.created_at < first_message.created_at,
+                    Message.id != first_message.id,
+                )
+                .order_by(Message.created_at.desc())
+                .limit(args["limit"])
+                .all()
+            )
         else:
-            history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \
-                .order_by(Message.created_at.desc()).limit(args['limit']).all()
+            history_messages = (
+                db.session.query(Message)
+                .filter(Message.conversation_id == conversation.id)
+                .order_by(Message.created_at.desc())
+                .limit(args["limit"])
+                .all()
+            )
 
         has_more = False
-        if len(history_messages) == args['limit']:
+        if len(history_messages) == args["limit"]:
             current_page_first_message = history_messages[-1]
-            rest_count = db.session.query(Message).filter(
-                Message.conversation_id == conversation.id,
-                Message.created_at < current_page_first_message.created_at,
-                Message.id != current_page_first_message.id
-            ).count()
+            rest_count = (
+                db.session.query(Message)
+                .filter(
+                    Message.conversation_id == conversation.id,
+                    Message.created_at < current_page_first_message.created_at,
+                    Message.id != current_page_first_message.id,
+                )
+                .count()
+            )
 
             if rest_count > 0:
                 has_more = True
 
         history_messages = list(reversed(history_messages))
 
-        return InfiniteScrollPagination(
-            data=history_messages,
-            limit=args['limit'],
-            has_more=has_more
-        )
+        return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)
 
 
 class MessageFeedbackApi(Resource):
@@ -103,49 +117,46 @@ class MessageFeedbackApi(Resource):
     @get_app_model
     def post(self, app_model):
         parser = reqparse.RequestParser()
-        parser.add_argument('message_id', required=True, type=uuid_value, location='json')
-        parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
+        parser.add_argument("message_id", required=True, type=uuid_value, location="json")
+        parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
         args = parser.parse_args()
 
-        message_id = str(args['message_id'])
+        message_id = str(args["message_id"])
 
-        message = db.session.query(Message).filter(
-            Message.id == message_id,
-            Message.app_id == app_model.id
-        ).first()
+        message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
 
         if not message:
             raise NotFound("Message Not Exists.")
 
         feedback = message.admin_feedback
 
-        if not args['rating'] and feedback:
+        if not args["rating"] and feedback:
             db.session.delete(feedback)
-        elif args['rating'] and feedback:
-            feedback.rating = args['rating']
-        elif not args['rating'] and not feedback:
-            raise ValueError('rating cannot be None when feedback not exists')
+        elif args["rating"] and feedback:
+            feedback.rating = args["rating"]
+        elif not args["rating"] and not feedback:
+            raise ValueError("rating cannot be None when feedback not exists")
         else:
             feedback = MessageFeedback(
                 app_id=app_model.id,
                 conversation_id=message.conversation_id,
                 message_id=message.id,
-                rating=args['rating'],
-                from_source='admin',
-                from_account_id=current_user.id
+                rating=args["rating"],
+                from_source="admin",
+                from_account_id=current_user.id,
             )
             db.session.add(feedback)
 
         db.session.commit()
 
-        return {'result': 'success'}
+        return {"result": "success"}
 
 
 class MessageAnnotationApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_resource_check('annotation')
+    @cloud_edition_billing_resource_check("annotation")
     @get_app_model
     @marshal_with(annotation_fields)
     def post(self, app_model):
@@ -153,10 +164,10 @@ class MessageAnnotationApi(Resource):
             raise Forbidden()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('message_id', required=False, type=uuid_value, location='json')
-        parser.add_argument('question', required=True, type=str, location='json')
-        parser.add_argument('answer', required=True, type=str, location='json')
-        parser.add_argument('annotation_reply', required=False, type=dict, location='json')
+        parser.add_argument("message_id", required=False, type=uuid_value, location="json")
+        parser.add_argument("question", required=True, type=str, location="json")
+        parser.add_argument("answer", required=True, type=str, location="json")
+        parser.add_argument("annotation_reply", required=False, type=dict, location="json")
         args = parser.parse_args()
         annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id)
 
@@ -169,11 +180,9 @@ class MessageAnnotationCountApi(Resource):
     @account_initialization_required
     @get_app_model
     def get(self, app_model):
-        count = db.session.query(MessageAnnotation).filter(
-            MessageAnnotation.app_id == app_model.id
-        ).count()
+        count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_model.id).count()
 
-        return {'count': count}
+        return {"count": count}
 
 
 class MessageSuggestedQuestionApi(Resource):
@@ -186,10 +195,7 @@ class MessageSuggestedQuestionApi(Resource):
 
         try:
             questions = MessageService.get_suggested_questions_after_answer(
-                app_model=app_model,
-                message_id=message_id,
-                user=current_user,
-                invoke_from=InvokeFrom.DEBUGGER
+                app_model=app_model, message_id=message_id, user=current_user, invoke_from=InvokeFrom.DEBUGGER
             )
         except MessageNotExistsError:
             raise NotFound("Message not found")
@@ -209,7 +215,7 @@ class MessageSuggestedQuestionApi(Resource):
             logging.exception("internal server error.")
             raise InternalServerError()
 
-        return {'data': questions}
+        return {"data": questions}
 
 
 class MessageApi(Resource):
@@ -221,10 +227,7 @@ class MessageApi(Resource):
     def get(self, app_model, message_id):
         message_id = str(message_id)
 
-        message = db.session.query(Message).filter(
-            Message.id == message_id,
-            Message.app_id == app_model.id
-        ).first()
+        message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
 
         if not message:
             raise NotFound("Message Not Exists.")
@@ -232,9 +235,9 @@ class MessageApi(Resource):
         return message
 
 
-api.add_resource(MessageSuggestedQuestionApi, '/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions')
-api.add_resource(ChatMessageListApi, '/apps/<uuid:app_id>/chat-messages', endpoint='console_chat_messages')
-api.add_resource(MessageFeedbackApi, '/apps/<uuid:app_id>/feedbacks')
-api.add_resource(MessageAnnotationApi, '/apps/<uuid:app_id>/annotations')
-api.add_resource(MessageAnnotationCountApi, '/apps/<uuid:app_id>/annotations/count')
-api.add_resource(MessageApi, '/apps/<uuid:app_id>/messages/<uuid:message_id>', endpoint='console_message')
+api.add_resource(MessageSuggestedQuestionApi, "/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions")
+api.add_resource(ChatMessageListApi, "/apps/<uuid:app_id>/chat-messages", endpoint="console_chat_messages")
+api.add_resource(MessageFeedbackApi, "/apps/<uuid:app_id>/feedbacks")
+api.add_resource(MessageAnnotationApi, "/apps/<uuid:app_id>/annotations")
+api.add_resource(MessageAnnotationCountApi, "/apps/<uuid:app_id>/annotations/count")
+api.add_resource(MessageApi, "/apps/<uuid:app_id>/messages/<uuid:message_id>", endpoint="console_message")

+ 19 - 24
api/controllers/console/app/model_config.py

@@ -19,19 +19,15 @@ from services.app_model_config_service import AppModelConfigService
 
 
 class ModelConfigResource(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
     @get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
     def post(self, app_model):
-        
         """Modify app model config"""
         # validate config
         model_configuration = AppModelConfigService.validate_configuration(
-            tenant_id=current_user.current_tenant_id,
-            config=request.json,
-            app_mode=AppMode.value_of(app_model.mode)
+            tenant_id=current_user.current_tenant_id, config=request.json, app_mode=AppMode.value_of(app_model.mode)
         )
 
         new_app_model_config = AppModelConfig(
@@ -41,15 +37,15 @@ class ModelConfigResource(Resource):
 
         if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
             # get original app model config
-            original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter(
-                AppModelConfig.id == app_model.app_model_config_id
-            ).first()
+            original_app_model_config: AppModelConfig = (
+                db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
+            )
             agent_mode = original_app_model_config.agent_mode_dict
             # decrypt agent tool parameters if it's secret-input
             parameter_map = {}
             masked_parameter_map = {}
             tool_map = {}
-            for tool in agent_mode.get('tools') or []:
+            for tool in agent_mode.get("tools") or []:
                 if not isinstance(tool, dict) or len(tool.keys()) <= 3:
                     continue
 
@@ -66,7 +62,7 @@ class ModelConfigResource(Resource):
                         tool_runtime=tool_runtime,
                         provider_name=agent_tool_entity.provider_id,
                         provider_type=agent_tool_entity.provider_type,
-                        identity_id=f'AGENT.{app_model.id}'
+                        identity_id=f"AGENT.{app_model.id}",
                     )
                 except Exception as e:
                     continue
@@ -79,18 +75,18 @@ class ModelConfigResource(Resource):
                     parameters = {}
                     masked_parameter = {}
 
-                key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
+                key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}"
                 masked_parameter_map[key] = masked_parameter
                 parameter_map[key] = parameters
                 tool_map[key] = tool_runtime
 
             # encrypt agent tool parameters if it's secret-input
             agent_mode = new_app_model_config.agent_mode_dict
-            for tool in agent_mode.get('tools') or []:
+            for tool in agent_mode.get("tools") or []:
                 agent_tool_entity = AgentToolEntity(**tool)
 
                 # get tool
-                key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
+                key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}"
                 if key in tool_map:
                     tool_runtime = tool_map[key]
                 else:
@@ -108,7 +104,7 @@ class ModelConfigResource(Resource):
                     tool_runtime=tool_runtime,
                     provider_name=agent_tool_entity.provider_id,
                     provider_type=agent_tool_entity.provider_type,
-                    identity_id=f'AGENT.{app_model.id}'
+                    identity_id=f"AGENT.{app_model.id}",
                 )
                 manager.delete_tool_parameters_cache()
 
@@ -116,15 +112,17 @@ class ModelConfigResource(Resource):
                 if agent_tool_entity.tool_parameters:
                     if key not in masked_parameter_map:
                         continue
-                    
+
                     for masked_key, masked_value in masked_parameter_map[key].items():
-                        if masked_key in agent_tool_entity.tool_parameters and \
-                                agent_tool_entity.tool_parameters[masked_key] == masked_value:
+                        if (
+                            masked_key in agent_tool_entity.tool_parameters
+                            and agent_tool_entity.tool_parameters[masked_key] == masked_value
+                        ):
                             agent_tool_entity.tool_parameters[masked_key] = parameter_map[key].get(masked_key)
 
                 # encrypt parameters
                 if agent_tool_entity.tool_parameters:
-                    tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
+                    tool["tool_parameters"] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
 
             # update app model config
             new_app_model_config.agent_mode = json.dumps(agent_mode)
@@ -135,12 +133,9 @@ class ModelConfigResource(Resource):
         app_model.app_model_config_id = new_app_model_config.id
         db.session.commit()
 
-        app_model_config_was_updated.send(
-            app_model,
-            app_model_config=new_app_model_config
-        )
+        app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)
 
-        return {'result': 'success'}
+        return {"result": "success"}
 
 
-api.add_resource(ModelConfigResource, '/apps/<uuid:app_id>/model-config')
+api.add_resource(ModelConfigResource, "/apps/<uuid:app_id>/model-config")

+ 12 - 21
api/controllers/console/app/ops_trace.py

@@ -18,13 +18,11 @@ class TraceAppConfigApi(Resource):
     @account_initialization_required
     def get(self, app_id):
         parser = reqparse.RequestParser()
-        parser.add_argument('tracing_provider', type=str, required=True, location='args')
+        parser.add_argument("tracing_provider", type=str, required=True, location="args")
         args = parser.parse_args()
 
         try:
-            trace_config = OpsService.get_tracing_app_config(
-                app_id=app_id, tracing_provider=args['tracing_provider']
-                )
+            trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
             if not trace_config:
                 return {"has_not_configured": True}
             return trace_config
@@ -37,19 +35,17 @@ class TraceAppConfigApi(Resource):
     def post(self, app_id):
         """Create a new trace app configuration"""
         parser = reqparse.RequestParser()
-        parser.add_argument('tracing_provider', type=str, required=True, location='json')
-        parser.add_argument('tracing_config', type=dict, required=True, location='json')
+        parser.add_argument("tracing_provider", type=str, required=True, location="json")
+        parser.add_argument("tracing_config", type=dict, required=True, location="json")
         args = parser.parse_args()
 
         try:
             result = OpsService.create_tracing_app_config(
-                app_id=app_id,
-                tracing_provider=args['tracing_provider'],
-                tracing_config=args['tracing_config']
+                app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
             )
             if not result:
                 raise TracingConfigIsExist()
-            if result.get('error'):
+            if result.get("error"):
                 raise TracingConfigCheckError()
             return result
         except Exception as e:
@@ -61,15 +57,13 @@ class TraceAppConfigApi(Resource):
     def patch(self, app_id):
         """Update an existing trace app configuration"""
         parser = reqparse.RequestParser()
-        parser.add_argument('tracing_provider', type=str, required=True, location='json')
-        parser.add_argument('tracing_config', type=dict, required=True, location='json')
+        parser.add_argument("tracing_provider", type=str, required=True, location="json")
+        parser.add_argument("tracing_config", type=dict, required=True, location="json")
         args = parser.parse_args()
 
         try:
             result = OpsService.update_tracing_app_config(
-                app_id=app_id,
-                tracing_provider=args['tracing_provider'],
-                tracing_config=args['tracing_config']
+                app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
             )
             if not result:
                 raise TracingConfigNotExist()
@@ -83,14 +77,11 @@ class TraceAppConfigApi(Resource):
     def delete(self, app_id):
         """Delete an existing trace app configuration"""
         parser = reqparse.RequestParser()
-        parser.add_argument('tracing_provider', type=str, required=True, location='args')
+        parser.add_argument("tracing_provider", type=str, required=True, location="args")
         args = parser.parse_args()
 
         try:
-            result = OpsService.delete_tracing_app_config(
-                app_id=app_id,
-                tracing_provider=args['tracing_provider']
-            )
+            result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
             if not result:
                 raise TracingConfigNotExist()
             return {"result": "success"}
@@ -98,4 +89,4 @@ class TraceAppConfigApi(Resource):
             raise e
 
 
-api.add_resource(TraceAppConfigApi, '/apps/<uuid:app_id>/trace-config')
+api.add_resource(TraceAppConfigApi, "/apps/<uuid:app_id>/trace-config")

+ 35 - 38
api/controllers/console/app/site.py

@@ -15,23 +15,23 @@ from models.model import Site
 
 def parse_app_site_args():
     parser = reqparse.RequestParser()
-    parser.add_argument('title', type=str, required=False, location='json')
-    parser.add_argument('icon_type', type=str, required=False, location='json')
-    parser.add_argument('icon', type=str, required=False, location='json')
-    parser.add_argument('icon_background', type=str, required=False, location='json')
-    parser.add_argument('description', type=str, required=False, location='json')
-    parser.add_argument('default_language', type=supported_language, required=False, location='json')
-    parser.add_argument('chat_color_theme', type=str, required=False, location='json')
-    parser.add_argument('chat_color_theme_inverted', type=bool, required=False, location='json')
-    parser.add_argument('customize_domain', type=str, required=False, location='json')
-    parser.add_argument('copyright', type=str, required=False, location='json')
-    parser.add_argument('privacy_policy', type=str, required=False, location='json')
-    parser.add_argument('custom_disclaimer', type=str, required=False, location='json')
-    parser.add_argument('customize_token_strategy', type=str, choices=['must', 'allow', 'not_allow'],
-                        required=False,
-                        location='json')
-    parser.add_argument('prompt_public', type=bool, required=False, location='json')
-    parser.add_argument('show_workflow_steps', type=bool, required=False, location='json')
+    parser.add_argument("title", type=str, required=False, location="json")
+    parser.add_argument("icon_type", type=str, required=False, location="json")
+    parser.add_argument("icon", type=str, required=False, location="json")
+    parser.add_argument("icon_background", type=str, required=False, location="json")
+    parser.add_argument("description", type=str, required=False, location="json")
+    parser.add_argument("default_language", type=supported_language, required=False, location="json")
+    parser.add_argument("chat_color_theme", type=str, required=False, location="json")
+    parser.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
+    parser.add_argument("customize_domain", type=str, required=False, location="json")
+    parser.add_argument("copyright", type=str, required=False, location="json")
+    parser.add_argument("privacy_policy", type=str, required=False, location="json")
+    parser.add_argument("custom_disclaimer", type=str, required=False, location="json")
+    parser.add_argument(
+        "customize_token_strategy", type=str, choices=["must", "allow", "not_allow"], required=False, location="json"
+    )
+    parser.add_argument("prompt_public", type=bool, required=False, location="json")
+    parser.add_argument("show_workflow_steps", type=bool, required=False, location="json")
     return parser.parse_args()
 
 
@@ -48,26 +48,24 @@ class AppSite(Resource):
         if not current_user.is_editor:
             raise Forbidden()
 
-        site = db.session.query(Site). \
-            filter(Site.app_id == app_model.id). \
-            one_or_404()
+        site = db.session.query(Site).filter(Site.app_id == app_model.id).one_or_404()
 
         for attr_name in [
-            'title',
-            'icon_type',
-            'icon',
-            'icon_background',
-            'description',
-            'default_language',
-            'chat_color_theme',
-            'chat_color_theme_inverted',
-            'customize_domain',
-            'copyright',
-            'privacy_policy',
-            'custom_disclaimer',
-            'customize_token_strategy',
-            'prompt_public',
-            'show_workflow_steps'
+            "title",
+            "icon_type",
+            "icon",
+            "icon_background",
+            "description",
+            "default_language",
+            "chat_color_theme",
+            "chat_color_theme_inverted",
+            "customize_domain",
+            "copyright",
+            "privacy_policy",
+            "custom_disclaimer",
+            "customize_token_strategy",
+            "prompt_public",
+            "show_workflow_steps",
         ]:
             value = args.get(attr_name)
             if value is not None:
@@ -79,7 +77,6 @@ class AppSite(Resource):
 
 
 class AppSiteAccessTokenReset(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -101,5 +98,5 @@ class AppSiteAccessTokenReset(Resource):
         return site
 
 
-api.add_resource(AppSite, '/apps/<uuid:app_id>/site')
-api.add_resource(AppSiteAccessTokenReset, '/apps/<uuid:app_id>/site/access-token-reset')
+api.add_resource(AppSite, "/apps/<uuid:app_id>/site")
+api.add_resource(AppSiteAccessTokenReset, "/apps/<uuid:app_id>/site/access-token-reset")

+ 131 - 161
api/controllers/console/app/statistic.py

@@ -17,7 +17,6 @@ from models.model import AppMode
 
 
 class DailyConversationStatistic(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -26,58 +25,52 @@ class DailyConversationStatistic(Resource):
         account = current_user
 
         parser = reqparse.RequestParser()
-        parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
-        parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()
 
-        sql_query = '''
+        sql_query = """
         SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.conversation_id) AS conversation_count
             FROM messages where app_id = :app_id 
-        '''
-        arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
+        """
+        arg_dict = {"tz": account.timezone, "app_id": app_model.id}
 
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
 
-        if args['start']:
-            start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
+        if args["start"]:
+            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
             start_datetime = start_datetime.replace(second=0)
 
             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and created_at >= :start'
-            arg_dict['start'] = start_datetime_utc
+            sql_query += " and created_at >= :start"
+            arg_dict["start"] = start_datetime_utc
 
-        if args['end']:
-            end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
+        if args["end"]:
+            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
             end_datetime = end_datetime.replace(second=0)
 
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and created_at < :end'
-            arg_dict['end'] = end_datetime_utc
+            sql_query += " and created_at < :end"
+            arg_dict["end"] = end_datetime_utc
 
-        sql_query += ' GROUP BY date order by date'
+        sql_query += " GROUP BY date order by date"
 
         response_data = []
 
         with db.engine.begin() as conn:
             rs = conn.execute(db.text(sql_query), arg_dict)
             for i in rs:
-                response_data.append({
-                    'date': str(i.date),
-                    'conversation_count': i.conversation_count
-                })
+                response_data.append({"date": str(i.date), "conversation_count": i.conversation_count})
 
-        return jsonify({
-            'data': response_data
-        })
+        return jsonify({"data": response_data})
 
 
 class DailyTerminalsStatistic(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -86,54 +79,49 @@ class DailyTerminalsStatistic(Resource):
         account = current_user
 
         parser = reqparse.RequestParser()
-        parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
-        parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()
 
-        sql_query = '''
+        sql_query = """
                 SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.from_end_user_id) AS terminal_count
                     FROM messages where app_id = :app_id 
-                '''
-        arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
+                """
+        arg_dict = {"tz": account.timezone, "app_id": app_model.id}
 
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
 
-        if args['start']:
-            start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
+        if args["start"]:
+            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
             start_datetime = start_datetime.replace(second=0)
 
             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and created_at >= :start'
-            arg_dict['start'] = start_datetime_utc
+            sql_query += " and created_at >= :start"
+            arg_dict["start"] = start_datetime_utc
 
-        if args['end']:
-            end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
+        if args["end"]:
+            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
             end_datetime = end_datetime.replace(second=0)
 
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and created_at < :end'
-            arg_dict['end'] = end_datetime_utc
+            sql_query += " and created_at < :end"
+            arg_dict["end"] = end_datetime_utc
 
-        sql_query += ' GROUP BY date order by date'
+        sql_query += " GROUP BY date order by date"
 
         response_data = []
 
         with db.engine.begin() as conn:
-            rs = conn.execute(db.text(sql_query), arg_dict)            
+            rs = conn.execute(db.text(sql_query), arg_dict)
             for i in rs:
-                response_data.append({
-                    'date': str(i.date),
-                    'terminal_count': i.terminal_count
-                })
+                response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
 
-        return jsonify({
-            'data': response_data
-        })
+        return jsonify({"data": response_data})
 
 
 class DailyTokenCostStatistic(Resource):
@@ -145,58 +133,53 @@ class DailyTokenCostStatistic(Resource):
         account = current_user
 
         parser = reqparse.RequestParser()
-        parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
-        parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()
 
-        sql_query = '''
+        sql_query = """
                 SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, 
                     (sum(messages.message_tokens) + sum(messages.answer_tokens)) as token_count,
                     sum(total_price) as total_price
                     FROM messages where app_id = :app_id 
-                '''
-        arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
+                """
+        arg_dict = {"tz": account.timezone, "app_id": app_model.id}
 
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
 
-        if args['start']:
-            start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
+        if args["start"]:
+            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
             start_datetime = start_datetime.replace(second=0)
 
             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and created_at >= :start'
-            arg_dict['start'] = start_datetime_utc
+            sql_query += " and created_at >= :start"
+            arg_dict["start"] = start_datetime_utc
 
-        if args['end']:
-            end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
+        if args["end"]:
+            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
             end_datetime = end_datetime.replace(second=0)
 
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and created_at < :end'
-            arg_dict['end'] = end_datetime_utc
+            sql_query += " and created_at < :end"
+            arg_dict["end"] = end_datetime_utc
 
-        sql_query += ' GROUP BY date order by date'
+        sql_query += " GROUP BY date order by date"
 
         response_data = []
 
         with db.engine.begin() as conn:
             rs = conn.execute(db.text(sql_query), arg_dict)
             for i in rs:
-                response_data.append({
-                    'date': str(i.date),
-                    'token_count': i.token_count,
-                    'total_price': i.total_price,
-                    'currency': 'USD'
-                })
+                response_data.append(
+                    {"date": str(i.date), "token_count": i.token_count, "total_price": i.total_price, "currency": "USD"}
+                )
 
-        return jsonify({
-            'data': response_data
-        })
+        return jsonify({"data": response_data})
 
 
 class AverageSessionInteractionStatistic(Resource):
@@ -208,8 +191,8 @@ class AverageSessionInteractionStatistic(Resource):
         account = current_user
 
         parser = reqparse.RequestParser()
-        parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
-        parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()
 
         sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, 
@@ -218,30 +201,30 @@ FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count
     FROM conversations c
     JOIN messages m ON c.id = m.conversation_id
     WHERE c.override_model_configs IS NULL AND c.app_id = :app_id"""
-        arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
+        arg_dict = {"tz": account.timezone, "app_id": app_model.id}
 
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
 
-        if args['start']:
-            start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
+        if args["start"]:
+            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
             start_datetime = start_datetime.replace(second=0)
 
             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and c.created_at >= :start'
-            arg_dict['start'] = start_datetime_utc
+            sql_query += " and c.created_at >= :start"
+            arg_dict["start"] = start_datetime_utc
 
-        if args['end']:
-            end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
+        if args["end"]:
+            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
             end_datetime = end_datetime.replace(second=0)
 
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and c.created_at < :end'
-            arg_dict['end'] = end_datetime_utc
+            sql_query += " and c.created_at < :end"
+            arg_dict["end"] = end_datetime_utc
 
         sql_query += """
         GROUP BY m.conversation_id) subquery
@@ -250,18 +233,15 @@ GROUP BY date
 ORDER BY date"""
 
         response_data = []
-        
+
         with db.engine.begin() as conn:
             rs = conn.execute(db.text(sql_query), arg_dict)
             for i in rs:
-                response_data.append({
-                    'date': str(i.date),
-                    'interactions': float(i.interactions.quantize(Decimal('0.01')))
-                })
+                response_data.append(
+                    {"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
+                )
 
-        return jsonify({
-            'data': response_data
-        })
+        return jsonify({"data": response_data})
 
 
 class UserSatisfactionRateStatistic(Resource):
@@ -273,57 +253,57 @@ class UserSatisfactionRateStatistic(Resource):
         account = current_user
 
         parser = reqparse.RequestParser()
-        parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
-        parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()
 
-        sql_query = '''
+        sql_query = """
                         SELECT date(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, 
                             COUNT(m.id) as message_count, COUNT(mf.id) as feedback_count 
                             FROM messages m
                             LEFT JOIN message_feedbacks mf on mf.message_id=m.id and mf.rating='like'
                             WHERE m.app_id = :app_id 
-                        '''
-        arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
+                        """
+        arg_dict = {"tz": account.timezone, "app_id": app_model.id}
 
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
 
-        if args['start']:
-            start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
+        if args["start"]:
+            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
             start_datetime = start_datetime.replace(second=0)
 
             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and m.created_at >= :start'
-            arg_dict['start'] = start_datetime_utc
+            sql_query += " and m.created_at >= :start"
+            arg_dict["start"] = start_datetime_utc
 
-        if args['end']:
-            end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
+        if args["end"]:
+            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
             end_datetime = end_datetime.replace(second=0)
 
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and m.created_at < :end'
-            arg_dict['end'] = end_datetime_utc
+            sql_query += " and m.created_at < :end"
+            arg_dict["end"] = end_datetime_utc
 
-        sql_query += ' GROUP BY date order by date'
+        sql_query += " GROUP BY date order by date"
 
         response_data = []
 
         with db.engine.begin() as conn:
             rs = conn.execute(db.text(sql_query), arg_dict)
             for i in rs:
-                response_data.append({
-                    'date': str(i.date),
-                    'rate': round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2),
-                })
+                response_data.append(
+                    {
+                        "date": str(i.date),
+                        "rate": round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2),
+                    }
+                )
 
-        return jsonify({
-            'data': response_data
-        })
+        return jsonify({"data": response_data})
 
 
 class AverageResponseTimeStatistic(Resource):
@@ -335,56 +315,51 @@ class AverageResponseTimeStatistic(Resource):
         account = current_user
 
         parser = reqparse.RequestParser()
-        parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
-        parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()
 
-        sql_query = '''
+        sql_query = """
                 SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, 
                     AVG(provider_response_latency) as latency
                     FROM messages
                     WHERE app_id = :app_id
-                '''
-        arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
+                """
+        arg_dict = {"tz": account.timezone, "app_id": app_model.id}
 
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
 
-        if args['start']:
-            start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
+        if args["start"]:
+            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
             start_datetime = start_datetime.replace(second=0)
 
             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and created_at >= :start'
-            arg_dict['start'] = start_datetime_utc
+            sql_query += " and created_at >= :start"
+            arg_dict["start"] = start_datetime_utc
 
-        if args['end']:
-            end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
+        if args["end"]:
+            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
             end_datetime = end_datetime.replace(second=0)
 
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and created_at < :end'
-            arg_dict['end'] = end_datetime_utc
+            sql_query += " and created_at < :end"
+            arg_dict["end"] = end_datetime_utc
 
-        sql_query += ' GROUP BY date order by date'
+        sql_query += " GROUP BY date order by date"
 
         response_data = []
 
         with db.engine.begin() as conn:
-            rs = conn.execute(db.text(sql_query), arg_dict)            
+            rs = conn.execute(db.text(sql_query), arg_dict)
             for i in rs:
-                response_data.append({
-                    'date': str(i.date),
-                    'latency': round(i.latency * 1000, 4)
-                })
+                response_data.append({"date": str(i.date), "latency": round(i.latency * 1000, 4)})
 
-        return jsonify({
-            'data': response_data
-        })
+        return jsonify({"data": response_data})
 
 
 class TokensPerSecondStatistic(Resource):
@@ -396,63 +371,58 @@ class TokensPerSecondStatistic(Resource):
         account = current_user
 
         parser = reqparse.RequestParser()
-        parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
-        parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()
 
-        sql_query = '''SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, 
+        sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, 
     CASE 
         WHEN SUM(provider_response_latency) = 0 THEN 0
         ELSE (SUM(answer_tokens) / SUM(provider_response_latency))
     END as tokens_per_second
 FROM messages
-WHERE app_id = :app_id'''
-        arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
+WHERE app_id = :app_id"""
+        arg_dict = {"tz": account.timezone, "app_id": app_model.id}
 
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
 
-        if args['start']:
-            start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
+        if args["start"]:
+            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
             start_datetime = start_datetime.replace(second=0)
 
             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and created_at >= :start'
-            arg_dict['start'] = start_datetime_utc
+            sql_query += " and created_at >= :start"
+            arg_dict["start"] = start_datetime_utc
 
-        if args['end']:
-            end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
+        if args["end"]:
+            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
             end_datetime = end_datetime.replace(second=0)
 
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and created_at < :end'
-            arg_dict['end'] = end_datetime_utc
+            sql_query += " and created_at < :end"
+            arg_dict["end"] = end_datetime_utc
 
-        sql_query += ' GROUP BY date order by date'
+        sql_query += " GROUP BY date order by date"
 
         response_data = []
 
         with db.engine.begin() as conn:
             rs = conn.execute(db.text(sql_query), arg_dict)
             for i in rs:
-                response_data.append({
-                    'date': str(i.date),
-                    'tps': round(i.tokens_per_second, 4)
-                })
-
-        return jsonify({
-            'data': response_data
-        })
-
-
-api.add_resource(DailyConversationStatistic, '/apps/<uuid:app_id>/statistics/daily-conversations')
-api.add_resource(DailyTerminalsStatistic, '/apps/<uuid:app_id>/statistics/daily-end-users')
-api.add_resource(DailyTokenCostStatistic, '/apps/<uuid:app_id>/statistics/token-costs')
-api.add_resource(AverageSessionInteractionStatistic, '/apps/<uuid:app_id>/statistics/average-session-interactions')
-api.add_resource(UserSatisfactionRateStatistic, '/apps/<uuid:app_id>/statistics/user-satisfaction-rate')
-api.add_resource(AverageResponseTimeStatistic, '/apps/<uuid:app_id>/statistics/average-response-time')
-api.add_resource(TokensPerSecondStatistic, '/apps/<uuid:app_id>/statistics/tokens-per-second')
+                response_data.append({"date": str(i.date), "tps": round(i.tokens_per_second, 4)})
+
+        return jsonify({"data": response_data})
+
+
+api.add_resource(DailyConversationStatistic, "/apps/<uuid:app_id>/statistics/daily-conversations")
+api.add_resource(DailyTerminalsStatistic, "/apps/<uuid:app_id>/statistics/daily-end-users")
+api.add_resource(DailyTokenCostStatistic, "/apps/<uuid:app_id>/statistics/token-costs")
+api.add_resource(AverageSessionInteractionStatistic, "/apps/<uuid:app_id>/statistics/average-session-interactions")
+api.add_resource(UserSatisfactionRateStatistic, "/apps/<uuid:app_id>/statistics/user-satisfaction-rate")
+api.add_resource(AverageResponseTimeStatistic, "/apps/<uuid:app_id>/statistics/average-response-time")
+api.add_resource(TokensPerSecondStatistic, "/apps/<uuid:app_id>/statistics/tokens-per-second")

+ 88 - 113
api/controllers/console/app/workflow.py

@@ -64,51 +64,51 @@ class DraftWorkflowApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
-        
-        content_type = request.headers.get('Content-Type', '')
 
-        if 'application/json' in content_type:
+        content_type = request.headers.get("Content-Type", "")
+
+        if "application/json" in content_type:
             parser = reqparse.RequestParser()
-            parser.add_argument('graph', type=dict, required=True, nullable=False, location='json')
-            parser.add_argument('features', type=dict, required=True, nullable=False, location='json')
-            parser.add_argument('hash', type=str, required=False, location='json')
+            parser.add_argument("graph", type=dict, required=True, nullable=False, location="json")
+            parser.add_argument("features", type=dict, required=True, nullable=False, location="json")
+            parser.add_argument("hash", type=str, required=False, location="json")
             # TODO: set this to required=True after frontend is updated
-            parser.add_argument('environment_variables', type=list, required=False, location='json')
-            parser.add_argument('conversation_variables', type=list, required=False, location='json')
+            parser.add_argument("environment_variables", type=list, required=False, location="json")
+            parser.add_argument("conversation_variables", type=list, required=False, location="json")
             args = parser.parse_args()
-        elif 'text/plain' in content_type:
+        elif "text/plain" in content_type:
             try:
-                data = json.loads(request.data.decode('utf-8'))
-                if 'graph' not in data or 'features' not in data:
-                    raise ValueError('graph or features not found in data')
+                data = json.loads(request.data.decode("utf-8"))
+                if "graph" not in data or "features" not in data:
+                    raise ValueError("graph or features not found in data")
 
-                if not isinstance(data.get('graph'), dict) or not isinstance(data.get('features'), dict):
-                    raise ValueError('graph or features is not a dict')
+                if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict):
+                    raise ValueError("graph or features is not a dict")
 
                 args = {
-                    'graph': data.get('graph'),
-                    'features': data.get('features'),
-                    'hash': data.get('hash'),
-                    'environment_variables': data.get('environment_variables'),
-                    'conversation_variables': data.get('conversation_variables'),
+                    "graph": data.get("graph"),
+                    "features": data.get("features"),
+                    "hash": data.get("hash"),
+                    "environment_variables": data.get("environment_variables"),
+                    "conversation_variables": data.get("conversation_variables"),
                 }
             except json.JSONDecodeError:
-                return {'message': 'Invalid JSON data'}, 400
+                return {"message": "Invalid JSON data"}, 400
         else:
             abort(415)
 
         workflow_service = WorkflowService()
 
         try:
-            environment_variables_list = args.get('environment_variables') or []
+            environment_variables_list = args.get("environment_variables") or []
             environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
-            conversation_variables_list = args.get('conversation_variables') or []
+            conversation_variables_list = args.get("conversation_variables") or []
             conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list]
             workflow = workflow_service.sync_draft_workflow(
                 app_model=app_model,
-                graph=args['graph'],
-                features=args['features'],
-                unique_hash=args.get('hash'),
+                graph=args["graph"],
+                features=args["features"],
+                unique_hash=args.get("hash"),
                 account=current_user,
                 environment_variables=environment_variables,
                 conversation_variables=conversation_variables,
@@ -119,7 +119,7 @@ class DraftWorkflowApi(Resource):
         return {
             "result": "success",
             "hash": workflow.unique_hash,
-            "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at)
+            "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
         }
 
 
@@ -138,13 +138,11 @@ class DraftWorkflowImportApi(Resource):
             raise Forbidden()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('data', type=str, required=True, nullable=False, location='json')
+        parser.add_argument("data", type=str, required=True, nullable=False, location="json")
         args = parser.parse_args()
 
         workflow = AppDslService.import_and_overwrite_workflow(
-            app_model=app_model,
-            data=args['data'],
-            account=current_user
+            app_model=app_model, data=args["data"], account=current_user
         )
 
         return workflow
@@ -162,21 +160,17 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
-        
+
         parser = reqparse.RequestParser()
-        parser.add_argument('inputs', type=dict, location='json')
-        parser.add_argument('query', type=str, required=True, location='json', default='')
-        parser.add_argument('files', type=list, location='json')
-        parser.add_argument('conversation_id', type=uuid_value, location='json')
+        parser.add_argument("inputs", type=dict, location="json")
+        parser.add_argument("query", type=str, required=True, location="json", default="")
+        parser.add_argument("files", type=list, location="json")
+        parser.add_argument("conversation_id", type=uuid_value, location="json")
         args = parser.parse_args()
 
         try:
             response = AppGenerateService.generate(
-                app_model=app_model,
-                user=current_user,
-                args=args,
-                invoke_from=InvokeFrom.DEBUGGER,
-                streaming=True
+                app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True
             )
 
             return helper.compact_generate_response(response)
@@ -190,6 +184,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
             logging.exception("internal server error.")
             raise InternalServerError()
 
+
 class AdvancedChatDraftRunIterationNodeApi(Resource):
     @setup_required
     @login_required
@@ -202,18 +197,14 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
-        
+
         parser = reqparse.RequestParser()
-        parser.add_argument('inputs', type=dict, location='json')
+        parser.add_argument("inputs", type=dict, location="json")
         args = parser.parse_args()
 
         try:
             response = AppGenerateService.generate_single_iteration(
-                app_model=app_model,
-                user=current_user,
-                node_id=node_id,
-                args=args,
-                streaming=True
+                app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True
             )
 
             return helper.compact_generate_response(response)
@@ -227,6 +218,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
             logging.exception("internal server error.")
             raise InternalServerError()
 
+
 class WorkflowDraftRunIterationNodeApi(Resource):
     @setup_required
     @login_required
@@ -239,18 +231,14 @@ class WorkflowDraftRunIterationNodeApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
-        
+
         parser = reqparse.RequestParser()
-        parser.add_argument('inputs', type=dict, location='json')
+        parser.add_argument("inputs", type=dict, location="json")
         args = parser.parse_args()
 
         try:
             response = AppGenerateService.generate_single_iteration(
-                app_model=app_model,
-                user=current_user,
-                node_id=node_id,
-                args=args,
-                streaming=True
+                app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True
             )
 
             return helper.compact_generate_response(response)
@@ -264,6 +252,7 @@ class WorkflowDraftRunIterationNodeApi(Resource):
             logging.exception("internal server error.")
             raise InternalServerError()
 
+
 class DraftWorkflowRunApi(Resource):
     @setup_required
     @login_required
@@ -276,19 +265,15 @@ class DraftWorkflowRunApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
-        
+
         parser = reqparse.RequestParser()
-        parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json')
-        parser.add_argument('files', type=list, required=False, location='json')
+        parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
+        parser.add_argument("files", type=list, required=False, location="json")
         args = parser.parse_args()
 
         try:
             response = AppGenerateService.generate(
-                app_model=app_model,
-                user=current_user,
-                args=args,
-                invoke_from=InvokeFrom.DEBUGGER,
-                streaming=True
+                app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True
             )
 
             return helper.compact_generate_response(response)
@@ -311,12 +296,10 @@ class WorkflowTaskStopApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
-        
+
         AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
 
-        return {
-            "result": "success"
-        }
+        return {"result": "success"}
 
 
 class DraftWorkflowNodeRunApi(Resource):
@@ -332,24 +315,20 @@ class DraftWorkflowNodeRunApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
-        
+
         parser = reqparse.RequestParser()
-        parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json')
+        parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
         args = parser.parse_args()
 
         workflow_service = WorkflowService()
         workflow_node_execution = workflow_service.run_draft_workflow_node(
-            app_model=app_model,
-            node_id=node_id,
-            user_inputs=args.get('inputs'),
-            account=current_user
+            app_model=app_model, node_id=node_id, user_inputs=args.get("inputs"), account=current_user
         )
 
         return workflow_node_execution
 
 
 class PublishedWorkflowApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -362,7 +341,7 @@ class PublishedWorkflowApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
-        
+
         # fetch published workflow by app_model
         workflow_service = WorkflowService()
         workflow = workflow_service.get_published_workflow(app_model=app_model)
@@ -381,14 +360,11 @@ class PublishedWorkflowApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
-        
+
         workflow_service = WorkflowService()
         workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user)
 
-        return {
-            "result": "success",
-            "created_at": TimestampField().format(workflow.created_at)
-        }
+        return {"result": "success", "created_at": TimestampField().format(workflow.created_at)}
 
 
 class DefaultBlockConfigsApi(Resource):
@@ -403,7 +379,7 @@ class DefaultBlockConfigsApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
-        
+
         # Get default block configs
         workflow_service = WorkflowService()
         return workflow_service.get_default_block_configs()
@@ -421,24 +397,21 @@ class DefaultBlockConfigApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
-        
+
         parser = reqparse.RequestParser()
-        parser.add_argument('q', type=str, location='args')
+        parser.add_argument("q", type=str, location="args")
         args = parser.parse_args()
 
         filters = None
-        if args.get('q'):
+        if args.get("q"):
             try:
-                filters = json.loads(args.get('q'))
+                filters = json.loads(args.get("q"))
             except json.JSONDecodeError:
-                raise ValueError('Invalid filters')
+                raise ValueError("Invalid filters")
 
         # Get default block configs
         workflow_service = WorkflowService()
-        return workflow_service.get_default_block_config(
-            node_type=block_type,
-            filters=filters
-        )
+        return workflow_service.get_default_block_config(node_type=block_type, filters=filters)
 
 
 class ConvertToWorkflowApi(Resource):
@@ -455,41 +428,43 @@ class ConvertToWorkflowApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
-        
+
         if request.data:
             parser = reqparse.RequestParser()
-            parser.add_argument('name', type=str, required=False, nullable=True, location='json')
-            parser.add_argument('icon_type', type=str, required=False, nullable=True, location='json')
-            parser.add_argument('icon', type=str, required=False, nullable=True, location='json')
-            parser.add_argument('icon_background', type=str, required=False, nullable=True, location='json')
+            parser.add_argument("name", type=str, required=False, nullable=True, location="json")
+            parser.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
+            parser.add_argument("icon", type=str, required=False, nullable=True, location="json")
+            parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
             args = parser.parse_args()
         else:
             args = {}
 
         # convert to workflow mode
         workflow_service = WorkflowService()
-        new_app_model = workflow_service.convert_to_workflow(
-            app_model=app_model,
-            account=current_user,
-            args=args
-        )
+        new_app_model = workflow_service.convert_to_workflow(app_model=app_model, account=current_user, args=args)
 
         # return app id
         return {
-            'new_app_id': new_app_model.id,
+            "new_app_id": new_app_model.id,
         }
 
 
-api.add_resource(DraftWorkflowApi, '/apps/<uuid:app_id>/workflows/draft')
-api.add_resource(DraftWorkflowImportApi, '/apps/<uuid:app_id>/workflows/draft/import')
-api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/run')
-api.add_resource(DraftWorkflowRunApi, '/apps/<uuid:app_id>/workflows/draft/run')
-api.add_resource(WorkflowTaskStopApi, '/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop')
-api.add_resource(DraftWorkflowNodeRunApi, '/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run')
-api.add_resource(AdvancedChatDraftRunIterationNodeApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run')
-api.add_resource(WorkflowDraftRunIterationNodeApi, '/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run')
-api.add_resource(PublishedWorkflowApi, '/apps/<uuid:app_id>/workflows/publish')
-api.add_resource(DefaultBlockConfigsApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs')
-api.add_resource(DefaultBlockConfigApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs'
-                                        '/<string:block_type>')
-api.add_resource(ConvertToWorkflowApi, '/apps/<uuid:app_id>/convert-to-workflow')
+api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft")
+api.add_resource(DraftWorkflowImportApi, "/apps/<uuid:app_id>/workflows/draft/import")
+api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run")
+api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run")
+api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")
+api.add_resource(DraftWorkflowNodeRunApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run")
+api.add_resource(
+    AdvancedChatDraftRunIterationNodeApi,
+    "/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run",
+)
+api.add_resource(
+    WorkflowDraftRunIterationNodeApi, "/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run"
+)
+api.add_resource(PublishedWorkflowApi, "/apps/<uuid:app_id>/workflows/publish")
+api.add_resource(DefaultBlockConfigsApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
+api.add_resource(
+    DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs" "/<string:block_type>"
+)
+api.add_resource(ConvertToWorkflowApi, "/apps/<uuid:app_id>/convert-to-workflow")

+ 6 - 7
api/controllers/console/app/workflow_app_log.py

@@ -22,20 +22,19 @@ class WorkflowAppLogApi(Resource):
         Get workflow app logs
         """
         parser = reqparse.RequestParser()
-        parser.add_argument('keyword', type=str, location='args')
-        parser.add_argument('status', type=str, choices=['succeeded', 'failed', 'stopped'], location='args')
-        parser.add_argument('page', type=int_range(1, 99999), default=1, location='args')
-        parser.add_argument('limit', type=int_range(1, 100), default=20, location='args')
+        parser.add_argument("keyword", type=str, location="args")
+        parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
+        parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
+        parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
         args = parser.parse_args()
 
         # get paginate workflow app logs
         workflow_app_service = WorkflowAppService()
         workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
-            app_model=app_model,
-            args=args
+            app_model=app_model, args=args
         )
 
         return workflow_app_log_pagination
 
 
-api.add_resource(WorkflowAppLogApi, '/apps/<uuid:app_id>/workflow-app-logs')
+api.add_resource(WorkflowAppLogApi, "/apps/<uuid:app_id>/workflow-app-logs")

+ 11 - 19
api/controllers/console/app/workflow_run.py

@@ -28,15 +28,12 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
         Get advanced chat app workflow run list
         """
         parser = reqparse.RequestParser()
-        parser.add_argument('last_id', type=uuid_value, location='args')
-        parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
+        parser.add_argument("last_id", type=uuid_value, location="args")
+        parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
         args = parser.parse_args()
 
         workflow_run_service = WorkflowRunService()
-        result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(
-            app_model=app_model,
-            args=args
-        )
+        result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args=args)
 
         return result
 
@@ -52,15 +49,12 @@ class WorkflowRunListApi(Resource):
         Get workflow run list
         """
         parser = reqparse.RequestParser()
-        parser.add_argument('last_id', type=uuid_value, location='args')
-        parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
+        parser.add_argument("last_id", type=uuid_value, location="args")
+        parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
         args = parser.parse_args()
 
         workflow_run_service = WorkflowRunService()
-        result = workflow_run_service.get_paginate_workflow_runs(
-            app_model=app_model,
-            args=args
-        )
+        result = workflow_run_service.get_paginate_workflow_runs(app_model=app_model, args=args)
 
         return result
 
@@ -98,12 +92,10 @@ class WorkflowRunNodeExecutionListApi(Resource):
         workflow_run_service = WorkflowRunService()
         node_executions = workflow_run_service.get_workflow_run_node_executions(app_model=app_model, run_id=run_id)
 
-        return {
-            'data': node_executions
-        }
+        return {"data": node_executions}
 
 
-api.add_resource(AdvancedChatAppWorkflowRunListApi, '/apps/<uuid:app_id>/advanced-chat/workflow-runs')
-api.add_resource(WorkflowRunListApi, '/apps/<uuid:app_id>/workflow-runs')
-api.add_resource(WorkflowRunDetailApi, '/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>')
-api.add_resource(WorkflowRunNodeExecutionListApi, '/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/node-executions')
+api.add_resource(AdvancedChatAppWorkflowRunListApi, "/apps/<uuid:app_id>/advanced-chat/workflow-runs")
+api.add_resource(WorkflowRunListApi, "/apps/<uuid:app_id>/workflow-runs")
+api.add_resource(WorkflowRunDetailApi, "/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>")
+api.add_resource(WorkflowRunNodeExecutionListApi, "/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/node-executions")

+ 100 - 91
api/controllers/console/app/workflow_statistic.py

@@ -26,56 +26,56 @@ class WorkflowDailyRunsStatistic(Resource):
         account = current_user
 
         parser = reqparse.RequestParser()
-        parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
-        parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()
 
-        sql_query = '''
+        sql_query = """
         SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(id) AS runs
             FROM workflow_runs 
             WHERE app_id = :app_id 
                 AND triggered_from = :triggered_from
-        '''
-        arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value}
+        """
+        arg_dict = {
+            "tz": account.timezone,
+            "app_id": app_model.id,
+            "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
+        }
 
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
 
-        if args['start']:
-            start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
+        if args["start"]:
+            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
             start_datetime = start_datetime.replace(second=0)
 
             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and created_at >= :start'
-            arg_dict['start'] = start_datetime_utc
+            sql_query += " and created_at >= :start"
+            arg_dict["start"] = start_datetime_utc
 
-        if args['end']:
-            end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
+        if args["end"]:
+            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
             end_datetime = end_datetime.replace(second=0)
 
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and created_at < :end'
-            arg_dict['end'] = end_datetime_utc
+            sql_query += " and created_at < :end"
+            arg_dict["end"] = end_datetime_utc
 
-        sql_query += ' GROUP BY date order by date'
+        sql_query += " GROUP BY date order by date"
 
         response_data = []
 
         with db.engine.begin() as conn:
             rs = conn.execute(db.text(sql_query), arg_dict)
             for i in rs:
-                response_data.append({
-                    'date': str(i.date),
-                    'runs': i.runs
-                })
+                response_data.append({"date": str(i.date), "runs": i.runs})
+
+        return jsonify({"data": response_data})
 
-        return jsonify({
-            'data': response_data
-        })
 
 class WorkflowDailyTerminalsStatistic(Resource):
     @setup_required
@@ -86,56 +86,56 @@ class WorkflowDailyTerminalsStatistic(Resource):
         account = current_user
 
         parser = reqparse.RequestParser()
-        parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
-        parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()
 
-        sql_query = '''
+        sql_query = """
                 SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct workflow_runs.created_by) AS terminal_count
                     FROM workflow_runs 
                     WHERE app_id = :app_id 
                         AND triggered_from = :triggered_from
-                '''
-        arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value}
+                """
+        arg_dict = {
+            "tz": account.timezone,
+            "app_id": app_model.id,
+            "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
+        }
 
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
 
-        if args['start']:
-            start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
+        if args["start"]:
+            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
             start_datetime = start_datetime.replace(second=0)
 
             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and created_at >= :start'
-            arg_dict['start'] = start_datetime_utc
+            sql_query += " and created_at >= :start"
+            arg_dict["start"] = start_datetime_utc
 
-        if args['end']:
-            end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
+        if args["end"]:
+            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
             end_datetime = end_datetime.replace(second=0)
 
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and created_at < :end'
-            arg_dict['end'] = end_datetime_utc
+            sql_query += " and created_at < :end"
+            arg_dict["end"] = end_datetime_utc
 
-        sql_query += ' GROUP BY date order by date'
+        sql_query += " GROUP BY date order by date"
 
         response_data = []
 
         with db.engine.begin() as conn:
-            rs = conn.execute(db.text(sql_query), arg_dict)            
+            rs = conn.execute(db.text(sql_query), arg_dict)
             for i in rs:
-                response_data.append({
-                    'date': str(i.date),
-                    'terminal_count': i.terminal_count
-                })
+                response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
+
+        return jsonify({"data": response_data})
 
-        return jsonify({
-            'data': response_data
-        })
 
 class WorkflowDailyTokenCostStatistic(Resource):
     @setup_required
@@ -146,58 +146,63 @@ class WorkflowDailyTokenCostStatistic(Resource):
         account = current_user
 
         parser = reqparse.RequestParser()
-        parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
-        parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()
 
-        sql_query = '''
+        sql_query = """
                 SELECT 
                     date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, 
                     SUM(workflow_runs.total_tokens) as token_count
                 FROM workflow_runs 
                 WHERE app_id = :app_id 
                     AND triggered_from = :triggered_from
-                '''
-        arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value}
+                """
+        arg_dict = {
+            "tz": account.timezone,
+            "app_id": app_model.id,
+            "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
+        }
 
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
 
-        if args['start']:
-            start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
+        if args["start"]:
+            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
             start_datetime = start_datetime.replace(second=0)
 
             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and created_at >= :start'
-            arg_dict['start'] = start_datetime_utc
+            sql_query += " and created_at >= :start"
+            arg_dict["start"] = start_datetime_utc
 
-        if args['end']:
-            end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
+        if args["end"]:
+            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
             end_datetime = end_datetime.replace(second=0)
 
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query += ' and created_at < :end'
-            arg_dict['end'] = end_datetime_utc
+            sql_query += " and created_at < :end"
+            arg_dict["end"] = end_datetime_utc
 
-        sql_query += ' GROUP BY date order by date'
+        sql_query += " GROUP BY date order by date"
 
         response_data = []
 
         with db.engine.begin() as conn:
             rs = conn.execute(db.text(sql_query), arg_dict)
             for i in rs:
-                response_data.append({
-                    'date': str(i.date),
-                    'token_count': i.token_count,
-                })
+                response_data.append(
+                    {
+                        "date": str(i.date),
+                        "token_count": i.token_count,
+                    }
+                )
+
+        return jsonify({"data": response_data})
 
-        return jsonify({
-            'data': response_data
-        })
 
 class WorkflowAverageAppInteractionStatistic(Resource):
     @setup_required
@@ -208,8 +213,8 @@ class WorkflowAverageAppInteractionStatistic(Resource):
         account = current_user
 
         parser = reqparse.RequestParser()
-        parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
-        parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()
 
         sql_query = """
@@ -229,50 +234,54 @@ class WorkflowAverageAppInteractionStatistic(Resource):
                 GROUP BY date, c.created_by) sub
             GROUP BY sub.date
             """
-        arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value}
+        arg_dict = {
+            "tz": account.timezone,
+            "app_id": app_model.id,
+            "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
+        }
 
         timezone = pytz.timezone(account.timezone)
         utc_timezone = pytz.utc
 
-        if args['start']:
-            start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
+        if args["start"]:
+            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
             start_datetime = start_datetime.replace(second=0)
 
             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query = sql_query.replace('{{start}}', ' AND c.created_at >= :start')
-            arg_dict['start'] = start_datetime_utc
+            sql_query = sql_query.replace("{{start}}", " AND c.created_at >= :start")
+            arg_dict["start"] = start_datetime_utc
         else:
-            sql_query = sql_query.replace('{{start}}', '')
+            sql_query = sql_query.replace("{{start}}", "")
 
-        if args['end']:
-            end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
+        if args["end"]:
+            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
             end_datetime = end_datetime.replace(second=0)
 
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
 
-            sql_query = sql_query.replace('{{end}}', ' and c.created_at < :end')
-            arg_dict['end'] = end_datetime_utc
+            sql_query = sql_query.replace("{{end}}", " and c.created_at < :end")
+            arg_dict["end"] = end_datetime_utc
         else:
-            sql_query = sql_query.replace('{{end}}', '')
+            sql_query = sql_query.replace("{{end}}", "")
 
         response_data = []
-        
+
         with db.engine.begin() as conn:
             rs = conn.execute(db.text(sql_query), arg_dict)
             for i in rs:
-                response_data.append({
-                    'date': str(i.date),
-                    'interactions': float(i.interactions.quantize(Decimal('0.01')))
-                })
-
-        return jsonify({
-            'data': response_data
-        })
-
-api.add_resource(WorkflowDailyRunsStatistic, '/apps/<uuid:app_id>/workflow/statistics/daily-conversations')
-api.add_resource(WorkflowDailyTerminalsStatistic, '/apps/<uuid:app_id>/workflow/statistics/daily-terminals')
-api.add_resource(WorkflowDailyTokenCostStatistic, '/apps/<uuid:app_id>/workflow/statistics/token-costs')
-api.add_resource(WorkflowAverageAppInteractionStatistic, '/apps/<uuid:app_id>/workflow/statistics/average-app-interactions')
+                response_data.append(
+                    {"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
+                )
+
+        return jsonify({"data": response_data})
+
+
+api.add_resource(WorkflowDailyRunsStatistic, "/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
+api.add_resource(WorkflowDailyTerminalsStatistic, "/apps/<uuid:app_id>/workflow/statistics/daily-terminals")
+api.add_resource(WorkflowDailyTokenCostStatistic, "/apps/<uuid:app_id>/workflow/statistics/token-costs")
+api.add_resource(
+    WorkflowAverageAppInteractionStatistic, "/apps/<uuid:app_id>/workflow/statistics/average-app-interactions"
+)

+ 12 - 12
api/controllers/console/app/wraps.py

@@ -8,24 +8,23 @@ from libs.login import current_user
 from models.model import App, AppMode
 
 
-def get_app_model(view: Optional[Callable] = None, *,
-                  mode: Union[AppMode, list[AppMode]] = None):
+def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None):
     def decorator(view_func):
         @wraps(view_func)
         def decorated_view(*args, **kwargs):
-            if not kwargs.get('app_id'):
-                raise ValueError('missing app_id in path parameters')
+            if not kwargs.get("app_id"):
+                raise ValueError("missing app_id in path parameters")
 
-            app_id = kwargs.get('app_id')
+            app_id = kwargs.get("app_id")
             app_id = str(app_id)
 
-            del kwargs['app_id']
+            del kwargs["app_id"]
 
-            app_model = db.session.query(App).filter(
-                App.id == app_id,
-                App.tenant_id == current_user.current_tenant_id,
-                App.status == 'normal'
-            ).first()
+            app_model = (
+                db.session.query(App)
+                .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
+                .first()
+            )
 
             if not app_model:
                 raise AppNotFoundError()
@@ -44,9 +43,10 @@ def get_app_model(view: Optional[Callable] = None, *,
                     mode_values = {m.value for m in modes}
                     raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}")
 
-            kwargs['app_model'] = app_model
+            kwargs["app_model"] = app_model
 
             return view_func(*args, **kwargs)
+
         return decorated_view
 
     if view is None:

+ 27 - 26
api/controllers/console/auth/activate.py

@@ -17,60 +17,61 @@ from services.account_service import RegisterService
 class ActivateCheckApi(Resource):
     def get(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='args')
-        parser.add_argument('email', type=email, required=False, nullable=True, location='args')
-        parser.add_argument('token', type=str, required=True, nullable=False, location='args')
+        parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="args")
+        parser.add_argument("email", type=email, required=False, nullable=True, location="args")
+        parser.add_argument("token", type=str, required=True, nullable=False, location="args")
         args = parser.parse_args()
 
-        workspaceId = args['workspace_id']
-        reg_email = args['email']
-        token = args['token']
+        workspaceId = args["workspace_id"]
+        reg_email = args["email"]
+        token = args["token"]
 
         invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
 
-        return {'is_valid': invitation is not None, 'workspace_name': invitation['tenant'].name if invitation else None}
+        return {"is_valid": invitation is not None, "workspace_name": invitation["tenant"].name if invitation else None}
 
 
 class ActivateApi(Resource):
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='json')
-        parser.add_argument('email', type=email, required=False, nullable=True, location='json')
-        parser.add_argument('token', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('name', type=str_len(30), required=True, nullable=False, location='json')
-        parser.add_argument('password', type=valid_password, required=True, nullable=False, location='json')
-        parser.add_argument('interface_language', type=supported_language, required=True, nullable=False,
-                            location='json')
-        parser.add_argument('timezone', type=timezone, required=True, nullable=False, location='json')
+        parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
+        parser.add_argument("email", type=email, required=False, nullable=True, location="json")
+        parser.add_argument("token", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("name", type=str_len(30), required=True, nullable=False, location="json")
+        parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json")
+        parser.add_argument(
+            "interface_language", type=supported_language, required=True, nullable=False, location="json"
+        )
+        parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
         args = parser.parse_args()
 
-        invitation = RegisterService.get_invitation_if_token_valid(args['workspace_id'], args['email'], args['token'])
+        invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"])
         if invitation is None:
             raise AlreadyActivateError()
 
-        RegisterService.revoke_token(args['workspace_id'], args['email'], args['token'])
+        RegisterService.revoke_token(args["workspace_id"], args["email"], args["token"])
 
-        account = invitation['account']
-        account.name = args['name']
+        account = invitation["account"]
+        account.name = args["name"]
 
         # generate password salt
         salt = secrets.token_bytes(16)
         base64_salt = base64.b64encode(salt).decode()
 
         # encrypt password with salt
-        password_hashed = hash_password(args['password'], salt)
+        password_hashed = hash_password(args["password"], salt)
         base64_password_hashed = base64.b64encode(password_hashed).decode()
         account.password = base64_password_hashed
         account.password_salt = base64_salt
-        account.interface_language = args['interface_language']
-        account.timezone = args['timezone']
-        account.interface_theme = 'light'
+        account.interface_language = args["interface_language"]
+        account.timezone = args["timezone"]
+        account.interface_theme = "light"
         account.status = AccountStatus.ACTIVE.value
         account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
         db.session.commit()
 
-        return {'result': 'success'}
+        return {"result": "success"}
 
 
-api.add_resource(ActivateCheckApi, '/activate/check')
-api.add_resource(ActivateApi, '/activate')
+api.add_resource(ActivateCheckApi, "/activate/check")
+api.add_resource(ActivateApi, "/activate")

+ 20 - 19
api/controllers/console/auth/data_source_bearer_auth.py

@@ -19,18 +19,19 @@ class ApiKeyAuthDataSource(Resource):
         data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id)
         if data_source_api_key_bindings:
             return {
-                'sources': [{
-                                'id': data_source_api_key_binding.id,
-                                'category': data_source_api_key_binding.category,
-                                'provider': data_source_api_key_binding.provider,
-                                'disabled': data_source_api_key_binding.disabled,
-                                'created_at': int(data_source_api_key_binding.created_at.timestamp()),
-                                'updated_at': int(data_source_api_key_binding.updated_at.timestamp()),
-                            }
-                            for data_source_api_key_binding in
-                             data_source_api_key_bindings]
+                "sources": [
+                    {
+                        "id": data_source_api_key_binding.id,
+                        "category": data_source_api_key_binding.category,
+                        "provider": data_source_api_key_binding.provider,
+                        "disabled": data_source_api_key_binding.disabled,
+                        "created_at": int(data_source_api_key_binding.created_at.timestamp()),
+                        "updated_at": int(data_source_api_key_binding.updated_at.timestamp()),
+                    }
+                    for data_source_api_key_binding in data_source_api_key_bindings
+                ]
             }
-        return {'sources': []}
+        return {"sources": []}
 
 
 class ApiKeyAuthDataSourceBinding(Resource):
@@ -42,16 +43,16 @@ class ApiKeyAuthDataSourceBinding(Resource):
         if not current_user.is_admin_or_owner:
             raise Forbidden()
         parser = reqparse.RequestParser()
-        parser.add_argument('category', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
+        parser.add_argument("category", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
         args = parser.parse_args()
         ApiKeyAuthService.validate_api_key_auth_args(args)
         try:
             ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args)
         except Exception as e:
             raise ApiKeyAuthFailedError(str(e))
-        return {'result': 'success'}, 200
+        return {"result": "success"}, 200
 
 
 class ApiKeyAuthDataSourceBindingDelete(Resource):
@@ -65,9 +66,9 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
 
         ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)
 
-        return {'result': 'success'}, 200
+        return {"result": "success"}, 200
 
 
-api.add_resource(ApiKeyAuthDataSource, '/api-key-auth/data-source')
-api.add_resource(ApiKeyAuthDataSourceBinding, '/api-key-auth/data-source/binding')
-api.add_resource(ApiKeyAuthDataSourceBindingDelete, '/api-key-auth/data-source/<uuid:binding_id>')
+api.add_resource(ApiKeyAuthDataSource, "/api-key-auth/data-source")
+api.add_resource(ApiKeyAuthDataSourceBinding, "/api-key-auth/data-source/binding")
+api.add_resource(ApiKeyAuthDataSourceBindingDelete, "/api-key-auth/data-source/<uuid:binding_id>")

+ 35 - 37
api/controllers/console/auth/data_source_oauth.py

@@ -17,13 +17,13 @@ from ..wraps import account_initialization_required
 
 def get_oauth_providers():
     with current_app.app_context():
-        notion_oauth = NotionOAuth(client_id=dify_config.NOTION_CLIENT_ID,
-                                   client_secret=dify_config.NOTION_CLIENT_SECRET,
-                                   redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/data-source/callback/notion')
+        notion_oauth = NotionOAuth(
+            client_id=dify_config.NOTION_CLIENT_ID,
+            client_secret=dify_config.NOTION_CLIENT_SECRET,
+            redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/data-source/callback/notion",
+        )
 
-        OAUTH_PROVIDERS = {
-            'notion': notion_oauth
-        }
+        OAUTH_PROVIDERS = {"notion": notion_oauth}
         return OAUTH_PROVIDERS
 
 
@@ -37,18 +37,16 @@ class OAuthDataSource(Resource):
             oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
             print(vars(oauth_provider))
         if not oauth_provider:
-            return {'error': 'Invalid provider'}, 400
-        if dify_config.NOTION_INTEGRATION_TYPE == 'internal':
+            return {"error": "Invalid provider"}, 400
+        if dify_config.NOTION_INTEGRATION_TYPE == "internal":
             internal_secret = dify_config.NOTION_INTERNAL_SECRET
             if not internal_secret:
-                return {'error': 'Internal secret is not set'},
+                return ({"error": "Internal secret is not set"},)
             oauth_provider.save_internal_access_token(internal_secret)
-            return { 'data': '' }
+            return {"data": ""}
         else:
             auth_url = oauth_provider.get_authorization_url()
-            return { 'data': auth_url }, 200
-
-
+            return {"data": auth_url}, 200
 
 
 class OAuthDataSourceCallback(Resource):
@@ -57,18 +55,18 @@ class OAuthDataSourceCallback(Resource):
         with current_app.app_context():
             oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
         if not oauth_provider:
-            return {'error': 'Invalid provider'}, 400
-        if 'code' in request.args:
-            code = request.args.get('code')
+            return {"error": "Invalid provider"}, 400
+        if "code" in request.args:
+            code = request.args.get("code")
 
-            return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}')
-        elif 'error' in request.args:
-            error = request.args.get('error')
+            return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}")
+        elif "error" in request.args:
+            error = request.args.get("error")
 
-            return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}')
+            return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}")
         else:
-            return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied')
-        
+            return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied")
+
 
 class OAuthDataSourceBinding(Resource):
     def get(self, provider: str):
@@ -76,17 +74,18 @@ class OAuthDataSourceBinding(Resource):
         with current_app.app_context():
             oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
         if not oauth_provider:
-            return {'error': 'Invalid provider'}, 400
-        if 'code' in request.args:
-            code = request.args.get('code')
+            return {"error": "Invalid provider"}, 400
+        if "code" in request.args:
+            code = request.args.get("code")
             try:
                 oauth_provider.get_access_token(code)
             except requests.exceptions.HTTPError as e:
                 logging.exception(
-                    f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
-                return {'error': 'OAuth data source process failed'}, 400
+                    f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}"
+                )
+                return {"error": "OAuth data source process failed"}, 400
 
-            return {'result': 'success'}, 200
+            return {"result": "success"}, 200
 
 
 class OAuthDataSourceSync(Resource):
@@ -100,18 +99,17 @@ class OAuthDataSourceSync(Resource):
         with current_app.app_context():
             oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
         if not oauth_provider:
-            return {'error': 'Invalid provider'}, 400
+            return {"error": "Invalid provider"}, 400
         try:
             oauth_provider.sync_data_source(binding_id)
         except requests.exceptions.HTTPError as e:
-            logging.exception(
-                f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
-            return {'error': 'OAuth data source process failed'}, 400
+            logging.exception(f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
+            return {"error": "OAuth data source process failed"}, 400
 
-        return {'result': 'success'}, 200
+        return {"result": "success"}, 200
 
 
-api.add_resource(OAuthDataSource, '/oauth/data-source/<string:provider>')
-api.add_resource(OAuthDataSourceCallback, '/oauth/data-source/callback/<string:provider>')
-api.add_resource(OAuthDataSourceBinding, '/oauth/data-source/binding/<string:provider>')
-api.add_resource(OAuthDataSourceSync, '/oauth/data-source/<string:provider>/<uuid:binding_id>/sync')
+api.add_resource(OAuthDataSource, "/oauth/data-source/<string:provider>")
+api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/<string:provider>")
+api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/<string:provider>")
+api.add_resource(OAuthDataSourceSync, "/oauth/data-source/<string:provider>/<uuid:binding_id>/sync")

+ 5 - 6
api/controllers/console/auth/error.py

@@ -2,31 +2,30 @@ from libs.exception import BaseHTTPException
 
 
 class ApiKeyAuthFailedError(BaseHTTPException):
-    error_code = 'auth_failed'
+    error_code = "auth_failed"
     description = "{message}"
     code = 500
 
 
 class InvalidEmailError(BaseHTTPException):
-    error_code = 'invalid_email'
+    error_code = "invalid_email"
     description = "The email address is not valid."
     code = 400
 
 
 class PasswordMismatchError(BaseHTTPException):
-    error_code = 'password_mismatch'
+    error_code = "password_mismatch"
     description = "The passwords do not match."
     code = 400
 
 
 class InvalidTokenError(BaseHTTPException):
-    error_code = 'invalid_or_expired_token'
+    error_code = "invalid_or_expired_token"
     description = "The token is invalid or has expired."
     code = 400
 
 
 class PasswordResetRateLimitExceededError(BaseHTTPException):
-    error_code = 'password_reset_rate_limit_exceeded'
+    error_code = "password_reset_rate_limit_exceeded"
     description = "Password reset rate limit exceeded. Try again later."
     code = 429
-

+ 17 - 20
api/controllers/console/auth/forgot_password.py

@@ -21,14 +21,13 @@ from services.errors.account import RateLimitExceededError
 
 
 class ForgotPasswordSendEmailApi(Resource):
-
     @setup_required
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('email', type=str, required=True, location='json')
+        parser.add_argument("email", type=str, required=True, location="json")
         args = parser.parse_args()
 
-        email = args['email']
+        email = args["email"]
 
         if not email_validate(email):
             raise InvalidEmailError()
@@ -49,38 +48,36 @@ class ForgotPasswordSendEmailApi(Resource):
 
 
 class ForgotPasswordCheckApi(Resource):
-
     @setup_required
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('token', type=str, required=True, nullable=False, location='json')
+        parser.add_argument("token", type=str, required=True, nullable=False, location="json")
         args = parser.parse_args()
-        token = args['token']
+        token = args["token"]
 
         reset_data = AccountService.get_reset_password_data(token)
 
         if reset_data is None:
-            return {'is_valid': False, 'email': None}
-        return {'is_valid': True, 'email': reset_data.get('email')}
+            return {"is_valid": False, "email": None}
+        return {"is_valid": True, "email": reset_data.get("email")}
 
 
 class ForgotPasswordResetApi(Resource):
-
     @setup_required
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('token', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('new_password', type=valid_password, required=True, nullable=False, location='json')
-        parser.add_argument('password_confirm', type=valid_password, required=True, nullable=False, location='json')
+        parser.add_argument("token", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
+        parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
         args = parser.parse_args()
 
-        new_password = args['new_password']
-        password_confirm = args['password_confirm']
+        new_password = args["new_password"]
+        password_confirm = args["password_confirm"]
 
         if str(new_password).strip() != str(password_confirm).strip():
             raise PasswordMismatchError()
 
-        token = args['token']
+        token = args["token"]
         reset_data = AccountService.get_reset_password_data(token)
 
         if reset_data is None:
@@ -94,14 +91,14 @@ class ForgotPasswordResetApi(Resource):
         password_hashed = hash_password(new_password, salt)
         base64_password_hashed = base64.b64encode(password_hashed).decode()
 
-        account = Account.query.filter_by(email=reset_data.get('email')).first()
+        account = Account.query.filter_by(email=reset_data.get("email")).first()
         account.password = base64_password_hashed
         account.password_salt = base64_salt
         db.session.commit()
 
-        return {'result': 'success'}
+        return {"result": "success"}
 
 
-api.add_resource(ForgotPasswordSendEmailApi, '/forgot-password')
-api.add_resource(ForgotPasswordCheckApi, '/forgot-password/validity')
-api.add_resource(ForgotPasswordResetApi, '/forgot-password/resets')
+api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
+api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
+api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets")

+ 17 - 15
api/controllers/console/auth/login.py

@@ -20,37 +20,39 @@ class LoginApi(Resource):
     def post(self):
         """Authenticate user and login."""
         parser = reqparse.RequestParser()
-        parser.add_argument('email', type=email, required=True, location='json')
-        parser.add_argument('password', type=valid_password, required=True, location='json')
-        parser.add_argument('remember_me', type=bool, required=False, default=False, location='json')
+        parser.add_argument("email", type=email, required=True, location="json")
+        parser.add_argument("password", type=valid_password, required=True, location="json")
+        parser.add_argument("remember_me", type=bool, required=False, default=False, location="json")
         args = parser.parse_args()
 
         # todo: Verify the recaptcha
 
         try:
-            account = AccountService.authenticate(args['email'], args['password'])
+            account = AccountService.authenticate(args["email"], args["password"])
         except services.errors.account.AccountLoginError as e:
-            return {'code': 'unauthorized', 'message': str(e)}, 401
+            return {"code": "unauthorized", "message": str(e)}, 401
 
         # SELF_HOSTED only have one workspace
         tenants = TenantService.get_join_tenants(account)
         if len(tenants) == 0:
-            return {'result': 'fail', 'data': 'workspace not found, please contact system admin to invite you to join in a workspace'}
+            return {
+                "result": "fail",
+                "data": "workspace not found, please contact system admin to invite you to join in a workspace",
+            }
 
         token = AccountService.login(account, ip_address=get_remote_ip(request))
 
-        return {'result': 'success', 'data': token}
+        return {"result": "success", "data": token}
 
 
 class LogoutApi(Resource):
-
     @setup_required
     def get(self):
         account = cast(Account, flask_login.current_user)
-        token = request.headers.get('Authorization', '').split(' ')[1]
+        token = request.headers.get("Authorization", "").split(" ")[1]
         AccountService.logout(account=account, token=token)
         flask_login.logout_user()
-        return {'result': 'success'}
+        return {"result": "success"}
 
 
 class ResetPasswordApi(Resource):
@@ -80,11 +82,11 @@ class ResetPasswordApi(Resource):
         #     'subject': 'Reset your Dify password',
         #     'html': """
         #         <p>Dear User,</p>
-        #         <p>The Dify team has generated a new password for you, details as follows:</p> 
+        #         <p>The Dify team has generated a new password for you, details as follows:</p>
         #         <p><strong>{new_password}</strong></p>
         #         <p>Please change your password to log in as soon as possible.</p>
         #         <p>Regards,</p>
-        #         <p>The Dify Team</p> 
+        #         <p>The Dify Team</p>
         #     """
         # }
 
@@ -101,8 +103,8 @@ class ResetPasswordApi(Resource):
         #     # handle error
         #     pass
 
-        return {'result': 'success'}
+        return {"result": "success"}
 
 
-api.add_resource(LoginApi, '/login')
-api.add_resource(LogoutApi, '/logout')
+api.add_resource(LoginApi, "/login")
+api.add_resource(LogoutApi, "/logout")

+ 13 - 13
api/controllers/console/auth/oauth.py

@@ -25,7 +25,7 @@ def get_oauth_providers():
             github_oauth = GitHubOAuth(
                 client_id=dify_config.GITHUB_CLIENT_ID,
                 client_secret=dify_config.GITHUB_CLIENT_SECRET,
-                redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/github',
+                redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/github",
             )
         if not dify_config.GOOGLE_CLIENT_ID or not dify_config.GOOGLE_CLIENT_SECRET:
             google_oauth = None
@@ -33,10 +33,10 @@ def get_oauth_providers():
             google_oauth = GoogleOAuth(
                 client_id=dify_config.GOOGLE_CLIENT_ID,
                 client_secret=dify_config.GOOGLE_CLIENT_SECRET,
-                redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/google',
+                redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/google",
             )
 
-        OAUTH_PROVIDERS = {'github': github_oauth, 'google': google_oauth}
+        OAUTH_PROVIDERS = {"github": github_oauth, "google": google_oauth}
         return OAUTH_PROVIDERS
 
 
@@ -47,7 +47,7 @@ class OAuthLogin(Resource):
             oauth_provider = OAUTH_PROVIDERS.get(provider)
             print(vars(oauth_provider))
         if not oauth_provider:
-            return {'error': 'Invalid provider'}, 400
+            return {"error": "Invalid provider"}, 400
 
         auth_url = oauth_provider.get_authorization_url()
         return redirect(auth_url)
@@ -59,20 +59,20 @@ class OAuthCallback(Resource):
         with current_app.app_context():
             oauth_provider = OAUTH_PROVIDERS.get(provider)
         if not oauth_provider:
-            return {'error': 'Invalid provider'}, 400
+            return {"error": "Invalid provider"}, 400
 
-        code = request.args.get('code')
+        code = request.args.get("code")
         try:
             token = oauth_provider.get_access_token(code)
             user_info = oauth_provider.get_user_info(token)
         except requests.exceptions.HTTPError as e:
-            logging.exception(f'An error occurred during the OAuth process with {provider}: {e.response.text}')
-            return {'error': 'OAuth process failed'}, 400
+            logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}")
+            return {"error": "OAuth process failed"}, 400
 
         account = _generate_account(provider, user_info)
         # Check account status
         if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
-            return {'error': 'Account is banned or closed.'}, 403
+            return {"error": "Account is banned or closed."}, 403
 
         if account.status == AccountStatus.PENDING.value:
             account.status = AccountStatus.ACTIVE.value
@@ -83,7 +83,7 @@ class OAuthCallback(Resource):
 
         token = AccountService.login(account, ip_address=get_remote_ip(request))
 
-        return redirect(f'{dify_config.CONSOLE_WEB_URL}?console_token={token}')
+        return redirect(f"{dify_config.CONSOLE_WEB_URL}?console_token={token}")
 
 
 def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
@@ -101,7 +101,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
 
     if not account:
         # Create account
-        account_name = user_info.name if user_info.name else 'Dify'
+        account_name = user_info.name if user_info.name else "Dify"
         account = RegisterService.register(
             email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
         )
@@ -121,5 +121,5 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
     return account
 
 
-api.add_resource(OAuthLogin, '/oauth/login/<provider>')
-api.add_resource(OAuthCallback, '/oauth/authorize/<provider>')
+api.add_resource(OAuthLogin, "/oauth/login/<provider>")
+api.add_resource(OAuthCallback, "/oauth/authorize/<provider>")

+ 7 - 11
api/controllers/console/billing/billing.py

@@ -9,28 +9,24 @@ from services.billing_service import BillingService
 
 
 class Subscription(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
     @only_edition_cloud
     def get(self):
-
         parser = reqparse.RequestParser()
-        parser.add_argument('plan', type=str, required=True, location='args', choices=['professional', 'team'])
-        parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year'])
+        parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
+        parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
         args = parser.parse_args()
 
         BillingService.is_tenant_owner_or_admin(current_user)
 
-        return BillingService.get_subscription(args['plan'],
-                                               args['interval'],
-                                               current_user.email,
-                                               current_user.current_tenant_id)
+        return BillingService.get_subscription(
+            args["plan"], args["interval"], current_user.email, current_user.current_tenant_id
+        )
 
 
 class Invoices(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -40,5 +36,5 @@ class Invoices(Resource):
         return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)
 
 
-api.add_resource(Subscription, '/billing/subscription')
-api.add_resource(Invoices, '/billing/invoices')
+api.add_resource(Subscription, "/billing/subscription")
+api.add_resource(Invoices, "/billing/invoices")

+ 93 - 90
api/controllers/console/datasets/data_source.py

@@ -22,19 +22,22 @@ from tasks.document_indexing_sync_task import document_indexing_sync_task
 
 
 class DataSourceApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
     @marshal_with(integrate_list_fields)
     def get(self):
         # get workspace data source integrates
-        data_source_integrates = db.session.query(DataSourceOauthBinding).filter(
-            DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
-            DataSourceOauthBinding.disabled == False
-        ).all()
+        data_source_integrates = (
+            db.session.query(DataSourceOauthBinding)
+            .filter(
+                DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+                DataSourceOauthBinding.disabled == False,
+            )
+            .all()
+        )
 
-        base_url = request.url_root.rstrip('/')
+        base_url = request.url_root.rstrip("/")
         data_source_oauth_base_path = "/console/api/oauth/data-source"
         providers = ["notion"]
 
@@ -44,26 +47,30 @@ class DataSourceApi(Resource):
             existing_integrates = filter(lambda item: item.provider == provider, data_source_integrates)
             if existing_integrates:
                 for existing_integrate in list(existing_integrates):
-                    integrate_data.append({
-                        'id': existing_integrate.id,
-                        'provider': provider,
-                        'created_at': existing_integrate.created_at,
-                        'is_bound': True,
-                        'disabled': existing_integrate.disabled,
-                        'source_info': existing_integrate.source_info,
-                        'link': f'{base_url}{data_source_oauth_base_path}/{provider}'
-                })
+                    integrate_data.append(
+                        {
+                            "id": existing_integrate.id,
+                            "provider": provider,
+                            "created_at": existing_integrate.created_at,
+                            "is_bound": True,
+                            "disabled": existing_integrate.disabled,
+                            "source_info": existing_integrate.source_info,
+                            "link": f"{base_url}{data_source_oauth_base_path}/{provider}",
+                        }
+                    )
             else:
-                integrate_data.append({
-                    'id': None,
-                    'provider': provider,
-                    'created_at': None,
-                    'source_info': None,
-                    'is_bound': False,
-                    'disabled': None,
-                    'link': f'{base_url}{data_source_oauth_base_path}/{provider}'
-                })
-        return {'data': integrate_data}, 200
+                integrate_data.append(
+                    {
+                        "id": None,
+                        "provider": provider,
+                        "created_at": None,
+                        "source_info": None,
+                        "is_bound": False,
+                        "disabled": None,
+                        "link": f"{base_url}{data_source_oauth_base_path}/{provider}",
+                    }
+                )
+        return {"data": integrate_data}, 200
 
     @setup_required
     @login_required
@@ -71,92 +78,82 @@ class DataSourceApi(Resource):
     def patch(self, binding_id, action):
         binding_id = str(binding_id)
         action = str(action)
-        data_source_binding = DataSourceOauthBinding.query.filter_by(
-            id=binding_id
-        ).first()
+        data_source_binding = DataSourceOauthBinding.query.filter_by(id=binding_id).first()
         if data_source_binding is None:
-            raise NotFound('Data source binding not found.')
+            raise NotFound("Data source binding not found.")
         # enable binding
-        if action == 'enable':
+        if action == "enable":
             if data_source_binding.disabled:
                 data_source_binding.disabled = False
                 data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                 db.session.add(data_source_binding)
                 db.session.commit()
             else:
-                raise ValueError('Data source is not disabled.')
+                raise ValueError("Data source is not disabled.")
         # disable binding
-        if action == 'disable':
+        if action == "disable":
             if not data_source_binding.disabled:
                 data_source_binding.disabled = True
                 data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                 db.session.add(data_source_binding)
                 db.session.commit()
             else:
-                raise ValueError('Data source is disabled.')
-        return {'result': 'success'}, 200
+                raise ValueError("Data source is disabled.")
+        return {"result": "success"}, 200
 
 
 class DataSourceNotionListApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
     @marshal_with(integrate_notion_info_list_fields)
     def get(self):
-        dataset_id = request.args.get('dataset_id', default=None, type=str)
+        dataset_id = request.args.get("dataset_id", default=None, type=str)
         exist_page_ids = []
         # import notion in the exist dataset
         if dataset_id:
             dataset = DatasetService.get_dataset(dataset_id)
             if not dataset:
-                raise NotFound('Dataset not found.')
-            if dataset.data_source_type != 'notion_import':
-                raise ValueError('Dataset is not notion type.')
+                raise NotFound("Dataset not found.")
+            if dataset.data_source_type != "notion_import":
+                raise ValueError("Dataset is not notion type.")
             documents = Document.query.filter_by(
                 dataset_id=dataset_id,
                 tenant_id=current_user.current_tenant_id,
-                data_source_type='notion_import',
-                enabled=True
+                data_source_type="notion_import",
+                enabled=True,
             ).all()
             if documents:
                 for document in documents:
                     data_source_info = json.loads(document.data_source_info)
-                    exist_page_ids.append(data_source_info['notion_page_id'])
+                    exist_page_ids.append(data_source_info["notion_page_id"])
         # get all authorized pages
         data_source_bindings = DataSourceOauthBinding.query.filter_by(
-            tenant_id=current_user.current_tenant_id,
-            provider='notion',
-            disabled=False
+            tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
         ).all()
         if not data_source_bindings:
-            return {
-                'notion_info': []
-            }, 200
+            return {"notion_info": []}, 200
         pre_import_info_list = []
         for data_source_binding in data_source_bindings:
             source_info = data_source_binding.source_info
-            pages = source_info['pages']
+            pages = source_info["pages"]
             # Filter out already bound pages
             for page in pages:
-                if page['page_id'] in exist_page_ids:
-                    page['is_bound'] = True
+                if page["page_id"] in exist_page_ids:
+                    page["is_bound"] = True
                 else:
-                    page['is_bound'] = False
+                    page["is_bound"] = False
             pre_import_info = {
-                'workspace_name': source_info['workspace_name'],
-                'workspace_icon': source_info['workspace_icon'],
-                'workspace_id': source_info['workspace_id'],
-                'pages': pages,
+                "workspace_name": source_info["workspace_name"],
+                "workspace_icon": source_info["workspace_icon"],
+                "workspace_id": source_info["workspace_id"],
+                "pages": pages,
             }
             pre_import_info_list.append(pre_import_info)
-        return {
-            'notion_info': pre_import_info_list
-        }, 200
+        return {"notion_info": pre_import_info_list}, 200
 
 
 class DataSourceNotionApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -166,64 +163,67 @@ class DataSourceNotionApi(Resource):
         data_source_binding = DataSourceOauthBinding.query.filter(
             db.and_(
                 DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
-                DataSourceOauthBinding.provider == 'notion',
+                DataSourceOauthBinding.provider == "notion",
                 DataSourceOauthBinding.disabled == False,
-                DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
+                DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
             )
         ).first()
         if not data_source_binding:
-            raise NotFound('Data source binding not found.')
+            raise NotFound("Data source binding not found.")
 
         extractor = NotionExtractor(
             notion_workspace_id=workspace_id,
             notion_obj_id=page_id,
             notion_page_type=page_type,
             notion_access_token=data_source_binding.access_token,
-            tenant_id=current_user.current_tenant_id
+            tenant_id=current_user.current_tenant_id,
         )
 
         text_docs = extractor.extract()
-        return {
-            'content': "\n".join([doc.page_content for doc in text_docs])
-        }, 200
+        return {"content": "\n".join([doc.page_content for doc in text_docs])}, 200
 
     @setup_required
     @login_required
     @account_initialization_required
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json')
-        parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
-        parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
-        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
+        parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
+        parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
+        parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
+        parser.add_argument(
+            "doc_language", type=str, default="English", required=False, nullable=False, location="json"
+        )
         args = parser.parse_args()
         # validate args
         DocumentService.estimate_args_validate(args)
-        notion_info_list = args['notion_info_list']
+        notion_info_list = args["notion_info_list"]
         extract_settings = []
         for notion_info in notion_info_list:
-            workspace_id = notion_info['workspace_id']
-            for page in notion_info['pages']:
+            workspace_id = notion_info["workspace_id"]
+            for page in notion_info["pages"]:
                 extract_setting = ExtractSetting(
                     datasource_type="notion_import",
                     notion_info={
                         "notion_workspace_id": workspace_id,
-                        "notion_obj_id": page['page_id'],
-                        "notion_page_type": page['type'],
-                        "tenant_id": current_user.current_tenant_id
+                        "notion_obj_id": page["page_id"],
+                        "notion_page_type": page["type"],
+                        "tenant_id": current_user.current_tenant_id,
                     },
-                    document_model=args['doc_form']
+                    document_model=args["doc_form"],
                 )
                 extract_settings.append(extract_setting)
         indexing_runner = IndexingRunner()
-        response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
-                                                     args['process_rule'], args['doc_form'],
-                                                     args['doc_language'])
+        response = indexing_runner.indexing_estimate(
+            current_user.current_tenant_id,
+            extract_settings,
+            args["process_rule"],
+            args["doc_form"],
+            args["doc_language"],
+        )
         return response, 200
 
 
 class DataSourceNotionDatasetSyncApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -240,7 +240,6 @@ class DataSourceNotionDatasetSyncApi(Resource):
 
 
 class DataSourceNotionDocumentSyncApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -258,10 +257,14 @@ class DataSourceNotionDocumentSyncApi(Resource):
         return 200
 
 
-api.add_resource(DataSourceApi, '/data-source/integrates', '/data-source/integrates/<uuid:binding_id>/<string:action>')
-api.add_resource(DataSourceNotionListApi, '/notion/pre-import/pages')
-api.add_resource(DataSourceNotionApi,
-                 '/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview',
-                 '/datasets/notion-indexing-estimate')
-api.add_resource(DataSourceNotionDatasetSyncApi, '/datasets/<uuid:dataset_id>/notion/sync')
-api.add_resource(DataSourceNotionDocumentSyncApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync')
+api.add_resource(DataSourceApi, "/data-source/integrates", "/data-source/integrates/<uuid:binding_id>/<string:action>")
+api.add_resource(DataSourceNotionListApi, "/notion/pre-import/pages")
+api.add_resource(
+    DataSourceNotionApi,
+    "/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
+    "/datasets/notion-indexing-estimate",
+)
+api.add_resource(DataSourceNotionDatasetSyncApi, "/datasets/<uuid:dataset_id>/notion/sync")
+api.add_resource(
+    DataSourceNotionDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync"
+)

+ 254 - 220
api/controllers/console/datasets/datasets.py

@@ -31,45 +31,40 @@ from services.dataset_service import DatasetPermissionService, DatasetService, D
 
 def _validate_name(name):
     if not name or len(name) < 1 or len(name) > 40:
-        raise ValueError('Name must be between 1 to 40 characters.')
+        raise ValueError("Name must be between 1 to 40 characters.")
     return name
 
 
 def _validate_description_length(description):
     if len(description) > 400:
-        raise ValueError('Description cannot exceed 400 characters.')
+        raise ValueError("Description cannot exceed 400 characters.")
     return description
 
 
 class DatasetListApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
     def get(self):
-        page = request.args.get('page', default=1, type=int)
-        limit = request.args.get('limit', default=20, type=int)
-        ids = request.args.getlist('ids')
-        provider = request.args.get('provider', default="vendor")
-        search = request.args.get('keyword', default=None, type=str)
-        tag_ids = request.args.getlist('tag_ids')
+        page = request.args.get("page", default=1, type=int)
+        limit = request.args.get("limit", default=20, type=int)
+        ids = request.args.getlist("ids")
+        provider = request.args.get("provider", default="vendor")
+        search = request.args.get("keyword", default=None, type=str)
+        tag_ids = request.args.getlist("tag_ids")
 
         if ids:
             datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
         else:
-            datasets, total = DatasetService.get_datasets(page, limit, provider,
-                                                          current_user.current_tenant_id, current_user, search, tag_ids)
+            datasets, total = DatasetService.get_datasets(
+                page, limit, provider, current_user.current_tenant_id, current_user, search, tag_ids
+            )
 
         # check embedding setting
         provider_manager = ProviderManager()
-        configurations = provider_manager.get_configurations(
-            tenant_id=current_user.current_tenant_id
-        )
+        configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
 
-        embedding_models = configurations.get_models(
-            model_type=ModelType.TEXT_EMBEDDING,
-            only_active=True
-        )
+        embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
 
         model_names = []
         for embedding_model in embedding_models:
@@ -77,28 +72,22 @@ class DatasetListApi(Resource):
 
         data = marshal(datasets, dataset_detail_fields)
         for item in data:
-            if item['indexing_technique'] == 'high_quality':
+            if item["indexing_technique"] == "high_quality":
                 item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
                 if item_model in model_names:
-                    item['embedding_available'] = True
+                    item["embedding_available"] = True
                 else:
-                    item['embedding_available'] = False
+                    item["embedding_available"] = False
             else:
-                item['embedding_available'] = True
+                item["embedding_available"] = True
 
-            if item.get('permission') == 'partial_members':
-                part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item['id'])
-                item.update({'partial_member_list': part_users_list})
+            if item.get("permission") == "partial_members":
+                part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item["id"])
+                item.update({"partial_member_list": part_users_list})
             else:
-                item.update({'partial_member_list': []})
+                item.update({"partial_member_list": []})
 
-        response = {
-            'data': data,
-            'has_more': len(datasets) == limit,
-            'limit': limit,
-            'total': total,
-            'page': page
-        }
+        response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
         return response, 200
 
     @setup_required
@@ -106,13 +95,21 @@ class DatasetListApi(Resource):
     @account_initialization_required
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('name', nullable=False, required=True,
-                            help='type is required. Name must be between 1 to 40 characters.',
-                            type=_validate_name)
-        parser.add_argument('indexing_technique', type=str, location='json',
-                            choices=Dataset.INDEXING_TECHNIQUE_LIST,
-                            nullable=True,
-                            help='Invalid indexing technique.')
+        parser.add_argument(
+            "name",
+            nullable=False,
+            required=True,
+            help="type is required. Name must be between 1 to 40 characters.",
+            type=_validate_name,
+        )
+        parser.add_argument(
+            "indexing_technique",
+            type=str,
+            location="json",
+            choices=Dataset.INDEXING_TECHNIQUE_LIST,
+            nullable=True,
+            help="Invalid indexing technique.",
+        )
         args = parser.parse_args()
 
         # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
@@ -122,9 +119,9 @@ class DatasetListApi(Resource):
         try:
             dataset = DatasetService.create_empty_dataset(
                 tenant_id=current_user.current_tenant_id,
-                name=args['name'],
-                indexing_technique=args['indexing_technique'],
-                account=current_user
+                name=args["name"],
+                indexing_technique=args["indexing_technique"],
+                account=current_user,
             )
         except services.errors.dataset.DatasetNameDuplicateError:
             raise DatasetNameDuplicateError()
@@ -142,42 +139,36 @@ class DatasetApi(Resource):
         if dataset is None:
             raise NotFound("Dataset not found.")
         try:
-            DatasetService.check_dataset_permission(
-                dataset, current_user)
+            DatasetService.check_dataset_permission(dataset, current_user)
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
         data = marshal(dataset, dataset_detail_fields)
-        if data.get('permission') == 'partial_members':
+        if data.get("permission") == "partial_members":
             part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
-            data.update({'partial_member_list': part_users_list})
+            data.update({"partial_member_list": part_users_list})
 
         # check embedding setting
         provider_manager = ProviderManager()
-        configurations = provider_manager.get_configurations(
-            tenant_id=current_user.current_tenant_id
-        )
+        configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
 
-        embedding_models = configurations.get_models(
-            model_type=ModelType.TEXT_EMBEDDING,
-            only_active=True
-        )
+        embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
 
         model_names = []
         for embedding_model in embedding_models:
             model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
 
-        if data['indexing_technique'] == 'high_quality':
+        if data["indexing_technique"] == "high_quality":
             item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
             if item_model in model_names:
-                data['embedding_available'] = True
+                data["embedding_available"] = True
             else:
-                data['embedding_available'] = False
+                data["embedding_available"] = False
         else:
-            data['embedding_available'] = True
+            data["embedding_available"] = True
 
-        if data.get('permission') == 'partial_members':
+        if data.get("permission") == "partial_members":
             part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
-            data.update({'partial_member_list': part_users_list})
+            data.update({"partial_member_list": part_users_list})
 
         return data, 200
 
@@ -191,42 +182,49 @@ class DatasetApi(Resource):
             raise NotFound("Dataset not found.")
 
         parser = reqparse.RequestParser()
-        parser.add_argument('name', nullable=False,
-                            help='type is required. Name must be between 1 to 40 characters.',
-                            type=_validate_name)
-        parser.add_argument('description',
-                            location='json', store_missing=False,
-                            type=_validate_description_length)
-        parser.add_argument('indexing_technique', type=str, location='json',
-                            choices=Dataset.INDEXING_TECHNIQUE_LIST,
-                            nullable=True,
-                            help='Invalid indexing technique.')
-        parser.add_argument('permission', type=str, location='json', choices=(
-            DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), help='Invalid permission.'
-                            )
-        parser.add_argument('embedding_model', type=str,
-                            location='json', help='Invalid embedding model.')
-        parser.add_argument('embedding_model_provider', type=str,
-                            location='json', help='Invalid embedding model provider.')
-        parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
-        parser.add_argument('partial_member_list', type=list, location='json', help='Invalid parent user list.')
+        parser.add_argument(
+            "name",
+            nullable=False,
+            help="type is required. Name must be between 1 to 40 characters.",
+            type=_validate_name,
+        )
+        parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
+        parser.add_argument(
+            "indexing_technique",
+            type=str,
+            location="json",
+            choices=Dataset.INDEXING_TECHNIQUE_LIST,
+            nullable=True,
+            help="Invalid indexing technique.",
+        )
+        parser.add_argument(
+            "permission",
+            type=str,
+            location="json",
+            choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
+            help="Invalid permission.",
+        )
+        parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
+        parser.add_argument(
+            "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
+        )
+        parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
+        parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
         args = parser.parse_args()
         data = request.get_json()
 
         # check embedding model setting
-        if data.get('indexing_technique') == 'high_quality':
-            DatasetService.check_embedding_model_setting(dataset.tenant_id,
-                                                         data.get('embedding_model_provider'),
-                                                         data.get('embedding_model')
-                                                         )
+        if data.get("indexing_technique") == "high_quality":
+            DatasetService.check_embedding_model_setting(
+                dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
+            )
 
         # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
         DatasetPermissionService.check_permission(
-            current_user, dataset, data.get('permission'), data.get('partial_member_list')
+            current_user, dataset, data.get("permission"), data.get("partial_member_list")
         )
 
-        dataset = DatasetService.update_dataset(
-            dataset_id_str, args, current_user)
+        dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
 
         if dataset is None:
             raise NotFound("Dataset not found.")
@@ -234,16 +232,19 @@ class DatasetApi(Resource):
         result_data = marshal(dataset, dataset_detail_fields)
         tenant_id = current_user.current_tenant_id
 
-        if data.get('partial_member_list') and data.get('permission') == 'partial_members':
+        if data.get("partial_member_list") and data.get("permission") == "partial_members":
             DatasetPermissionService.update_partial_member_list(
-                tenant_id, dataset_id_str, data.get('partial_member_list')
+                tenant_id, dataset_id_str, data.get("partial_member_list")
             )
         # clear partial member list when permission is only_me or all_team_members
-        elif data.get('permission') == DatasetPermissionEnum.ONLY_ME or data.get('permission') == DatasetPermissionEnum.ALL_TEAM:
+        elif (
+            data.get("permission") == DatasetPermissionEnum.ONLY_ME
+            or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
+        ):
             DatasetPermissionService.clear_partial_member_list(dataset_id_str)
 
         partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
-        result_data.update({'partial_member_list': partial_member_list})
+        result_data.update({"partial_member_list": partial_member_list})
 
         return result_data, 200
 
@@ -260,12 +261,13 @@ class DatasetApi(Resource):
         try:
             if DatasetService.delete_dataset(dataset_id_str, current_user):
                 DatasetPermissionService.clear_partial_member_list(dataset_id_str)
-                return {'result': 'success'}, 204
+                return {"result": "success"}, 204
             else:
                 raise NotFound("Dataset not found.")
         except services.errors.dataset.DatasetInUseError:
             raise DatasetInUseError()
 
+
 class DatasetUseCheckApi(Resource):
     @setup_required
     @login_required
@@ -274,10 +276,10 @@ class DatasetUseCheckApi(Resource):
         dataset_id_str = str(dataset_id)
 
         dataset_is_using = DatasetService.dataset_use_check(dataset_id_str)
-        return {'is_using': dataset_is_using}, 200
+        return {"is_using": dataset_is_using}, 200
 
-class DatasetQueryApi(Resource):
 
+class DatasetQueryApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
@@ -292,51 +294,53 @@ class DatasetQueryApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
 
-        page = request.args.get('page', default=1, type=int)
-        limit = request.args.get('limit', default=20, type=int)
+        page = request.args.get("page", default=1, type=int)
+        limit = request.args.get("limit", default=20, type=int)
 
-        dataset_queries, total = DatasetService.get_dataset_queries(
-            dataset_id=dataset.id,
-            page=page,
-            per_page=limit
-        )
+        dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit)
 
         response = {
-            'data': marshal(dataset_queries, dataset_query_detail_fields),
-            'has_more': len(dataset_queries) == limit,
-            'limit': limit,
-            'total': total,
-            'page': page
+            "data": marshal(dataset_queries, dataset_query_detail_fields),
+            "has_more": len(dataset_queries) == limit,
+            "limit": limit,
+            "total": total,
+            "page": page,
         }
         return response, 200
 
 
 class DatasetIndexingEstimateApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
-        parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
-        parser.add_argument('indexing_technique', type=str, required=True,
-                            choices=Dataset.INDEXING_TECHNIQUE_LIST,
-                            nullable=True, location='json')
-        parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
-        parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
-        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
-                            location='json')
+        parser.add_argument("info_list", type=dict, required=True, nullable=True, location="json")
+        parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
+        parser.add_argument(
+            "indexing_technique",
+            type=str,
+            required=True,
+            choices=Dataset.INDEXING_TECHNIQUE_LIST,
+            nullable=True,
+            location="json",
+        )
+        parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
+        parser.add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
+        parser.add_argument(
+            "doc_language", type=str, default="English", required=False, nullable=False, location="json"
+        )
         args = parser.parse_args()
         # validate args
         DocumentService.estimate_args_validate(args)
         extract_settings = []
-        if args['info_list']['data_source_type'] == 'upload_file':
-            file_ids = args['info_list']['file_info_list']['file_ids']
-            file_details = db.session.query(UploadFile).filter(
-                UploadFile.tenant_id == current_user.current_tenant_id,
-                UploadFile.id.in_(file_ids)
-            ).all()
+        if args["info_list"]["data_source_type"] == "upload_file":
+            file_ids = args["info_list"]["file_info_list"]["file_ids"]
+            file_details = (
+                db.session.query(UploadFile)
+                .filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids))
+                .all()
+            )
 
             if file_details is None:
                 raise NotFound("File not found.")
@@ -344,55 +348,58 @@ class DatasetIndexingEstimateApi(Resource):
             if file_details:
                 for file_detail in file_details:
                     extract_setting = ExtractSetting(
-                        datasource_type="upload_file",
-                        upload_file=file_detail,
-                        document_model=args['doc_form']
+                        datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"]
                     )
                     extract_settings.append(extract_setting)
-        elif args['info_list']['data_source_type'] == 'notion_import':
-            notion_info_list = args['info_list']['notion_info_list']
+        elif args["info_list"]["data_source_type"] == "notion_import":
+            notion_info_list = args["info_list"]["notion_info_list"]
             for notion_info in notion_info_list:
-                workspace_id = notion_info['workspace_id']
-                for page in notion_info['pages']:
+                workspace_id = notion_info["workspace_id"]
+                for page in notion_info["pages"]:
                     extract_setting = ExtractSetting(
                         datasource_type="notion_import",
                         notion_info={
                             "notion_workspace_id": workspace_id,
-                            "notion_obj_id": page['page_id'],
-                            "notion_page_type": page['type'],
-                            "tenant_id": current_user.current_tenant_id
+                            "notion_obj_id": page["page_id"],
+                            "notion_page_type": page["type"],
+                            "tenant_id": current_user.current_tenant_id,
                         },
-                        document_model=args['doc_form']
+                        document_model=args["doc_form"],
                     )
                     extract_settings.append(extract_setting)
-        elif args['info_list']['data_source_type'] == 'website_crawl':
-            website_info_list = args['info_list']['website_info_list']
-            for url in website_info_list['urls']:
+        elif args["info_list"]["data_source_type"] == "website_crawl":
+            website_info_list = args["info_list"]["website_info_list"]
+            for url in website_info_list["urls"]:
                 extract_setting = ExtractSetting(
                     datasource_type="website_crawl",
                     website_info={
-                        "provider": website_info_list['provider'],
-                        "job_id": website_info_list['job_id'],
+                        "provider": website_info_list["provider"],
+                        "job_id": website_info_list["job_id"],
                         "url": url,
                         "tenant_id": current_user.current_tenant_id,
-                        "mode": 'crawl',
-                        "only_main_content": website_info_list['only_main_content']
+                        "mode": "crawl",
+                        "only_main_content": website_info_list["only_main_content"],
                     },
-                    document_model=args['doc_form']
+                    document_model=args["doc_form"],
                 )
                 extract_settings.append(extract_setting)
         else:
-            raise ValueError('Data source type not support')
+            raise ValueError("Data source type not support")
         indexing_runner = IndexingRunner()
         try:
-            response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
-                                                         args['process_rule'], args['doc_form'],
-                                                         args['doc_language'], args['dataset_id'],
-                                                         args['indexing_technique'])
+            response = indexing_runner.indexing_estimate(
+                current_user.current_tenant_id,
+                extract_settings,
+                args["process_rule"],
+                args["doc_form"],
+                args["doc_language"],
+                args["dataset_id"],
+                args["indexing_technique"],
+            )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider.")
+                "No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider."
+            )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
         except Exception as e:
@@ -402,7 +409,6 @@ class DatasetIndexingEstimateApi(Resource):
 
 
 class DatasetRelatedAppListApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -426,52 +432,52 @@ class DatasetRelatedAppListApi(Resource):
             if app_model:
                 related_apps.append(app_model)
 
-        return {
-            'data': related_apps,
-            'total': len(related_apps)
-        }, 200
+        return {"data": related_apps, "total": len(related_apps)}, 200
 
 
 class DatasetIndexingStatusApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
     def get(self, dataset_id):
         dataset_id = str(dataset_id)
-        documents = db.session.query(Document).filter(
-            Document.dataset_id == dataset_id,
-            Document.tenant_id == current_user.current_tenant_id
-        ).all()
+        documents = (
+            db.session.query(Document)
+            .filter(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id)
+            .all()
+        )
         documents_status = []
         for document in documents:
-            completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
-                                                              DocumentSegment.document_id == str(document.id),
-                                                              DocumentSegment.status != 're_segment').count()
-            total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
-                                                          DocumentSegment.status != 're_segment').count()
+            completed_segments = DocumentSegment.query.filter(
+                DocumentSegment.completed_at.isnot(None),
+                DocumentSegment.document_id == str(document.id),
+                DocumentSegment.status != "re_segment",
+            ).count()
+            total_segments = DocumentSegment.query.filter(
+                DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
+            ).count()
             document.completed_segments = completed_segments
             document.total_segments = total_segments
             documents_status.append(marshal(document, document_status_fields))
-        data = {
-            'data': documents_status
-        }
+        data = {"data": documents_status}
         return data
 
 
 class DatasetApiKeyApi(Resource):
     max_keys = 10
-    token_prefix = 'dataset-'
-    resource_type = 'dataset'
+    token_prefix = "dataset-"
+    resource_type = "dataset"
 
     @setup_required
     @login_required
     @account_initialization_required
     @marshal_with(api_key_list)
     def get(self):
-        keys = db.session.query(ApiToken). \
-            filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
-            all()
+        keys = (
+            db.session.query(ApiToken)
+            .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
+            .all()
+        )
         return {"items": keys}
 
     @setup_required
@@ -483,15 +489,17 @@ class DatasetApiKeyApi(Resource):
         if not current_user.is_admin_or_owner:
             raise Forbidden()
 
-        current_key_count = db.session.query(ApiToken). \
-            filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
-            count()
+        current_key_count = (
+            db.session.query(ApiToken)
+            .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
+            .count()
+        )
 
         if current_key_count >= self.max_keys:
             flask_restful.abort(
                 400,
                 message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
-                code='max_keys_exceeded'
+                code="max_keys_exceeded",
             )
 
         key = ApiToken.generate_api_key(self.token_prefix, 24)
@@ -505,7 +513,7 @@ class DatasetApiKeyApi(Resource):
 
 
 class DatasetApiDeleteApi(Resource):
-    resource_type = 'dataset'
+    resource_type = "dataset"
 
     @setup_required
     @login_required
@@ -517,18 +525,23 @@ class DatasetApiDeleteApi(Resource):
         if not current_user.is_admin_or_owner:
             raise Forbidden()
 
-        key = db.session.query(ApiToken). \
-            filter(ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.type == self.resource_type,
-                   ApiToken.id == api_key_id). \
-            first()
+        key = (
+            db.session.query(ApiToken)
+            .filter(
+                ApiToken.tenant_id == current_user.current_tenant_id,
+                ApiToken.type == self.resource_type,
+                ApiToken.id == api_key_id,
+            )
+            .first()
+        )
 
         if key is None:
-            flask_restful.abort(404, message='API key not found')
+            flask_restful.abort(404, message="API key not found")
 
         db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
         db.session.commit()
 
-        return {'result': 'success'}, 204
+        return {"result": "success"}, 204
 
 
 class DatasetApiBaseUrlApi(Resource):
@@ -537,8 +550,10 @@ class DatasetApiBaseUrlApi(Resource):
     @account_initialization_required
     def get(self):
         return {
-            'api_base_url': (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL
-                             else request.host_url.rstrip('/')) + '/v1'
+            "api_base_url": (
+                dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/")
+            )
+            + "/v1"
         }
 
 
@@ -549,15 +564,26 @@ class DatasetRetrievalSettingApi(Resource):
     def get(self):
         vector_type = dify_config.VECTOR_STORE
         match vector_type:
-            case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT:
-                return {
-                    'retrieval_method': [
-                        RetrievalMethod.SEMANTIC_SEARCH.value
-                    ]
-                }
-            case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH:
+            case (
+                VectorType.MILVUS
+                | VectorType.RELYT
+                | VectorType.PGVECTOR
+                | VectorType.TIDB_VECTOR
+                | VectorType.CHROMA
+                | VectorType.TENCENT
+            ):
+                return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
+            case (
+                VectorType.QDRANT
+                | VectorType.WEAVIATE
+                | VectorType.OPENSEARCH
+                | VectorType.ANALYTICDB
+                | VectorType.MYSCALE
+                | VectorType.ORACLE
+                | VectorType.ELASTICSEARCH
+            ):
                 return {
-                    'retrieval_method': [
+                    "retrieval_method": [
                         RetrievalMethod.SEMANTIC_SEARCH.value,
                         RetrievalMethod.FULL_TEXT_SEARCH.value,
                         RetrievalMethod.HYBRID_SEARCH.value,
@@ -573,15 +599,27 @@ class DatasetRetrievalSettingMockApi(Resource):
     @account_initialization_required
     def get(self, vector_type):
         match vector_type:
-            case VectorType.MILVUS | VectorType.RELYT | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.PGVECTO_RS:
+            case (
+                VectorType.MILVUS
+                | VectorType.RELYT
+                | VectorType.TIDB_VECTOR
+                | VectorType.CHROMA
+                | VectorType.TENCENT
+                | VectorType.PGVECTO_RS
+            ):
+                return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
+            case (
+                VectorType.QDRANT
+                | VectorType.WEAVIATE
+                | VectorType.OPENSEARCH
+                | VectorType.ANALYTICDB
+                | VectorType.MYSCALE
+                | VectorType.ORACLE
+                | VectorType.ELASTICSEARCH
+                | VectorType.PGVECTOR
+            ):
                 return {
-                    'retrieval_method': [
-                        RetrievalMethod.SEMANTIC_SEARCH.value
-                    ]
-                }
-            case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH | VectorType.PGVECTOR:
-                return {
-                    'retrieval_method': [
+                    "retrieval_method": [
                         RetrievalMethod.SEMANTIC_SEARCH.value,
                         RetrievalMethod.FULL_TEXT_SEARCH.value,
                         RetrievalMethod.HYBRID_SEARCH.value,
@@ -591,7 +629,6 @@ class DatasetRetrievalSettingMockApi(Resource):
                 raise ValueError(f"Unsupported vector db type {vector_type}.")
 
 
-
 class DatasetErrorDocs(Resource):
     @setup_required
     @login_required
@@ -603,10 +640,7 @@ class DatasetErrorDocs(Resource):
             raise NotFound("Dataset not found.")
         results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str)
 
-        return {
-            'data': [marshal(item, document_status_fields) for item in results],
-            'total': len(results)
-        }, 200
+        return {"data": [marshal(item, document_status_fields) for item in results], "total": len(results)}, 200
 
 
 class DatasetPermissionUserListApi(Resource):
@@ -626,21 +660,21 @@ class DatasetPermissionUserListApi(Resource):
         partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
 
         return {
-            'data': partial_members_list,
+            "data": partial_members_list,
         }, 200
 
 
-api.add_resource(DatasetListApi, '/datasets')
-api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
-api.add_resource(DatasetUseCheckApi, '/datasets/<uuid:dataset_id>/use-check')
-api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
-api.add_resource(DatasetErrorDocs, '/datasets/<uuid:dataset_id>/error-docs')
-api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
-api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
-api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status')
-api.add_resource(DatasetApiKeyApi, '/datasets/api-keys')
-api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
-api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
-api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
-api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
-api.add_resource(DatasetPermissionUserListApi, '/datasets/<uuid:dataset_id>/permission-part-users')
+api.add_resource(DatasetListApi, "/datasets")
+api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
+api.add_resource(DatasetUseCheckApi, "/datasets/<uuid:dataset_id>/use-check")
+api.add_resource(DatasetQueryApi, "/datasets/<uuid:dataset_id>/queries")
+api.add_resource(DatasetErrorDocs, "/datasets/<uuid:dataset_id>/error-docs")
+api.add_resource(DatasetIndexingEstimateApi, "/datasets/indexing-estimate")
+api.add_resource(DatasetRelatedAppListApi, "/datasets/<uuid:dataset_id>/related-apps")
+api.add_resource(DatasetIndexingStatusApi, "/datasets/<uuid:dataset_id>/indexing-status")
+api.add_resource(DatasetApiKeyApi, "/datasets/api-keys")
+api.add_resource(DatasetApiDeleteApi, "/datasets/api-keys/<uuid:api_key_id>")
+api.add_resource(DatasetApiBaseUrlApi, "/datasets/api-base-info")
+api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting")
+api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/<string:vector_type>")
+api.add_resource(DatasetPermissionUserListApi, "/datasets/<uuid:dataset_id>/permission-part-users")

文件差异内容过多而无法显示
+ 260 - 281
api/controllers/console/datasets/datasets_document.py


+ 100 - 114
api/controllers/console/datasets/datasets_segments.py

@@ -40,7 +40,7 @@ class DatasetDocumentSegmentListApi(Resource):
         document_id = str(document_id)
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
-            raise NotFound('Dataset not found.')
+            raise NotFound("Dataset not found.")
 
         try:
             DatasetService.check_dataset_permission(dataset, current_user)
@@ -50,37 +50,33 @@ class DatasetDocumentSegmentListApi(Resource):
         document = DocumentService.get_document(dataset_id, document_id)
 
         if not document:
-            raise NotFound('Document not found.')
+            raise NotFound("Document not found.")
 
         parser = reqparse.RequestParser()
-        parser.add_argument('last_id', type=str, default=None, location='args')
-        parser.add_argument('limit', type=int, default=20, location='args')
-        parser.add_argument('status', type=str,
-                            action='append', default=[], location='args')
-        parser.add_argument('hit_count_gte', type=int,
-                            default=None, location='args')
-        parser.add_argument('enabled', type=str, default='all', location='args')
-        parser.add_argument('keyword', type=str, default=None, location='args')
+        parser.add_argument("last_id", type=str, default=None, location="args")
+        parser.add_argument("limit", type=int, default=20, location="args")
+        parser.add_argument("status", type=str, action="append", default=[], location="args")
+        parser.add_argument("hit_count_gte", type=int, default=None, location="args")
+        parser.add_argument("enabled", type=str, default="all", location="args")
+        parser.add_argument("keyword", type=str, default=None, location="args")
         args = parser.parse_args()
 
-        last_id = args['last_id']
-        limit = min(args['limit'], 100)
-        status_list = args['status']
-        hit_count_gte = args['hit_count_gte']
-        keyword = args['keyword']
+        last_id = args["last_id"]
+        limit = min(args["limit"], 100)
+        status_list = args["status"]
+        hit_count_gte = args["hit_count_gte"]
+        keyword = args["keyword"]
 
         query = DocumentSegment.query.filter(
-            DocumentSegment.document_id == str(document_id),
-            DocumentSegment.tenant_id == current_user.current_tenant_id
+            DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id
         )
 
         if last_id is not None:
             last_segment = db.session.get(DocumentSegment, str(last_id))
             if last_segment:
-                query = query.filter(
-                    DocumentSegment.position > last_segment.position)
+                query = query.filter(DocumentSegment.position > last_segment.position)
             else:
-                return {'data': [], 'has_more': False, 'limit': limit}, 200
+                return {"data": [], "has_more": False, "limit": limit}, 200
 
         if status_list:
             query = query.filter(DocumentSegment.status.in_(status_list))
@@ -89,12 +85,12 @@ class DatasetDocumentSegmentListApi(Resource):
             query = query.filter(DocumentSegment.hit_count >= hit_count_gte)
 
         if keyword:
-            query = query.where(DocumentSegment.content.ilike(f'%{keyword}%'))
+            query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
 
-        if args['enabled'].lower() != 'all':
-            if args['enabled'].lower() == 'true':
+        if args["enabled"].lower() != "all":
+            if args["enabled"].lower() == "true":
                 query = query.filter(DocumentSegment.enabled == True)
-            elif args['enabled'].lower() == 'false':
+            elif args["enabled"].lower() == "false":
                 query = query.filter(DocumentSegment.enabled == False)
 
         total = query.count()
@@ -106,11 +102,11 @@ class DatasetDocumentSegmentListApi(Resource):
             segments = segments[:-1]
 
         return {
-            'data': marshal(segments, segment_fields),
-            'doc_form': document.doc_form,
-            'has_more': has_more,
-            'limit': limit,
-            'total': total
+            "data": marshal(segments, segment_fields),
+            "doc_form": document.doc_form,
+            "has_more": has_more,
+            "limit": limit,
+            "total": total,
         }, 200
 
 
@@ -118,12 +114,12 @@ class DatasetDocumentSegmentApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_resource_check('vector_space')
+    @cloud_edition_billing_resource_check("vector_space")
     def patch(self, dataset_id, segment_id, action):
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
-            raise NotFound('Dataset not found.')
+            raise NotFound("Dataset not found.")
         # check user's model setting
         DatasetService.check_dataset_model_setting(dataset)
         # The role of the current user in the ta table must be admin, owner, or editor
@@ -134,7 +130,7 @@ class DatasetDocumentSegmentApi(Resource):
             DatasetService.check_dataset_permission(dataset, current_user)
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
-        if dataset.indexing_technique == 'high_quality':
+        if dataset.indexing_technique == "high_quality":
             # check embedding model setting
             try:
                 model_manager = ModelManager()
@@ -142,32 +138,32 @@ class DatasetDocumentSegmentApi(Resource):
                     tenant_id=current_user.current_tenant_id,
                     provider=dataset.embedding_model_provider,
                     model_type=ModelType.TEXT_EMBEDDING,
-                    model=dataset.embedding_model
+                    model=dataset.embedding_model,
                 )
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
                     "No Embedding Model available. Please configure a valid provider "
-                    "in the Settings -> Model Provider.")
+                    "in the Settings -> Model Provider."
+                )
             except ProviderTokenNotInitError as ex:
                 raise ProviderNotInitializeError(ex.description)
 
         segment = DocumentSegment.query.filter(
-            DocumentSegment.id == str(segment_id),
-            DocumentSegment.tenant_id == current_user.current_tenant_id
+            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
         ).first()
 
         if not segment:
-            raise NotFound('Segment not found.')
+            raise NotFound("Segment not found.")
 
-        if segment.status != 'completed':
-            raise NotFound('Segment is not completed, enable or disable function is not allowed')
+        if segment.status != "completed":
+            raise NotFound("Segment is not completed, enable or disable function is not allowed")
 
-        document_indexing_cache_key = 'document_{}_indexing'.format(segment.document_id)
+        document_indexing_cache_key = "document_{}_indexing".format(segment.document_id)
         cache_result = redis_client.get(document_indexing_cache_key)
         if cache_result is not None:
             raise InvalidActionError("Document is being indexed, please try again later")
 
-        indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
+        indexing_cache_key = "segment_{}_indexing".format(segment.id)
         cache_result = redis_client.get(indexing_cache_key)
         if cache_result is not None:
             raise InvalidActionError("Segment is being indexed, please try again later")
@@ -186,7 +182,7 @@ class DatasetDocumentSegmentApi(Resource):
 
             enable_segment_to_index_task.delay(segment.id)
 
-            return {'result': 'success'}, 200
+            return {"result": "success"}, 200
         elif action == "disable":
             if not segment.enabled:
                 raise InvalidActionError("Segment is already disabled.")
@@ -201,7 +197,7 @@ class DatasetDocumentSegmentApi(Resource):
 
             disable_segment_from_index_task.delay(segment.id)
 
-            return {'result': 'success'}, 200
+            return {"result": "success"}, 200
         else:
             raise InvalidActionError()
 
@@ -210,35 +206,36 @@ class DatasetDocumentSegmentAddApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_resource_check('vector_space')
-    @cloud_edition_billing_knowledge_limit_check('add_segment')
+    @cloud_edition_billing_resource_check("vector_space")
+    @cloud_edition_billing_knowledge_limit_check("add_segment")
     def post(self, dataset_id, document_id):
         # check dataset
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
-            raise NotFound('Dataset not found.')
+            raise NotFound("Dataset not found.")
         # check document
         document_id = str(document_id)
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
-            raise NotFound('Document not found.')
+            raise NotFound("Document not found.")
         if not current_user.is_editor:
             raise Forbidden()
         # check embedding model setting
-        if dataset.indexing_technique == 'high_quality':
+        if dataset.indexing_technique == "high_quality":
             try:
                 model_manager = ModelManager()
                 model_manager.get_model_instance(
                     tenant_id=current_user.current_tenant_id,
                     provider=dataset.embedding_model_provider,
                     model_type=ModelType.TEXT_EMBEDDING,
-                    model=dataset.embedding_model
+                    model=dataset.embedding_model,
                 )
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
                     "No Embedding Model available. Please configure a valid provider "
-                    "in the Settings -> Model Provider.")
+                    "in the Settings -> Model Provider."
+                )
             except ProviderTokenNotInitError as ex:
                 raise ProviderNotInitializeError(ex.description)
         try:
@@ -247,37 +244,34 @@ class DatasetDocumentSegmentAddApi(Resource):
             raise Forbidden(str(e))
         # validate args
         parser = reqparse.RequestParser()
-        parser.add_argument('content', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
-        parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
+        parser.add_argument("content", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
+        parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
         args = parser.parse_args()
         SegmentService.segment_create_args_validate(args, document)
         segment = SegmentService.create_segment(args, document, dataset)
-        return {
-            'data': marshal(segment, segment_fields),
-            'doc_form': document.doc_form
-        }, 200
+        return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
 
 
 class DatasetDocumentSegmentUpdateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_resource_check('vector_space')
+    @cloud_edition_billing_resource_check("vector_space")
     def patch(self, dataset_id, document_id, segment_id):
         # check dataset
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
-            raise NotFound('Dataset not found.')
+            raise NotFound("Dataset not found.")
         # check user's model setting
         DatasetService.check_dataset_model_setting(dataset)
         # check document
         document_id = str(document_id)
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
-            raise NotFound('Document not found.')
-        if dataset.indexing_technique == 'high_quality':
+            raise NotFound("Document not found.")
+        if dataset.indexing_technique == "high_quality":
             # check embedding model setting
             try:
                 model_manager = ModelManager()
@@ -285,22 +279,22 @@ class DatasetDocumentSegmentUpdateApi(Resource):
                     tenant_id=current_user.current_tenant_id,
                     provider=dataset.embedding_model_provider,
                     model_type=ModelType.TEXT_EMBEDDING,
-                    model=dataset.embedding_model
+                    model=dataset.embedding_model,
                 )
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
                     "No Embedding Model available. Please configure a valid provider "
-                    "in the Settings -> Model Provider.")
+                    "in the Settings -> Model Provider."
+                )
             except ProviderTokenNotInitError as ex:
                 raise ProviderNotInitializeError(ex.description)
             # check segment
         segment_id = str(segment_id)
         segment = DocumentSegment.query.filter(
-            DocumentSegment.id == str(segment_id),
-            DocumentSegment.tenant_id == current_user.current_tenant_id
+            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
         ).first()
         if not segment:
-            raise NotFound('Segment not found.')
+            raise NotFound("Segment not found.")
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
@@ -310,16 +304,13 @@ class DatasetDocumentSegmentUpdateApi(Resource):
             raise Forbidden(str(e))
         # validate args
         parser = reqparse.RequestParser()
-        parser.add_argument('content', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
-        parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
+        parser.add_argument("content", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
+        parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
         args = parser.parse_args()
         SegmentService.segment_create_args_validate(args, document)
         segment = SegmentService.update_segment(args, segment, document, dataset)
-        return {
-            'data': marshal(segment, segment_fields),
-            'doc_form': document.doc_form
-        }, 200
+        return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
 
     @setup_required
     @login_required
@@ -329,22 +320,21 @@ class DatasetDocumentSegmentUpdateApi(Resource):
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
-            raise NotFound('Dataset not found.')
+            raise NotFound("Dataset not found.")
         # check user's model setting
         DatasetService.check_dataset_model_setting(dataset)
         # check document
         document_id = str(document_id)
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
-            raise NotFound('Document not found.')
+            raise NotFound("Document not found.")
         # check segment
         segment_id = str(segment_id)
         segment = DocumentSegment.query.filter(
-            DocumentSegment.id == str(segment_id),
-            DocumentSegment.tenant_id == current_user.current_tenant_id
+            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
         ).first()
         if not segment:
-            raise NotFound('Segment not found.')
+            raise NotFound("Segment not found.")
         # The role of the current user in the ta table must be admin or owner
         if not current_user.is_editor:
             raise Forbidden()
@@ -353,36 +343,36 @@ class DatasetDocumentSegmentUpdateApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
         SegmentService.delete_segment(segment, document, dataset)
-        return {'result': 'success'}, 200
+        return {"result": "success"}, 200
 
 
 class DatasetDocumentSegmentBatchImportApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_resource_check('vector_space')
-    @cloud_edition_billing_knowledge_limit_check('add_segment')
+    @cloud_edition_billing_resource_check("vector_space")
+    @cloud_edition_billing_knowledge_limit_check("add_segment")
     def post(self, dataset_id, document_id):
         # check dataset
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
-            raise NotFound('Dataset not found.')
+            raise NotFound("Dataset not found.")
         # check document
         document_id = str(document_id)
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
-            raise NotFound('Document not found.')
+            raise NotFound("Document not found.")
         # get file from request
-        file = request.files['file']
+        file = request.files["file"]
         # check file
-        if 'file' not in request.files:
+        if "file" not in request.files:
             raise NoFileUploadedError()
 
         if len(request.files) > 1:
             raise TooManyFilesError()
         # check file type
-        if not file.filename.endswith('.csv'):
+        if not file.filename.endswith(".csv"):
             raise ValueError("Invalid file type. Only CSV files are allowed")
 
         try:
@@ -390,51 +380,47 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
             df = pd.read_csv(file)
             result = []
             for index, row in df.iterrows():
-                if document.doc_form == 'qa_model':
-                    data = {'content': row[0], 'answer': row[1]}
+                if document.doc_form == "qa_model":
+                    data = {"content": row[0], "answer": row[1]}
                 else:
-                    data = {'content': row[0]}
+                    data = {"content": row[0]}
                 result.append(data)
             if len(result) == 0:
                 raise ValueError("The CSV file is empty.")
             # async job
             job_id = str(uuid.uuid4())
-            indexing_cache_key = 'segment_batch_import_{}'.format(str(job_id))
+            indexing_cache_key = "segment_batch_import_{}".format(str(job_id))
             # send batch add segments task
-            redis_client.setnx(indexing_cache_key, 'waiting')
-            batch_create_segment_to_index_task.delay(str(job_id), result, dataset_id, document_id,
-                                                     current_user.current_tenant_id, current_user.id)
+            redis_client.setnx(indexing_cache_key, "waiting")
+            batch_create_segment_to_index_task.delay(
+                str(job_id), result, dataset_id, document_id, current_user.current_tenant_id, current_user.id
+            )
         except Exception as e:
-            return {'error': str(e)}, 500
-        return {
-            'job_id': job_id,
-            'job_status': 'waiting'
-        }, 200
+            return {"error": str(e)}, 500
+        return {"job_id": job_id, "job_status": "waiting"}, 200
 
     @setup_required
     @login_required
     @account_initialization_required
     def get(self, job_id):
         job_id = str(job_id)
-        indexing_cache_key = 'segment_batch_import_{}'.format(job_id)
+        indexing_cache_key = "segment_batch_import_{}".format(job_id)
         cache_result = redis_client.get(indexing_cache_key)
         if cache_result is None:
             raise ValueError("The job is not exist.")
 
-        return {
-            'job_id': job_id,
-            'job_status': cache_result.decode()
-        }, 200
+        return {"job_id": job_id, "job_status": cache_result.decode()}, 200
 
 
-api.add_resource(DatasetDocumentSegmentListApi,
-                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
-api.add_resource(DatasetDocumentSegmentApi,
-                 '/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>')
-api.add_resource(DatasetDocumentSegmentAddApi,
-                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment')
-api.add_resource(DatasetDocumentSegmentUpdateApi,
-                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
-api.add_resource(DatasetDocumentSegmentBatchImportApi,
-                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import',
-                 '/datasets/batch_import_status/<uuid:job_id>')
+api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
+api.add_resource(DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>")
+api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
+api.add_resource(
+    DatasetDocumentSegmentUpdateApi,
+    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>",
+)
+api.add_resource(
+    DatasetDocumentSegmentBatchImportApi,
+    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
+    "/datasets/batch_import_status/<uuid:job_id>",
+)

+ 15 - 15
api/controllers/console/datasets/error.py

@@ -2,90 +2,90 @@ from libs.exception import BaseHTTPException
 
 
 class NoFileUploadedError(BaseHTTPException):
-    error_code = 'no_file_uploaded'
+    error_code = "no_file_uploaded"
     description = "Please upload your file."
     code = 400
 
 
 class TooManyFilesError(BaseHTTPException):
-    error_code = 'too_many_files'
+    error_code = "too_many_files"
     description = "Only one file is allowed."
     code = 400
 
 
 class FileTooLargeError(BaseHTTPException):
-    error_code = 'file_too_large'
+    error_code = "file_too_large"
     description = "File size exceeded. {message}"
     code = 413
 
 
 class UnsupportedFileTypeError(BaseHTTPException):
-    error_code = 'unsupported_file_type'
+    error_code = "unsupported_file_type"
     description = "File type not allowed."
     code = 415
 
 
 class HighQualityDatasetOnlyError(BaseHTTPException):
-    error_code = 'high_quality_dataset_only'
+    error_code = "high_quality_dataset_only"
     description = "Current operation only supports 'high-quality' datasets."
     code = 400
 
 
 class DatasetNotInitializedError(BaseHTTPException):
-    error_code = 'dataset_not_initialized'
+    error_code = "dataset_not_initialized"
     description = "The dataset is still being initialized or indexing. Please wait a moment."
     code = 400
 
 
 class ArchivedDocumentImmutableError(BaseHTTPException):
-    error_code = 'archived_document_immutable'
+    error_code = "archived_document_immutable"
     description = "The archived document is not editable."
     code = 403
 
 
 class DatasetNameDuplicateError(BaseHTTPException):
-    error_code = 'dataset_name_duplicate'
+    error_code = "dataset_name_duplicate"
     description = "The dataset name already exists. Please modify your dataset name."
     code = 409
 
 
 class InvalidActionError(BaseHTTPException):
-    error_code = 'invalid_action'
+    error_code = "invalid_action"
     description = "Invalid action."
     code = 400
 
 
 class DocumentAlreadyFinishedError(BaseHTTPException):
-    error_code = 'document_already_finished'
+    error_code = "document_already_finished"
     description = "The document has been processed. Please refresh the page or go to the document details."
     code = 400
 
 
 class DocumentIndexingError(BaseHTTPException):
-    error_code = 'document_indexing'
+    error_code = "document_indexing"
     description = "The document is being processed and cannot be edited."
     code = 400
 
 
 class InvalidMetadataError(BaseHTTPException):
-    error_code = 'invalid_metadata'
+    error_code = "invalid_metadata"
     description = "The metadata content is incorrect. Please check and verify."
     code = 400
 
 
 class WebsiteCrawlError(BaseHTTPException):
-    error_code = 'crawl_failed'
+    error_code = "crawl_failed"
     description = "{message}"
     code = 500
 
 
 class DatasetInUseError(BaseHTTPException):
-    error_code = 'dataset_in_use'
+    error_code = "dataset_in_use"
     description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it."
     code = 409
 
 
 class IndexingEstimateError(BaseHTTPException):
-    error_code = 'indexing_estimate_error'
+    error_code = "indexing_estimate_error"
     description = "Knowledge indexing estimate failed: {message}"
     code = 500

+ 12 - 14
api/controllers/console/datasets/file.py

@@ -21,7 +21,6 @@ PREVIEW_WORDS_LIMIT = 3000
 
 
 class FileApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -31,23 +30,22 @@ class FileApi(Resource):
         batch_count_limit = dify_config.UPLOAD_FILE_BATCH_LIMIT
         image_file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
         return {
-            'file_size_limit': file_size_limit,
-            'batch_count_limit': batch_count_limit,
-            'image_file_size_limit': image_file_size_limit
+            "file_size_limit": file_size_limit,
+            "batch_count_limit": batch_count_limit,
+            "image_file_size_limit": image_file_size_limit,
         }, 200
 
     @setup_required
     @login_required
     @account_initialization_required
     @marshal_with(file_fields)
-    @cloud_edition_billing_resource_check(resource='documents')
+    @cloud_edition_billing_resource_check(resource="documents")
     def post(self):
-
         # get file from request
-        file = request.files['file']
+        file = request.files["file"]
 
         # check file
-        if 'file' not in request.files:
+        if "file" not in request.files:
             raise NoFileUploadedError()
 
         if len(request.files) > 1:
@@ -69,7 +67,7 @@ class FilePreviewApi(Resource):
     def get(self, file_id):
         file_id = str(file_id)
         text = FileService.get_file_preview(file_id)
-        return {'content': text}
+        return {"content": text}
 
 
 class FileSupportTypeApi(Resource):
@@ -78,10 +76,10 @@ class FileSupportTypeApi(Resource):
     @account_initialization_required
     def get(self):
         etl_type = dify_config.ETL_TYPE
-        allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS
-        return {'allowed_extensions': allowed_extensions}
+        allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS
+        return {"allowed_extensions": allowed_extensions}
 
 
-api.add_resource(FileApi, '/files/upload')
-api.add_resource(FilePreviewApi, '/files/<uuid:file_id>/preview')
-api.add_resource(FileSupportTypeApi, '/files/support-type')
+api.add_resource(FileApi, "/files/upload")
+api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
+api.add_resource(FileSupportTypeApi, "/files/support-type")

+ 9 - 9
api/controllers/console/datasets/hit_testing.py

@@ -29,7 +29,6 @@ from services.hit_testing_service import HitTestingService
 
 
 class HitTestingApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -46,8 +45,8 @@ class HitTestingApi(Resource):
             raise Forbidden(str(e))
 
         parser = reqparse.RequestParser()
-        parser.add_argument('query', type=str, location='json')
-        parser.add_argument('retrieval_model', type=dict, required=False, location='json')
+        parser.add_argument("query", type=str, location="json")
+        parser.add_argument("retrieval_model", type=dict, required=False, location="json")
         args = parser.parse_args()
 
         HitTestingService.hit_testing_args_check(args)
@@ -55,13 +54,13 @@ class HitTestingApi(Resource):
         try:
             response = HitTestingService.retrieve(
                 dataset=dataset,
-                query=args['query'],
+                query=args["query"],
                 account=current_user,
-                retrieval_model=args['retrieval_model'],
-                limit=10
+                retrieval_model=args["retrieval_model"],
+                limit=10,
             )
 
-            return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
+            return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
         except services.errors.index.IndexNotInitializedError:
             raise DatasetNotInitializedError()
         except ProviderTokenNotInitError as ex:
@@ -73,7 +72,8 @@ class HitTestingApi(Resource):
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
                 "No Embedding Model or Reranking Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider.")
+                "in the Settings -> Model Provider."
+            )
         except InvokeError as e:
             raise CompletionRequestError(e.description)
         except ValueError as e:
@@ -83,4 +83,4 @@ class HitTestingApi(Resource):
             raise InternalServerError(str(e))
 
 
-api.add_resource(HitTestingApi, '/datasets/<uuid:dataset_id>/hit-testing')
+api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")

+ 7 - 9
api/controllers/console/datasets/website.py

@@ -9,16 +9,14 @@ from services.website_service import WebsiteService
 
 
 class WebsiteCrawlApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('provider', type=str, choices=['firecrawl'],
-                            required=True, nullable=True, location='json')
-        parser.add_argument('url', type=str, required=True, nullable=True, location='json')
-        parser.add_argument('options', type=dict, required=True, nullable=True, location='json')
+        parser.add_argument("provider", type=str, choices=["firecrawl"], required=True, nullable=True, location="json")
+        parser.add_argument("url", type=str, required=True, nullable=True, location="json")
+        parser.add_argument("options", type=dict, required=True, nullable=True, location="json")
         args = parser.parse_args()
         WebsiteService.document_create_args_validate(args)
         # crawl url
@@ -35,15 +33,15 @@ class WebsiteCrawlStatusApi(Resource):
     @account_initialization_required
     def get(self, job_id: str):
         parser = reqparse.RequestParser()
-        parser.add_argument('provider', type=str, choices=['firecrawl'], required=True, location='args')
+        parser.add_argument("provider", type=str, choices=["firecrawl"], required=True, location="args")
         args = parser.parse_args()
         # get crawl status
         try:
-            result = WebsiteService.get_crawl_status(job_id, args['provider'])
+            result = WebsiteService.get_crawl_status(job_id, args["provider"])
         except Exception as e:
             raise WebsiteCrawlError(str(e))
         return result, 200
 
 
-api.add_resource(WebsiteCrawlApi, '/website/crawl')
-api.add_resource(WebsiteCrawlStatusApi, '/website/crawl/status/<string:job_id>')
+api.add_resource(WebsiteCrawlApi, "/website/crawl")
+api.add_resource(WebsiteCrawlStatusApi, "/website/crawl/status/<string:job_id>")

+ 16 - 10
api/controllers/console/error.py

@@ -2,35 +2,41 @@ from libs.exception import BaseHTTPException
 
 
 class AlreadySetupError(BaseHTTPException):
-    error_code = 'already_setup'
+    error_code = "already_setup"
     description = "Dify has been successfully installed. Please refresh the page or return to the dashboard homepage."
     code = 403
 
 
 class NotSetupError(BaseHTTPException):
-    error_code = 'not_setup'
-    description = "Dify has not been initialized and installed yet. " \
-                  "Please proceed with the initialization and installation process first."
+    error_code = "not_setup"
+    description = (
+        "Dify has not been initialized and installed yet. "
+        "Please proceed with the initialization and installation process first."
+    )
     code = 401
 
+
 class NotInitValidateError(BaseHTTPException):
-    error_code = 'not_init_validated'
-    description = "Init validation has not been completed yet. " \
-                  "Please proceed with the init validation process first."
+    error_code = "not_init_validated"
+    description = (
+        "Init validation has not been completed yet. " "Please proceed with the init validation process first."
+    )
     code = 401
 
+
 class InitValidateFailedError(BaseHTTPException):
-    error_code = 'init_validate_failed'
+    error_code = "init_validate_failed"
     description = "Init validation failed. Please check the password and try again."
     code = 401
 
+
 class AccountNotLinkTenantError(BaseHTTPException):
-    error_code = 'account_not_link_tenant'
+    error_code = "account_not_link_tenant"
     description = "Account not link tenant."
     code = 403
 
 
 class AlreadyActivateError(BaseHTTPException):
-    error_code = 'already_activate'
+    error_code = "already_activate"
     description = "Auth Token is invalid or account already activated, please check again."
     code = 403

+ 23 - 26
api/controllers/console/explore/audio.py

@@ -33,14 +33,10 @@ class ChatAudioApi(InstalledAppResource):
     def post(self, installed_app):
         app_model = installed_app.app
 
-        file = request.files['file']
+        file = request.files["file"]
 
         try:
-            response = AudioService.transcript_asr(
-                app_model=app_model,
-                file=file,
-                end_user=None
-            )
+            response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None)
 
             return response
         except services.errors.app_model_config.AppModelConfigBrokenError:
@@ -76,30 +72,31 @@ class ChatTextApi(InstalledAppResource):
         app_model = installed_app.app
         try:
             parser = reqparse.RequestParser()
-            parser.add_argument('message_id', type=str, required=False, location='json')
-            parser.add_argument('voice', type=str, location='json')
-            parser.add_argument('text', type=str, location='json')
-            parser.add_argument('streaming', type=bool, location='json')
+            parser.add_argument("message_id", type=str, required=False, location="json")
+            parser.add_argument("voice", type=str, location="json")
+            parser.add_argument("text", type=str, location="json")
+            parser.add_argument("streaming", type=bool, location="json")
             args = parser.parse_args()
 
-            message_id = args.get('message_id', None)
-            text = args.get('text', None)
-            if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
-                    and app_model.workflow
-                    and app_model.workflow.features_dict):
-                text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
-                voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
+            message_id = args.get("message_id", None)
+            text = args.get("text", None)
+            if (
+                app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
+                and app_model.workflow
+                and app_model.workflow.features_dict
+            ):
+                text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
+                voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
             else:
                 try:
-                    voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice')
+                    voice = (
+                        args.get("voice")
+                        if args.get("voice")
+                        else app_model.app_model_config.text_to_speech_dict.get("voice")
+                    )
                 except Exception:
                     voice = None
-            response = AudioService.transcript_tts(
-                app_model=app_model,
-                message_id=message_id,
-                voice=voice,
-                text=text
-            )
+            response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text)
             return response
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
@@ -127,7 +124,7 @@ class ChatTextApi(InstalledAppResource):
             raise InternalServerError()
 
 
-api.add_resource(ChatAudioApi, '/installed-apps/<uuid:installed_app_id>/audio-to-text', endpoint='installed_app_audio')
-api.add_resource(ChatTextApi, '/installed-apps/<uuid:installed_app_id>/text-to-audio', endpoint='installed_app_text')
+api.add_resource(ChatAudioApi, "/installed-apps/<uuid:installed_app_id>/audio-to-text", endpoint="installed_app_audio")
+api.add_resource(ChatTextApi, "/installed-apps/<uuid:installed_app_id>/text-to-audio", endpoint="installed_app_text")
 # api.add_resource(ChatTextApiWithMessageId, '/installed-apps/<uuid:installed_app_id>/text-to-audio/message-id',
 #                  endpoint='installed_app_text_with_message_id')

+ 35 - 32
api/controllers/console/explore/completion.py

@@ -30,33 +30,28 @@ from services.app_generate_service import AppGenerateService
 
 # define completion api for user
 class CompletionApi(InstalledAppResource):
-
     def post(self, installed_app):
         app_model = installed_app.app
-        if app_model.mode != 'completion':
+        if app_model.mode != "completion":
             raise NotCompletionAppError()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('inputs', type=dict, required=True, location='json')
-        parser.add_argument('query', type=str, location='json', default='')
-        parser.add_argument('files', type=list, required=False, location='json')
-        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
-        parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
+        parser.add_argument("inputs", type=dict, required=True, location="json")
+        parser.add_argument("query", type=str, location="json", default="")
+        parser.add_argument("files", type=list, required=False, location="json")
+        parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
+        parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
         args = parser.parse_args()
 
-        streaming = args['response_mode'] == 'streaming'
-        args['auto_generate_name'] = False
+        streaming = args["response_mode"] == "streaming"
+        args["auto_generate_name"] = False
 
         installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
         db.session.commit()
 
         try:
             response = AppGenerateService.generate(
-                app_model=app_model,
-                user=current_user,
-                args=args,
-                invoke_from=InvokeFrom.EXPLORE,
-                streaming=streaming
+                app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
             )
 
             return helper.compact_generate_response(response)
@@ -85,12 +80,12 @@ class CompletionApi(InstalledAppResource):
 class CompletionStopApi(InstalledAppResource):
     def post(self, installed_app, task_id):
         app_model = installed_app.app
-        if app_model.mode != 'completion':
+        if app_model.mode != "completion":
             raise NotCompletionAppError()
 
         AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
 
-        return {'result': 'success'}, 200
+        return {"result": "success"}, 200
 
 
 class ChatApi(InstalledAppResource):
@@ -101,25 +96,21 @@ class ChatApi(InstalledAppResource):
             raise NotChatAppError()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('inputs', type=dict, required=True, location='json')
-        parser.add_argument('query', type=str, required=True, location='json')
-        parser.add_argument('files', type=list, required=False, location='json')
-        parser.add_argument('conversation_id', type=uuid_value, location='json')
-        parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
+        parser.add_argument("inputs", type=dict, required=True, location="json")
+        parser.add_argument("query", type=str, required=True, location="json")
+        parser.add_argument("files", type=list, required=False, location="json")
+        parser.add_argument("conversation_id", type=uuid_value, location="json")
+        parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
         args = parser.parse_args()
 
-        args['auto_generate_name'] = False
+        args["auto_generate_name"] = False
 
         installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
         db.session.commit()
 
         try:
             response = AppGenerateService.generate(
-                app_model=app_model,
-                user=current_user,
-                args=args,
-                invoke_from=InvokeFrom.EXPLORE,
-                streaming=True
+                app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
             )
 
             return helper.compact_generate_response(response)
@@ -154,10 +145,22 @@ class ChatStopApi(InstalledAppResource):
 
         AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
 
-        return {'result': 'success'}, 200
+        return {"result": "success"}, 200
 
 
-api.add_resource(CompletionApi, '/installed-apps/<uuid:installed_app_id>/completion-messages', endpoint='installed_app_completion')
-api.add_resource(CompletionStopApi, '/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop', endpoint='installed_app_stop_completion')
-api.add_resource(ChatApi, '/installed-apps/<uuid:installed_app_id>/chat-messages', endpoint='installed_app_chat_completion')
-api.add_resource(ChatStopApi, '/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop', endpoint='installed_app_stop_chat_completion')
+api.add_resource(
+    CompletionApi, "/installed-apps/<uuid:installed_app_id>/completion-messages", endpoint="installed_app_completion"
+)
+api.add_resource(
+    CompletionStopApi,
+    "/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop",
+    endpoint="installed_app_stop_completion",
+)
+api.add_resource(
+    ChatApi, "/installed-apps/<uuid:installed_app_id>/chat-messages", endpoint="installed_app_chat_completion"
+)
+api.add_resource(
+    ChatStopApi,
+    "/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop",
+    endpoint="installed_app_stop_chat_completion",
+)

+ 33 - 22
api/controllers/console/explore/conversation.py

@@ -16,7 +16,6 @@ from services.web_conversation_service import WebConversationService
 
 
 class ConversationListApi(InstalledAppResource):
-
     @marshal_with(conversation_infinite_scroll_pagination_fields)
     def get(self, installed_app):
         app_model = installed_app.app
@@ -25,21 +24,21 @@ class ConversationListApi(InstalledAppResource):
             raise NotChatAppError()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('last_id', type=uuid_value, location='args')
-        parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
-        parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args')
+        parser.add_argument("last_id", type=uuid_value, location="args")
+        parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
+        parser.add_argument("pinned", type=str, choices=["true", "false", None], location="args")
         args = parser.parse_args()
 
         pinned = None
-        if 'pinned' in args and args['pinned'] is not None:
-            pinned = True if args['pinned'] == 'true' else False
+        if "pinned" in args and args["pinned"] is not None:
+            pinned = True if args["pinned"] == "true" else False
 
         try:
             return WebConversationService.pagination_by_last_id(
                 app_model=app_model,
                 user=current_user,
-                last_id=args['last_id'],
-                limit=args['limit'],
+                last_id=args["last_id"],
+                limit=args["limit"],
                 invoke_from=InvokeFrom.EXPLORE,
                 pinned=pinned,
             )
@@ -65,7 +64,6 @@ class ConversationApi(InstalledAppResource):
 
 
 class ConversationRenameApi(InstalledAppResource):
-
     @marshal_with(simple_conversation_fields)
     def post(self, installed_app, c_id):
         app_model = installed_app.app
@@ -76,24 +74,19 @@ class ConversationRenameApi(InstalledAppResource):
         conversation_id = str(c_id)
 
         parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, required=False, location='json')
-        parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
+        parser.add_argument("name", type=str, required=False, location="json")
+        parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
         args = parser.parse_args()
 
         try:
             return ConversationService.rename(
-                app_model,
-                conversation_id,
-                current_user,
-                args['name'],
-                args['auto_generate']
+                app_model, conversation_id, current_user, args["name"], args["auto_generate"]
             )
         except ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
 
 
 class ConversationPinApi(InstalledAppResource):
-
     def patch(self, installed_app, c_id):
         app_model = installed_app.app
         app_mode = AppMode.value_of(app_model.mode)
@@ -123,8 +116,26 @@ class ConversationUnPinApi(InstalledAppResource):
         return {"result": "success"}
 
 
-api.add_resource(ConversationRenameApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name', endpoint='installed_app_conversation_rename')
-api.add_resource(ConversationListApi, '/installed-apps/<uuid:installed_app_id>/conversations', endpoint='installed_app_conversations')
-api.add_resource(ConversationApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>', endpoint='installed_app_conversation')
-api.add_resource(ConversationPinApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin', endpoint='installed_app_conversation_pin')
-api.add_resource(ConversationUnPinApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin', endpoint='installed_app_conversation_unpin')
+api.add_resource(
+    ConversationRenameApi,
+    "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name",
+    endpoint="installed_app_conversation_rename",
+)
+api.add_resource(
+    ConversationListApi, "/installed-apps/<uuid:installed_app_id>/conversations", endpoint="installed_app_conversations"
+)
+api.add_resource(
+    ConversationApi,
+    "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>",
+    endpoint="installed_app_conversation",
+)
+api.add_resource(
+    ConversationPinApi,
+    "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin",
+    endpoint="installed_app_conversation_pin",
+)
+api.add_resource(
+    ConversationUnPinApi,
+    "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin",
+    endpoint="installed_app_conversation_unpin",
+)

+ 4 - 4
api/controllers/console/explore/error.py

@@ -2,24 +2,24 @@ from libs.exception import BaseHTTPException
 
 
 class NotCompletionAppError(BaseHTTPException):
-    error_code = 'not_completion_app'
+    error_code = "not_completion_app"
     description = "Not Completion App"
     code = 400
 
 
 class NotChatAppError(BaseHTTPException):
-    error_code = 'not_chat_app'
+    error_code = "not_chat_app"
     description = "App mode is invalid."
     code = 400
 
 
 class NotWorkflowAppError(BaseHTTPException):
-    error_code = 'not_workflow_app'
+    error_code = "not_workflow_app"
     description = "Only support workflow app."
     code = 400
 
 
 class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException):
-    error_code = 'app_suggested_questions_after_answer_disabled'
+    error_code = "app_suggested_questions_after_answer_disabled"
     description = "Function Suggested questions after answer disabled."
     code = 403

+ 38 - 38
api/controllers/console/explore/installed_app.py

@@ -21,72 +21,71 @@ class InstalledAppsListApi(Resource):
     @marshal_with(installed_app_list_fields)
     def get(self):
         current_tenant_id = current_user.current_tenant_id
-        installed_apps = db.session.query(InstalledApp).filter(
-            InstalledApp.tenant_id == current_tenant_id
-        ).all()
+        installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all()
 
         current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
         installed_apps = [
             {
-                'id': installed_app.id,
-                'app': installed_app.app,
-                'app_owner_tenant_id': installed_app.app_owner_tenant_id,
-                'is_pinned': installed_app.is_pinned,
-                'last_used_at': installed_app.last_used_at,
-                'editable': current_user.role in ["owner", "admin"],
-                'uninstallable': current_tenant_id == installed_app.app_owner_tenant_id
+                "id": installed_app.id,
+                "app": installed_app.app,
+                "app_owner_tenant_id": installed_app.app_owner_tenant_id,
+                "is_pinned": installed_app.is_pinned,
+                "last_used_at": installed_app.last_used_at,
+                "editable": current_user.role in ["owner", "admin"],
+                "uninstallable": current_tenant_id == installed_app.app_owner_tenant_id,
             }
             for installed_app in installed_apps
         ]
-        installed_apps.sort(key=lambda app: (-app['is_pinned'],
-                                             app['last_used_at'] is None,
-                                             -app['last_used_at'].timestamp() if app['last_used_at'] is not None else 0))
+        installed_apps.sort(
+            key=lambda app: (
+                -app["is_pinned"],
+                app["last_used_at"] is None,
+                -app["last_used_at"].timestamp() if app["last_used_at"] is not None else 0,
+            )
+        )
 
-        return {'installed_apps': installed_apps}
+        return {"installed_apps": installed_apps}
 
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_resource_check('apps')
+    @cloud_edition_billing_resource_check("apps")
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('app_id', type=str, required=True, help='Invalid app_id')
+        parser.add_argument("app_id", type=str, required=True, help="Invalid app_id")
         args = parser.parse_args()
 
-        recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first()
+        recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
         if recommended_app is None:
-            raise NotFound('App not found')
+            raise NotFound("App not found")
 
         current_tenant_id = current_user.current_tenant_id
-        app = db.session.query(App).filter(
-            App.id == args['app_id']
-        ).first()
+        app = db.session.query(App).filter(App.id == args["app_id"]).first()
 
         if app is None:
-            raise NotFound('App not found')
+            raise NotFound("App not found")
 
         if not app.is_public:
-            raise Forbidden('You can\'t install a non-public app')
+            raise Forbidden("You can't install a non-public app")
 
-        installed_app = InstalledApp.query.filter(and_(
-            InstalledApp.app_id == args['app_id'],
-            InstalledApp.tenant_id == current_tenant_id
-        )).first()
+        installed_app = InstalledApp.query.filter(
+            and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id)
+        ).first()
 
         if installed_app is None:
             # todo: position
             recommended_app.install_count += 1
 
             new_installed_app = InstalledApp(
-                app_id=args['app_id'],
+                app_id=args["app_id"],
                 tenant_id=current_tenant_id,
                 app_owner_tenant_id=app.tenant_id,
                 is_pinned=False,
-                last_used_at=datetime.now(timezone.utc).replace(tzinfo=None)
+                last_used_at=datetime.now(timezone.utc).replace(tzinfo=None),
             )
             db.session.add(new_installed_app)
             db.session.commit()
 
-        return {'message': 'App installed successfully'}
+        return {"message": "App installed successfully"}
 
 
 class InstalledAppApi(InstalledAppResource):
@@ -94,30 +93,31 @@ class InstalledAppApi(InstalledAppResource):
     update and delete an installed app
     use InstalledAppResource to apply default decorators and get installed_app
     """
+
     def delete(self, installed_app):
         if installed_app.app_owner_tenant_id == current_user.current_tenant_id:
-            raise BadRequest('You can\'t uninstall an app owned by the current tenant')
+            raise BadRequest("You can't uninstall an app owned by the current tenant")
 
         db.session.delete(installed_app)
         db.session.commit()
 
-        return {'result': 'success', 'message': 'App uninstalled successfully'}
+        return {"result": "success", "message": "App uninstalled successfully"}
 
     def patch(self, installed_app):
         parser = reqparse.RequestParser()
-        parser.add_argument('is_pinned', type=inputs.boolean)
+        parser.add_argument("is_pinned", type=inputs.boolean)
         args = parser.parse_args()
 
         commit_args = False
-        if 'is_pinned' in args:
-            installed_app.is_pinned = args['is_pinned']
+        if "is_pinned" in args:
+            installed_app.is_pinned = args["is_pinned"]
             commit_args = True
 
         if commit_args:
             db.session.commit()
 
-        return {'result': 'success', 'message': 'App info updated successfully'}
+        return {"result": "success", "message": "App info updated successfully"}
 
 
-api.add_resource(InstalledAppsListApi, '/installed-apps')
-api.add_resource(InstalledAppApi, '/installed-apps/<uuid:installed_app_id>')
+api.add_resource(InstalledAppsListApi, "/installed-apps")
+api.add_resource(InstalledAppApi, "/installed-apps/<uuid:installed_app_id>")

+ 34 - 21
api/controllers/console/explore/message.py

@@ -44,19 +44,21 @@ class MessageListApi(InstalledAppResource):
             raise NotChatAppError()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
-        parser.add_argument('first_id', type=uuid_value, location='args')
-        parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
+        parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
+        parser.add_argument("first_id", type=uuid_value, location="args")
+        parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
         args = parser.parse_args()
 
         try:
-            return MessageService.pagination_by_first_id(app_model, current_user,
-                                                     args['conversation_id'], args['first_id'], args['limit'])
+            return MessageService.pagination_by_first_id(
+                app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
+            )
         except services.errors.conversation.ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
         except services.errors.message.FirstMessageNotExistsError:
             raise NotFound("First Message Not Exists.")
 
+
 class MessageFeedbackApi(InstalledAppResource):
     def post(self, installed_app, message_id):
         app_model = installed_app.app
@@ -64,30 +66,32 @@ class MessageFeedbackApi(InstalledAppResource):
         message_id = str(message_id)
 
         parser = reqparse.RequestParser()
-        parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
+        parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
         args = parser.parse_args()
 
         try:
-            MessageService.create_feedback(app_model, message_id, current_user, args['rating'])
+            MessageService.create_feedback(app_model, message_id, current_user, args["rating"])
         except services.errors.message.MessageNotExistsError:
             raise NotFound("Message Not Exists.")
 
-        return {'result': 'success'}
+        return {"result": "success"}
 
 
 class MessageMoreLikeThisApi(InstalledAppResource):
     def get(self, installed_app, message_id):
         app_model = installed_app.app
-        if app_model.mode != 'completion':
+        if app_model.mode != "completion":
             raise NotCompletionAppError()
 
         message_id = str(message_id)
 
         parser = reqparse.RequestParser()
-        parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args')
+        parser.add_argument(
+            "response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args"
+        )
         args = parser.parse_args()
 
-        streaming = args['response_mode'] == 'streaming'
+        streaming = args["response_mode"] == "streaming"
 
         try:
             response = AppGenerateService.generate_more_like_this(
@@ -95,7 +99,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
                 user=current_user,
                 message_id=message_id,
                 invoke_from=InvokeFrom.EXPLORE,
-                streaming=streaming
+                streaming=streaming,
             )
             return helper.compact_generate_response(response)
         except MessageNotExistsError:
@@ -128,10 +132,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
 
         try:
             questions = MessageService.get_suggested_questions_after_answer(
-                app_model=app_model,
-                user=current_user,
-                message_id=message_id,
-                invoke_from=InvokeFrom.EXPLORE
+                app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
             )
         except MessageNotExistsError:
             raise NotFound("Message not found")
@@ -151,10 +152,22 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
             logging.exception("internal server error.")
             raise InternalServerError()
 
-        return {'data': questions}
+        return {"data": questions}
 
 
-api.add_resource(MessageListApi, '/installed-apps/<uuid:installed_app_id>/messages', endpoint='installed_app_messages')
-api.add_resource(MessageFeedbackApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks', endpoint='installed_app_message_feedback')
-api.add_resource(MessageMoreLikeThisApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this', endpoint='installed_app_more_like_this')
-api.add_resource(MessageSuggestedQuestionApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions', endpoint='installed_app_suggested_question')
+api.add_resource(MessageListApi, "/installed-apps/<uuid:installed_app_id>/messages", endpoint="installed_app_messages")
+api.add_resource(
+    MessageFeedbackApi,
+    "/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks",
+    endpoint="installed_app_message_feedback",
+)
+api.add_resource(
+    MessageMoreLikeThisApi,
+    "/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this",
+    endpoint="installed_app_more_like_this",
+)
+api.add_resource(
+    MessageSuggestedQuestionApi,
+    "/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions",
+    endpoint="installed_app_suggested_question",
+)

+ 52 - 48
api/controllers/console/explore/parameter.py

@@ -1,4 +1,3 @@
-
 from flask_restful import fields, marshal_with
 
 from configs import dify_config
@@ -11,33 +10,32 @@ from services.app_service import AppService
 
 class AppParameterApi(InstalledAppResource):
     """Resource for app variables."""
+
     variable_fields = {
-        'key': fields.String,
-        'name': fields.String,
-        'description': fields.String,
-        'type': fields.String,
-        'default': fields.String,
-        'max_length': fields.Integer,
-        'options': fields.List(fields.String)
+        "key": fields.String,
+        "name": fields.String,
+        "description": fields.String,
+        "type": fields.String,
+        "default": fields.String,
+        "max_length": fields.Integer,
+        "options": fields.List(fields.String),
     }
 
-    system_parameters_fields = {
-        'image_file_size_limit': fields.String
-    }
+    system_parameters_fields = {"image_file_size_limit": fields.String}
 
     parameters_fields = {
-        'opening_statement': fields.String,
-        'suggested_questions': fields.Raw,
-        'suggested_questions_after_answer': fields.Raw,
-        'speech_to_text': fields.Raw,
-        'text_to_speech': fields.Raw,
-        'retriever_resource': fields.Raw,
-        'annotation_reply': fields.Raw,
-        'more_like_this': fields.Raw,
-        'user_input_form': fields.Raw,
-        'sensitive_word_avoidance': fields.Raw,
-        'file_upload': fields.Raw,
-        'system_parameters': fields.Nested(system_parameters_fields)
+        "opening_statement": fields.String,
+        "suggested_questions": fields.Raw,
+        "suggested_questions_after_answer": fields.Raw,
+        "speech_to_text": fields.Raw,
+        "text_to_speech": fields.Raw,
+        "retriever_resource": fields.Raw,
+        "annotation_reply": fields.Raw,
+        "more_like_this": fields.Raw,
+        "user_input_form": fields.Raw,
+        "sensitive_word_avoidance": fields.Raw,
+        "file_upload": fields.Raw,
+        "system_parameters": fields.Nested(system_parameters_fields),
     }
 
     @marshal_with(parameters_fields)
@@ -56,30 +54,35 @@ class AppParameterApi(InstalledAppResource):
             app_model_config = app_model.app_model_config
             features_dict = app_model_config.to_dict()
 
-            user_input_form = features_dict.get('user_input_form', [])
+            user_input_form = features_dict.get("user_input_form", [])
 
         return {
-            'opening_statement': features_dict.get('opening_statement'),
-            'suggested_questions': features_dict.get('suggested_questions', []),
-            'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer',
-                                                                  {"enabled": False}),
-            'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}),
-            'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}),
-            'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}),
-            'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}),
-            'more_like_this': features_dict.get('more_like_this', {"enabled": False}),
-            'user_input_form': user_input_form,
-            'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance',
-                                                          {"enabled": False, "type": "", "configs": []}),
-            'file_upload': features_dict.get('file_upload', {"image": {
-                                                     "enabled": False,
-                                                     "number_limits": 3,
-                                                     "detail": "high",
-                                                     "transfer_methods": ["remote_url", "local_file"]
-                                                 }}),
-            'system_parameters': {
-                'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
-            }
+            "opening_statement": features_dict.get("opening_statement"),
+            "suggested_questions": features_dict.get("suggested_questions", []),
+            "suggested_questions_after_answer": features_dict.get(
+                "suggested_questions_after_answer", {"enabled": False}
+            ),
+            "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
+            "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
+            "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
+            "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
+            "more_like_this": features_dict.get("more_like_this", {"enabled": False}),
+            "user_input_form": user_input_form,
+            "sensitive_word_avoidance": features_dict.get(
+                "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
+            ),
+            "file_upload": features_dict.get(
+                "file_upload",
+                {
+                    "image": {
+                        "enabled": False,
+                        "number_limits": 3,
+                        "detail": "high",
+                        "transfer_methods": ["remote_url", "local_file"],
+                    }
+                },
+            ),
+            "system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT},
         }
 
 
@@ -90,6 +93,7 @@ class ExploreAppMetaApi(InstalledAppResource):
         return AppService().get_app_meta(app_model)
 
 
-api.add_resource(AppParameterApi, '/installed-apps/<uuid:installed_app_id>/parameters',
-                 endpoint='installed_app_parameters')
-api.add_resource(ExploreAppMetaApi, '/installed-apps/<uuid:installed_app_id>/meta', endpoint='installed_app_meta')
+api.add_resource(
+    AppParameterApi, "/installed-apps/<uuid:installed_app_id>/parameters", endpoint="installed_app_parameters"
+)
+api.add_resource(ExploreAppMetaApi, "/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")

+ 21 - 21
api/controllers/console/explore/recommended_app.py

@@ -8,28 +8,28 @@ from libs.login import login_required
 from services.recommended_app_service import RecommendedAppService
 
 app_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'mode': fields.String,
-    'icon': fields.String,
-    'icon_background': fields.String
+    "id": fields.String,
+    "name": fields.String,
+    "mode": fields.String,
+    "icon": fields.String,
+    "icon_background": fields.String,
 }
 
 recommended_app_fields = {
-    'app': fields.Nested(app_fields, attribute='app'),
-    'app_id': fields.String,
-    'description': fields.String(attribute='description'),
-    'copyright': fields.String,
-    'privacy_policy': fields.String,
-    'custom_disclaimer': fields.String,
-    'category': fields.String,
-    'position': fields.Integer,
-    'is_listed': fields.Boolean
+    "app": fields.Nested(app_fields, attribute="app"),
+    "app_id": fields.String,
+    "description": fields.String(attribute="description"),
+    "copyright": fields.String,
+    "privacy_policy": fields.String,
+    "custom_disclaimer": fields.String,
+    "category": fields.String,
+    "position": fields.Integer,
+    "is_listed": fields.Boolean,
 }
 
 recommended_app_list_fields = {
-    'recommended_apps': fields.List(fields.Nested(recommended_app_fields)),
-    'categories': fields.List(fields.String)
+    "recommended_apps": fields.List(fields.Nested(recommended_app_fields)),
+    "categories": fields.List(fields.String),
 }
 
 
@@ -40,11 +40,11 @@ class RecommendedAppListApi(Resource):
     def get(self):
         # language args
         parser = reqparse.RequestParser()
-        parser.add_argument('language', type=str, location='args')
+        parser.add_argument("language", type=str, location="args")
         args = parser.parse_args()
 
-        if args.get('language') and args.get('language') in languages:
-            language_prefix = args.get('language')
+        if args.get("language") and args.get("language") in languages:
+            language_prefix = args.get("language")
         elif current_user and current_user.interface_language:
             language_prefix = current_user.interface_language
         else:
@@ -61,5 +61,5 @@ class RecommendedAppApi(Resource):
         return RecommendedAppService.get_recommend_app_detail(app_id)
 
 
-api.add_resource(RecommendedAppListApi, '/explore/apps')
-api.add_resource(RecommendedAppApi, '/explore/apps/<uuid:app_id>')
+api.add_resource(RecommendedAppListApi, "/explore/apps")
+api.add_resource(RecommendedAppApi, "/explore/apps/<uuid:app_id>")

+ 31 - 25
api/controllers/console/explore/saved_message.py

@@ -11,56 +11,54 @@ from libs.helper import TimestampField, uuid_value
 from services.errors.message import MessageNotExistsError
 from services.saved_message_service import SavedMessageService
 
-feedback_fields = {
-    'rating': fields.String
-}
+feedback_fields = {"rating": fields.String}
 
 message_fields = {
-    'id': fields.String,
-    'inputs': fields.Raw,
-    'query': fields.String,
-    'answer': fields.String,
-    'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
-    'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
-    'created_at': TimestampField
+    "id": fields.String,
+    "inputs": fields.Raw,
+    "query": fields.String,
+    "answer": fields.String,
+    "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
+    "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
+    "created_at": TimestampField,
 }
 
 
 class SavedMessageListApi(InstalledAppResource):
     saved_message_infinite_scroll_pagination_fields = {
-        'limit': fields.Integer,
-        'has_more': fields.Boolean,
-        'data': fields.List(fields.Nested(message_fields))
+        "limit": fields.Integer,
+        "has_more": fields.Boolean,
+        "data": fields.List(fields.Nested(message_fields)),
     }
 
     @marshal_with(saved_message_infinite_scroll_pagination_fields)
     def get(self, installed_app):
         app_model = installed_app.app
-        if app_model.mode != 'completion':
+        if app_model.mode != "completion":
             raise NotCompletionAppError()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('last_id', type=uuid_value, location='args')
-        parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
+        parser.add_argument("last_id", type=uuid_value, location="args")
+        parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
         args = parser.parse_args()
 
-        return SavedMessageService.pagination_by_last_id(app_model, current_user, args['last_id'], args['limit'])
+        return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"])
 
     def post(self, installed_app):
         app_model = installed_app.app
-        if app_model.mode != 'completion':
+        if app_model.mode != "completion":
             raise NotCompletionAppError()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('message_id', type=uuid_value, required=True, location='json')
+        parser.add_argument("message_id", type=uuid_value, required=True, location="json")
         args = parser.parse_args()
 
         try:
-            SavedMessageService.save(app_model, current_user, args['message_id'])
+            SavedMessageService.save(app_model, current_user, args["message_id"])
         except MessageNotExistsError:
             raise NotFound("Message Not Exists.")
 
-        return {'result': 'success'}
+        return {"result": "success"}
 
 
 class SavedMessageApi(InstalledAppResource):
@@ -69,13 +67,21 @@ class SavedMessageApi(InstalledAppResource):
 
         message_id = str(message_id)
 
-        if app_model.mode != 'completion':
+        if app_model.mode != "completion":
             raise NotCompletionAppError()
 
         SavedMessageService.delete(app_model, current_user, message_id)
 
-        return {'result': 'success'}
+        return {"result": "success"}
 
 
-api.add_resource(SavedMessageListApi, '/installed-apps/<uuid:installed_app_id>/saved-messages', endpoint='installed_app_saved_messages')
-api.add_resource(SavedMessageApi, '/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>', endpoint='installed_app_saved_message')
+api.add_resource(
+    SavedMessageListApi,
+    "/installed-apps/<uuid:installed_app_id>/saved-messages",
+    endpoint="installed_app_saved_messages",
+)
+api.add_resource(
+    SavedMessageApi,
+    "/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>",
+    endpoint="installed_app_saved_message",
+)

+ 8 - 12
api/controllers/console/explore/workflow.py

@@ -35,17 +35,13 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
             raise NotWorkflowAppError()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json')
-        parser.add_argument('files', type=list, required=False, location='json')
+        parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
+        parser.add_argument("files", type=list, required=False, location="json")
         args = parser.parse_args()
 
         try:
             response = AppGenerateService.generate(
-                app_model=app_model,
-                user=current_user,
-                args=args,
-                invoke_from=InvokeFrom.EXPLORE,
-                streaming=True
+                app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
             )
 
             return helper.compact_generate_response(response)
@@ -76,10 +72,10 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
 
         AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
 
-        return {
-            "result": "success"
-        }
+        return {"result": "success"}
 
 
-api.add_resource(InstalledAppWorkflowRunApi, '/installed-apps/<uuid:installed_app_id>/workflows/run')
-api.add_resource(InstalledAppWorkflowTaskStopApi, '/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop')
+api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps/<uuid:installed_app_id>/workflows/run")
+api.add_resource(
+    InstalledAppWorkflowTaskStopApi, "/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop"
+)

+ 14 - 10
api/controllers/console/explore/wraps.py

@@ -14,29 +14,33 @@ def installed_app_required(view=None):
     def decorator(view):
         @wraps(view)
         def decorated(*args, **kwargs):
-            if not kwargs.get('installed_app_id'):
-                raise ValueError('missing installed_app_id in path parameters')
+            if not kwargs.get("installed_app_id"):
+                raise ValueError("missing installed_app_id in path parameters")
 
-            installed_app_id = kwargs.get('installed_app_id')
+            installed_app_id = kwargs.get("installed_app_id")
             installed_app_id = str(installed_app_id)
 
-            del kwargs['installed_app_id']
+            del kwargs["installed_app_id"]
 
-            installed_app = db.session.query(InstalledApp).filter(
-                InstalledApp.id == str(installed_app_id),
-                InstalledApp.tenant_id == current_user.current_tenant_id
-            ).first()
+            installed_app = (
+                db.session.query(InstalledApp)
+                .filter(
+                    InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id
+                )
+                .first()
+            )
 
             if installed_app is None:
-                raise NotFound('Installed app not found')
+                raise NotFound("Installed app not found")
 
             if not installed_app.app:
                 db.session.delete(installed_app)
                 db.session.commit()
 
-                raise NotFound('Installed app not found')
+                raise NotFound("Installed app not found")
 
             return view(installed_app, *args, **kwargs)
+
         return decorated
 
     if view:

+ 19 - 25
api/controllers/console/extension.py

@@ -13,23 +13,18 @@ from services.code_based_extension_service import CodeBasedExtensionService
 
 
 class CodeBasedExtensionAPI(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
     def get(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('module', type=str, required=True, location='args')
+        parser.add_argument("module", type=str, required=True, location="args")
         args = parser.parse_args()
 
-        return {
-            'module': args['module'],
-            'data': CodeBasedExtensionService.get_code_based_extension(args['module'])
-        }
+        return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])}
 
 
 class APIBasedExtensionAPI(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -44,23 +39,22 @@ class APIBasedExtensionAPI(Resource):
     @marshal_with(api_based_extension_fields)
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, required=True, location='json')
-        parser.add_argument('api_endpoint', type=str, required=True, location='json')
-        parser.add_argument('api_key', type=str, required=True, location='json')
+        parser.add_argument("name", type=str, required=True, location="json")
+        parser.add_argument("api_endpoint", type=str, required=True, location="json")
+        parser.add_argument("api_key", type=str, required=True, location="json")
         args = parser.parse_args()
 
         extension_data = APIBasedExtension(
             tenant_id=current_user.current_tenant_id,
-            name=args['name'],
-            api_endpoint=args['api_endpoint'],
-            api_key=args['api_key']
+            name=args["name"],
+            api_endpoint=args["api_endpoint"],
+            api_key=args["api_key"],
         )
 
         return APIBasedExtensionService.save(extension_data)
 
 
 class APIBasedExtensionDetailAPI(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -82,16 +76,16 @@ class APIBasedExtensionDetailAPI(Resource):
         extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
 
         parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, required=True, location='json')
-        parser.add_argument('api_endpoint', type=str, required=True, location='json')
-        parser.add_argument('api_key', type=str, required=True, location='json')
+        parser.add_argument("name", type=str, required=True, location="json")
+        parser.add_argument("api_endpoint", type=str, required=True, location="json")
+        parser.add_argument("api_key", type=str, required=True, location="json")
         args = parser.parse_args()
 
-        extension_data_from_db.name = args['name']
-        extension_data_from_db.api_endpoint = args['api_endpoint']
+        extension_data_from_db.name = args["name"]
+        extension_data_from_db.api_endpoint = args["api_endpoint"]
 
-        if args['api_key'] != HIDDEN_VALUE:
-            extension_data_from_db.api_key = args['api_key']
+        if args["api_key"] != HIDDEN_VALUE:
+            extension_data_from_db.api_key = args["api_key"]
 
         return APIBasedExtensionService.save(extension_data_from_db)
 
@@ -106,10 +100,10 @@ class APIBasedExtensionDetailAPI(Resource):
 
         APIBasedExtensionService.delete(extension_data_from_db)
 
-        return {'result': 'success'}
+        return {"result": "success"}
 
 
-api.add_resource(CodeBasedExtensionAPI, '/code-based-extension')
+api.add_resource(CodeBasedExtensionAPI, "/code-based-extension")
 
-api.add_resource(APIBasedExtensionAPI, '/api-based-extension')
-api.add_resource(APIBasedExtensionDetailAPI, '/api-based-extension/<uuid:id>')
+api.add_resource(APIBasedExtensionAPI, "/api-based-extension")
+api.add_resource(APIBasedExtensionDetailAPI, "/api-based-extension/<uuid:id>")

+ 2 - 3
api/controllers/console/feature.py

@@ -10,7 +10,6 @@ from .wraps import account_initialization_required, cloud_utm_record
 
 
 class FeatureApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -24,5 +23,5 @@ class SystemFeatureApi(Resource):
         return FeatureService.get_system_features().model_dump()
 
 
-api.add_resource(FeatureApi, '/features')
-api.add_resource(SystemFeatureApi, '/system-features')
+api.add_resource(FeatureApi, "/features")
+api.add_resource(SystemFeatureApi, "/system-features")

+ 16 - 16
api/controllers/console/init_validate.py

@@ -14,12 +14,11 @@ from .wraps import only_edition_self_hosted
 
 
 class InitValidateAPI(Resource):
-
     def get(self):
         init_status = get_init_validate_status()
         if init_status:
-            return { 'status': 'finished' }
-        return {'status': 'not_started' }
+            return {"status": "finished"}
+        return {"status": "not_started"}
 
     @only_edition_self_hosted
     def post(self):
@@ -29,22 +28,23 @@ class InitValidateAPI(Resource):
             raise AlreadySetupError()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('password', type=str_len(30),
-                            required=True, location='json')
-        input_password = parser.parse_args()['password']
+        parser.add_argument("password", type=str_len(30), required=True, location="json")
+        input_password = parser.parse_args()["password"]
 
-        if input_password != os.environ.get('INIT_PASSWORD'):
-            session['is_init_validated'] = False
+        if input_password != os.environ.get("INIT_PASSWORD"):
+            session["is_init_validated"] = False
             raise InitValidateFailedError()
-            
-        session['is_init_validated'] = True
-        return {'result': 'success'}, 201
+
+        session["is_init_validated"] = True
+        return {"result": "success"}, 201
+
 
 def get_init_validate_status():
-    if dify_config.EDITION == 'SELF_HOSTED':
-        if os.environ.get('INIT_PASSWORD'):
-            return session.get('is_init_validated') or DifySetup.query.first()
-    
+    if dify_config.EDITION == "SELF_HOSTED":
+        if os.environ.get("INIT_PASSWORD"):
+            return session.get("is_init_validated") or DifySetup.query.first()
+
     return True
 
-api.add_resource(InitValidateAPI, '/init')
+
+api.add_resource(InitValidateAPI, "/init")

+ 2 - 5
api/controllers/console/ping.py

@@ -4,14 +4,11 @@ from controllers.console import api
 
 
 class PingApi(Resource):
-
     def get(self):
         """
         For connection health check
         """
-        return {
-            "result": "pong"
-        }
+        return {"result": "pong"}
 
 
-api.add_resource(PingApi, '/ping')
+api.add_resource(PingApi, "/ping")

+ 14 - 23
api/controllers/console/setup.py

@@ -16,17 +16,13 @@ from .wraps import only_edition_self_hosted
 
 
 class SetupApi(Resource):
-
     def get(self):
-        if dify_config.EDITION == 'SELF_HOSTED':
+        if dify_config.EDITION == "SELF_HOSTED":
             setup_status = get_setup_status()
             if setup_status:
-                return {
-                    'step': 'finished',
-                    'setup_at': setup_status.setup_at.isoformat()
-                }
-            return {'step': 'not_started'}
-        return {'step': 'finished'}
+                return {"step": "finished", "setup_at": setup_status.setup_at.isoformat()}
+            return {"step": "not_started"}
+        return {"step": "finished"}
 
     @only_edition_self_hosted
     def post(self):
@@ -38,28 +34,22 @@ class SetupApi(Resource):
         tenant_count = TenantService.get_tenant_count()
         if tenant_count > 0:
             raise AlreadySetupError()
-    
+
         if not get_init_validate_status():
             raise NotInitValidateError()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('email', type=email,
-                            required=True, location='json')
-        parser.add_argument('name', type=str_len(
-            30), required=True, location='json')
-        parser.add_argument('password', type=valid_password,
-                            required=True, location='json')
+        parser.add_argument("email", type=email, required=True, location="json")
+        parser.add_argument("name", type=str_len(30), required=True, location="json")
+        parser.add_argument("password", type=valid_password, required=True, location="json")
         args = parser.parse_args()
 
         # setup
         RegisterService.setup(
-            email=args['email'],
-            name=args['name'],
-            password=args['password'],
-            ip_address=get_remote_ip(request)
+            email=args["email"], name=args["name"], password=args["password"], ip_address=get_remote_ip(request)
         )
 
-        return {'result': 'success'}, 201
+        return {"result": "success"}, 201
 
 
 def setup_required(view):
@@ -68,7 +58,7 @@ def setup_required(view):
         # check setup
         if not get_init_validate_status():
             raise NotInitValidateError()
-        
+
         elif not get_setup_status():
             raise NotSetupError()
 
@@ -78,9 +68,10 @@ def setup_required(view):
 
 
 def get_setup_status():
-    if dify_config.EDITION == 'SELF_HOSTED':
+    if dify_config.EDITION == "SELF_HOSTED":
         return DifySetup.query.first()
     else:
         return True
 
-api.add_resource(SetupApi, '/setup')
+
+api.add_resource(SetupApi, "/setup")

+ 32 - 49
api/controllers/console/tag/tags.py

@@ -14,19 +14,18 @@ from services.tag_service import TagService
 
 def _validate_name(name):
     if not name or len(name) < 1 or len(name) > 40:
-        raise ValueError('Name must be between 1 to 50 characters.')
+        raise ValueError("Name must be between 1 to 50 characters.")
     return name
 
 
 class TagListApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
     @marshal_with(tag_fields)
     def get(self):
-        tag_type = request.args.get('type', type=str)
-        keyword = request.args.get('keyword', default=None, type=str)
+        tag_type = request.args.get("type", type=str)
+        keyword = request.args.get("keyword", default=None, type=str)
         tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword)
 
         return tags, 200
@@ -40,28 +39,21 @@ class TagListApi(Resource):
             raise Forbidden()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('name', nullable=False, required=True,
-                            help='Name must be between 1 to 50 characters.',
-                            type=_validate_name)
-        parser.add_argument('type', type=str, location='json',
-                            choices=Tag.TAG_TYPE_LIST,
-                            nullable=True,
-                            help='Invalid tag type.')
+        parser.add_argument(
+            "name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
+        )
+        parser.add_argument(
+            "type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
+        )
         args = parser.parse_args()
         tag = TagService.save_tags(args)
 
-        response = {
-            'id': tag.id,
-            'name': tag.name,
-            'type': tag.type,
-            'binding_count': 0
-        }
+        response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
 
         return response, 200
 
 
 class TagUpdateDeleteApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -72,20 +64,15 @@ class TagUpdateDeleteApi(Resource):
             raise Forbidden()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('name', nullable=False, required=True,
-                            help='Name must be between 1 to 50 characters.',
-                            type=_validate_name)
+        parser.add_argument(
+            "name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
+        )
         args = parser.parse_args()
         tag = TagService.update_tags(args, tag_id)
 
         binding_count = TagService.get_tag_binding_count(tag_id)
 
-        response = {
-            'id': tag.id,
-            'name': tag.name,
-            'type': tag.type,
-            'binding_count': binding_count
-        }
+        response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
 
         return response, 200
 
@@ -104,7 +91,6 @@ class TagUpdateDeleteApi(Resource):
 
 
 class TagBindingCreateApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -114,14 +100,15 @@ class TagBindingCreateApi(Resource):
             raise Forbidden()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('tag_ids', type=list, nullable=False, required=True, location='json',
-                            help='Tag IDs is required.')
-        parser.add_argument('target_id', type=str, nullable=False, required=True, location='json',
-                            help='Target ID is required.')
-        parser.add_argument('type', type=str, location='json',
-                            choices=Tag.TAG_TYPE_LIST,
-                            nullable=True,
-                            help='Invalid tag type.')
+        parser.add_argument(
+            "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
+        )
+        parser.add_argument(
+            "target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required."
+        )
+        parser.add_argument(
+            "type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
+        )
         args = parser.parse_args()
         TagService.save_tag_binding(args)
 
@@ -129,7 +116,6 @@ class TagBindingCreateApi(Resource):
 
 
 class TagBindingDeleteApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -139,21 +125,18 @@ class TagBindingDeleteApi(Resource):
             raise Forbidden()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('tag_id', type=str, nullable=False, required=True,
-                            help='Tag ID is required.')
-        parser.add_argument('target_id', type=str, nullable=False, required=True,
-                            help='Target ID is required.')
-        parser.add_argument('type', type=str, location='json',
-                            choices=Tag.TAG_TYPE_LIST,
-                            nullable=True,
-                            help='Invalid tag type.')
+        parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
+        parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
+        parser.add_argument(
+            "type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
+        )
         args = parser.parse_args()
         TagService.delete_tag_binding(args)
 
         return 200
 
 
-api.add_resource(TagListApi, '/tags')
-api.add_resource(TagUpdateDeleteApi, '/tags/<uuid:tag_id>')
-api.add_resource(TagBindingCreateApi, '/tag-bindings/create')
-api.add_resource(TagBindingDeleteApi, '/tag-bindings/remove')
+api.add_resource(TagListApi, "/tags")
+api.add_resource(TagUpdateDeleteApi, "/tags/<uuid:tag_id>")
+api.add_resource(TagBindingCreateApi, "/tag-bindings/create")
+api.add_resource(TagBindingDeleteApi, "/tag-bindings/remove")

+ 16 - 20
api/controllers/console/version.py

@@ -1,4 +1,3 @@
-
 import json
 import logging
 
@@ -11,42 +10,39 @@ from . import api
 
 
 class VersionApi(Resource):
-
     def get(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('current_version', type=str, required=True, location='args')
+        parser.add_argument("current_version", type=str, required=True, location="args")
         args = parser.parse_args()
         check_update_url = dify_config.CHECK_UPDATE_URL
 
         result = {
-            'version': dify_config.CURRENT_VERSION,
-            'release_date': '',
-            'release_notes': '',
-            'can_auto_update': False,
-            'features': {
-                'can_replace_logo': dify_config.CAN_REPLACE_LOGO,
-                'model_load_balancing_enabled': dify_config.MODEL_LB_ENABLED
-            }
+            "version": dify_config.CURRENT_VERSION,
+            "release_date": "",
+            "release_notes": "",
+            "can_auto_update": False,
+            "features": {
+                "can_replace_logo": dify_config.CAN_REPLACE_LOGO,
+                "model_load_balancing_enabled": dify_config.MODEL_LB_ENABLED,
+            },
         }
 
         if not check_update_url:
             return result
 
         try:
-            response = requests.get(check_update_url, {
-                'current_version': args.get('current_version')
-            })
+            response = requests.get(check_update_url, {"current_version": args.get("current_version")})
         except Exception as error:
             logging.warning("Check update version error: {}.".format(str(error)))
-            result['version'] = args.get('current_version')
+            result["version"] = args.get("current_version")
             return result
 
         content = json.loads(response.content)
-        result['version'] = content['version']
-        result['release_date'] = content['releaseDate']
-        result['release_notes'] = content['releaseNotes']
-        result['can_auto_update'] = content['canAutoUpdate']
+        result["version"] = content["version"]
+        result["release_date"] = content["releaseDate"]
+        result["release_notes"] = content["releaseNotes"]
+        result["can_auto_update"] = content["canAutoUpdate"]
         return result
 
 
-api.add_resource(VersionApi, '/version')
+api.add_resource(VersionApi, "/version")

+ 75 - 81
api/controllers/console/workspace/account.py

@@ -26,52 +26,53 @@ from services.errors.account import CurrentPasswordIncorrectError as ServiceCurr
 
 
 class AccountInitApi(Resource):
-
     @setup_required
     @login_required
     def post(self):
         account = current_user
 
-        if account.status == 'active':
+        if account.status == "active":
             raise AccountAlreadyInitedError()
 
         parser = reqparse.RequestParser()
 
-        if dify_config.EDITION == 'CLOUD':
-            parser.add_argument('invitation_code', type=str, location='json')
+        if dify_config.EDITION == "CLOUD":
+            parser.add_argument("invitation_code", type=str, location="json")
 
-        parser.add_argument(
-            'interface_language', type=supported_language, required=True, location='json')
-        parser.add_argument('timezone', type=timezone,
-                            required=True, location='json')
+        parser.add_argument("interface_language", type=supported_language, required=True, location="json")
+        parser.add_argument("timezone", type=timezone, required=True, location="json")
         args = parser.parse_args()
 
-        if dify_config.EDITION == 'CLOUD':
-            if not args['invitation_code']:
-                raise ValueError('invitation_code is required')
+        if dify_config.EDITION == "CLOUD":
+            if not args["invitation_code"]:
+                raise ValueError("invitation_code is required")
 
             # check invitation code
-            invitation_code = db.session.query(InvitationCode).filter(
-                InvitationCode.code == args['invitation_code'],
-                InvitationCode.status == 'unused',
-            ).first()
+            invitation_code = (
+                db.session.query(InvitationCode)
+                .filter(
+                    InvitationCode.code == args["invitation_code"],
+                    InvitationCode.status == "unused",
+                )
+                .first()
+            )
 
             if not invitation_code:
                 raise InvalidInvitationCodeError()
 
-            invitation_code.status = 'used'
+            invitation_code.status = "used"
             invitation_code.used_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
             invitation_code.used_by_tenant_id = account.current_tenant_id
             invitation_code.used_by_account_id = account.id
 
-        account.interface_language = args['interface_language']
-        account.timezone = args['timezone']
-        account.interface_theme = 'light'
-        account.status = 'active'
+        account.interface_language = args["interface_language"]
+        account.timezone = args["timezone"]
+        account.interface_theme = "light"
+        account.status = "active"
         account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
         db.session.commit()
 
-        return {'result': 'success'}
+        return {"result": "success"}
 
 
 class AccountProfileApi(Resource):
@@ -90,15 +91,14 @@ class AccountNameApi(Resource):
     @marshal_with(account_fields)
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, required=True, location='json')
+        parser.add_argument("name", type=str, required=True, location="json")
         args = parser.parse_args()
 
         # Validate account name length
-        if len(args['name']) < 3 or len(args['name']) > 30:
-            raise ValueError(
-                "Account name must be between 3 and 30 characters.")
+        if len(args["name"]) < 3 or len(args["name"]) > 30:
+            raise ValueError("Account name must be between 3 and 30 characters.")
 
-        updated_account = AccountService.update_account(current_user, name=args['name'])
+        updated_account = AccountService.update_account(current_user, name=args["name"])
 
         return updated_account
 
@@ -110,10 +110,10 @@ class AccountAvatarApi(Resource):
     @marshal_with(account_fields)
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('avatar', type=str, required=True, location='json')
+        parser.add_argument("avatar", type=str, required=True, location="json")
         args = parser.parse_args()
 
-        updated_account = AccountService.update_account(current_user, avatar=args['avatar'])
+        updated_account = AccountService.update_account(current_user, avatar=args["avatar"])
 
         return updated_account
 
@@ -125,11 +125,10 @@ class AccountInterfaceLanguageApi(Resource):
     @marshal_with(account_fields)
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument(
-            'interface_language', type=supported_language, required=True, location='json')
+        parser.add_argument("interface_language", type=supported_language, required=True, location="json")
         args = parser.parse_args()
 
-        updated_account = AccountService.update_account(current_user, interface_language=args['interface_language'])
+        updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"])
 
         return updated_account
 
@@ -141,11 +140,10 @@ class AccountInterfaceThemeApi(Resource):
     @marshal_with(account_fields)
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('interface_theme', type=str, choices=[
-            'light', 'dark'], required=True, location='json')
+        parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json")
         args = parser.parse_args()
 
-        updated_account = AccountService.update_account(current_user, interface_theme=args['interface_theme'])
+        updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"])
 
         return updated_account
 
@@ -157,15 +155,14 @@ class AccountTimezoneApi(Resource):
     @marshal_with(account_fields)
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('timezone', type=str,
-                            required=True, location='json')
+        parser.add_argument("timezone", type=str, required=True, location="json")
         args = parser.parse_args()
 
         # Validate timezone string, e.g. America/New_York, Asia/Shanghai
-        if args['timezone'] not in pytz.all_timezones:
+        if args["timezone"] not in pytz.all_timezones:
             raise ValueError("Invalid timezone string.")
 
-        updated_account = AccountService.update_account(current_user, timezone=args['timezone'])
+        updated_account = AccountService.update_account(current_user, timezone=args["timezone"])
 
         return updated_account
 
@@ -177,20 +174,16 @@ class AccountPasswordApi(Resource):
     @marshal_with(account_fields)
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('password', type=str,
-                            required=False, location='json')
-        parser.add_argument('new_password', type=str,
-                            required=True, location='json')
-        parser.add_argument('repeat_new_password', type=str,
-                            required=True, location='json')
+        parser.add_argument("password", type=str, required=False, location="json")
+        parser.add_argument("new_password", type=str, required=True, location="json")
+        parser.add_argument("repeat_new_password", type=str, required=True, location="json")
         args = parser.parse_args()
 
-        if args['new_password'] != args['repeat_new_password']:
+        if args["new_password"] != args["repeat_new_password"]:
             raise RepeatPasswordNotMatchError()
 
         try:
-            AccountService.update_account_password(
-                current_user, args['password'], args['new_password'])
+            AccountService.update_account_password(current_user, args["password"], args["new_password"])
         except ServiceCurrentPasswordIncorrectError:
             raise CurrentPasswordIncorrectError()
 
@@ -199,14 +192,14 @@ class AccountPasswordApi(Resource):
 
 class AccountIntegrateApi(Resource):
     integrate_fields = {
-        'provider': fields.String,
-        'created_at': TimestampField,
-        'is_bound': fields.Boolean,
-        'link': fields.String
+        "provider": fields.String,
+        "created_at": TimestampField,
+        "is_bound": fields.Boolean,
+        "link": fields.String,
     }
 
     integrate_list_fields = {
-        'data': fields.List(fields.Nested(integrate_fields)),
+        "data": fields.List(fields.Nested(integrate_fields)),
     }
 
     @setup_required
@@ -216,10 +209,9 @@ class AccountIntegrateApi(Resource):
     def get(self):
         account = current_user
 
-        account_integrates = db.session.query(AccountIntegrate).filter(
-            AccountIntegrate.account_id == account.id).all()
+        account_integrates = db.session.query(AccountIntegrate).filter(AccountIntegrate.account_id == account.id).all()
 
-        base_url = request.url_root.rstrip('/')
+        base_url = request.url_root.rstrip("/")
         oauth_base_path = "/console/api/oauth/login"
         providers = ["github", "google"]
 
@@ -227,36 +219,38 @@ class AccountIntegrateApi(Resource):
         for provider in providers:
             existing_integrate = next((ai for ai in account_integrates if ai.provider == provider), None)
             if existing_integrate:
-                integrate_data.append({
-                    'id': existing_integrate.id,
-                    'provider': provider,
-                    'created_at': existing_integrate.created_at,
-                    'is_bound': True,
-                    'link': None
-                })
+                integrate_data.append(
+                    {
+                        "id": existing_integrate.id,
+                        "provider": provider,
+                        "created_at": existing_integrate.created_at,
+                        "is_bound": True,
+                        "link": None,
+                    }
+                )
             else:
-                integrate_data.append({
-                    'id': None,
-                    'provider': provider,
-                    'created_at': None,
-                    'is_bound': False,
-                    'link': f'{base_url}{oauth_base_path}/{provider}'
-                })
-
-        return {'data': integrate_data}
-
+                integrate_data.append(
+                    {
+                        "id": None,
+                        "provider": provider,
+                        "created_at": None,
+                        "is_bound": False,
+                        "link": f"{base_url}{oauth_base_path}/{provider}",
+                    }
+                )
 
+        return {"data": integrate_data}
 
 
 # Register API resources
-api.add_resource(AccountInitApi, '/account/init')
-api.add_resource(AccountProfileApi, '/account/profile')
-api.add_resource(AccountNameApi, '/account/name')
-api.add_resource(AccountAvatarApi, '/account/avatar')
-api.add_resource(AccountInterfaceLanguageApi, '/account/interface-language')
-api.add_resource(AccountInterfaceThemeApi, '/account/interface-theme')
-api.add_resource(AccountTimezoneApi, '/account/timezone')
-api.add_resource(AccountPasswordApi, '/account/password')
-api.add_resource(AccountIntegrateApi, '/account/integrates')
+api.add_resource(AccountInitApi, "/account/init")
+api.add_resource(AccountProfileApi, "/account/profile")
+api.add_resource(AccountNameApi, "/account/name")
+api.add_resource(AccountAvatarApi, "/account/avatar")
+api.add_resource(AccountInterfaceLanguageApi, "/account/interface-language")
+api.add_resource(AccountInterfaceThemeApi, "/account/interface-theme")
+api.add_resource(AccountTimezoneApi, "/account/timezone")
+api.add_resource(AccountPasswordApi, "/account/password")
+api.add_resource(AccountIntegrateApi, "/account/integrates")
 # api.add_resource(AccountEmailApi, '/account/email')
 # api.add_resource(AccountEmailVerifyApi, '/account/email-verify')

+ 6 - 6
api/controllers/console/workspace/error.py

@@ -2,36 +2,36 @@ from libs.exception import BaseHTTPException
 
 
 class RepeatPasswordNotMatchError(BaseHTTPException):
-    error_code = 'repeat_password_not_match'
+    error_code = "repeat_password_not_match"
     description = "New password and repeat password does not match."
     code = 400
 
 
 class CurrentPasswordIncorrectError(BaseHTTPException):
-    error_code = 'current_password_incorrect'
+    error_code = "current_password_incorrect"
     description = "Current password is incorrect."
     code = 400
 
 
 class ProviderRequestFailedError(BaseHTTPException):
-    error_code = 'provider_request_failed'
+    error_code = "provider_request_failed"
     description = None
     code = 400
 
 
 class InvalidInvitationCodeError(BaseHTTPException):
-    error_code = 'invalid_invitation_code'
+    error_code = "invalid_invitation_code"
     description = "Invalid invitation code."
     code = 400
 
 
 class AccountAlreadyInitedError(BaseHTTPException):
-    error_code = 'account_already_inited'
+    error_code = "account_already_inited"
     description = "The account has been initialized. Please refresh the page."
     code = 400
 
 
 class AccountNotInitializedError(BaseHTTPException):
-    error_code = 'account_not_initialized'
+    error_code = "account_not_initialized"
     description = "The account has not been initialized yet. Please proceed with the initialization process first."
     code = 400

+ 39 - 23
api/controllers/console/workspace/load_balancing_config.py

@@ -22,10 +22,16 @@ class LoadBalancingCredentialsValidateApi(Resource):
         tenant_id = current_user.current_tenant_id
 
         parser = reqparse.RequestParser()
-        parser.add_argument('model', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=[mt.value for mt in ModelType], location='json')
-        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
+        parser.add_argument("model", type=str, required=True, nullable=False, location="json")
+        parser.add_argument(
+            "model_type",
+            type=str,
+            required=True,
+            nullable=False,
+            choices=[mt.value for mt in ModelType],
+            location="json",
+        )
+        parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
         args = parser.parse_args()
 
         # validate model load balancing credentials
@@ -38,18 +44,18 @@ class LoadBalancingCredentialsValidateApi(Resource):
             model_load_balancing_service.validate_load_balancing_credentials(
                 tenant_id=tenant_id,
                 provider=provider,
-                model=args['model'],
-                model_type=args['model_type'],
-                credentials=args['credentials']
+                model=args["model"],
+                model_type=args["model_type"],
+                credentials=args["credentials"],
             )
         except CredentialsValidateFailedError as ex:
             result = False
             error = str(ex)
 
-        response = {'result': 'success' if result else 'error'}
+        response = {"result": "success" if result else "error"}
 
         if not result:
-            response['error'] = error
+            response["error"] = error
 
         return response
 
@@ -65,10 +71,16 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
         tenant_id = current_user.current_tenant_id
 
         parser = reqparse.RequestParser()
-        parser.add_argument('model', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=[mt.value for mt in ModelType], location='json')
-        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
+        parser.add_argument("model", type=str, required=True, nullable=False, location="json")
+        parser.add_argument(
+            "model_type",
+            type=str,
+            required=True,
+            nullable=False,
+            choices=[mt.value for mt in ModelType],
+            location="json",
+        )
+        parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
         args = parser.parse_args()
 
         # validate model load balancing config credentials
@@ -81,26 +93,30 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
             model_load_balancing_service.validate_load_balancing_credentials(
                 tenant_id=tenant_id,
                 provider=provider,
-                model=args['model'],
-                model_type=args['model_type'],
-                credentials=args['credentials'],
+                model=args["model"],
+                model_type=args["model_type"],
+                credentials=args["credentials"],
                 config_id=config_id,
             )
         except CredentialsValidateFailedError as ex:
             result = False
             error = str(ex)
 
-        response = {'result': 'success' if result else 'error'}
+        response = {"result": "success" if result else "error"}
 
         if not result:
-            response['error'] = error
+            response["error"] = error
 
         return response
 
 
 # Load Balancing Config
-api.add_resource(LoadBalancingCredentialsValidateApi,
-                 '/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/credentials-validate')
-
-api.add_resource(LoadBalancingConfigCredentialsValidateApi,
-                 '/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate')
+api.add_resource(
+    LoadBalancingCredentialsValidateApi,
+    "/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/credentials-validate",
+)
+
+api.add_resource(
+    LoadBalancingConfigCredentialsValidateApi,
+    "/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate",
+)

+ 39 - 41
api/controllers/console/workspace/members.py

@@ -23,7 +23,7 @@ class MemberListApi(Resource):
     @marshal_with(account_with_role_list_fields)
     def get(self):
         members = TenantService.get_tenant_members(current_user.current_tenant)
-        return {'result': 'success', 'accounts': members}, 200
+        return {"result": "success", "accounts": members}, 200
 
 
 class MemberInviteEmailApi(Resource):
@@ -32,48 +32,46 @@ class MemberInviteEmailApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_resource_check('members')
+    @cloud_edition_billing_resource_check("members")
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('emails', type=str, required=True, location='json', action='append')
-        parser.add_argument('role', type=str, required=True, default='admin', location='json')
-        parser.add_argument('language', type=str, required=False, location='json')
+        parser.add_argument("emails", type=str, required=True, location="json", action="append")
+        parser.add_argument("role", type=str, required=True, default="admin", location="json")
+        parser.add_argument("language", type=str, required=False, location="json")
         args = parser.parse_args()
 
-        invitee_emails = args['emails']
-        invitee_role = args['role']
-        interface_language = args['language']
+        invitee_emails = args["emails"]
+        invitee_role = args["role"]
+        interface_language = args["language"]
         if not TenantAccountRole.is_non_owner_role(invitee_role):
-            return {'code': 'invalid-role', 'message': 'Invalid role'}, 400
+            return {"code": "invalid-role", "message": "Invalid role"}, 400
 
         inviter = current_user
         invitation_results = []
         console_web_url = dify_config.CONSOLE_WEB_URL
         for invitee_email in invitee_emails:
             try:
-                token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter)
-                invitation_results.append({
-                    'status': 'success',
-                    'email': invitee_email,
-                    'url': f'{console_web_url}/activate?email={invitee_email}&token={token}'
-                })
+                token = RegisterService.invite_new_member(
+                    inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter
+                )
+                invitation_results.append(
+                    {
+                        "status": "success",
+                        "email": invitee_email,
+                        "url": f"{console_web_url}/activate?email={invitee_email}&token={token}",
+                    }
+                )
             except AccountAlreadyInTenantError:
-                invitation_results.append({
-                    'status': 'success',
-                    'email': invitee_email,
-                    'url': f'{console_web_url}/signin'
-                })
+                invitation_results.append(
+                    {"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"}
+                )
                 break
             except Exception as e:
-                invitation_results.append({
-                    'status': 'failed',
-                    'email': invitee_email,
-                    'message': str(e)
-                })
+                invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)})
 
         return {
-            'result': 'success',
-            'invitation_results': invitation_results,
+            "result": "success",
+            "invitation_results": invitation_results,
         }, 201
 
 
@@ -91,15 +89,15 @@ class MemberCancelInviteApi(Resource):
         try:
             TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user)
         except services.errors.account.CannotOperateSelfError as e:
-            return {'code': 'cannot-operate-self', 'message': str(e)}, 400
+            return {"code": "cannot-operate-self", "message": str(e)}, 400
         except services.errors.account.NoPermissionError as e:
-            return {'code': 'forbidden', 'message': str(e)}, 403
+            return {"code": "forbidden", "message": str(e)}, 403
         except services.errors.account.MemberNotInTenantError as e:
-            return {'code': 'member-not-found', 'message': str(e)}, 404
+            return {"code": "member-not-found", "message": str(e)}, 404
         except Exception as e:
             raise ValueError(str(e))
 
-        return {'result': 'success'}, 204
+        return {"result": "success"}, 204
 
 
 class MemberUpdateRoleApi(Resource):
@@ -110,12 +108,12 @@ class MemberUpdateRoleApi(Resource):
     @account_initialization_required
     def put(self, member_id):
         parser = reqparse.RequestParser()
-        parser.add_argument('role', type=str, required=True, location='json')
+        parser.add_argument("role", type=str, required=True, location="json")
         args = parser.parse_args()
-        new_role = args['role']
+        new_role = args["role"]
 
         if not TenantAccountRole.is_valid_role(new_role):
-            return {'code': 'invalid-role', 'message': 'Invalid role'}, 400
+            return {"code": "invalid-role", "message": "Invalid role"}, 400
 
         member = db.session.get(Account, str(member_id))
         if not member:
@@ -128,7 +126,7 @@ class MemberUpdateRoleApi(Resource):
 
         # todo: 403
 
-        return {'result': 'success'}
+        return {"result": "success"}
 
 
 class DatasetOperatorMemberListApi(Resource):
@@ -140,11 +138,11 @@ class DatasetOperatorMemberListApi(Resource):
     @marshal_with(account_with_role_list_fields)
     def get(self):
         members = TenantService.get_dataset_operator_members(current_user.current_tenant)
-        return {'result': 'success', 'accounts': members}, 200
+        return {"result": "success", "accounts": members}, 200
 
 
-api.add_resource(MemberListApi, '/workspaces/current/members')
-api.add_resource(MemberInviteEmailApi, '/workspaces/current/members/invite-email')
-api.add_resource(MemberCancelInviteApi, '/workspaces/current/members/<uuid:member_id>')
-api.add_resource(MemberUpdateRoleApi, '/workspaces/current/members/<uuid:member_id>/update-role')
-api.add_resource(DatasetOperatorMemberListApi, '/workspaces/current/dataset-operators')
+api.add_resource(MemberListApi, "/workspaces/current/members")
+api.add_resource(MemberInviteEmailApi, "/workspaces/current/members/invite-email")
+api.add_resource(MemberCancelInviteApi, "/workspaces/current/members/<uuid:member_id>")
+api.add_resource(MemberUpdateRoleApi, "/workspaces/current/members/<uuid:member_id>/update-role")
+api.add_resource(DatasetOperatorMemberListApi, "/workspaces/current/dataset-operators")

+ 64 - 74
api/controllers/console/workspace/model_providers.py

@@ -17,7 +17,6 @@ from services.model_provider_service import ModelProviderService
 
 
 class ModelProviderListApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -25,21 +24,23 @@ class ModelProviderListApi(Resource):
         tenant_id = current_user.current_tenant_id
 
         parser = reqparse.RequestParser()
-        parser.add_argument('model_type', type=str, required=False, nullable=True,
-                            choices=[mt.value for mt in ModelType], location='args')
+        parser.add_argument(
+            "model_type",
+            type=str,
+            required=False,
+            nullable=True,
+            choices=[mt.value for mt in ModelType],
+            location="args",
+        )
         args = parser.parse_args()
 
         model_provider_service = ModelProviderService()
-        provider_list = model_provider_service.get_provider_list(
-            tenant_id=tenant_id,
-            model_type=args.get('model_type')
-        )
+        provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type"))
 
         return jsonable_encoder({"data": provider_list})
 
 
 class ModelProviderCredentialApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -47,25 +48,18 @@ class ModelProviderCredentialApi(Resource):
         tenant_id = current_user.current_tenant_id
 
         model_provider_service = ModelProviderService()
-        credentials = model_provider_service.get_provider_credentials(
-            tenant_id=tenant_id,
-            provider=provider
-        )
+        credentials = model_provider_service.get_provider_credentials(tenant_id=tenant_id, provider=provider)
 
-        return {
-            "credentials": credentials
-        }
+        return {"credentials": credentials}
 
 
 class ModelProviderValidateApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
     def post(self, provider: str):
-
         parser = reqparse.RequestParser()
-        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
+        parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
         args = parser.parse_args()
 
         tenant_id = current_user.current_tenant_id
@@ -77,24 +71,21 @@ class ModelProviderValidateApi(Resource):
 
         try:
             model_provider_service.provider_credentials_validate(
-                tenant_id=tenant_id,
-                provider=provider,
-                credentials=args['credentials']
+                tenant_id=tenant_id, provider=provider, credentials=args["credentials"]
             )
         except CredentialsValidateFailedError as ex:
             result = False
             error = str(ex)
 
-        response = {'result': 'success' if result else 'error'}
+        response = {"result": "success" if result else "error"}
 
         if not result:
-            response['error'] = error
+            response["error"] = error
 
         return response
 
 
 class ModelProviderApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -103,21 +94,19 @@ class ModelProviderApi(Resource):
             raise Forbidden()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
+        parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
         args = parser.parse_args()
 
         model_provider_service = ModelProviderService()
 
         try:
             model_provider_service.save_provider_credentials(
-                tenant_id=current_user.current_tenant_id,
-                provider=provider,
-                credentials=args['credentials']
+                tenant_id=current_user.current_tenant_id, provider=provider, credentials=args["credentials"]
             )
         except CredentialsValidateFailedError as ex:
             raise ValueError(str(ex))
 
-        return {'result': 'success'}, 201
+        return {"result": "success"}, 201
 
     @setup_required
     @login_required
@@ -127,12 +116,9 @@ class ModelProviderApi(Resource):
             raise Forbidden()
 
         model_provider_service = ModelProviderService()
-        model_provider_service.remove_provider_credentials(
-            tenant_id=current_user.current_tenant_id,
-            provider=provider
-        )
+        model_provider_service.remove_provider_credentials(tenant_id=current_user.current_tenant_id, provider=provider)
 
-        return {'result': 'success'}, 204
+        return {"result": "success"}, 204
 
 
 class ModelProviderIconApi(Resource):
@@ -146,16 +132,13 @@ class ModelProviderIconApi(Resource):
     def get(self, provider: str, icon_type: str, lang: str):
         model_provider_service = ModelProviderService()
         icon, mimetype = model_provider_service.get_model_provider_icon(
-            provider=provider,
-            icon_type=icon_type,
-            lang=lang
+            provider=provider, icon_type=icon_type, lang=lang
         )
 
         return send_file(io.BytesIO(icon), mimetype=mimetype)
 
 
 class PreferredProviderTypeUpdateApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -166,18 +149,22 @@ class PreferredProviderTypeUpdateApi(Resource):
         tenant_id = current_user.current_tenant_id
 
         parser = reqparse.RequestParser()
-        parser.add_argument('preferred_provider_type', type=str, required=True, nullable=False,
-                            choices=['system', 'custom'], location='json')
+        parser.add_argument(
+            "preferred_provider_type",
+            type=str,
+            required=True,
+            nullable=False,
+            choices=["system", "custom"],
+            location="json",
+        )
         args = parser.parse_args()
 
         model_provider_service = ModelProviderService()
         model_provider_service.switch_preferred_provider(
-            tenant_id=tenant_id,
-            provider=provider,
-            preferred_provider_type=args['preferred_provider_type']
+            tenant_id=tenant_id, provider=provider, preferred_provider_type=args["preferred_provider_type"]
         )
 
-        return {'result': 'success'}
+        return {"result": "success"}
 
 
 class ModelProviderPaymentCheckoutUrlApi(Resource):
@@ -185,13 +172,15 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
     @login_required
     @account_initialization_required
     def get(self, provider: str):
-        if provider != 'anthropic':
-            raise ValueError(f'provider name {provider} is invalid')
+        if provider != "anthropic":
+            raise ValueError(f"provider name {provider} is invalid")
         BillingService.is_tenant_owner_or_admin(current_user)
-        data = BillingService.get_model_provider_payment_link(provider_name=provider,
-                                                              tenant_id=current_user.current_tenant_id,
-                                                              account_id=current_user.id,
-                                                              prefilled_email=current_user.email)
+        data = BillingService.get_model_provider_payment_link(
+            provider_name=provider,
+            tenant_id=current_user.current_tenant_id,
+            account_id=current_user.id,
+            prefilled_email=current_user.email,
+        )
         return data
 
 
@@ -201,10 +190,7 @@ class ModelProviderFreeQuotaSubmitApi(Resource):
     @account_initialization_required
     def post(self, provider: str):
         model_provider_service = ModelProviderService()
-        result = model_provider_service.free_quota_submit(
-            tenant_id=current_user.current_tenant_id,
-            provider=provider
-        )
+        result = model_provider_service.free_quota_submit(tenant_id=current_user.current_tenant_id, provider=provider)
 
         return result
 
@@ -215,32 +201,36 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
     @account_initialization_required
     def get(self, provider: str):
         parser = reqparse.RequestParser()
-        parser.add_argument('token', type=str, required=False, nullable=True, location='args')
+        parser.add_argument("token", type=str, required=False, nullable=True, location="args")
         args = parser.parse_args()
 
         model_provider_service = ModelProviderService()
         result = model_provider_service.free_quota_qualification_verify(
-            tenant_id=current_user.current_tenant_id,
-            provider=provider,
-            token=args['token']
+            tenant_id=current_user.current_tenant_id, provider=provider, token=args["token"]
         )
 
         return result
 
 
-api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers')
-
-api.add_resource(ModelProviderCredentialApi, '/workspaces/current/model-providers/<string:provider>/credentials')
-api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider>/credentials/validate')
-api.add_resource(ModelProviderApi, '/workspaces/current/model-providers/<string:provider>')
-api.add_resource(ModelProviderIconApi, '/workspaces/current/model-providers/<string:provider>/'
-                                       '<string:icon_type>/<string:lang>')
-
-api.add_resource(PreferredProviderTypeUpdateApi,
-                 '/workspaces/current/model-providers/<string:provider>/preferred-provider-type')
-api.add_resource(ModelProviderPaymentCheckoutUrlApi,
-                 '/workspaces/current/model-providers/<string:provider>/checkout-url')
-api.add_resource(ModelProviderFreeQuotaSubmitApi,
-                 '/workspaces/current/model-providers/<string:provider>/free-quota-submit')
-api.add_resource(ModelProviderFreeQuotaQualificationVerifyApi,
-                 '/workspaces/current/model-providers/<string:provider>/free-quota-qualification-verify')
+api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers")
+
+api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<string:provider>/credentials")
+api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<string:provider>/credentials/validate")
+api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<string:provider>")
+api.add_resource(
+    ModelProviderIconApi, "/workspaces/current/model-providers/<string:provider>/" "<string:icon_type>/<string:lang>"
+)
+
+api.add_resource(
+    PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<string:provider>/preferred-provider-type"
+)
+api.add_resource(
+    ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<string:provider>/checkout-url"
+)
+api.add_resource(
+    ModelProviderFreeQuotaSubmitApi, "/workspaces/current/model-providers/<string:provider>/free-quota-submit"
+)
+api.add_resource(
+    ModelProviderFreeQuotaQualificationVerifyApi,
+    "/workspaces/current/model-providers/<string:provider>/free-quota-qualification-verify",
+)

+ 145 - 139
api/controllers/console/workspace/models.py

@@ -16,27 +16,29 @@ from services.model_provider_service import ModelProviderService
 
 
 class DefaultModelApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
     def get(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=[mt.value for mt in ModelType], location='args')
+        parser.add_argument(
+            "model_type",
+            type=str,
+            required=True,
+            nullable=False,
+            choices=[mt.value for mt in ModelType],
+            location="args",
+        )
         args = parser.parse_args()
 
         tenant_id = current_user.current_tenant_id
 
         model_provider_service = ModelProviderService()
         default_model_entity = model_provider_service.get_default_model_of_model_type(
-            tenant_id=tenant_id,
-            model_type=args['model_type']
+            tenant_id=tenant_id, model_type=args["model_type"]
         )
 
-        return jsonable_encoder({
-            "data": default_model_entity
-        })
+        return jsonable_encoder({"data": default_model_entity})
 
     @setup_required
     @login_required
@@ -44,40 +46,39 @@ class DefaultModelApi(Resource):
     def post(self):
         if not current_user.is_admin_or_owner:
             raise Forbidden()
-        
+
         parser = reqparse.RequestParser()
-        parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json')
+        parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json")
         args = parser.parse_args()
 
         tenant_id = current_user.current_tenant_id
 
         model_provider_service = ModelProviderService()
-        model_settings = args['model_settings']
+        model_settings = args["model_settings"]
         for model_setting in model_settings:
-            if 'model_type' not in model_setting or model_setting['model_type'] not in [mt.value for mt in ModelType]:
-                raise ValueError('invalid model type')
+            if "model_type" not in model_setting or model_setting["model_type"] not in [mt.value for mt in ModelType]:
+                raise ValueError("invalid model type")
 
-            if 'provider' not in model_setting:
+            if "provider" not in model_setting:
                 continue
 
-            if 'model' not in model_setting:
-                raise ValueError('invalid model')
+            if "model" not in model_setting:
+                raise ValueError("invalid model")
 
             try:
                 model_provider_service.update_default_model_of_model_type(
                     tenant_id=tenant_id,
-                    model_type=model_setting['model_type'],
-                    provider=model_setting['provider'],
-                    model=model_setting['model']
+                    model_type=model_setting["model_type"],
+                    provider=model_setting["provider"],
+                    model=model_setting["model"],
                 )
             except Exception:
                 logging.warning(f"{model_setting['model_type']} save error")
 
-        return {'result': 'success'}
+        return {"result": "success"}
 
 
 class ModelProviderModelApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -85,14 +86,9 @@ class ModelProviderModelApi(Resource):
         tenant_id = current_user.current_tenant_id
 
         model_provider_service = ModelProviderService()
-        models = model_provider_service.get_models_by_provider(
-            tenant_id=tenant_id,
-            provider=provider
-        )
+        models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)
 
-        return jsonable_encoder({
-            "data": models
-        })
+        return jsonable_encoder({"data": models})
 
     @setup_required
     @login_required
@@ -104,62 +100,66 @@ class ModelProviderModelApi(Resource):
         tenant_id = current_user.current_tenant_id
 
         parser = reqparse.RequestParser()
-        parser.add_argument('model', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=[mt.value for mt in ModelType], location='json')
-        parser.add_argument('credentials', type=dict, required=False, nullable=True, location='json')
-        parser.add_argument('load_balancing', type=dict, required=False, nullable=True, location='json')
-        parser.add_argument('config_from', type=str, required=False, nullable=True, location='json')
+        parser.add_argument("model", type=str, required=True, nullable=False, location="json")
+        parser.add_argument(
+            "model_type",
+            type=str,
+            required=True,
+            nullable=False,
+            choices=[mt.value for mt in ModelType],
+            location="json",
+        )
+        parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
+        parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
+        parser.add_argument("config_from", type=str, required=False, nullable=True, location="json")
         args = parser.parse_args()
 
         model_load_balancing_service = ModelLoadBalancingService()
 
-        if ('load_balancing' in args and args['load_balancing'] and
-                'enabled' in args['load_balancing'] and args['load_balancing']['enabled']):
-            if 'configs' not in args['load_balancing']:
-                raise ValueError('invalid load balancing configs')
+        if (
+            "load_balancing" in args
+            and args["load_balancing"]
+            and "enabled" in args["load_balancing"]
+            and args["load_balancing"]["enabled"]
+        ):
+            if "configs" not in args["load_balancing"]:
+                raise ValueError("invalid load balancing configs")
 
             # save load balancing configs
             model_load_balancing_service.update_load_balancing_configs(
                 tenant_id=tenant_id,
                 provider=provider,
-                model=args['model'],
-                model_type=args['model_type'],
-                configs=args['load_balancing']['configs']
+                model=args["model"],
+                model_type=args["model_type"],
+                configs=args["load_balancing"]["configs"],
             )
 
             # enable load balancing
             model_load_balancing_service.enable_model_load_balancing(
-                tenant_id=tenant_id,
-                provider=provider,
-                model=args['model'],
-                model_type=args['model_type']
+                tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
             )
         else:
             # disable load balancing
             model_load_balancing_service.disable_model_load_balancing(
-                tenant_id=tenant_id,
-                provider=provider,
-                model=args['model'],
-                model_type=args['model_type']
+                tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
             )
 
-            if args.get('config_from', '') != 'predefined-model':
+            if args.get("config_from", "") != "predefined-model":
                 model_provider_service = ModelProviderService()
 
                 try:
                     model_provider_service.save_model_credentials(
                         tenant_id=tenant_id,
                         provider=provider,
-                        model=args['model'],
-                        model_type=args['model_type'],
-                        credentials=args['credentials']
+                        model=args["model"],
+                        model_type=args["model_type"],
+                        credentials=args["credentials"],
                     )
                 except CredentialsValidateFailedError as ex:
                     logging.exception(f"save model credentials error: {ex}")
                     raise ValueError(str(ex))
 
-        return {'result': 'success'}, 200
+        return {"result": "success"}, 200
 
     @setup_required
     @login_required
@@ -171,24 +171,26 @@ class ModelProviderModelApi(Resource):
         tenant_id = current_user.current_tenant_id
 
         parser = reqparse.RequestParser()
-        parser.add_argument('model', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=[mt.value for mt in ModelType], location='json')
+        parser.add_argument("model", type=str, required=True, nullable=False, location="json")
+        parser.add_argument(
+            "model_type",
+            type=str,
+            required=True,
+            nullable=False,
+            choices=[mt.value for mt in ModelType],
+            location="json",
+        )
         args = parser.parse_args()
 
         model_provider_service = ModelProviderService()
         model_provider_service.remove_model_credentials(
-            tenant_id=tenant_id,
-            provider=provider,
-            model=args['model'],
-            model_type=args['model_type']
+            tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
         )
 
-        return {'result': 'success'}, 204
+        return {"result": "success"}, 204
 
 
 class ModelProviderModelCredentialApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -196,38 +198,34 @@ class ModelProviderModelCredentialApi(Resource):
         tenant_id = current_user.current_tenant_id
 
         parser = reqparse.RequestParser()
-        parser.add_argument('model', type=str, required=True, nullable=False, location='args')
-        parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=[mt.value for mt in ModelType], location='args')
+        parser.add_argument("model", type=str, required=True, nullable=False, location="args")
+        parser.add_argument(
+            "model_type",
+            type=str,
+            required=True,
+            nullable=False,
+            choices=[mt.value for mt in ModelType],
+            location="args",
+        )
         args = parser.parse_args()
 
         model_provider_service = ModelProviderService()
         credentials = model_provider_service.get_model_credentials(
-            tenant_id=tenant_id,
-            provider=provider,
-            model_type=args['model_type'],
-            model=args['model']
+            tenant_id=tenant_id, provider=provider, model_type=args["model_type"], model=args["model"]
         )
 
         model_load_balancing_service = ModelLoadBalancingService()
         is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
-            tenant_id=tenant_id,
-            provider=provider,
-            model=args['model'],
-            model_type=args['model_type']
+            tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
         )
 
         return {
             "credentials": credentials,
-            "load_balancing": {
-                "enabled": is_load_balancing_enabled,
-                "configs": load_balancing_configs
-            }
+            "load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs},
         }
 
 
 class ModelProviderModelEnableApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -235,24 +233,26 @@ class ModelProviderModelEnableApi(Resource):
         tenant_id = current_user.current_tenant_id
 
         parser = reqparse.RequestParser()
-        parser.add_argument('model', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=[mt.value for mt in ModelType], location='json')
+        parser.add_argument("model", type=str, required=True, nullable=False, location="json")
+        parser.add_argument(
+            "model_type",
+            type=str,
+            required=True,
+            nullable=False,
+            choices=[mt.value for mt in ModelType],
+            location="json",
+        )
         args = parser.parse_args()
 
         model_provider_service = ModelProviderService()
         model_provider_service.enable_model(
-            tenant_id=tenant_id,
-            provider=provider,
-            model=args['model'],
-            model_type=args['model_type']
+            tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
         )
 
-        return {'result': 'success'}
+        return {"result": "success"}
 
 
 class ModelProviderModelDisableApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -260,24 +260,26 @@ class ModelProviderModelDisableApi(Resource):
         tenant_id = current_user.current_tenant_id
 
         parser = reqparse.RequestParser()
-        parser.add_argument('model', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=[mt.value for mt in ModelType], location='json')
+        parser.add_argument("model", type=str, required=True, nullable=False, location="json")
+        parser.add_argument(
+            "model_type",
+            type=str,
+            required=True,
+            nullable=False,
+            choices=[mt.value for mt in ModelType],
+            location="json",
+        )
         args = parser.parse_args()
 
         model_provider_service = ModelProviderService()
         model_provider_service.disable_model(
-            tenant_id=tenant_id,
-            provider=provider,
-            model=args['model'],
-            model_type=args['model_type']
+            tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
         )
 
-        return {'result': 'success'}
+        return {"result": "success"}
 
 
 class ModelProviderModelValidateApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -285,10 +287,16 @@ class ModelProviderModelValidateApi(Resource):
         tenant_id = current_user.current_tenant_id
 
         parser = reqparse.RequestParser()
-        parser.add_argument('model', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=[mt.value for mt in ModelType], location='json')
-        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
+        parser.add_argument("model", type=str, required=True, nullable=False, location="json")
+        parser.add_argument(
+            "model_type",
+            type=str,
+            required=True,
+            nullable=False,
+            choices=[mt.value for mt in ModelType],
+            location="json",
+        )
+        parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
         args = parser.parse_args()
 
         model_provider_service = ModelProviderService()
@@ -300,48 +308,42 @@ class ModelProviderModelValidateApi(Resource):
             model_provider_service.model_credentials_validate(
                 tenant_id=tenant_id,
                 provider=provider,
-                model=args['model'],
-                model_type=args['model_type'],
-                credentials=args['credentials']
+                model=args["model"],
+                model_type=args["model_type"],
+                credentials=args["credentials"],
             )
         except CredentialsValidateFailedError as ex:
             result = False
             error = str(ex)
 
-        response = {'result': 'success' if result else 'error'}
+        response = {"result": "success" if result else "error"}
 
         if not result:
-            response['error'] = error
+            response["error"] = error
 
         return response
 
 
 class ModelProviderModelParameterRuleApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
     def get(self, provider: str):
         parser = reqparse.RequestParser()
-        parser.add_argument('model', type=str, required=True, nullable=False, location='args')
+        parser.add_argument("model", type=str, required=True, nullable=False, location="args")
         args = parser.parse_args()
 
         tenant_id = current_user.current_tenant_id
 
         model_provider_service = ModelProviderService()
         parameter_rules = model_provider_service.get_model_parameter_rules(
-            tenant_id=tenant_id,
-            provider=provider,
-            model=args['model']
+            tenant_id=tenant_id, provider=provider, model=args["model"]
         )
 
-        return jsonable_encoder({
-            "data": parameter_rules
-        })
+        return jsonable_encoder({"data": parameter_rules})
 
 
 class ModelProviderAvailableModelApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -349,27 +351,31 @@ class ModelProviderAvailableModelApi(Resource):
         tenant_id = current_user.current_tenant_id
 
         model_provider_service = ModelProviderService()
-        models = model_provider_service.get_models_by_model_type(
-            tenant_id=tenant_id,
-            model_type=model_type
-        )
-
-        return jsonable_encoder({
-            "data": models
-        })
-
-
-api.add_resource(ModelProviderModelApi, '/workspaces/current/model-providers/<string:provider>/models')
-api.add_resource(ModelProviderModelEnableApi, '/workspaces/current/model-providers/<string:provider>/models/enable',
-                 endpoint='model-provider-model-enable')
-api.add_resource(ModelProviderModelDisableApi, '/workspaces/current/model-providers/<string:provider>/models/disable',
-                 endpoint='model-provider-model-disable')
-api.add_resource(ModelProviderModelCredentialApi,
-                 '/workspaces/current/model-providers/<string:provider>/models/credentials')
-api.add_resource(ModelProviderModelValidateApi,
-                 '/workspaces/current/model-providers/<string:provider>/models/credentials/validate')
-
-api.add_resource(ModelProviderModelParameterRuleApi,
-                 '/workspaces/current/model-providers/<string:provider>/models/parameter-rules')
-api.add_resource(ModelProviderAvailableModelApi, '/workspaces/current/models/model-types/<string:model_type>')
-api.add_resource(DefaultModelApi, '/workspaces/current/default-model')
+        models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)
+
+        return jsonable_encoder({"data": models})
+
+
+api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<string:provider>/models")
+api.add_resource(
+    ModelProviderModelEnableApi,
+    "/workspaces/current/model-providers/<string:provider>/models/enable",
+    endpoint="model-provider-model-enable",
+)
+api.add_resource(
+    ModelProviderModelDisableApi,
+    "/workspaces/current/model-providers/<string:provider>/models/disable",
+    endpoint="model-provider-model-disable",
+)
+api.add_resource(
+    ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<string:provider>/models/credentials"
+)
+api.add_resource(
+    ModelProviderModelValidateApi, "/workspaces/current/model-providers/<string:provider>/models/credentials/validate"
+)
+
+api.add_resource(
+    ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<string:provider>/models/parameter-rules"
+)
+api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>")
+api.add_resource(DefaultModelApi, "/workspaces/current/default-model")

+ 228 - 172
api/controllers/console/workspace/tool_providers.py

@@ -28,10 +28,18 @@ class ToolProviderListApi(Resource):
         tenant_id = current_user.current_tenant_id
 
         req = reqparse.RequestParser()
-        req.add_argument('type', type=str, choices=['builtin', 'model', 'api', 'workflow'], required=False, nullable=True, location='args')
+        req.add_argument(
+            "type",
+            type=str,
+            choices=["builtin", "model", "api", "workflow"],
+            required=False,
+            nullable=True,
+            location="args",
+        )
         args = req.parse_args()
 
-        return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get('type', None))
+        return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None))
+
 
 class ToolBuiltinProviderListToolsApi(Resource):
     @setup_required
@@ -41,11 +49,14 @@ class ToolBuiltinProviderListToolsApi(Resource):
         user_id = current_user.id
         tenant_id = current_user.current_tenant_id
 
-        return jsonable_encoder(BuiltinToolManageService.list_builtin_tool_provider_tools(
-            user_id,
-            tenant_id,
-            provider,
-        ))
+        return jsonable_encoder(
+            BuiltinToolManageService.list_builtin_tool_provider_tools(
+                user_id,
+                tenant_id,
+                provider,
+            )
+        )
+
 
 class ToolBuiltinProviderDeleteApi(Resource):
     @setup_required
@@ -54,7 +65,7 @@ class ToolBuiltinProviderDeleteApi(Resource):
     def post(self, provider):
         if not current_user.is_admin_or_owner:
             raise Forbidden()
-        
+
         user_id = current_user.id
         tenant_id = current_user.current_tenant_id
 
@@ -63,7 +74,8 @@ class ToolBuiltinProviderDeleteApi(Resource):
             tenant_id,
             provider,
         )
-    
+
+
 class ToolBuiltinProviderUpdateApi(Resource):
     @setup_required
     @login_required
@@ -71,12 +83,12 @@ class ToolBuiltinProviderUpdateApi(Resource):
     def post(self, provider):
         if not current_user.is_admin_or_owner:
             raise Forbidden()
-        
+
         user_id = current_user.id
         tenant_id = current_user.current_tenant_id
 
         parser = reqparse.RequestParser()
-        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
+        parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
 
         args = parser.parse_args()
 
@@ -84,9 +96,10 @@ class ToolBuiltinProviderUpdateApi(Resource):
             user_id,
             tenant_id,
             provider,
-            args['credentials'],
+            args["credentials"],
         )
-    
+
+
 class ToolBuiltinProviderGetCredentialsApi(Resource):
     @setup_required
     @login_required
@@ -101,6 +114,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
             provider,
         )
 
+
 class ToolBuiltinProviderIconApi(Resource):
     @setup_required
     def get(self, provider):
@@ -108,6 +122,7 @@ class ToolBuiltinProviderIconApi(Resource):
         icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE
         return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
 
+
 class ToolApiProviderAddApi(Resource):
     @setup_required
     @login_required
@@ -115,35 +130,36 @@ class ToolApiProviderAddApi(Resource):
     def post(self):
         if not current_user.is_admin_or_owner:
             raise Forbidden()
-        
+
         user_id = current_user.id
         tenant_id = current_user.current_tenant_id
 
         parser = reqparse.RequestParser()
-        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
-        parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('schema', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
-        parser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json')
-        parser.add_argument('labels', type=list[str], required=False, nullable=True, location='json', default=[])
-        parser.add_argument('custom_disclaimer', type=str, required=False, nullable=True, location='json')
+        parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
+        parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("icon", type=dict, required=True, nullable=False, location="json")
+        parser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json")
+        parser.add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[])
+        parser.add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json")
 
         args = parser.parse_args()
 
         return ApiToolManageService.create_api_tool_provider(
             user_id,
             tenant_id,
-            args['provider'],
-            args['icon'],
-            args['credentials'],
-            args['schema_type'],
-            args['schema'],
-            args.get('privacy_policy', ''),
-            args.get('custom_disclaimer', ''),
-            args.get('labels', []),
+            args["provider"],
+            args["icon"],
+            args["credentials"],
+            args["schema_type"],
+            args["schema"],
+            args.get("privacy_policy", ""),
+            args.get("custom_disclaimer", ""),
+            args.get("labels", []),
         )
 
+
 class ToolApiProviderGetRemoteSchemaApi(Resource):
     @setup_required
     @login_required
@@ -151,16 +167,17 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
     def get(self):
         parser = reqparse.RequestParser()
 
-        parser.add_argument('url', type=str, required=True, nullable=False, location='args')
+        parser.add_argument("url", type=str, required=True, nullable=False, location="args")
 
         args = parser.parse_args()
 
         return ApiToolManageService.get_api_tool_provider_remote_schema(
             current_user.id,
             current_user.current_tenant_id,
-            args['url'],
+            args["url"],
         )
-    
+
+
 class ToolApiProviderListToolsApi(Resource):
     @setup_required
     @login_required
@@ -171,15 +188,18 @@ class ToolApiProviderListToolsApi(Resource):
 
         parser = reqparse.RequestParser()
 
-        parser.add_argument('provider', type=str, required=True, nullable=False, location='args')
+        parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
 
         args = parser.parse_args()
 
-        return jsonable_encoder(ApiToolManageService.list_api_tool_provider_tools(
-            user_id,
-            tenant_id,
-            args['provider'],
-        ))
+        return jsonable_encoder(
+            ApiToolManageService.list_api_tool_provider_tools(
+                user_id,
+                tenant_id,
+                args["provider"],
+            )
+        )
+
 
 class ToolApiProviderUpdateApi(Resource):
     @setup_required
@@ -188,37 +208,38 @@ class ToolApiProviderUpdateApi(Resource):
     def post(self):
         if not current_user.is_admin_or_owner:
             raise Forbidden()
-        
+
         user_id = current_user.id
         tenant_id = current_user.current_tenant_id
 
         parser = reqparse.RequestParser()
-        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
-        parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('schema', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('original_provider', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
-        parser.add_argument('privacy_policy', type=str, required=True, nullable=True, location='json')
-        parser.add_argument('labels', type=list[str], required=False, nullable=True, location='json')
-        parser.add_argument('custom_disclaimer', type=str, required=True, nullable=True, location='json')
+        parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
+        parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("original_provider", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("icon", type=dict, required=True, nullable=False, location="json")
+        parser.add_argument("privacy_policy", type=str, required=True, nullable=True, location="json")
+        parser.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
+        parser.add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json")
 
         args = parser.parse_args()
 
         return ApiToolManageService.update_api_tool_provider(
             user_id,
             tenant_id,
-            args['provider'],
-            args['original_provider'],
-            args['icon'],
-            args['credentials'],
-            args['schema_type'],
-            args['schema'],
-            args['privacy_policy'],
-            args['custom_disclaimer'],
-            args.get('labels', []),
+            args["provider"],
+            args["original_provider"],
+            args["icon"],
+            args["credentials"],
+            args["schema_type"],
+            args["schema"],
+            args["privacy_policy"],
+            args["custom_disclaimer"],
+            args.get("labels", []),
         )
 
+
 class ToolApiProviderDeleteApi(Resource):
     @setup_required
     @login_required
@@ -226,22 +247,23 @@ class ToolApiProviderDeleteApi(Resource):
     def post(self):
         if not current_user.is_admin_or_owner:
             raise Forbidden()
-        
+
         user_id = current_user.id
         tenant_id = current_user.current_tenant_id
 
         parser = reqparse.RequestParser()
 
-        parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
+        parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
 
         args = parser.parse_args()
 
         return ApiToolManageService.delete_api_tool_provider(
             user_id,
             tenant_id,
-            args['provider'],
+            args["provider"],
         )
 
+
 class ToolApiProviderGetApi(Resource):
     @setup_required
     @login_required
@@ -252,16 +274,17 @@ class ToolApiProviderGetApi(Resource):
 
         parser = reqparse.RequestParser()
 
-        parser.add_argument('provider', type=str, required=True, nullable=False, location='args')
+        parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
 
         args = parser.parse_args()
 
         return ApiToolManageService.get_api_tool_provider(
             user_id,
             tenant_id,
-            args['provider'],
+            args["provider"],
         )
 
+
 class ToolBuiltinProviderCredentialsSchemaApi(Resource):
     @setup_required
     @login_required
@@ -269,6 +292,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
     def get(self, provider):
         return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider)
 
+
 class ToolApiProviderSchemaApi(Resource):
     @setup_required
     @login_required
@@ -276,14 +300,15 @@ class ToolApiProviderSchemaApi(Resource):
     def post(self):
         parser = reqparse.RequestParser()
 
-        parser.add_argument('schema', type=str, required=True, nullable=False, location='json')
+        parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
 
         args = parser.parse_args()
 
         return ApiToolManageService.parser_api_schema(
-            schema=args['schema'],
+            schema=args["schema"],
         )
 
+
 class ToolApiProviderPreviousTestApi(Resource):
     @setup_required
     @login_required
@@ -291,25 +316,26 @@ class ToolApiProviderPreviousTestApi(Resource):
     def post(self):
         parser = reqparse.RequestParser()
 
-        parser.add_argument('tool_name', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('provider_name', type=str, required=False, nullable=False, location='json')
-        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
-        parser.add_argument('parameters', type=dict, required=True, nullable=False, location='json')
-        parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('schema', type=str, required=True, nullable=False, location='json')
+        parser.add_argument("tool_name", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("provider_name", type=str, required=False, nullable=False, location="json")
+        parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
+        parser.add_argument("parameters", type=dict, required=True, nullable=False, location="json")
+        parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
 
         args = parser.parse_args()
 
         return ApiToolManageService.test_api_tool_preview(
             current_user.current_tenant_id,
-            args['provider_name'] if args['provider_name'] else '',
-            args['tool_name'],
-            args['credentials'],
-            args['parameters'],
-            args['schema_type'],
-            args['schema'],
+            args["provider_name"] if args["provider_name"] else "",
+            args["tool_name"],
+            args["credentials"],
+            args["parameters"],
+            args["schema_type"],
+            args["schema"],
         )
 
+
 class ToolWorkflowProviderCreateApi(Resource):
     @setup_required
     @login_required
@@ -317,35 +343,36 @@ class ToolWorkflowProviderCreateApi(Resource):
     def post(self):
         if not current_user.is_admin_or_owner:
             raise Forbidden()
-        
+
         user_id = current_user.id
         tenant_id = current_user.current_tenant_id
 
         reqparser = reqparse.RequestParser()
-        reqparser.add_argument('workflow_app_id', type=uuid_value, required=True, nullable=False, location='json')
-        reqparser.add_argument('name', type=alphanumeric, required=True, nullable=False, location='json')
-        reqparser.add_argument('label', type=str, required=True, nullable=False, location='json')
-        reqparser.add_argument('description', type=str, required=True, nullable=False, location='json')
-        reqparser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
-        reqparser.add_argument('parameters', type=list[dict], required=True, nullable=False, location='json')
-        reqparser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json', default='')
-        reqparser.add_argument('labels', type=list[str], required=False, nullable=True, location='json')
+        reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
+        reqparser.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
+        reqparser.add_argument("label", type=str, required=True, nullable=False, location="json")
+        reqparser.add_argument("description", type=str, required=True, nullable=False, location="json")
+        reqparser.add_argument("icon", type=dict, required=True, nullable=False, location="json")
+        reqparser.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
+        reqparser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
+        reqparser.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
 
         args = reqparser.parse_args()
 
         return WorkflowToolManageService.create_workflow_tool(
             user_id,
             tenant_id,
-            args['workflow_app_id'],
-            args['name'],
-            args['label'],
-            args['icon'],
-            args['description'],
-            args['parameters'],
-            args['privacy_policy'],
-            args.get('labels', []),
+            args["workflow_app_id"],
+            args["name"],
+            args["label"],
+            args["icon"],
+            args["description"],
+            args["parameters"],
+            args["privacy_policy"],
+            args.get("labels", []),
         )
 
+
 class ToolWorkflowProviderUpdateApi(Resource):
     @setup_required
     @login_required
@@ -353,38 +380,39 @@ class ToolWorkflowProviderUpdateApi(Resource):
     def post(self):
         if not current_user.is_admin_or_owner:
             raise Forbidden()
-        
+
         user_id = current_user.id
         tenant_id = current_user.current_tenant_id
 
         reqparser = reqparse.RequestParser()
-        reqparser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='json')
-        reqparser.add_argument('name', type=alphanumeric, required=True, nullable=False, location='json')
-        reqparser.add_argument('label', type=str, required=True, nullable=False, location='json')
-        reqparser.add_argument('description', type=str, required=True, nullable=False, location='json')
-        reqparser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
-        reqparser.add_argument('parameters', type=list[dict], required=True, nullable=False, location='json')
-        reqparser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json', default='')
-        reqparser.add_argument('labels', type=list[str], required=False, nullable=True, location='json')
-        
+        reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
+        reqparser.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
+        reqparser.add_argument("label", type=str, required=True, nullable=False, location="json")
+        reqparser.add_argument("description", type=str, required=True, nullable=False, location="json")
+        reqparser.add_argument("icon", type=dict, required=True, nullable=False, location="json")
+        reqparser.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
+        reqparser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
+        reqparser.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
+
         args = reqparser.parse_args()
 
-        if not args['workflow_tool_id']:
-            raise ValueError('incorrect workflow_tool_id')
-        
+        if not args["workflow_tool_id"]:
+            raise ValueError("incorrect workflow_tool_id")
+
         return WorkflowToolManageService.update_workflow_tool(
             user_id,
             tenant_id,
-            args['workflow_tool_id'],
-            args['name'],
-            args['label'],
-            args['icon'],
-            args['description'],
-            args['parameters'],
-            args['privacy_policy'],
-            args.get('labels', []),
+            args["workflow_tool_id"],
+            args["name"],
+            args["label"],
+            args["icon"],
+            args["description"],
+            args["parameters"],
+            args["privacy_policy"],
+            args.get("labels", []),
         )
 
+
 class ToolWorkflowProviderDeleteApi(Resource):
     @setup_required
     @login_required
@@ -392,21 +420,22 @@ class ToolWorkflowProviderDeleteApi(Resource):
     def post(self):
         if not current_user.is_admin_or_owner:
             raise Forbidden()
-        
+
         user_id = current_user.id
         tenant_id = current_user.current_tenant_id
 
         reqparser = reqparse.RequestParser()
-        reqparser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='json')
+        reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
 
         args = reqparser.parse_args()
 
         return WorkflowToolManageService.delete_workflow_tool(
             user_id,
             tenant_id,
-            args['workflow_tool_id'],
+            args["workflow_tool_id"],
         )
-        
+
+
 class ToolWorkflowProviderGetApi(Resource):
     @setup_required
     @login_required
@@ -416,28 +445,29 @@ class ToolWorkflowProviderGetApi(Resource):
         tenant_id = current_user.current_tenant_id
 
         parser = reqparse.RequestParser()
-        parser.add_argument('workflow_tool_id', type=uuid_value, required=False, nullable=True, location='args')
-        parser.add_argument('workflow_app_id', type=uuid_value, required=False, nullable=True, location='args')
+        parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
+        parser.add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args")
 
         args = parser.parse_args()
 
-        if args.get('workflow_tool_id'):
+        if args.get("workflow_tool_id"):
             tool = WorkflowToolManageService.get_workflow_tool_by_tool_id(
                 user_id,
                 tenant_id,
-                args['workflow_tool_id'],
+                args["workflow_tool_id"],
             )
-        elif args.get('workflow_app_id'):
+        elif args.get("workflow_app_id"):
             tool = WorkflowToolManageService.get_workflow_tool_by_app_id(
                 user_id,
                 tenant_id,
-                args['workflow_app_id'],
+                args["workflow_app_id"],
             )
         else:
-            raise ValueError('incorrect workflow_tool_id or workflow_app_id')
+            raise ValueError("incorrect workflow_tool_id or workflow_app_id")
 
         return jsonable_encoder(tool)
-    
+
+
 class ToolWorkflowProviderListToolApi(Resource):
     @setup_required
     @login_required
@@ -447,15 +477,18 @@ class ToolWorkflowProviderListToolApi(Resource):
         tenant_id = current_user.current_tenant_id
 
         parser = reqparse.RequestParser()
-        parser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='args')
+        parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args")
 
         args = parser.parse_args()
 
-        return jsonable_encoder(WorkflowToolManageService.list_single_workflow_tools(
-            user_id,
-            tenant_id,
-            args['workflow_tool_id'],
-        ))
+        return jsonable_encoder(
+            WorkflowToolManageService.list_single_workflow_tools(
+                user_id,
+                tenant_id,
+                args["workflow_tool_id"],
+            )
+        )
+
 
 class ToolBuiltinListApi(Resource):
     @setup_required
@@ -465,11 +498,17 @@ class ToolBuiltinListApi(Resource):
         user_id = current_user.id
         tenant_id = current_user.current_tenant_id
 
-        return jsonable_encoder([provider.to_dict() for provider in BuiltinToolManageService.list_builtin_tools(
-            user_id,
-            tenant_id,
-        )])
-    
+        return jsonable_encoder(
+            [
+                provider.to_dict()
+                for provider in BuiltinToolManageService.list_builtin_tools(
+                    user_id,
+                    tenant_id,
+                )
+            ]
+        )
+
+
 class ToolApiListApi(Resource):
     @setup_required
     @login_required
@@ -478,11 +517,17 @@ class ToolApiListApi(Resource):
         user_id = current_user.id
         tenant_id = current_user.current_tenant_id
 
-        return jsonable_encoder([provider.to_dict() for provider in ApiToolManageService.list_api_tools(
-            user_id,
-            tenant_id,
-        )])
-    
+        return jsonable_encoder(
+            [
+                provider.to_dict()
+                for provider in ApiToolManageService.list_api_tools(
+                    user_id,
+                    tenant_id,
+                )
+            ]
+        )
+
+
 class ToolWorkflowListApi(Resource):
     @setup_required
     @login_required
@@ -491,11 +536,17 @@ class ToolWorkflowListApi(Resource):
         user_id = current_user.id
         tenant_id = current_user.current_tenant_id
 
-        return jsonable_encoder([provider.to_dict() for provider in WorkflowToolManageService.list_tenant_workflow_tools(
-            user_id,
-            tenant_id,
-        )])
-    
+        return jsonable_encoder(
+            [
+                provider.to_dict()
+                for provider in WorkflowToolManageService.list_tenant_workflow_tools(
+                    user_id,
+                    tenant_id,
+                )
+            ]
+        )
+
+
 class ToolLabelsApi(Resource):
     @setup_required
     @login_required
@@ -503,36 +554,41 @@ class ToolLabelsApi(Resource):
     def get(self):
         return jsonable_encoder(ToolLabelsService.list_tool_labels())
 
+
 # tool provider
-api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers')
+api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
 
 # builtin tool provider
-api.add_resource(ToolBuiltinProviderListToolsApi, '/workspaces/current/tool-provider/builtin/<provider>/tools')
-api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provider/builtin/<provider>/delete')
-api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin/<provider>/update')
-api.add_resource(ToolBuiltinProviderGetCredentialsApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials')
-api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials_schema')
-api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin/<provider>/icon')
+api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<provider>/tools")
+api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<provider>/delete")
+api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<provider>/update")
+api.add_resource(
+    ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials"
+)
+api.add_resource(
+    ToolBuiltinProviderCredentialsSchemaApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials_schema"
+)
+api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<provider>/icon")
 
 # api tool provider
-api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add')
-api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote')
-api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools')
-api.add_resource(ToolApiProviderUpdateApi, '/workspaces/current/tool-provider/api/update')
-api.add_resource(ToolApiProviderDeleteApi, '/workspaces/current/tool-provider/api/delete')
-api.add_resource(ToolApiProviderGetApi, '/workspaces/current/tool-provider/api/get')
-api.add_resource(ToolApiProviderSchemaApi, '/workspaces/current/tool-provider/api/schema')
-api.add_resource(ToolApiProviderPreviousTestApi, '/workspaces/current/tool-provider/api/test/pre')
+api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add")
+api.add_resource(ToolApiProviderGetRemoteSchemaApi, "/workspaces/current/tool-provider/api/remote")
+api.add_resource(ToolApiProviderListToolsApi, "/workspaces/current/tool-provider/api/tools")
+api.add_resource(ToolApiProviderUpdateApi, "/workspaces/current/tool-provider/api/update")
+api.add_resource(ToolApiProviderDeleteApi, "/workspaces/current/tool-provider/api/delete")
+api.add_resource(ToolApiProviderGetApi, "/workspaces/current/tool-provider/api/get")
+api.add_resource(ToolApiProviderSchemaApi, "/workspaces/current/tool-provider/api/schema")
+api.add_resource(ToolApiProviderPreviousTestApi, "/workspaces/current/tool-provider/api/test/pre")
 
 # workflow tool provider
-api.add_resource(ToolWorkflowProviderCreateApi, '/workspaces/current/tool-provider/workflow/create')
-api.add_resource(ToolWorkflowProviderUpdateApi, '/workspaces/current/tool-provider/workflow/update')
-api.add_resource(ToolWorkflowProviderDeleteApi, '/workspaces/current/tool-provider/workflow/delete')
-api.add_resource(ToolWorkflowProviderGetApi, '/workspaces/current/tool-provider/workflow/get')
-api.add_resource(ToolWorkflowProviderListToolApi, '/workspaces/current/tool-provider/workflow/tools')
+api.add_resource(ToolWorkflowProviderCreateApi, "/workspaces/current/tool-provider/workflow/create")
+api.add_resource(ToolWorkflowProviderUpdateApi, "/workspaces/current/tool-provider/workflow/update")
+api.add_resource(ToolWorkflowProviderDeleteApi, "/workspaces/current/tool-provider/workflow/delete")
+api.add_resource(ToolWorkflowProviderGetApi, "/workspaces/current/tool-provider/workflow/get")
+api.add_resource(ToolWorkflowProviderListToolApi, "/workspaces/current/tool-provider/workflow/tools")
 
-api.add_resource(ToolBuiltinListApi, '/workspaces/current/tools/builtin')
-api.add_resource(ToolApiListApi, '/workspaces/current/tools/api')
-api.add_resource(ToolWorkflowListApi, '/workspaces/current/tools/workflow')
+api.add_resource(ToolBuiltinListApi, "/workspaces/current/tools/builtin")
+api.add_resource(ToolApiListApi, "/workspaces/current/tools/api")
+api.add_resource(ToolWorkflowListApi, "/workspaces/current/tools/workflow")
 
-api.add_resource(ToolLabelsApi, '/workspaces/current/tool-labels')
+api.add_resource(ToolLabelsApi, "/workspaces/current/tool-labels")

+ 73 - 70
api/controllers/console/workspace/workspace.py

@@ -26,39 +26,34 @@ from services.file_service import FileService
 from services.workspace_service import WorkspaceService
 
 provider_fields = {
-    'provider_name': fields.String,
-    'provider_type': fields.String,
-    'is_valid': fields.Boolean,
-    'token_is_set': fields.Boolean,
+    "provider_name": fields.String,
+    "provider_type": fields.String,
+    "is_valid": fields.Boolean,
+    "token_is_set": fields.Boolean,
 }
 
 tenant_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'plan': fields.String,
-    'status': fields.String,
-    'created_at': TimestampField,
-    'role': fields.String,
-    'in_trial': fields.Boolean,
-    'trial_end_reason': fields.String,
-    'custom_config': fields.Raw(attribute='custom_config'),
+    "id": fields.String,
+    "name": fields.String,
+    "plan": fields.String,
+    "status": fields.String,
+    "created_at": TimestampField,
+    "role": fields.String,
+    "in_trial": fields.Boolean,
+    "trial_end_reason": fields.String,
+    "custom_config": fields.Raw(attribute="custom_config"),
 }
 
 tenants_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'plan': fields.String,
-    'status': fields.String,
-    'created_at': TimestampField,
-    'current': fields.Boolean
+    "id": fields.String,
+    "name": fields.String,
+    "plan": fields.String,
+    "status": fields.String,
+    "created_at": TimestampField,
+    "current": fields.Boolean,
 }
 
-workspace_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'status': fields.String,
-    'created_at': TimestampField
-}
+workspace_fields = {"id": fields.String, "name": fields.String, "status": fields.String, "created_at": TimestampField}
 
 
 class TenantListApi(Resource):
@@ -71,7 +66,7 @@ class TenantListApi(Resource):
         for tenant in tenants:
             if tenant.id == current_user.current_tenant_id:
                 tenant.current = True  # Set current=True for current tenant
-        return {'workspaces': marshal(tenants, tenants_fields)}, 200
+        return {"workspaces": marshal(tenants, tenants_fields)}, 200
 
 
 class WorkspaceListApi(Resource):
@@ -79,31 +74,37 @@ class WorkspaceListApi(Resource):
     @admin_required
     def get(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args')
-        parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args')
+        parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
+        parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
         args = parser.parse_args()
 
-        tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc())\
-            .paginate(page=args['page'], per_page=args['limit'])
+        tenants = (
+            db.session.query(Tenant)
+            .order_by(Tenant.created_at.desc())
+            .paginate(page=args["page"], per_page=args["limit"])
+        )
 
         has_more = False
-        if len(tenants.items) == args['limit']:
+        if len(tenants.items) == args["limit"]:
             current_page_first_tenant = tenants[-1]
-            rest_count = db.session.query(Tenant).filter(
-                Tenant.created_at < current_page_first_tenant.created_at,
-                Tenant.id != current_page_first_tenant.id
-            ).count()
+            rest_count = (
+                db.session.query(Tenant)
+                .filter(
+                    Tenant.created_at < current_page_first_tenant.created_at, Tenant.id != current_page_first_tenant.id
+                )
+                .count()
+            )
 
             if rest_count > 0:
                 has_more = True
         total = db.session.query(Tenant).count()
         return {
-            'data': marshal(tenants.items, workspace_fields),
-            'has_more': has_more,
-            'limit': args['limit'],
-            'page': args['page'],
-            'total': total
-                }, 200
+            "data": marshal(tenants.items, workspace_fields),
+            "has_more": has_more,
+            "limit": args["limit"],
+            "page": args["page"],
+            "total": total,
+        }, 200
 
 
 class TenantApi(Resource):
@@ -112,8 +113,8 @@ class TenantApi(Resource):
     @account_initialization_required
     @marshal_with(tenant_fields)
     def get(self):
-        if request.path == '/info':
-            logging.warning('Deprecated URL /info was used.')
+        if request.path == "/info":
+            logging.warning("Deprecated URL /info was used.")
 
         tenant = current_user.current_tenant
 
@@ -125,7 +126,7 @@ class TenantApi(Resource):
                 tenant = tenants[0]
             # else, raise Unauthorized
             else:
-                raise Unauthorized('workspace is archived')
+                raise Unauthorized("workspace is archived")
 
         return WorkspaceService.get_tenant_info(tenant), 200
 
@@ -136,62 +137,64 @@ class SwitchWorkspaceApi(Resource):
     @account_initialization_required
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('tenant_id', type=str, required=True, location='json')
+        parser.add_argument("tenant_id", type=str, required=True, location="json")
         args = parser.parse_args()
 
         # check if tenant_id is valid, 403 if not
         try:
-            TenantService.switch_tenant(current_user, args['tenant_id'])
+            TenantService.switch_tenant(current_user, args["tenant_id"])
         except Exception:
             raise AccountNotLinkTenantError("Account not link tenant")
 
-        new_tenant = db.session.query(Tenant).get(args['tenant_id'])  # Get new tenant
+        new_tenant = db.session.query(Tenant).get(args["tenant_id"])  # Get new tenant
+
+        return {"result": "success", "new_tenant": marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)}
 
-        return {'result': 'success', 'new_tenant': marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)}
-    
 
 class CustomConfigWorkspaceApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_resource_check('workspace_custom')
+    @cloud_edition_billing_resource_check("workspace_custom")
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('remove_webapp_brand', type=bool, location='json')
-        parser.add_argument('replace_webapp_logo', type=str,  location='json')
+        parser.add_argument("remove_webapp_brand", type=bool, location="json")
+        parser.add_argument("replace_webapp_logo", type=str, location="json")
         args = parser.parse_args()
 
         tenant = db.session.query(Tenant).filter(Tenant.id == current_user.current_tenant_id).one_or_404()
 
         custom_config_dict = {
-            'remove_webapp_brand': args['remove_webapp_brand'],
-            'replace_webapp_logo': args['replace_webapp_logo'] if args['replace_webapp_logo'] is not None else tenant.custom_config_dict.get('replace_webapp_logo') ,
+            "remove_webapp_brand": args["remove_webapp_brand"],
+            "replace_webapp_logo": args["replace_webapp_logo"]
+            if args["replace_webapp_logo"] is not None
+            else tenant.custom_config_dict.get("replace_webapp_logo"),
         }
 
         tenant.custom_config_dict = custom_config_dict
         db.session.commit()
 
-        return {'result': 'success', 'tenant': marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
-    
+        return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
+
 
 class WebappLogoWorkspaceApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_resource_check('workspace_custom')
+    @cloud_edition_billing_resource_check("workspace_custom")
     def post(self):
         # get file from request
-        file = request.files['file']
+        file = request.files["file"]
 
         # check file
-        if 'file' not in request.files:
+        if "file" not in request.files:
             raise NoFileUploadedError()
 
         if len(request.files) > 1:
             raise TooManyFilesError()
 
-        extension = file.filename.split('.')[-1]
-        if extension.lower() not in ['svg', 'png']:
+        extension = file.filename.split(".")[-1]
+        if extension.lower() not in ["svg", "png"]:
             raise UnsupportedFileTypeError()
 
         try:
@@ -201,14 +204,14 @@ class WebappLogoWorkspaceApi(Resource):
             raise FileTooLargeError(file_too_large_error.description)
         except services.errors.file.UnsupportedFileTypeError:
             raise UnsupportedFileTypeError()
-        
-        return { 'id': upload_file.id }, 201
+
+        return {"id": upload_file.id}, 201
 
 
-api.add_resource(TenantListApi, '/workspaces')  # GET for getting all tenants
-api.add_resource(WorkspaceListApi, '/all-workspaces')  # GET for getting all tenants
-api.add_resource(TenantApi, '/workspaces/current', endpoint='workspaces_current')  # GET for getting current tenant info
-api.add_resource(TenantApi, '/info', endpoint='info')  # Deprecated
-api.add_resource(SwitchWorkspaceApi, '/workspaces/switch')  # POST for switching tenant
-api.add_resource(CustomConfigWorkspaceApi, '/workspaces/custom-config')
-api.add_resource(WebappLogoWorkspaceApi, '/workspaces/custom-config/webapp-logo/upload')
+api.add_resource(TenantListApi, "/workspaces")  # GET for getting all tenants
+api.add_resource(WorkspaceListApi, "/all-workspaces")  # GET for getting all tenants
+api.add_resource(TenantApi, "/workspaces/current", endpoint="workspaces_current")  # GET for getting current tenant info
+api.add_resource(TenantApi, "/info", endpoint="info")  # Deprecated
+api.add_resource(SwitchWorkspaceApi, "/workspaces/switch")  # POST for switching tenant
+api.add_resource(CustomConfigWorkspaceApi, "/workspaces/custom-config")
+api.add_resource(WebappLogoWorkspaceApi, "/workspaces/custom-config/webapp-logo/upload")

+ 21 - 18
api/controllers/console/wraps.py

@@ -16,7 +16,7 @@ def account_initialization_required(view):
         # check account initialization
         account = current_user
 
-        if account.status == 'uninitialized':
+        if account.status == "uninitialized":
             raise AccountNotInitializedError()
 
         return view(*args, **kwargs)
@@ -27,7 +27,7 @@ def account_initialization_required(view):
 def only_edition_cloud(view):
     @wraps(view)
     def decorated(*args, **kwargs):
-        if dify_config.EDITION != 'CLOUD':
+        if dify_config.EDITION != "CLOUD":
             abort(404)
 
         return view(*args, **kwargs)
@@ -38,7 +38,7 @@ def only_edition_cloud(view):
 def only_edition_self_hosted(view):
     @wraps(view)
     def decorated(*args, **kwargs):
-        if dify_config.EDITION != 'SELF_HOSTED':
+        if dify_config.EDITION != "SELF_HOSTED":
             abort(404)
 
         return view(*args, **kwargs)
@@ -46,8 +46,9 @@ def only_edition_self_hosted(view):
     return decorated
 
 
-def cloud_edition_billing_resource_check(resource: str,
-                                         error_msg: str = "You have reached the limit of your subscription."):
+def cloud_edition_billing_resource_check(
+    resource: str, error_msg: str = "You have reached the limit of your subscription."
+):
     def interceptor(view):
         @wraps(view)
         def decorated(*args, **kwargs):
@@ -58,22 +59,22 @@ def cloud_edition_billing_resource_check(resource: str,
                 vector_space = features.vector_space
                 documents_upload_quota = features.documents_upload_quota
                 annotation_quota_limit = features.annotation_quota_limit
-                if resource == 'members' and 0 < members.limit <= members.size:
+                if resource == "members" and 0 < members.limit <= members.size:
                     abort(403, error_msg)
-                elif resource == 'apps' and 0 < apps.limit <= apps.size:
+                elif resource == "apps" and 0 < apps.limit <= apps.size:
                     abort(403, error_msg)
-                elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size:
+                elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
                     abort(403, error_msg)
-                elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
+                elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
                     # The api of file upload is used in the multiple places, so we need to check the source of the request from datasets
-                    source = request.args.get('source')
-                    if source == 'datasets':
+                    source = request.args.get("source")
+                    if source == "datasets":
                         abort(403, error_msg)
                     else:
                         return view(*args, **kwargs)
-                elif resource == 'workspace_custom' and not features.can_replace_logo:
+                elif resource == "workspace_custom" and not features.can_replace_logo:
                     abort(403, error_msg)
-                elif resource == 'annotation' and 0 < annotation_quota_limit.limit < annotation_quota_limit.size:
+                elif resource == "annotation" and 0 < annotation_quota_limit.limit < annotation_quota_limit.size:
                     abort(403, error_msg)
                 else:
                     return view(*args, **kwargs)
@@ -85,15 +86,17 @@ def cloud_edition_billing_resource_check(resource: str,
     return interceptor
 
 
-def cloud_edition_billing_knowledge_limit_check(resource: str,
-                                                error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."):
+def cloud_edition_billing_knowledge_limit_check(
+    resource: str,
+    error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.",
+):
     def interceptor(view):
         @wraps(view)
         def decorated(*args, **kwargs):
             features = FeatureService.get_features(current_user.current_tenant_id)
             if features.billing.enabled:
-                if resource == 'add_segment':
-                    if features.billing.subscription.plan == 'sandbox':
+                if resource == "add_segment":
+                    if features.billing.subscription.plan == "sandbox":
                         abort(403, error_msg)
                 else:
                     return view(*args, **kwargs)
@@ -112,7 +115,7 @@ def cloud_utm_record(view):
             features = FeatureService.get_features(current_user.current_tenant_id)
 
             if features.billing.enabled:
-                utm_info = request.cookies.get('utm_info')
+                utm_info = request.cookies.get("utm_info")
 
                 if utm_info:
                     utm_info = json.loads(utm_info)

+ 1 - 1
api/controllers/files/__init__.py

@@ -2,7 +2,7 @@ from flask import Blueprint
 
 from libs.external_api import ExternalApi
 
-bp = Blueprint('files', __name__)
+bp = Blueprint("files", __name__)
 api = ExternalApi(bp)
 
 

+ 11 - 16
api/controllers/files/image_preview.py

@@ -13,35 +13,30 @@ class ImagePreviewApi(Resource):
     def get(self, file_id):
         file_id = str(file_id)
 
-        timestamp = request.args.get('timestamp')
-        nonce = request.args.get('nonce')
-        sign = request.args.get('sign')
+        timestamp = request.args.get("timestamp")
+        nonce = request.args.get("nonce")
+        sign = request.args.get("sign")
 
         if not timestamp or not nonce or not sign:
-            return {'content': 'Invalid request.'}, 400
+            return {"content": "Invalid request."}, 400
 
         try:
-            generator, mimetype = FileService.get_image_preview(
-                file_id,
-                timestamp,
-                nonce,
-                sign
-            )
+            generator, mimetype = FileService.get_image_preview(file_id, timestamp, nonce, sign)
         except services.errors.file.UnsupportedFileTypeError:
             raise UnsupportedFileTypeError()
 
         return Response(generator, mimetype=mimetype)
-    
+
 
 class WorkspaceWebappLogoApi(Resource):
     def get(self, workspace_id):
         workspace_id = str(workspace_id)
 
         custom_config = TenantService.get_custom_config(workspace_id)
-        webapp_logo_file_id = custom_config.get('replace_webapp_logo') if custom_config is not None else None
+        webapp_logo_file_id = custom_config.get("replace_webapp_logo") if custom_config is not None else None
 
         if not webapp_logo_file_id:
-            raise NotFound('webapp logo is not found')
+            raise NotFound("webapp logo is not found")
 
         try:
             generator, mimetype = FileService.get_public_image_preview(
@@ -53,11 +48,11 @@ class WorkspaceWebappLogoApi(Resource):
         return Response(generator, mimetype=mimetype)
 
 
-api.add_resource(ImagePreviewApi, '/files/<uuid:file_id>/image-preview')
-api.add_resource(WorkspaceWebappLogoApi, '/files/workspaces/<uuid:workspace_id>/webapp-logo')
+api.add_resource(ImagePreviewApi, "/files/<uuid:file_id>/image-preview")
+api.add_resource(WorkspaceWebappLogoApi, "/files/workspaces/<uuid:workspace_id>/webapp-logo")
 
 
 class UnsupportedFileTypeError(BaseHTTPException):
-    error_code = 'unsupported_file_type'
+    error_code = "unsupported_file_type"
     description = "File type not allowed."
     code = 415

+ 16 - 13
api/controllers/files/tool_files.py

@@ -13,36 +13,39 @@ class ToolFilePreviewApi(Resource):
 
         parser = reqparse.RequestParser()
 
-        parser.add_argument('timestamp', type=str, required=True, location='args')
-        parser.add_argument('nonce', type=str, required=True, location='args')
-        parser.add_argument('sign', type=str, required=True, location='args')
+        parser.add_argument("timestamp", type=str, required=True, location="args")
+        parser.add_argument("nonce", type=str, required=True, location="args")
+        parser.add_argument("sign", type=str, required=True, location="args")
 
         args = parser.parse_args()
 
-        if not ToolFileManager.verify_file(file_id=file_id,
-                                            timestamp=args['timestamp'],
-                                            nonce=args['nonce'],
-                                            sign=args['sign'],
+        if not ToolFileManager.verify_file(
+            file_id=file_id,
+            timestamp=args["timestamp"],
+            nonce=args["nonce"],
+            sign=args["sign"],
         ):
-            raise Forbidden('Invalid request.')
-        
+            raise Forbidden("Invalid request.")
+
         try:
             result = ToolFileManager.get_file_generator_by_tool_file_id(
                 file_id,
             )
 
             if not result:
-                raise NotFound('file is not found')
-            
+                raise NotFound("file is not found")
+
             generator, mimetype = result
         except Exception:
             raise UnsupportedFileTypeError()
 
         return Response(generator, mimetype=mimetype)
 
-api.add_resource(ToolFilePreviewApi, '/files/tools/<uuid:file_id>.<string:extension>')
+
+api.add_resource(ToolFilePreviewApi, "/files/tools/<uuid:file_id>.<string:extension>")
+
 
 class UnsupportedFileTypeError(BaseHTTPException):
-    error_code = 'unsupported_file_type'
+    error_code = "unsupported_file_type"
     description = "File type not allowed."
     code = 415

+ 1 - 2
api/controllers/inner_api/__init__.py

@@ -2,8 +2,7 @@ from flask import Blueprint
 
 from libs.external_api import ExternalApi
 
-bp = Blueprint('inner_api', __name__, url_prefix='/inner/api')
+bp = Blueprint("inner_api", __name__, url_prefix="/inner/api")
 api = ExternalApi(bp)
 
 from .workspace import workspace
-

+ 8 - 13
api/controllers/inner_api/workspace/workspace.py

@@ -9,29 +9,24 @@ from services.account_service import TenantService
 
 
 class EnterpriseWorkspace(Resource):
-
     @setup_required
     @inner_api_only
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, required=True, location='json')
-        parser.add_argument('owner_email', type=str, required=True, location='json')
+        parser.add_argument("name", type=str, required=True, location="json")
+        parser.add_argument("owner_email", type=str, required=True, location="json")
         args = parser.parse_args()
 
-        account = Account.query.filter_by(email=args['owner_email']).first()
+        account = Account.query.filter_by(email=args["owner_email"]).first()
         if account is None:
-            return {
-                'message': 'owner account not found.'
-            }, 404
+            return {"message": "owner account not found."}, 404
 
-        tenant = TenantService.create_tenant(args['name'])
-        TenantService.create_tenant_member(tenant, account, role='owner')
+        tenant = TenantService.create_tenant(args["name"])
+        TenantService.create_tenant_member(tenant, account, role="owner")
 
         tenant_was_created.send(tenant)
 
-        return {
-            'message': 'enterprise workspace created.'
-        }
+        return {"message": "enterprise workspace created."}
 
 
-api.add_resource(EnterpriseWorkspace, '/enterprise/workspace')
+api.add_resource(EnterpriseWorkspace, "/enterprise/workspace")

+ 10 - 10
api/controllers/inner_api/wraps.py

@@ -17,7 +17,7 @@ def inner_api_only(view):
             abort(404)
 
         # get header 'X-Inner-Api-Key'
-        inner_api_key = request.headers.get('X-Inner-Api-Key')
+        inner_api_key = request.headers.get("X-Inner-Api-Key")
         if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY:
             abort(401)
 
@@ -33,29 +33,29 @@ def inner_api_user_auth(view):
             return view(*args, **kwargs)
 
         # get header 'X-Inner-Api-Key'
-        authorization = request.headers.get('Authorization')
+        authorization = request.headers.get("Authorization")
         if not authorization:
             return view(*args, **kwargs)
 
-        parts = authorization.split(':')
+        parts = authorization.split(":")
         if len(parts) != 2:
             return view(*args, **kwargs)
 
         user_id, token = parts
-        if ' ' in user_id:
-            user_id = user_id.split(' ')[1]
+        if " " in user_id:
+            user_id = user_id.split(" ")[1]
 
-        inner_api_key = request.headers.get('X-Inner-Api-Key')
+        inner_api_key = request.headers.get("X-Inner-Api-Key")
 
-        data_to_sign = f'DIFY {user_id}'
+        data_to_sign = f"DIFY {user_id}"
 
-        signature = hmac_new(inner_api_key.encode('utf-8'), data_to_sign.encode('utf-8'), sha1)
-        signature = b64encode(signature.digest()).decode('utf-8')
+        signature = hmac_new(inner_api_key.encode("utf-8"), data_to_sign.encode("utf-8"), sha1)
+        signature = b64encode(signature.digest()).decode("utf-8")
 
         if signature != token:
             return view(*args, **kwargs)
 
-        kwargs['user'] = db.session.query(EndUser).filter(EndUser.id == user_id).first()
+        kwargs["user"] = db.session.query(EndUser).filter(EndUser.id == user_id).first()
 
         return view(*args, **kwargs)
 

+ 1 - 1
api/controllers/service_api/__init__.py

@@ -2,7 +2,7 @@ from flask import Blueprint
 
 from libs.external_api import ExternalApi
 
-bp = Blueprint('service_api', __name__, url_prefix='/v1')
+bp = Blueprint("service_api", __name__, url_prefix="/v1")
 api = ExternalApi(bp)
 
 

+ 52 - 52
api/controllers/service_api/app/app.py

@@ -1,4 +1,3 @@
-
 from flask_restful import Resource, fields, marshal_with
 
 from configs import dify_config
@@ -13,32 +12,30 @@ class AppParameterApi(Resource):
     """Resource for app variables."""
 
     variable_fields = {
-        'key': fields.String,
-        'name': fields.String,
-        'description': fields.String,
-        'type': fields.String,
-        'default': fields.String,
-        'max_length': fields.Integer,
-        'options': fields.List(fields.String)
+        "key": fields.String,
+        "name": fields.String,
+        "description": fields.String,
+        "type": fields.String,
+        "default": fields.String,
+        "max_length": fields.Integer,
+        "options": fields.List(fields.String),
     }
 
-    system_parameters_fields = {
-        'image_file_size_limit': fields.String
-    }
+    system_parameters_fields = {"image_file_size_limit": fields.String}
 
     parameters_fields = {
-        'opening_statement': fields.String,
-        'suggested_questions': fields.Raw,
-        'suggested_questions_after_answer': fields.Raw,
-        'speech_to_text': fields.Raw,
-        'text_to_speech': fields.Raw,
-        'retriever_resource': fields.Raw,
-        'annotation_reply': fields.Raw,
-        'more_like_this': fields.Raw,
-        'user_input_form': fields.Raw,
-        'sensitive_word_avoidance': fields.Raw,
-        'file_upload': fields.Raw,
-        'system_parameters': fields.Nested(system_parameters_fields)
+        "opening_statement": fields.String,
+        "suggested_questions": fields.Raw,
+        "suggested_questions_after_answer": fields.Raw,
+        "speech_to_text": fields.Raw,
+        "text_to_speech": fields.Raw,
+        "retriever_resource": fields.Raw,
+        "annotation_reply": fields.Raw,
+        "more_like_this": fields.Raw,
+        "user_input_form": fields.Raw,
+        "sensitive_word_avoidance": fields.Raw,
+        "file_upload": fields.Raw,
+        "system_parameters": fields.Nested(system_parameters_fields),
     }
 
     @validate_app_token
@@ -56,30 +53,35 @@ class AppParameterApi(Resource):
             app_model_config = app_model.app_model_config
             features_dict = app_model_config.to_dict()
 
-            user_input_form = features_dict.get('user_input_form', [])
+            user_input_form = features_dict.get("user_input_form", [])
 
         return {
-            'opening_statement': features_dict.get('opening_statement'),
-            'suggested_questions': features_dict.get('suggested_questions', []),
-            'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer',
-                                                                  {"enabled": False}),
-            'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}),
-            'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}),
-            'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}),
-            'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}),
-            'more_like_this': features_dict.get('more_like_this', {"enabled": False}),
-            'user_input_form': user_input_form,
-            'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance',
-                                                          {"enabled": False, "type": "", "configs": []}),
-            'file_upload': features_dict.get('file_upload', {"image": {
-                                                     "enabled": False,
-                                                     "number_limits": 3,
-                                                     "detail": "high",
-                                                     "transfer_methods": ["remote_url", "local_file"]
-                                                 }}),
-            'system_parameters': {
-                'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
-            }
+            "opening_statement": features_dict.get("opening_statement"),
+            "suggested_questions": features_dict.get("suggested_questions", []),
+            "suggested_questions_after_answer": features_dict.get(
+                "suggested_questions_after_answer", {"enabled": False}
+            ),
+            "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
+            "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
+            "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
+            "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
+            "more_like_this": features_dict.get("more_like_this", {"enabled": False}),
+            "user_input_form": user_input_form,
+            "sensitive_word_avoidance": features_dict.get(
+                "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
+            ),
+            "file_upload": features_dict.get(
+                "file_upload",
+                {
+                    "image": {
+                        "enabled": False,
+                        "number_limits": 3,
+                        "detail": "high",
+                        "transfer_methods": ["remote_url", "local_file"],
+                    }
+                },
+            ),
+            "system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT},
         }
 
 
@@ -89,16 +91,14 @@ class AppMetaApi(Resource):
         """Get app meta"""
         return AppService().get_app_meta(app_model)
 
+
 class AppInfoApi(Resource):
     @validate_app_token
     def get(self, app_model: App):
         """Get app information"""
-        return {
-            'name':app_model.name,
-            'description':app_model.description
-        } 
+        return {"name": app_model.name, "description": app_model.description}
 
 
-api.add_resource(AppParameterApi, '/parameters')
-api.add_resource(AppMetaApi, '/meta')
-api.add_resource(AppInfoApi, '/info')
+api.add_resource(AppParameterApi, "/parameters")
+api.add_resource(AppMetaApi, "/meta")
+api.add_resource(AppInfoApi, "/info")

+ 23 - 25
api/controllers/service_api/app/audio.py

@@ -33,14 +33,10 @@ from services.errors.audio import (
 class AudioApi(Resource):
     @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
     def post(self, app_model: App, end_user: EndUser):
-        file = request.files['file']
+        file = request.files["file"]
 
         try:
-            response = AudioService.transcript_asr(
-                app_model=app_model,
-                file=file,
-                end_user=end_user
-            )
+            response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user)
 
             return response
         except services.errors.app_model_config.AppModelConfigBrokenError:
@@ -74,30 +70,32 @@ class TextApi(Resource):
     def post(self, app_model: App, end_user: EndUser):
         try:
             parser = reqparse.RequestParser()
-            parser.add_argument('message_id', type=str, required=False, location='json')
-            parser.add_argument('voice', type=str, location='json')
-            parser.add_argument('text', type=str, location='json')
-            parser.add_argument('streaming', type=bool, location='json')
+            parser.add_argument("message_id", type=str, required=False, location="json")
+            parser.add_argument("voice", type=str, location="json")
+            parser.add_argument("text", type=str, location="json")
+            parser.add_argument("streaming", type=bool, location="json")
             args = parser.parse_args()
 
-            message_id = args.get('message_id', None)
-            text = args.get('text', None)
-            if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
-                    and app_model.workflow
-                    and app_model.workflow.features_dict):
-                text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
-                voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
+            message_id = args.get("message_id", None)
+            text = args.get("text", None)
+            if (
+                app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
+                and app_model.workflow
+                and app_model.workflow.features_dict
+            ):
+                text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
+                voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
             else:
                 try:
-                    voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice')
+                    voice = (
+                        args.get("voice")
+                        if args.get("voice")
+                        else app_model.app_model_config.text_to_speech_dict.get("voice")
+                    )
                 except Exception:
                     voice = None
             response = AudioService.transcript_tts(
-                app_model=app_model,
-                message_id=message_id,
-                end_user=end_user.external_user_id,
-                voice=voice,
-                text=text
+                app_model=app_model, message_id=message_id, end_user=end_user.external_user_id, voice=voice, text=text
             )
 
             return response
@@ -127,5 +125,5 @@ class TextApi(Resource):
             raise InternalServerError()
 
 
-api.add_resource(AudioApi, '/audio-to-text')
-api.add_resource(TextApi, '/text-to-audio')
+api.add_resource(AudioApi, "/audio-to-text")
+api.add_resource(TextApi, "/text-to-audio")

+ 24 - 28
api/controllers/service_api/app/completion.py

@@ -33,21 +33,21 @@ from services.app_generate_service import AppGenerateService
 class CompletionApi(Resource):
     @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
     def post(self, app_model: App, end_user: EndUser):
-        if app_model.mode != 'completion':
+        if app_model.mode != "completion":
             raise AppUnavailableError()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('inputs', type=dict, required=True, location='json')
-        parser.add_argument('query', type=str, location='json', default='')
-        parser.add_argument('files', type=list, required=False, location='json')
-        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
-        parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
+        parser.add_argument("inputs", type=dict, required=True, location="json")
+        parser.add_argument("query", type=str, location="json", default="")
+        parser.add_argument("files", type=list, required=False, location="json")
+        parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
+        parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
 
         args = parser.parse_args()
 
-        streaming = args['response_mode'] == 'streaming'
+        streaming = args["response_mode"] == "streaming"
 
-        args['auto_generate_name'] = False
+        args["auto_generate_name"] = False
 
         try:
             response = AppGenerateService.generate(
@@ -84,12 +84,12 @@ class CompletionApi(Resource):
 class CompletionStopApi(Resource):
     @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
     def post(self, app_model: App, end_user: EndUser, task_id):
-        if app_model.mode != 'completion':
+        if app_model.mode != "completion":
             raise AppUnavailableError()
 
         AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
 
-        return {'result': 'success'}, 200
+        return {"result": "success"}, 200
 
 
 class ChatApi(Resource):
@@ -100,25 +100,21 @@ class ChatApi(Resource):
             raise NotChatAppError()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('inputs', type=dict, required=True, location='json')
-        parser.add_argument('query', type=str, required=True, location='json')
-        parser.add_argument('files', type=list, required=False, location='json')
-        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
-        parser.add_argument('conversation_id', type=uuid_value, location='json')
-        parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
-        parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json')
+        parser.add_argument("inputs", type=dict, required=True, location="json")
+        parser.add_argument("query", type=str, required=True, location="json")
+        parser.add_argument("files", type=list, required=False, location="json")
+        parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
+        parser.add_argument("conversation_id", type=uuid_value, location="json")
+        parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
+        parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json")
 
         args = parser.parse_args()
 
-        streaming = args['response_mode'] == 'streaming'
+        streaming = args["response_mode"] == "streaming"
 
         try:
             response = AppGenerateService.generate(
-                app_model=app_model,
-                user=end_user,
-                args=args,
-                invoke_from=InvokeFrom.SERVICE_API,
-                streaming=streaming
+                app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming
             )
 
             return helper.compact_generate_response(response)
@@ -153,10 +149,10 @@ class ChatStopApi(Resource):
 
         AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
 
-        return {'result': 'success'}, 200
+        return {"result": "success"}, 200
 
 
-api.add_resource(CompletionApi, '/completion-messages')
-api.add_resource(CompletionStopApi, '/completion-messages/<string:task_id>/stop')
-api.add_resource(ChatApi, '/chat-messages')
-api.add_resource(ChatStopApi, '/chat-messages/<string:task_id>/stop')
+api.add_resource(CompletionApi, "/completion-messages")
+api.add_resource(CompletionStopApi, "/completion-messages/<string:task_id>/stop")
+api.add_resource(ChatApi, "/chat-messages")
+api.add_resource(ChatStopApi, "/chat-messages/<string:task_id>/stop")

+ 20 - 22
api/controllers/service_api/app/conversation.py

@@ -14,7 +14,6 @@ from services.conversation_service import ConversationService
 
 
 class ConversationApi(Resource):
-
     @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
     @marshal_with(conversation_infinite_scroll_pagination_fields)
     def get(self, app_model: App, end_user: EndUser):
@@ -23,20 +22,26 @@ class ConversationApi(Resource):
             raise NotChatAppError()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('last_id', type=uuid_value, location='args')
-        parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
-        parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'],
-                            required=False, default='-updated_at', location='args')
+        parser.add_argument("last_id", type=uuid_value, location="args")
+        parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
+        parser.add_argument(
+            "sort_by",
+            type=str,
+            choices=["created_at", "-created_at", "updated_at", "-updated_at"],
+            required=False,
+            default="-updated_at",
+            location="args",
+        )
         args = parser.parse_args()
 
         try:
             return ConversationService.pagination_by_last_id(
                 app_model=app_model,
                 user=end_user,
-                last_id=args['last_id'],
-                limit=args['limit'],
+                last_id=args["last_id"],
+                limit=args["limit"],
                 invoke_from=InvokeFrom.SERVICE_API,
-                sort_by=args['sort_by']
+                sort_by=args["sort_by"],
             )
         except services.errors.conversation.LastConversationNotExistsError:
             raise NotFound("Last Conversation Not Exists.")
@@ -56,11 +61,10 @@ class ConversationDetailApi(Resource):
             ConversationService.delete(app_model, conversation_id, end_user)
         except services.errors.conversation.ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
-        return {'result': 'success'}, 200
+        return {"result": "success"}, 200
 
 
 class ConversationRenameApi(Resource):
-
     @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
     @marshal_with(simple_conversation_fields)
     def post(self, app_model: App, end_user: EndUser, c_id):
@@ -71,22 +75,16 @@ class ConversationRenameApi(Resource):
         conversation_id = str(c_id)
 
         parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, required=False, location='json')
-        parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
+        parser.add_argument("name", type=str, required=False, location="json")
+        parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
         args = parser.parse_args()
 
         try:
-            return ConversationService.rename(
-                app_model,
-                conversation_id,
-                end_user,
-                args['name'],
-                args['auto_generate']
-            )
+            return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"])
         except services.errors.conversation.ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
 
 
-api.add_resource(ConversationRenameApi, '/conversations/<uuid:c_id>/name', endpoint='conversation_name')
-api.add_resource(ConversationApi, '/conversations')
-api.add_resource(ConversationDetailApi, '/conversations/<uuid:c_id>', endpoint='conversation_detail')
+api.add_resource(ConversationRenameApi, "/conversations/<uuid:c_id>/name", endpoint="conversation_name")
+api.add_resource(ConversationApi, "/conversations")
+api.add_resource(ConversationDetailApi, "/conversations/<uuid:c_id>", endpoint="conversation_detail")

+ 25 - 21
api/controllers/service_api/app/error.py

@@ -2,104 +2,108 @@ from libs.exception import BaseHTTPException
 
 
 class AppUnavailableError(BaseHTTPException):
-    error_code = 'app_unavailable'
+    error_code = "app_unavailable"
     description = "App unavailable, please check your app configurations."
     code = 400
 
 
 class NotCompletionAppError(BaseHTTPException):
-    error_code = 'not_completion_app'
+    error_code = "not_completion_app"
     description = "Please check if your Completion app mode matches the right API route."
     code = 400
 
 
 class NotChatAppError(BaseHTTPException):
-    error_code = 'not_chat_app'
+    error_code = "not_chat_app"
     description = "Please check if your app mode matches the right API route."
     code = 400
 
 
 class NotWorkflowAppError(BaseHTTPException):
-    error_code = 'not_workflow_app'
+    error_code = "not_workflow_app"
     description = "Please check if your app mode matches the right API route."
     code = 400
 
 
 class ConversationCompletedError(BaseHTTPException):
-    error_code = 'conversation_completed'
+    error_code = "conversation_completed"
     description = "The conversation has ended. Please start a new conversation."
     code = 400
 
 
 class ProviderNotInitializeError(BaseHTTPException):
-    error_code = 'provider_not_initialize'
-    description = "No valid model provider credentials found. " \
-                  "Please go to Settings -> Model Provider to complete your provider credentials."
+    error_code = "provider_not_initialize"
+    description = (
+        "No valid model provider credentials found. "
+        "Please go to Settings -> Model Provider to complete your provider credentials."
+    )
     code = 400
 
 
 class ProviderQuotaExceededError(BaseHTTPException):
-    error_code = 'provider_quota_exceeded'
-    description = "Your quota for Dify Hosted OpenAI has been exhausted. " \
-                  "Please go to Settings -> Model Provider to complete your own provider credentials."
+    error_code = "provider_quota_exceeded"
+    description = (
+        "Your quota for Dify Hosted OpenAI has been exhausted. "
+        "Please go to Settings -> Model Provider to complete your own provider credentials."
+    )
     code = 400
 
 
 class ProviderModelCurrentlyNotSupportError(BaseHTTPException):
-    error_code = 'model_currently_not_support'
+    error_code = "model_currently_not_support"
     description = "Dify Hosted OpenAI trial currently not support the GPT-4 model."
     code = 400
 
 
 class CompletionRequestError(BaseHTTPException):
-    error_code = 'completion_request_error'
+    error_code = "completion_request_error"
     description = "Completion request failed."
     code = 400
 
 
 class NoAudioUploadedError(BaseHTTPException):
-    error_code = 'no_audio_uploaded'
+    error_code = "no_audio_uploaded"
     description = "Please upload your audio."
     code = 400
 
 
 class AudioTooLargeError(BaseHTTPException):
-    error_code = 'audio_too_large'
+    error_code = "audio_too_large"
     description = "Audio size exceeded. {message}"
     code = 413
 
 
 class UnsupportedAudioTypeError(BaseHTTPException):
-    error_code = 'unsupported_audio_type'
+    error_code = "unsupported_audio_type"
     description = "Audio type not allowed."
     code = 415
 
 
 class ProviderNotSupportSpeechToTextError(BaseHTTPException):
-    error_code = 'provider_not_support_speech_to_text'
+    error_code = "provider_not_support_speech_to_text"
     description = "Provider not support speech to text."
     code = 400
 
 
 class NoFileUploadedError(BaseHTTPException):
-    error_code = 'no_file_uploaded'
+    error_code = "no_file_uploaded"
     description = "Please upload your file."
     code = 400
 
 
 class TooManyFilesError(BaseHTTPException):
-    error_code = 'too_many_files'
+    error_code = "too_many_files"
     description = "Only one file is allowed."
     code = 400
 
 
 class FileTooLargeError(BaseHTTPException):
-    error_code = 'file_too_large'
+    error_code = "file_too_large"
     description = "File size exceeded. {message}"
     code = 413
 
 
 class UnsupportedFileTypeError(BaseHTTPException):
-    error_code = 'unsupported_file_type'
+    error_code = "unsupported_file_type"
     description = "File type not allowed."
     code = 415

+ 3 - 5
api/controllers/service_api/app/file.py

@@ -16,15 +16,13 @@ from services.file_service import FileService
 
 
 class FileApi(Resource):
-
     @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
     @marshal_with(file_fields)
     def post(self, app_model: App, end_user: EndUser):
-
-        file = request.files['file']
+        file = request.files["file"]
 
         # check file
-        if 'file' not in request.files:
+        if "file" not in request.files:
             raise NoFileUploadedError()
 
         if not file.mimetype:
@@ -43,4 +41,4 @@ class FileApi(Resource):
         return upload_file, 201
 
 
-api.add_resource(FileApi, '/files/upload')
+api.add_resource(FileApi, "/files/upload")

+ 57 - 61
api/controllers/service_api/app/message.py

@@ -17,61 +17,59 @@ from services.message_service import MessageService
 
 
 class MessageListApi(Resource):
-    feedback_fields = {
-        'rating': fields.String
-    }
+    feedback_fields = {"rating": fields.String}
     retriever_resource_fields = {
-        'id': fields.String,
-        'message_id': fields.String,
-        'position': fields.Integer,
-        'dataset_id': fields.String,
-        'dataset_name': fields.String,
-        'document_id': fields.String,
-        'document_name': fields.String,
-        'data_source_type': fields.String,
-        'segment_id': fields.String,
-        'score': fields.Float,
-        'hit_count': fields.Integer,
-        'word_count': fields.Integer,
-        'segment_position': fields.Integer,
-        'index_node_hash': fields.String,
-        'content': fields.String,
-        'created_at': TimestampField
+        "id": fields.String,
+        "message_id": fields.String,
+        "position": fields.Integer,
+        "dataset_id": fields.String,
+        "dataset_name": fields.String,
+        "document_id": fields.String,
+        "document_name": fields.String,
+        "data_source_type": fields.String,
+        "segment_id": fields.String,
+        "score": fields.Float,
+        "hit_count": fields.Integer,
+        "word_count": fields.Integer,
+        "segment_position": fields.Integer,
+        "index_node_hash": fields.String,
+        "content": fields.String,
+        "created_at": TimestampField,
     }
 
     agent_thought_fields = {
-        'id': fields.String,
-        'chain_id': fields.String,
-        'message_id': fields.String,
-        'position': fields.Integer,
-        'thought': fields.String,
-        'tool': fields.String,
-        'tool_labels': fields.Raw,
-        'tool_input': fields.String,
-        'created_at': TimestampField,
-        'observation': fields.String,
-        'message_files': fields.List(fields.String, attribute='files')
+        "id": fields.String,
+        "chain_id": fields.String,
+        "message_id": fields.String,
+        "position": fields.Integer,
+        "thought": fields.String,
+        "tool": fields.String,
+        "tool_labels": fields.Raw,
+        "tool_input": fields.String,
+        "created_at": TimestampField,
+        "observation": fields.String,
+        "message_files": fields.List(fields.String, attribute="files"),
     }
 
     message_fields = {
-        'id': fields.String,
-        'conversation_id': fields.String,
-        'inputs': fields.Raw,
-        'query': fields.String,
-        'answer': fields.String(attribute='re_sign_file_url_answer'),
-        'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
-        'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
-        'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
-        'created_at': TimestampField,
-        'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)),
-        'status': fields.String,
-        'error': fields.String,
+        "id": fields.String,
+        "conversation_id": fields.String,
+        "inputs": fields.Raw,
+        "query": fields.String,
+        "answer": fields.String(attribute="re_sign_file_url_answer"),
+        "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
+        "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
+        "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
+        "created_at": TimestampField,
+        "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
+        "status": fields.String,
+        "error": fields.String,
     }
 
     message_infinite_scroll_pagination_fields = {
-        'limit': fields.Integer,
-        'has_more': fields.Boolean,
-        'data': fields.List(fields.Nested(message_fields))
+        "limit": fields.Integer,
+        "has_more": fields.Boolean,
+        "data": fields.List(fields.Nested(message_fields)),
     }
 
     @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@@ -82,14 +80,15 @@ class MessageListApi(Resource):
             raise NotChatAppError()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
-        parser.add_argument('first_id', type=uuid_value, location='args')
-        parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
+        parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
+        parser.add_argument("first_id", type=uuid_value, location="args")
+        parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
         args = parser.parse_args()
 
         try:
-            return MessageService.pagination_by_first_id(app_model, end_user,
-                                                         args['conversation_id'], args['first_id'], args['limit'])
+            return MessageService.pagination_by_first_id(
+                app_model, end_user, args["conversation_id"], args["first_id"], args["limit"]
+            )
         except services.errors.conversation.ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
         except services.errors.message.FirstMessageNotExistsError:
@@ -102,15 +101,15 @@ class MessageFeedbackApi(Resource):
         message_id = str(message_id)
 
         parser = reqparse.RequestParser()
-        parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
+        parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
         args = parser.parse_args()
 
         try:
-            MessageService.create_feedback(app_model, message_id, end_user, args['rating'])
+            MessageService.create_feedback(app_model, message_id, end_user, args["rating"])
         except services.errors.message.MessageNotExistsError:
             raise NotFound("Message Not Exists.")
 
-        return {'result': 'success'}
+        return {"result": "success"}
 
 
 class MessageSuggestedApi(Resource):
@@ -123,10 +122,7 @@ class MessageSuggestedApi(Resource):
 
         try:
             questions = MessageService.get_suggested_questions_after_answer(
-                app_model=app_model,
-                user=end_user,
-                message_id=message_id,
-                invoke_from=InvokeFrom.SERVICE_API
+                app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.SERVICE_API
             )
         except services.errors.message.MessageNotExistsError:
             raise NotFound("Message Not Exists.")
@@ -136,9 +132,9 @@ class MessageSuggestedApi(Resource):
             logging.exception("internal server error.")
             raise InternalServerError()
 
-        return {'result': 'success', 'data': questions}
+        return {"result": "success", "data": questions}
 
 
-api.add_resource(MessageListApi, '/messages')
-api.add_resource(MessageFeedbackApi, '/messages/<uuid:message_id>/feedbacks')
-api.add_resource(MessageSuggestedApi, '/messages/<uuid:message_id>/suggested')
+api.add_resource(MessageListApi, "/messages")
+api.add_resource(MessageFeedbackApi, "/messages/<uuid:message_id>/feedbacks")
+api.add_resource(MessageSuggestedApi, "/messages/<uuid:message_id>/suggested")

+ 23 - 26
api/controllers/service_api/app/workflow.py

@@ -30,19 +30,20 @@ from services.app_generate_service import AppGenerateService
 logger = logging.getLogger(__name__)
 
 workflow_run_fields = {
-    'id': fields.String,
-    'workflow_id': fields.String,
-    'status': fields.String,
-    'inputs': fields.Raw,
-    'outputs': fields.Raw,
-    'error': fields.String,
-    'total_steps': fields.Integer,
-    'total_tokens': fields.Integer,
-    'created_at': fields.DateTime,
-    'finished_at': fields.DateTime,
-    'elapsed_time': fields.Float,
+    "id": fields.String,
+    "workflow_id": fields.String,
+    "status": fields.String,
+    "inputs": fields.Raw,
+    "outputs": fields.Raw,
+    "error": fields.String,
+    "total_steps": fields.Integer,
+    "total_tokens": fields.Integer,
+    "created_at": fields.DateTime,
+    "finished_at": fields.DateTime,
+    "elapsed_time": fields.Float,
 }
 
+
 class WorkflowRunDetailApi(Resource):
     @validate_app_token
     @marshal_with(workflow_run_fields)
@@ -56,6 +57,8 @@ class WorkflowRunDetailApi(Resource):
 
         workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_id).first()
         return workflow_run
+
+
 class WorkflowRunApi(Resource):
     @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
     def post(self, app_model: App, end_user: EndUser):
@@ -67,20 +70,16 @@ class WorkflowRunApi(Resource):
             raise NotWorkflowAppError()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json')
-        parser.add_argument('files', type=list, required=False, location='json')
-        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
+        parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
+        parser.add_argument("files", type=list, required=False, location="json")
+        parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
         args = parser.parse_args()
 
-        streaming = args.get('response_mode') == 'streaming'
+        streaming = args.get("response_mode") == "streaming"
 
         try:
             response = AppGenerateService.generate(
-                app_model=app_model,
-                user=end_user,
-                args=args,
-                invoke_from=InvokeFrom.SERVICE_API,
-                streaming=streaming
+                app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming
             )
 
             return helper.compact_generate_response(response)
@@ -111,11 +110,9 @@ class WorkflowTaskStopApi(Resource):
 
         AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
 
-        return {
-            "result": "success"
-        }
+        return {"result": "success"}
 
 
-api.add_resource(WorkflowRunApi, '/workflows/run')
-api.add_resource(WorkflowRunDetailApi, '/workflows/run/<string:workflow_id>')
-api.add_resource(WorkflowTaskStopApi, '/workflows/tasks/<string:task_id>/stop')
+api.add_resource(WorkflowRunApi, "/workflows/run")
+api.add_resource(WorkflowRunDetailApi, "/workflows/run/<string:workflow_id>")
+api.add_resource(WorkflowTaskStopApi, "/workflows/tasks/<string:task_id>/stop")

+ 46 - 42
api/controllers/service_api/dataset/dataset.py

@@ -16,7 +16,7 @@ from services.dataset_service import DatasetService
 
 def _validate_name(name):
     if not name or len(name) < 1 or len(name) > 40:
-        raise ValueError('Name must be between 1 to 40 characters.')
+        raise ValueError("Name must be between 1 to 40 characters.")
     return name
 
 
@@ -26,24 +26,18 @@ class DatasetListApi(DatasetApiResource):
     def get(self, tenant_id):
         """Resource for getting datasets."""
 
-        page = request.args.get('page', default=1, type=int)
-        limit = request.args.get('limit', default=20, type=int)
-        provider = request.args.get('provider', default="vendor")
-        search = request.args.get('keyword', default=None, type=str)
-        tag_ids = request.args.getlist('tag_ids')
+        page = request.args.get("page", default=1, type=int)
+        limit = request.args.get("limit", default=20, type=int)
+        provider = request.args.get("provider", default="vendor")
+        search = request.args.get("keyword", default=None, type=str)
+        tag_ids = request.args.getlist("tag_ids")
 
-        datasets, total = DatasetService.get_datasets(page, limit, provider,
-                                                      tenant_id, current_user, search, tag_ids)
+        datasets, total = DatasetService.get_datasets(page, limit, provider, tenant_id, current_user, search, tag_ids)
         # check embedding setting
         provider_manager = ProviderManager()
-        configurations = provider_manager.get_configurations(
-            tenant_id=current_user.current_tenant_id
-        )
+        configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
 
-        embedding_models = configurations.get_models(
-            model_type=ModelType.TEXT_EMBEDDING,
-            only_active=True
-        )
+        embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
 
         model_names = []
         for embedding_model in embedding_models:
@@ -51,50 +45,59 @@ class DatasetListApi(DatasetApiResource):
 
         data = marshal(datasets, dataset_detail_fields)
         for item in data:
-            if item['indexing_technique'] == 'high_quality':
+            if item["indexing_technique"] == "high_quality":
                 item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
                 if item_model in model_names:
-                    item['embedding_available'] = True
+                    item["embedding_available"] = True
                 else:
-                    item['embedding_available'] = False
+                    item["embedding_available"] = False
             else:
-                item['embedding_available'] = True
-        response = {
-            'data': data,
-            'has_more': len(datasets) == limit,
-            'limit': limit,
-            'total': total,
-            'page': page
-        }
+                item["embedding_available"] = True
+        response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
         return response, 200
 
-
     def post(self, tenant_id):
         """Resource for creating datasets."""
         parser = reqparse.RequestParser()
-        parser.add_argument('name', nullable=False, required=True,
-                            help='type is required. Name must be between 1 to 40 characters.',
-                            type=_validate_name)
-        parser.add_argument('indexing_technique', type=str, location='json',
-                            choices=Dataset.INDEXING_TECHNIQUE_LIST,
-                            help='Invalid indexing technique.')
-        parser.add_argument('permission', type=str, location='json', choices=(
-            DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), help='Invalid permission.', required=False, nullable=False)
+        parser.add_argument(
+            "name",
+            nullable=False,
+            required=True,
+            help="type is required. Name must be between 1 to 40 characters.",
+            type=_validate_name,
+        )
+        parser.add_argument(
+            "indexing_technique",
+            type=str,
+            location="json",
+            choices=Dataset.INDEXING_TECHNIQUE_LIST,
+            help="Invalid indexing technique.",
+        )
+        parser.add_argument(
+            "permission",
+            type=str,
+            location="json",
+            choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
+            help="Invalid permission.",
+            required=False,
+            nullable=False,
+        )
         args = parser.parse_args()
 
         try:
             dataset = DatasetService.create_empty_dataset(
                 tenant_id=tenant_id,
-                name=args['name'],
-                indexing_technique=args['indexing_technique'],
+                name=args["name"],
+                indexing_technique=args["indexing_technique"],
                 account=current_user,
-                permission=args['permission']
+                permission=args["permission"],
             )
         except services.errors.dataset.DatasetNameDuplicateError:
             raise DatasetNameDuplicateError()
 
         return marshal(dataset, dataset_detail_fields), 200
 
+
 class DatasetApi(DatasetApiResource):
     """Resource for dataset."""
 
@@ -106,7 +109,7 @@ class DatasetApi(DatasetApiResource):
             dataset_id (UUID): The ID of the dataset to be deleted.
 
         Returns:
-            dict: A dictionary with a key 'result' and a value 'success' 
+            dict: A dictionary with a key 'result' and a value 'success'
                   if the dataset was successfully deleted. Omitted in HTTP response.
             int: HTTP status code 204 indicating that the operation was successful.
 
@@ -118,11 +121,12 @@ class DatasetApi(DatasetApiResource):
 
         try:
             if DatasetService.delete_dataset(dataset_id_str, current_user):
-                return {'result': 'success'}, 204
+                return {"result": "success"}, 204
             else:
                 raise NotFound("Dataset not found.")
         except services.errors.dataset.DatasetInUseError:
             raise DatasetInUseError()
 
-api.add_resource(DatasetListApi, '/datasets')
-api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
+
+api.add_resource(DatasetListApi, "/datasets")
+api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")

+ 119 - 176
api/controllers/service_api/dataset/document.py

@@ -27,47 +27,40 @@ from services.file_service import FileService
 class DocumentAddByTextApi(DatasetApiResource):
     """Resource for documents."""
 
-    @cloud_edition_billing_resource_check('vector_space', 'dataset')
-    @cloud_edition_billing_resource_check('documents', 'dataset')
+    @cloud_edition_billing_resource_check("vector_space", "dataset")
+    @cloud_edition_billing_resource_check("documents", "dataset")
     def post(self, tenant_id, dataset_id):
         """Create document by text."""
         parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('text', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json')
-        parser.add_argument('original_document_id', type=str, required=False, location='json')
-        parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
-        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
-                            location='json')
-        parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False,
-                            location='json')
-        parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
-                            location='json')
+        parser.add_argument("name", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("text", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json")
+        parser.add_argument("original_document_id", type=str, required=False, location="json")
+        parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
+        parser.add_argument(
+            "doc_language", type=str, default="English", required=False, nullable=False, location="json"
+        )
+        parser.add_argument(
+            "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
+        )
+        parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
         args = parser.parse_args()
         dataset_id = str(dataset_id)
         tenant_id = str(tenant_id)
-        dataset = db.session.query(Dataset).filter(
-            Dataset.tenant_id == tenant_id,
-            Dataset.id == dataset_id
-        ).first()
+        dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
 
         if not dataset:
-            raise ValueError('Dataset is not exist.')
+            raise ValueError("Dataset is not exist.")
 
-        if not dataset.indexing_technique and not args['indexing_technique']:
-            raise ValueError('indexing_technique is required.')
+        if not dataset.indexing_technique and not args["indexing_technique"]:
+            raise ValueError("indexing_technique is required.")
 
-        upload_file = FileService.upload_text(args.get('text'), args.get('name'))
+        upload_file = FileService.upload_text(args.get("text"), args.get("name"))
         data_source = {
-            'type': 'upload_file',
-            'info_list': {
-                'data_source_type': 'upload_file',
-                'file_info_list': {
-                    'file_ids': [upload_file.id]
-                }
-            }
+            "type": "upload_file",
+            "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
         }
-        args['data_source'] = data_source
+        args["data_source"] = data_source
         # validate args
         DocumentService.document_create_args_validate(args)
 
@@ -76,60 +69,49 @@ class DocumentAddByTextApi(DatasetApiResource):
                 dataset=dataset,
                 document_data=args,
                 account=current_user,
-                dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
-                created_from='api'
+                dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
+                created_from="api",
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
         document = documents[0]
 
-        documents_and_batch_fields = {
-            'document': marshal(document, document_fields),
-            'batch': batch
-        }
+        documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
         return documents_and_batch_fields, 200
 
 
 class DocumentUpdateByTextApi(DatasetApiResource):
     """Resource for update documents."""
 
-    @cloud_edition_billing_resource_check('vector_space', 'dataset')
+    @cloud_edition_billing_resource_check("vector_space", "dataset")
     def post(self, tenant_id, dataset_id, document_id):
         """Update document by text."""
         parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, required=False, nullable=True, location='json')
-        parser.add_argument('text', type=str, required=False, nullable=True, location='json')
-        parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json')
-        parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
-        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
-                            location='json')
-        parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
-                            location='json')
+        parser.add_argument("name", type=str, required=False, nullable=True, location="json")
+        parser.add_argument("text", type=str, required=False, nullable=True, location="json")
+        parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json")
+        parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
+        parser.add_argument(
+            "doc_language", type=str, default="English", required=False, nullable=False, location="json"
+        )
+        parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
         args = parser.parse_args()
         dataset_id = str(dataset_id)
         tenant_id = str(tenant_id)
-        dataset = db.session.query(Dataset).filter(
-            Dataset.tenant_id == tenant_id,
-            Dataset.id == dataset_id
-        ).first()
+        dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
 
         if not dataset:
-            raise ValueError('Dataset is not exist.')
+            raise ValueError("Dataset is not exist.")
 
-        if args['text']:
-            upload_file = FileService.upload_text(args.get('text'), args.get('name'))
+        if args["text"]:
+            upload_file = FileService.upload_text(args.get("text"), args.get("name"))
             data_source = {
-                'type': 'upload_file',
-                'info_list': {
-                    'data_source_type': 'upload_file',
-                    'file_info_list': {
-                        'file_ids': [upload_file.id]
-                    }
-                }
+                "type": "upload_file",
+                "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
             }
-            args['data_source'] = data_source
+            args["data_source"] = data_source
         # validate args
-        args['original_document_id'] = str(document_id)
+        args["original_document_id"] = str(document_id)
         DocumentService.document_create_args_validate(args)
 
         try:
@@ -137,65 +119,53 @@ class DocumentUpdateByTextApi(DatasetApiResource):
                 dataset=dataset,
                 document_data=args,
                 account=current_user,
-                dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
-                created_from='api'
+                dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
+                created_from="api",
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
         document = documents[0]
 
-        documents_and_batch_fields = {
-            'document': marshal(document, document_fields),
-            'batch': batch
-        }
+        documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
         return documents_and_batch_fields, 200
 
 
 class DocumentAddByFileApi(DatasetApiResource):
     """Resource for documents."""
-    @cloud_edition_billing_resource_check('vector_space', 'dataset')
-    @cloud_edition_billing_resource_check('documents', 'dataset')
+
+    @cloud_edition_billing_resource_check("vector_space", "dataset")
+    @cloud_edition_billing_resource_check("documents", "dataset")
     def post(self, tenant_id, dataset_id):
         """Create document by upload file."""
         args = {}
-        if 'data' in request.form:
-            args = json.loads(request.form['data'])
-        if 'doc_form' not in args:
-            args['doc_form'] = 'text_model'
-        if 'doc_language' not in args:
-            args['doc_language'] = 'English'
+        if "data" in request.form:
+            args = json.loads(request.form["data"])
+        if "doc_form" not in args:
+            args["doc_form"] = "text_model"
+        if "doc_language" not in args:
+            args["doc_language"] = "English"
         # get dataset info
         dataset_id = str(dataset_id)
         tenant_id = str(tenant_id)
-        dataset = db.session.query(Dataset).filter(
-            Dataset.tenant_id == tenant_id,
-            Dataset.id == dataset_id
-        ).first()
+        dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
 
         if not dataset:
-            raise ValueError('Dataset is not exist.')
-        if not dataset.indexing_technique and not args.get('indexing_technique'):
-            raise ValueError('indexing_technique is required.')
+            raise ValueError("Dataset is not exist.")
+        if not dataset.indexing_technique and not args.get("indexing_technique"):
+            raise ValueError("indexing_technique is required.")
 
         # save file info
-        file = request.files['file']
+        file = request.files["file"]
         # check file
-        if 'file' not in request.files:
+        if "file" not in request.files:
             raise NoFileUploadedError()
 
         if len(request.files) > 1:
             raise TooManyFilesError()
 
         upload_file = FileService.upload_file(file, current_user)
-        data_source = {
-            'type': 'upload_file',
-            'info_list': {
-                'file_info_list': {
-                    'file_ids': [upload_file.id]
-                }
-            }
-        }
-        args['data_source'] = data_source
+        data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
+        args["data_source"] = data_source
         # validate args
         DocumentService.document_create_args_validate(args)
 
@@ -204,63 +174,49 @@ class DocumentAddByFileApi(DatasetApiResource):
                 dataset=dataset,
                 document_data=args,
                 account=dataset.created_by_account,
-                dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
-                created_from='api'
+                dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
+                created_from="api",
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
         document = documents[0]
-        documents_and_batch_fields = {
-            'document': marshal(document, document_fields),
-            'batch': batch
-        }
+        documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
         return documents_and_batch_fields, 200
 
 
 class DocumentUpdateByFileApi(DatasetApiResource):
     """Resource for update documents."""
 
-    @cloud_edition_billing_resource_check('vector_space', 'dataset')
+    @cloud_edition_billing_resource_check("vector_space", "dataset")
     def post(self, tenant_id, dataset_id, document_id):
         """Update document by upload file."""
         args = {}
-        if 'data' in request.form:
-            args = json.loads(request.form['data'])
-        if 'doc_form' not in args:
-            args['doc_form'] = 'text_model'
-        if 'doc_language' not in args:
-            args['doc_language'] = 'English'
+        if "data" in request.form:
+            args = json.loads(request.form["data"])
+        if "doc_form" not in args:
+            args["doc_form"] = "text_model"
+        if "doc_language" not in args:
+            args["doc_language"] = "English"
 
         # get dataset info
         dataset_id = str(dataset_id)
         tenant_id = str(tenant_id)
-        dataset = db.session.query(Dataset).filter(
-            Dataset.tenant_id == tenant_id,
-            Dataset.id == dataset_id
-        ).first()
+        dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
 
         if not dataset:
-            raise ValueError('Dataset is not exist.')
-        if 'file' in request.files:
+            raise ValueError("Dataset is not exist.")
+        if "file" in request.files:
             # save file info
-            file = request.files['file']
-
+            file = request.files["file"]
 
             if len(request.files) > 1:
                 raise TooManyFilesError()
 
             upload_file = FileService.upload_file(file, current_user)
-            data_source = {
-                'type': 'upload_file',
-                'info_list': {
-                    'file_info_list': {
-                        'file_ids': [upload_file.id]
-                    }
-                }
-            }
-            args['data_source'] = data_source
+            data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
+            args["data_source"] = data_source
         # validate args
-        args['original_document_id'] = str(document_id)
+        args["original_document_id"] = str(document_id)
         DocumentService.document_create_args_validate(args)
 
         try:
@@ -268,16 +224,13 @@ class DocumentUpdateByFileApi(DatasetApiResource):
                 dataset=dataset,
                 document_data=args,
                 account=dataset.created_by_account,
-                dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
-                created_from='api'
+                dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
+                created_from="api",
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
         document = documents[0]
-        documents_and_batch_fields = {
-            'document': marshal(document, document_fields),
-            'batch': batch
-        }
+        documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
         return documents_and_batch_fields, 200
 
 
@@ -289,13 +242,10 @@ class DocumentDeleteApi(DatasetApiResource):
         tenant_id = str(tenant_id)
 
         # get dataset info
-        dataset = db.session.query(Dataset).filter(
-            Dataset.tenant_id == tenant_id,
-            Dataset.id == dataset_id
-        ).first()
+        dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
 
         if not dataset:
-            raise ValueError('Dataset is not exist.')
+            raise ValueError("Dataset is not exist.")
 
         document = DocumentService.get_document(dataset.id, document_id)
 
@@ -311,44 +261,39 @@ class DocumentDeleteApi(DatasetApiResource):
             # delete document
             DocumentService.delete_document(document)
         except services.errors.document.DocumentIndexingError:
-            raise DocumentIndexingError('Cannot delete document during indexing.')
+            raise DocumentIndexingError("Cannot delete document during indexing.")
 
-        return {'result': 'success'}, 200
+        return {"result": "success"}, 200
 
 
 class DocumentListApi(DatasetApiResource):
     def get(self, tenant_id, dataset_id):
         dataset_id = str(dataset_id)
         tenant_id = str(tenant_id)
-        page = request.args.get('page', default=1, type=int)
-        limit = request.args.get('limit', default=20, type=int)
-        search = request.args.get('keyword', default=None, type=str)
-        dataset = db.session.query(Dataset).filter(
-            Dataset.tenant_id == tenant_id,
-            Dataset.id == dataset_id
-        ).first()
+        page = request.args.get("page", default=1, type=int)
+        limit = request.args.get("limit", default=20, type=int)
+        search = request.args.get("keyword", default=None, type=str)
+        dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
         if not dataset:
-            raise NotFound('Dataset not found.')
+            raise NotFound("Dataset not found.")
 
-        query = Document.query.filter_by(
-            dataset_id=str(dataset_id), tenant_id=tenant_id)
+        query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id)
 
         if search:
-            search = f'%{search}%'
+            search = f"%{search}%"
             query = query.filter(Document.name.like(search))
 
         query = query.order_by(desc(Document.created_at))
 
-        paginated_documents = query.paginate(
-            page=page, per_page=limit, max_per_page=100, error_out=False)
+        paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
         documents = paginated_documents.items
 
         response = {
-            'data': marshal(documents, document_fields),
-            'has_more': len(documents) == limit,
-            'limit': limit,
-            'total': paginated_documents.total,
-            'page': page
+            "data": marshal(documents, document_fields),
+            "has_more": len(documents) == limit,
+            "limit": limit,
+            "total": paginated_documents.total,
+            "page": page,
         }
 
         return response
@@ -360,38 +305,36 @@ class DocumentIndexingStatusApi(DatasetApiResource):
         batch = str(batch)
         tenant_id = str(tenant_id)
         # get dataset
-        dataset = db.session.query(Dataset).filter(
-            Dataset.tenant_id == tenant_id,
-            Dataset.id == dataset_id
-        ).first()
+        dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
         if not dataset:
-            raise NotFound('Dataset not found.')
+            raise NotFound("Dataset not found.")
         # get documents
         documents = DocumentService.get_batch_documents(dataset_id, batch)
         if not documents:
-            raise NotFound('Documents not found.')
+            raise NotFound("Documents not found.")
         documents_status = []
         for document in documents:
-            completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
-                                                              DocumentSegment.document_id == str(document.id),
-                                                              DocumentSegment.status != 're_segment').count()
-            total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
-                                                          DocumentSegment.status != 're_segment').count()
+            completed_segments = DocumentSegment.query.filter(
+                DocumentSegment.completed_at.isnot(None),
+                DocumentSegment.document_id == str(document.id),
+                DocumentSegment.status != "re_segment",
+            ).count()
+            total_segments = DocumentSegment.query.filter(
+                DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
+            ).count()
             document.completed_segments = completed_segments
             document.total_segments = total_segments
             if document.is_paused:
-                document.indexing_status = 'paused'
+                document.indexing_status = "paused"
             documents_status.append(marshal(document, document_status_fields))
-        data = {
-            'data': documents_status
-        }
+        data = {"data": documents_status}
         return data
 
 
-api.add_resource(DocumentAddByTextApi, '/datasets/<uuid:dataset_id>/document/create_by_text')
-api.add_resource(DocumentAddByFileApi, '/datasets/<uuid:dataset_id>/document/create_by_file')
-api.add_resource(DocumentUpdateByTextApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text')
-api.add_resource(DocumentUpdateByFileApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file')
-api.add_resource(DocumentDeleteApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>')
-api.add_resource(DocumentListApi, '/datasets/<uuid:dataset_id>/documents')
-api.add_resource(DocumentIndexingStatusApi, '/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status')
+api.add_resource(DocumentAddByTextApi, "/datasets/<uuid:dataset_id>/document/create_by_text")
+api.add_resource(DocumentAddByFileApi, "/datasets/<uuid:dataset_id>/document/create_by_file")
+api.add_resource(DocumentUpdateByTextApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text")
+api.add_resource(DocumentUpdateByFileApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file")
+api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
+api.add_resource(DocumentListApi, "/datasets/<uuid:dataset_id>/documents")
+api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status")

+ 13 - 13
api/controllers/service_api/dataset/error.py

@@ -2,78 +2,78 @@ from libs.exception import BaseHTTPException
 
 
 class NoFileUploadedError(BaseHTTPException):
-    error_code = 'no_file_uploaded'
+    error_code = "no_file_uploaded"
     description = "Please upload your file."
     code = 400
 
 
 class TooManyFilesError(BaseHTTPException):
-    error_code = 'too_many_files'
+    error_code = "too_many_files"
     description = "Only one file is allowed."
     code = 400
 
 
 class FileTooLargeError(BaseHTTPException):
-    error_code = 'file_too_large'
+    error_code = "file_too_large"
     description = "File size exceeded. {message}"
     code = 413
 
 
 class UnsupportedFileTypeError(BaseHTTPException):
-    error_code = 'unsupported_file_type'
+    error_code = "unsupported_file_type"
     description = "File type not allowed."
     code = 415
 
 
 class HighQualityDatasetOnlyError(BaseHTTPException):
-    error_code = 'high_quality_dataset_only'
+    error_code = "high_quality_dataset_only"
     description = "Current operation only supports 'high-quality' datasets."
     code = 400
 
 
 class DatasetNotInitializedError(BaseHTTPException):
-    error_code = 'dataset_not_initialized'
+    error_code = "dataset_not_initialized"
     description = "The dataset is still being initialized or indexing. Please wait a moment."
     code = 400
 
 
 class ArchivedDocumentImmutableError(BaseHTTPException):
-    error_code = 'archived_document_immutable'
+    error_code = "archived_document_immutable"
     description = "The archived document is not editable."
     code = 403
 
 
 class DatasetNameDuplicateError(BaseHTTPException):
-    error_code = 'dataset_name_duplicate'
+    error_code = "dataset_name_duplicate"
     description = "The dataset name already exists. Please modify your dataset name."
     code = 409
 
 
 class InvalidActionError(BaseHTTPException):
-    error_code = 'invalid_action'
+    error_code = "invalid_action"
     description = "Invalid action."
     code = 400
 
 
 class DocumentAlreadyFinishedError(BaseHTTPException):
-    error_code = 'document_already_finished'
+    error_code = "document_already_finished"
     description = "The document has been processed. Please refresh the page or go to the document details."
     code = 400
 
 
 class DocumentIndexingError(BaseHTTPException):
-    error_code = 'document_indexing'
+    error_code = "document_indexing"
     description = "The document is being processed and cannot be edited."
     code = 400
 
 
 class InvalidMetadataError(BaseHTTPException):
-    error_code = 'invalid_metadata'
+    error_code = "invalid_metadata"
     description = "The metadata content is incorrect. Please check and verify."
     code = 400
 
 
 class DatasetInUseError(BaseHTTPException):
-    error_code = 'dataset_in_use'
+    error_code = "dataset_in_use"
     description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it."
     code = 409

+ 53 - 74
api/controllers/service_api/dataset/segment.py

@@ -21,52 +21,47 @@ from services.dataset_service import DatasetService, DocumentService, SegmentSer
 class SegmentApi(DatasetApiResource):
     """Resource for segments."""
 
-    @cloud_edition_billing_resource_check('vector_space', 'dataset')
-    @cloud_edition_billing_knowledge_limit_check('add_segment', 'dataset')
+    @cloud_edition_billing_resource_check("vector_space", "dataset")
+    @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
     def post(self, tenant_id, dataset_id, document_id):
         """Create single segment."""
         # check dataset
         dataset_id = str(dataset_id)
         tenant_id = str(tenant_id)
-        dataset = db.session.query(Dataset).filter(
-            Dataset.tenant_id == tenant_id,
-            Dataset.id == dataset_id
-        ).first()
+        dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
         if not dataset:
-            raise NotFound('Dataset not found.')
+            raise NotFound("Dataset not found.")
         # check document
         document_id = str(document_id)
         document = DocumentService.get_document(dataset.id, document_id)
         if not document:
-            raise NotFound('Document not found.')
+            raise NotFound("Document not found.")
         # check embedding model setting
-        if dataset.indexing_technique == 'high_quality':
+        if dataset.indexing_technique == "high_quality":
             try:
                 model_manager = ModelManager()
                 model_manager.get_model_instance(
                     tenant_id=current_user.current_tenant_id,
                     provider=dataset.embedding_model_provider,
                     model_type=ModelType.TEXT_EMBEDDING,
-                    model=dataset.embedding_model
+                    model=dataset.embedding_model,
                 )
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
                     "No Embedding Model available. Please configure a valid provider "
-                    "in the Settings -> Model Provider.")
-            except ProviderTokenNotInitError as ex:   
+                    "in the Settings -> Model Provider."
+                )
+            except ProviderTokenNotInitError as ex:
                 raise ProviderNotInitializeError(ex.description)
         # validate args
         parser = reqparse.RequestParser()
-        parser.add_argument('segments', type=list, required=False, nullable=True, location='json')
+        parser.add_argument("segments", type=list, required=False, nullable=True, location="json")
         args = parser.parse_args()
-        if args['segments'] is not None:
-            for args_item in args['segments']:
+        if args["segments"] is not None:
+            for args_item in args["segments"]:
                 SegmentService.segment_create_args_validate(args_item, document)
-            segments = SegmentService.multi_create_segment(args['segments'], document, dataset)
-            return {
-                'data': marshal(segments, segment_fields),
-                'doc_form': document.doc_form
-            }, 200
+            segments = SegmentService.multi_create_segment(args["segments"], document, dataset)
+            return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200
         else:
             return {"error": "Segemtns is required"}, 400
 
@@ -75,61 +70,53 @@ class SegmentApi(DatasetApiResource):
         # check dataset
         dataset_id = str(dataset_id)
         tenant_id = str(tenant_id)
-        dataset = db.session.query(Dataset).filter(
-            Dataset.tenant_id == tenant_id,
-            Dataset.id == dataset_id
-        ).first()
+        dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
         if not dataset:
-            raise NotFound('Dataset not found.')
+            raise NotFound("Dataset not found.")
         # check document
         document_id = str(document_id)
         document = DocumentService.get_document(dataset.id, document_id)
         if not document:
-            raise NotFound('Document not found.')
+            raise NotFound("Document not found.")
         # check embedding model setting
-        if dataset.indexing_technique == 'high_quality':
+        if dataset.indexing_technique == "high_quality":
             try:
                 model_manager = ModelManager()
                 model_manager.get_model_instance(
                     tenant_id=current_user.current_tenant_id,
                     provider=dataset.embedding_model_provider,
                     model_type=ModelType.TEXT_EMBEDDING,
-                    model=dataset.embedding_model
+                    model=dataset.embedding_model,
                 )
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
                     "No Embedding Model available. Please configure a valid provider "
-                    "in the Settings -> Model Provider.")
+                    "in the Settings -> Model Provider."
+                )
             except ProviderTokenNotInitError as ex:
                 raise ProviderNotInitializeError(ex.description)
 
         parser = reqparse.RequestParser()
-        parser.add_argument('status', type=str,
-                            action='append', default=[], location='args')
-        parser.add_argument('keyword', type=str, default=None, location='args')
+        parser.add_argument("status", type=str, action="append", default=[], location="args")
+        parser.add_argument("keyword", type=str, default=None, location="args")
         args = parser.parse_args()
 
-        status_list = args['status']
-        keyword = args['keyword']
+        status_list = args["status"]
+        keyword = args["keyword"]
 
         query = DocumentSegment.query.filter(
-            DocumentSegment.document_id == str(document_id),
-            DocumentSegment.tenant_id == current_user.current_tenant_id
+            DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id
         )
 
         if status_list:
             query = query.filter(DocumentSegment.status.in_(status_list))
 
         if keyword:
-            query = query.where(DocumentSegment.content.ilike(f'%{keyword}%'))
+            query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
 
         total = query.count()
         segments = query.order_by(DocumentSegment.position).all()
-        return {
-            'data': marshal(segments, segment_fields),
-            'doc_form': document.doc_form,
-            'total': total
-        }, 200
+        return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form, "total": total}, 200
 
 
 class DatasetSegmentApi(DatasetApiResource):
@@ -137,48 +124,41 @@ class DatasetSegmentApi(DatasetApiResource):
         # check dataset
         dataset_id = str(dataset_id)
         tenant_id = str(tenant_id)
-        dataset = db.session.query(Dataset).filter(
-            Dataset.tenant_id == tenant_id,
-            Dataset.id == dataset_id
-        ).first()
+        dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
         if not dataset:
-            raise NotFound('Dataset not found.')
+            raise NotFound("Dataset not found.")
         # check user's model setting
         DatasetService.check_dataset_model_setting(dataset)
         # check document
         document_id = str(document_id)
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
-            raise NotFound('Document not found.')
+            raise NotFound("Document not found.")
         # check segment
         segment = DocumentSegment.query.filter(
-            DocumentSegment.id == str(segment_id),
-            DocumentSegment.tenant_id == current_user.current_tenant_id
+            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
         ).first()
         if not segment:
-            raise NotFound('Segment not found.')
+            raise NotFound("Segment not found.")
         SegmentService.delete_segment(segment, document, dataset)
-        return {'result': 'success'}, 200
+        return {"result": "success"}, 200
 
-    @cloud_edition_billing_resource_check('vector_space', 'dataset')
+    @cloud_edition_billing_resource_check("vector_space", "dataset")
     def post(self, tenant_id, dataset_id, document_id, segment_id):
         # check dataset
         dataset_id = str(dataset_id)
         tenant_id = str(tenant_id)
-        dataset = db.session.query(Dataset).filter(
-            Dataset.tenant_id == tenant_id,
-            Dataset.id == dataset_id
-        ).first()
+        dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
         if not dataset:
-            raise NotFound('Dataset not found.')
+            raise NotFound("Dataset not found.")
         # check user's model setting
         DatasetService.check_dataset_model_setting(dataset)
         # check document
         document_id = str(document_id)
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
-            raise NotFound('Document not found.')
-        if dataset.indexing_technique == 'high_quality':
+            raise NotFound("Document not found.")
+        if dataset.indexing_technique == "high_quality":
             # check embedding model setting
             try:
                 model_manager = ModelManager()
@@ -186,35 +166,34 @@ class DatasetSegmentApi(DatasetApiResource):
                     tenant_id=current_user.current_tenant_id,
                     provider=dataset.embedding_model_provider,
                     model_type=ModelType.TEXT_EMBEDDING,
-                    model=dataset.embedding_model
+                    model=dataset.embedding_model,
                 )
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
                     "No Embedding Model available. Please configure a valid provider "
-                    "in the Settings -> Model Provider.")
+                    "in the Settings -> Model Provider."
+                )
             except ProviderTokenNotInitError as ex:
                 raise ProviderNotInitializeError(ex.description)
             # check segment
         segment_id = str(segment_id)
         segment = DocumentSegment.query.filter(
-            DocumentSegment.id == str(segment_id),
-            DocumentSegment.tenant_id == current_user.current_tenant_id
+            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
         ).first()
         if not segment:
-            raise NotFound('Segment not found.')
+            raise NotFound("Segment not found.")
 
         # validate args
         parser = reqparse.RequestParser()
-        parser.add_argument('segment', type=dict, required=False, nullable=True, location='json')
+        parser.add_argument("segment", type=dict, required=False, nullable=True, location="json")
         args = parser.parse_args()
 
-        SegmentService.segment_create_args_validate(args['segment'], document)
-        segment = SegmentService.update_segment(args['segment'], segment, document, dataset)
-        return {
-            'data': marshal(segment, segment_fields),
-            'doc_form': document.doc_form
-        }, 200
+        SegmentService.segment_create_args_validate(args["segment"], document)
+        segment = SegmentService.update_segment(args["segment"], segment, document, dataset)
+        return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
 
 
-api.add_resource(SegmentApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
-api.add_resource(DatasetSegmentApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
+api.add_resource(SegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
+api.add_resource(
+    DatasetSegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>"
+)

+ 1 - 1
api/controllers/service_api/index.py

@@ -13,4 +13,4 @@ class IndexApi(Resource):
         }
 
 
-api.add_resource(IndexApi, '/')
+api.add_resource(IndexApi, "/")

+ 63 - 46
api/controllers/service_api/wraps.py

@@ -21,9 +21,10 @@ class WhereisUserArg(Enum):
     """
     Enum for whereis_user_arg.
     """
-    QUERY = 'query'
-    JSON = 'json'
-    FORM = 'form'
+
+    QUERY = "query"
+    JSON = "json"
+    FORM = "form"
 
 
 class FetchUserArg(BaseModel):
@@ -35,13 +36,13 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
     def decorator(view_func):
         @wraps(view_func)
         def decorated_view(*args, **kwargs):
-            api_token = validate_and_get_api_token('app')
+            api_token = validate_and_get_api_token("app")
 
             app_model = db.session.query(App).filter(App.id == api_token.app_id).first()
             if not app_model:
                 raise Forbidden("The app no longer exists.")
 
-            if app_model.status != 'normal':
+            if app_model.status != "normal":
                 raise Forbidden("The app's status is abnormal.")
 
             if not app_model.enable_api:
@@ -51,15 +52,15 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
             if tenant.status == TenantStatus.ARCHIVE:
                 raise Forbidden("The workspace's status is archived.")
 
-            kwargs['app_model'] = app_model
+            kwargs["app_model"] = app_model
 
             if fetch_user_arg:
                 if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
-                    user_id = request.args.get('user')
+                    user_id = request.args.get("user")
                 elif fetch_user_arg.fetch_from == WhereisUserArg.JSON:
-                    user_id = request.get_json().get('user')
+                    user_id = request.get_json().get("user")
                 elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
-                    user_id = request.form.get('user')
+                    user_id = request.form.get("user")
                 else:
                     # use default-user
                     user_id = None
@@ -70,9 +71,10 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
                 if user_id:
                     user_id = str(user_id)
 
-                kwargs['end_user'] = create_or_update_end_user_for_user_id(app_model, user_id)
+                kwargs["end_user"] = create_or_update_end_user_for_user_id(app_model, user_id)
 
             return view_func(*args, **kwargs)
+
         return decorated_view
 
     if view is None:
@@ -81,9 +83,9 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
         return decorator(view)
 
 
-def cloud_edition_billing_resource_check(resource: str,
-                                         api_token_type: str,
-                                         error_msg: str = "You have reached the limit of your subscription."):
+def cloud_edition_billing_resource_check(
+    resource: str, api_token_type: str, error_msg: str = "You have reached the limit of your subscription."
+):
     def interceptor(view):
         def decorated(*args, **kwargs):
             api_token = validate_and_get_api_token(api_token_type)
@@ -95,33 +97,37 @@ def cloud_edition_billing_resource_check(resource: str,
                 vector_space = features.vector_space
                 documents_upload_quota = features.documents_upload_quota
 
-                if resource == 'members' and 0 < members.limit <= members.size:
+                if resource == "members" and 0 < members.limit <= members.size:
                     raise Forbidden(error_msg)
-                elif resource == 'apps' and 0 < apps.limit <= apps.size:
+                elif resource == "apps" and 0 < apps.limit <= apps.size:
                     raise Forbidden(error_msg)
-                elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size:
+                elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
                     raise Forbidden(error_msg)
-                elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
+                elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
                     raise Forbidden(error_msg)
                 else:
                     return view(*args, **kwargs)
 
             return view(*args, **kwargs)
+
         return decorated
+
     return interceptor
 
 
-def cloud_edition_billing_knowledge_limit_check(resource: str,
-                                                api_token_type: str,
-                                                error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."):
+def cloud_edition_billing_knowledge_limit_check(
+    resource: str,
+    api_token_type: str,
+    error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.",
+):
     def interceptor(view):
         @wraps(view)
         def decorated(*args, **kwargs):
             api_token = validate_and_get_api_token(api_token_type)
             features = FeatureService.get_features(api_token.tenant_id)
             if features.billing.enabled:
-                if resource == 'add_segment':
-                    if features.billing.subscription.plan == 'sandbox':
+                if resource == "add_segment":
+                    if features.billing.subscription.plan == "sandbox":
                         raise Forbidden(error_msg)
                 else:
                     return view(*args, **kwargs)
@@ -132,17 +138,20 @@ def cloud_edition_billing_knowledge_limit_check(resource: str,
 
     return interceptor
 
+
 def validate_dataset_token(view=None):
     def decorator(view):
         @wraps(view)
         def decorated(*args, **kwargs):
-            api_token = validate_and_get_api_token('dataset')
-            tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \
-                .filter(Tenant.id == api_token.tenant_id) \
-                .filter(TenantAccountJoin.tenant_id == Tenant.id) \
-                .filter(TenantAccountJoin.role.in_(['owner'])) \
-                .filter(Tenant.status == TenantStatus.NORMAL) \
-                .one_or_none() # TODO: only owner information is required, so only one is returned.
+            api_token = validate_and_get_api_token("dataset")
+            tenant_account_join = (
+                db.session.query(Tenant, TenantAccountJoin)
+                .filter(Tenant.id == api_token.tenant_id)
+                .filter(TenantAccountJoin.tenant_id == Tenant.id)
+                .filter(TenantAccountJoin.role.in_(["owner"]))
+                .filter(Tenant.status == TenantStatus.NORMAL)
+                .one_or_none()
+            )  # TODO: only owner information is required, so only one is returned.
             if tenant_account_join:
                 tenant, ta = tenant_account_join
                 account = Account.query.filter_by(id=ta.account_id).first()
@@ -156,6 +165,7 @@ def validate_dataset_token(view=None):
             else:
                 raise Unauthorized("Tenant does not exist.")
             return view(api_token.tenant_id, *args, **kwargs)
+
         return decorated
 
     if view:
@@ -170,20 +180,24 @@ def validate_and_get_api_token(scope=None):
     """
     Validate and get API token.
     """
-    auth_header = request.headers.get('Authorization')
-    if auth_header is None or ' ' not in auth_header:
+    auth_header = request.headers.get("Authorization")
+    if auth_header is None or " " not in auth_header:
         raise Unauthorized("Authorization header must be provided and start with 'Bearer'")
 
     auth_scheme, auth_token = auth_header.split(None, 1)
     auth_scheme = auth_scheme.lower()
 
-    if auth_scheme != 'bearer':
+    if auth_scheme != "bearer":
         raise Unauthorized("Authorization scheme must be 'Bearer'")
 
-    api_token = db.session.query(ApiToken).filter(
-        ApiToken.token == auth_token,
-        ApiToken.type == scope,
-    ).first()
+    api_token = (
+        db.session.query(ApiToken)
+        .filter(
+            ApiToken.token == auth_token,
+            ApiToken.type == scope,
+        )
+        .first()
+    )
 
     if not api_token:
         raise Unauthorized("Access token is invalid")
@@ -199,23 +213,26 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
     Create or update session terminal based on user ID.
     """
     if not user_id:
-        user_id = 'DEFAULT-USER'
+        user_id = "DEFAULT-USER"
 
-    end_user = db.session.query(EndUser) \
+    end_user = (
+        db.session.query(EndUser)
         .filter(
-        EndUser.tenant_id == app_model.tenant_id,
-        EndUser.app_id == app_model.id,
-        EndUser.session_id == user_id,
-        EndUser.type == 'service_api'
-    ).first()
+            EndUser.tenant_id == app_model.tenant_id,
+            EndUser.app_id == app_model.id,
+            EndUser.session_id == user_id,
+            EndUser.type == "service_api",
+        )
+        .first()
+    )
 
     if end_user is None:
         end_user = EndUser(
             tenant_id=app_model.tenant_id,
             app_id=app_model.id,
-            type='service_api',
-            is_anonymous=True if user_id == 'DEFAULT-USER' else False,
-            session_id=user_id
+            type="service_api",
+            is_anonymous=True if user_id == "DEFAULT-USER" else False,
+            session_id=user_id,
         )
         db.session.add(end_user)
         db.session.commit()

+ 1 - 1
api/controllers/web/__init__.py

@@ -2,7 +2,7 @@ from flask import Blueprint
 
 from libs.external_api import ExternalApi
 
-bp = Blueprint('web', __name__, url_prefix='/api')
+bp = Blueprint("web", __name__, url_prefix="/api")
 api = ExternalApi(bp)
 
 

+ 50 - 46
api/controllers/web/app.py

@@ -10,33 +10,32 @@ from services.app_service import AppService
 
 class AppParameterApi(WebApiResource):
     """Resource for app variables."""
+
     variable_fields = {
-        'key': fields.String,
-        'name': fields.String,
-        'description': fields.String,
-        'type': fields.String,
-        'default': fields.String,
-        'max_length': fields.Integer,
-        'options': fields.List(fields.String)
+        "key": fields.String,
+        "name": fields.String,
+        "description": fields.String,
+        "type": fields.String,
+        "default": fields.String,
+        "max_length": fields.Integer,
+        "options": fields.List(fields.String),
     }
 
-    system_parameters_fields = {
-        'image_file_size_limit': fields.String
-    }
+    system_parameters_fields = {"image_file_size_limit": fields.String}
 
     parameters_fields = {
-        'opening_statement': fields.String,
-        'suggested_questions': fields.Raw,
-        'suggested_questions_after_answer': fields.Raw,
-        'speech_to_text': fields.Raw,
-        'text_to_speech': fields.Raw,
-        'retriever_resource': fields.Raw,
-        'annotation_reply': fields.Raw,
-        'more_like_this': fields.Raw,
-        'user_input_form': fields.Raw,
-        'sensitive_word_avoidance': fields.Raw,
-        'file_upload': fields.Raw,
-        'system_parameters': fields.Nested(system_parameters_fields)
+        "opening_statement": fields.String,
+        "suggested_questions": fields.Raw,
+        "suggested_questions_after_answer": fields.Raw,
+        "speech_to_text": fields.Raw,
+        "text_to_speech": fields.Raw,
+        "retriever_resource": fields.Raw,
+        "annotation_reply": fields.Raw,
+        "more_like_this": fields.Raw,
+        "user_input_form": fields.Raw,
+        "sensitive_word_avoidance": fields.Raw,
+        "file_upload": fields.Raw,
+        "system_parameters": fields.Nested(system_parameters_fields),
     }
 
     @marshal_with(parameters_fields)
@@ -53,30 +52,35 @@ class AppParameterApi(WebApiResource):
             app_model_config = app_model.app_model_config
             features_dict = app_model_config.to_dict()
 
-            user_input_form = features_dict.get('user_input_form', [])
+            user_input_form = features_dict.get("user_input_form", [])
 
         return {
-            'opening_statement': features_dict.get('opening_statement'),
-            'suggested_questions': features_dict.get('suggested_questions', []),
-            'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer',
-                                                                  {"enabled": False}),
-            'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}),
-            'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}),
-            'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}),
-            'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}),
-            'more_like_this': features_dict.get('more_like_this', {"enabled": False}),
-            'user_input_form': user_input_form,
-            'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance',
-                                                          {"enabled": False, "type": "", "configs": []}),
-            'file_upload': features_dict.get('file_upload', {"image": {
-                "enabled": False,
-                "number_limits": 3,
-                "detail": "high",
-                "transfer_methods": ["remote_url", "local_file"]
-            }}),
-            'system_parameters': {
-                'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
-            }
+            "opening_statement": features_dict.get("opening_statement"),
+            "suggested_questions": features_dict.get("suggested_questions", []),
+            "suggested_questions_after_answer": features_dict.get(
+                "suggested_questions_after_answer", {"enabled": False}
+            ),
+            "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
+            "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
+            "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
+            "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
+            "more_like_this": features_dict.get("more_like_this", {"enabled": False}),
+            "user_input_form": user_input_form,
+            "sensitive_word_avoidance": features_dict.get(
+                "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
+            ),
+            "file_upload": features_dict.get(
+                "file_upload",
+                {
+                    "image": {
+                        "enabled": False,
+                        "number_limits": 3,
+                        "detail": "high",
+                        "transfer_methods": ["remote_url", "local_file"],
+                    }
+                },
+            ),
+            "system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT},
         }
 
 
@@ -86,5 +90,5 @@ class AppMeta(WebApiResource):
         return AppService().get_app_meta(app_model)
 
 
-api.add_resource(AppParameterApi, '/parameters')
-api.add_resource(AppMeta, '/meta')
+api.add_resource(AppParameterApi, "/parameters")
+api.add_resource(AppMeta, "/meta")

+ 24 - 26
api/controllers/web/audio.py

@@ -31,14 +31,10 @@ from services.errors.audio import (
 
 class AudioApi(WebApiResource):
     def post(self, app_model: App, end_user):
-        file = request.files['file']
+        file = request.files["file"]
 
         try:
-            response = AudioService.transcript_asr(
-                app_model=app_model,
-                file=file,
-                end_user=end_user
-            )
+            response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user)
 
             return response
         except services.errors.app_model_config.AppModelConfigBrokenError:
@@ -70,34 +66,36 @@ class AudioApi(WebApiResource):
 class TextApi(WebApiResource):
     def post(self, app_model: App, end_user):
         from flask_restful import reqparse
+
         try:
             parser = reqparse.RequestParser()
-            parser.add_argument('message_id', type=str, required=False, location='json')
-            parser.add_argument('voice', type=str, location='json')
-            parser.add_argument('text', type=str, location='json')
-            parser.add_argument('streaming', type=bool, location='json')
+            parser.add_argument("message_id", type=str, required=False, location="json")
+            parser.add_argument("voice", type=str, location="json")
+            parser.add_argument("text", type=str, location="json")
+            parser.add_argument("streaming", type=bool, location="json")
             args = parser.parse_args()
 
-            message_id = args.get('message_id', None)
-            text = args.get('text', None)
-            if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
-                    and app_model.workflow
-                    and app_model.workflow.features_dict):
-                text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
-                voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
+            message_id = args.get("message_id", None)
+            text = args.get("text", None)
+            if (
+                app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
+                and app_model.workflow
+                and app_model.workflow.features_dict
+            ):
+                text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
+                voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
             else:
                 try:
-                    voice = args.get('voice') if args.get(
-                        'voice') else app_model.app_model_config.text_to_speech_dict.get('voice')
+                    voice = (
+                        args.get("voice")
+                        if args.get("voice")
+                        else app_model.app_model_config.text_to_speech_dict.get("voice")
+                    )
                 except Exception:
                     voice = None
 
             response = AudioService.transcript_tts(
-                app_model=app_model,
-                message_id=message_id,
-                end_user=end_user.external_user_id,
-                voice=voice,
-                text=text
+                app_model=app_model, message_id=message_id, end_user=end_user.external_user_id, voice=voice, text=text
             )
 
             return response
@@ -127,5 +125,5 @@ class TextApi(WebApiResource):
             raise InternalServerError()
 
 
-api.add_resource(AudioApi, '/audio-to-text')
-api.add_resource(TextApi, '/text-to-audio')
+api.add_resource(AudioApi, "/audio-to-text")
+api.add_resource(TextApi, "/text-to-audio")

+ 25 - 34
api/controllers/web/completion.py

@@ -28,30 +28,25 @@ from services.app_generate_service import AppGenerateService
 
 # define completion api for user
 class CompletionApi(WebApiResource):
-
     def post(self, app_model, end_user):
-        if app_model.mode != 'completion':
+        if app_model.mode != "completion":
             raise NotCompletionAppError()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('inputs', type=dict, required=True, location='json')
-        parser.add_argument('query', type=str, location='json', default='')
-        parser.add_argument('files', type=list, required=False, location='json')
-        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
-        parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
+        parser.add_argument("inputs", type=dict, required=True, location="json")
+        parser.add_argument("query", type=str, location="json", default="")
+        parser.add_argument("files", type=list, required=False, location="json")
+        parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
+        parser.add_argument("retriever_from", type=str, required=False, default="web_app", location="json")
 
         args = parser.parse_args()
 
-        streaming = args['response_mode'] == 'streaming'
-        args['auto_generate_name'] = False
+        streaming = args["response_mode"] == "streaming"
+        args["auto_generate_name"] = False
 
         try:
             response = AppGenerateService.generate(
-                app_model=app_model,
-                user=end_user,
-                args=args,
-                invoke_from=InvokeFrom.WEB_APP,
-                streaming=streaming
+                app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.WEB_APP, streaming=streaming
             )
 
             return helper.compact_generate_response(response)
@@ -79,12 +74,12 @@ class CompletionApi(WebApiResource):
 
 class CompletionStopApi(WebApiResource):
     def post(self, app_model, end_user, task_id):
-        if app_model.mode != 'completion':
+        if app_model.mode != "completion":
             raise NotCompletionAppError()
 
         AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
 
-        return {'result': 'success'}, 200
+        return {"result": "success"}, 200
 
 
 class ChatApi(WebApiResource):
@@ -94,25 +89,21 @@ class ChatApi(WebApiResource):
             raise NotChatAppError()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('inputs', type=dict, required=True, location='json')
-        parser.add_argument('query', type=str, required=True, location='json')
-        parser.add_argument('files', type=list, required=False, location='json')
-        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
-        parser.add_argument('conversation_id', type=uuid_value, location='json')
-        parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
+        parser.add_argument("inputs", type=dict, required=True, location="json")
+        parser.add_argument("query", type=str, required=True, location="json")
+        parser.add_argument("files", type=list, required=False, location="json")
+        parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
+        parser.add_argument("conversation_id", type=uuid_value, location="json")
+        parser.add_argument("retriever_from", type=str, required=False, default="web_app", location="json")
 
         args = parser.parse_args()
 
-        streaming = args['response_mode'] == 'streaming'
-        args['auto_generate_name'] = False
+        streaming = args["response_mode"] == "streaming"
+        args["auto_generate_name"] = False
 
         try:
             response = AppGenerateService.generate(
-                app_model=app_model,
-                user=end_user,
-                args=args,
-                invoke_from=InvokeFrom.WEB_APP,
-                streaming=streaming
+                app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.WEB_APP, streaming=streaming
             )
 
             return helper.compact_generate_response(response)
@@ -146,10 +137,10 @@ class ChatStopApi(WebApiResource):
 
         AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
 
-        return {'result': 'success'}, 200
+        return {"result": "success"}, 200
 
 
-api.add_resource(CompletionApi, '/completion-messages')
-api.add_resource(CompletionStopApi, '/completion-messages/<string:task_id>/stop')
-api.add_resource(ChatApi, '/chat-messages')
-api.add_resource(ChatStopApi, '/chat-messages/<string:task_id>/stop')
+api.add_resource(CompletionApi, "/completion-messages")
+api.add_resource(CompletionStopApi, "/completion-messages/<string:task_id>/stop")
+api.add_resource(ChatApi, "/chat-messages")
+api.add_resource(ChatStopApi, "/chat-messages/<string:task_id>/stop")

+ 24 - 27
api/controllers/web/conversation.py

@@ -15,7 +15,6 @@ from services.web_conversation_service import WebConversationService
 
 
 class ConversationListApi(WebApiResource):
-
     @marshal_with(conversation_infinite_scroll_pagination_fields)
     def get(self, app_model, end_user):
         app_mode = AppMode.value_of(app_model.mode)
@@ -23,26 +22,32 @@ class ConversationListApi(WebApiResource):
             raise NotChatAppError()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('last_id', type=uuid_value, location='args')
-        parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
-        parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args')
-        parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'],
-                            required=False, default='-updated_at', location='args')
+        parser.add_argument("last_id", type=uuid_value, location="args")
+        parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
+        parser.add_argument("pinned", type=str, choices=["true", "false", None], location="args")
+        parser.add_argument(
+            "sort_by",
+            type=str,
+            choices=["created_at", "-created_at", "updated_at", "-updated_at"],
+            required=False,
+            default="-updated_at",
+            location="args",
+        )
         args = parser.parse_args()
 
         pinned = None
-        if 'pinned' in args and args['pinned'] is not None:
-            pinned = True if args['pinned'] == 'true' else False
+        if "pinned" in args and args["pinned"] is not None:
+            pinned = True if args["pinned"] == "true" else False
 
         try:
             return WebConversationService.pagination_by_last_id(
                 app_model=app_model,
                 user=end_user,
-                last_id=args['last_id'],
-                limit=args['limit'],
+                last_id=args["last_id"],
+                limit=args["limit"],
                 invoke_from=InvokeFrom.WEB_APP,
                 pinned=pinned,
-                sort_by=args['sort_by']
+                sort_by=args["sort_by"],
             )
         except LastConversationNotExistsError:
             raise NotFound("Last Conversation Not Exists.")
@@ -65,7 +70,6 @@ class ConversationApi(WebApiResource):
 
 
 class ConversationRenameApi(WebApiResource):
-
     @marshal_with(simple_conversation_fields)
     def post(self, app_model, end_user, c_id):
         app_mode = AppMode.value_of(app_model.mode)
@@ -75,24 +79,17 @@ class ConversationRenameApi(WebApiResource):
         conversation_id = str(c_id)
 
         parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, required=False, location='json')
-        parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
+        parser.add_argument("name", type=str, required=False, location="json")
+        parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
         args = parser.parse_args()
 
         try:
-            return ConversationService.rename(
-                app_model,
-                conversation_id,
-                end_user,
-                args['name'],
-                args['auto_generate']
-            )
+            return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"])
         except ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
 
 
 class ConversationPinApi(WebApiResource):
-
     def patch(self, app_model, end_user, c_id):
         app_mode = AppMode.value_of(app_model.mode)
         if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
@@ -120,8 +117,8 @@ class ConversationUnPinApi(WebApiResource):
         return {"result": "success"}
 
 
-api.add_resource(ConversationRenameApi, '/conversations/<uuid:c_id>/name', endpoint='web_conversation_name')
-api.add_resource(ConversationListApi, '/conversations')
-api.add_resource(ConversationApi, '/conversations/<uuid:c_id>')
-api.add_resource(ConversationPinApi, '/conversations/<uuid:c_id>/pin')
-api.add_resource(ConversationUnPinApi, '/conversations/<uuid:c_id>/unpin')
+api.add_resource(ConversationRenameApi, "/conversations/<uuid:c_id>/name", endpoint="web_conversation_name")
+api.add_resource(ConversationListApi, "/conversations")
+api.add_resource(ConversationApi, "/conversations/<uuid:c_id>")
+api.add_resource(ConversationPinApi, "/conversations/<uuid:c_id>/pin")
+api.add_resource(ConversationUnPinApi, "/conversations/<uuid:c_id>/unpin")

+ 28 - 24
api/controllers/web/error.py

@@ -2,122 +2,126 @@ from libs.exception import BaseHTTPException
 
 
 class AppUnavailableError(BaseHTTPException):
-    error_code = 'app_unavailable'
+    error_code = "app_unavailable"
     description = "App unavailable, please check your app configurations."
     code = 400
 
 
 class NotCompletionAppError(BaseHTTPException):
-    error_code = 'not_completion_app'
+    error_code = "not_completion_app"
     description = "Please check if your Completion app mode matches the right API route."
     code = 400
 
 
 class NotChatAppError(BaseHTTPException):
-    error_code = 'not_chat_app'
+    error_code = "not_chat_app"
     description = "Please check if your app mode matches the right API route."
     code = 400
 
 
 class NotWorkflowAppError(BaseHTTPException):
-    error_code = 'not_workflow_app'
+    error_code = "not_workflow_app"
     description = "Please check if your Workflow app mode matches the right API route."
     code = 400
 
 
 class ConversationCompletedError(BaseHTTPException):
-    error_code = 'conversation_completed'
+    error_code = "conversation_completed"
     description = "The conversation has ended. Please start a new conversation."
     code = 400
 
 
 class ProviderNotInitializeError(BaseHTTPException):
-    error_code = 'provider_not_initialize'
-    description = "No valid model provider credentials found. " \
-                  "Please go to Settings -> Model Provider to complete your provider credentials."
+    error_code = "provider_not_initialize"
+    description = (
+        "No valid model provider credentials found. "
+        "Please go to Settings -> Model Provider to complete your provider credentials."
+    )
     code = 400
 
 
 class ProviderQuotaExceededError(BaseHTTPException):
-    error_code = 'provider_quota_exceeded'
-    description = "Your quota for Dify Hosted OpenAI has been exhausted. " \
-                  "Please go to Settings -> Model Provider to complete your own provider credentials."
+    error_code = "provider_quota_exceeded"
+    description = (
+        "Your quota for Dify Hosted OpenAI has been exhausted. "
+        "Please go to Settings -> Model Provider to complete your own provider credentials."
+    )
     code = 400
 
 
 class ProviderModelCurrentlyNotSupportError(BaseHTTPException):
-    error_code = 'model_currently_not_support'
+    error_code = "model_currently_not_support"
     description = "Dify Hosted OpenAI trial currently not support the GPT-4 model."
     code = 400
 
 
 class CompletionRequestError(BaseHTTPException):
-    error_code = 'completion_request_error'
+    error_code = "completion_request_error"
     description = "Completion request failed."
     code = 400
 
 
 class AppMoreLikeThisDisabledError(BaseHTTPException):
-    error_code = 'app_more_like_this_disabled'
+    error_code = "app_more_like_this_disabled"
     description = "The 'More like this' feature is disabled. Please refresh your page."
     code = 403
 
 
 class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException):
-    error_code = 'app_suggested_questions_after_answer_disabled'
+    error_code = "app_suggested_questions_after_answer_disabled"
     description = "The 'Suggested Questions After Answer' feature is disabled. Please refresh your page."
     code = 403
 
 
 class NoAudioUploadedError(BaseHTTPException):
-    error_code = 'no_audio_uploaded'
+    error_code = "no_audio_uploaded"
     description = "Please upload your audio."
     code = 400
 
 
 class AudioTooLargeError(BaseHTTPException):
-    error_code = 'audio_too_large'
+    error_code = "audio_too_large"
     description = "Audio size exceeded. {message}"
     code = 413
 
 
 class UnsupportedAudioTypeError(BaseHTTPException):
-    error_code = 'unsupported_audio_type'
+    error_code = "unsupported_audio_type"
     description = "Audio type not allowed."
     code = 415
 
 
 class ProviderNotSupportSpeechToTextError(BaseHTTPException):
-    error_code = 'provider_not_support_speech_to_text'
+    error_code = "provider_not_support_speech_to_text"
     description = "Provider not support speech to text."
     code = 400
 
 
 class NoFileUploadedError(BaseHTTPException):
-    error_code = 'no_file_uploaded'
+    error_code = "no_file_uploaded"
     description = "Please upload your file."
     code = 400
 
 
 class TooManyFilesError(BaseHTTPException):
-    error_code = 'too_many_files'
+    error_code = "too_many_files"
     description = "Only one file is allowed."
     code = 400
 
 
 class FileTooLargeError(BaseHTTPException):
-    error_code = 'file_too_large'
+    error_code = "file_too_large"
     description = "File size exceeded. {message}"
     code = 413
 
 
 class UnsupportedFileTypeError(BaseHTTPException):
-    error_code = 'unsupported_file_type'
+    error_code = "unsupported_file_type"
     description = "File type not allowed."
     code = 415
 
 
 class WebSSOAuthRequiredError(BaseHTTPException):
-    error_code = 'web_sso_auth_required'
+    error_code = "web_sso_auth_required"
     description = "Web SSO authentication required."
     code = 401

+ 1 - 1
api/controllers/web/feature.py

@@ -9,4 +9,4 @@ class SystemFeatureApi(Resource):
         return FeatureService.get_system_features().model_dump()
 
 
-api.add_resource(SystemFeatureApi, '/system-features')
+api.add_resource(SystemFeatureApi, "/system-features")

+ 3 - 4
api/controllers/web/file.py

@@ -10,14 +10,13 @@ from services.file_service import FileService
 
 
 class FileApi(WebApiResource):
-
     @marshal_with(file_fields)
     def post(self, app_model, end_user):
         # get file from request
-        file = request.files['file']
+        file = request.files["file"]
 
         # check file
-        if 'file' not in request.files:
+        if "file" not in request.files:
             raise NoFileUploadedError()
 
         if len(request.files) > 1:
@@ -32,4 +31,4 @@ class FileApi(WebApiResource):
         return upload_file, 201
 
 
-api.add_resource(FileApi, '/files/upload')
+api.add_resource(FileApi, "/files/upload")

+ 53 - 55
api/controllers/web/message.py

@@ -33,48 +33,46 @@ from services.message_service import MessageService
 
 
 class MessageListApi(WebApiResource):
-    feedback_fields = {
-        'rating': fields.String
-    }
+    feedback_fields = {"rating": fields.String}
 
     retriever_resource_fields = {
-        'id': fields.String,
-        'message_id': fields.String,
-        'position': fields.Integer,
-        'dataset_id': fields.String,
-        'dataset_name': fields.String,
-        'document_id': fields.String,
-        'document_name': fields.String,
-        'data_source_type': fields.String,
-        'segment_id': fields.String,
-        'score': fields.Float,
-        'hit_count': fields.Integer,
-        'word_count': fields.Integer,
-        'segment_position': fields.Integer,
-        'index_node_hash': fields.String,
-        'content': fields.String,
-        'created_at': TimestampField
+        "id": fields.String,
+        "message_id": fields.String,
+        "position": fields.Integer,
+        "dataset_id": fields.String,
+        "dataset_name": fields.String,
+        "document_id": fields.String,
+        "document_name": fields.String,
+        "data_source_type": fields.String,
+        "segment_id": fields.String,
+        "score": fields.Float,
+        "hit_count": fields.Integer,
+        "word_count": fields.Integer,
+        "segment_position": fields.Integer,
+        "index_node_hash": fields.String,
+        "content": fields.String,
+        "created_at": TimestampField,
     }
 
     message_fields = {
-        'id': fields.String,
-        'conversation_id': fields.String,
-        'inputs': fields.Raw,
-        'query': fields.String,
-        'answer': fields.String(attribute='re_sign_file_url_answer'),
-        'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
-        'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
-        'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
-        'created_at': TimestampField,
-        'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)),
-        'status': fields.String,
-        'error': fields.String,
+        "id": fields.String,
+        "conversation_id": fields.String,
+        "inputs": fields.Raw,
+        "query": fields.String,
+        "answer": fields.String(attribute="re_sign_file_url_answer"),
+        "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
+        "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
+        "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
+        "created_at": TimestampField,
+        "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
+        "status": fields.String,
+        "error": fields.String,
     }
 
     message_infinite_scroll_pagination_fields = {
-        'limit': fields.Integer,
-        'has_more': fields.Boolean,
-        'data': fields.List(fields.Nested(message_fields))
+        "limit": fields.Integer,
+        "has_more": fields.Boolean,
+        "data": fields.List(fields.Nested(message_fields)),
     }
 
     @marshal_with(message_infinite_scroll_pagination_fields)
@@ -84,14 +82,15 @@ class MessageListApi(WebApiResource):
             raise NotChatAppError()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
-        parser.add_argument('first_id', type=uuid_value, location='args')
-        parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
+        parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
+        parser.add_argument("first_id", type=uuid_value, location="args")
+        parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
         args = parser.parse_args()
 
         try:
-            return MessageService.pagination_by_first_id(app_model, end_user,
-                                                     args['conversation_id'], args['first_id'], args['limit'])
+            return MessageService.pagination_by_first_id(
+                app_model, end_user, args["conversation_id"], args["first_id"], args["limit"]
+            )
         except services.errors.conversation.ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
         except services.errors.message.FirstMessageNotExistsError:
@@ -103,29 +102,31 @@ class MessageFeedbackApi(WebApiResource):
         message_id = str(message_id)
 
         parser = reqparse.RequestParser()
-        parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
+        parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
         args = parser.parse_args()
 
         try:
-            MessageService.create_feedback(app_model, message_id, end_user, args['rating'])
+            MessageService.create_feedback(app_model, message_id, end_user, args["rating"])
         except services.errors.message.MessageNotExistsError:
             raise NotFound("Message Not Exists.")
 
-        return {'result': 'success'}
+        return {"result": "success"}
 
 
 class MessageMoreLikeThisApi(WebApiResource):
     def get(self, app_model, end_user, message_id):
-        if app_model.mode != 'completion':
+        if app_model.mode != "completion":
             raise NotCompletionAppError()
 
         message_id = str(message_id)
 
         parser = reqparse.RequestParser()
-        parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args')
+        parser.add_argument(
+            "response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args"
+        )
         args = parser.parse_args()
 
-        streaming = args['response_mode'] == 'streaming'
+        streaming = args["response_mode"] == "streaming"
 
         try:
             response = AppGenerateService.generate_more_like_this(
@@ -133,7 +134,7 @@ class MessageMoreLikeThisApi(WebApiResource):
                 user=end_user,
                 message_id=message_id,
                 invoke_from=InvokeFrom.WEB_APP,
-                streaming=streaming
+                streaming=streaming,
             )
 
             return helper.compact_generate_response(response)
@@ -166,10 +167,7 @@ class MessageSuggestedQuestionApi(WebApiResource):
 
         try:
             questions = MessageService.get_suggested_questions_after_answer(
-                app_model=app_model,
-                user=end_user,
-                message_id=message_id,
-                invoke_from=InvokeFrom.WEB_APP
+                app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.WEB_APP
             )
         except MessageNotExistsError:
             raise NotFound("Message not found")
@@ -189,10 +187,10 @@ class MessageSuggestedQuestionApi(WebApiResource):
             logging.exception("internal server error.")
             raise InternalServerError()
 
-        return {'data': questions}
+        return {"data": questions}
 
 
-api.add_resource(MessageListApi, '/messages')
-api.add_resource(MessageFeedbackApi, '/messages/<uuid:message_id>/feedbacks')
-api.add_resource(MessageMoreLikeThisApi, '/messages/<uuid:message_id>/more-like-this')
-api.add_resource(MessageSuggestedQuestionApi, '/messages/<uuid:message_id>/suggested-questions')
+api.add_resource(MessageListApi, "/messages")
+api.add_resource(MessageFeedbackApi, "/messages/<uuid:message_id>/feedbacks")
+api.add_resource(MessageMoreLikeThisApi, "/messages/<uuid:message_id>/more-like-this")
+api.add_resource(MessageSuggestedQuestionApi, "/messages/<uuid:message_id>/suggested-questions")

+ 15 - 18
api/controllers/web/passport.py

@@ -15,33 +15,31 @@ from services.feature_service import FeatureService
 
 class PassportResource(Resource):
     """Base resource for passport."""
+
     def get(self):
         system_features = FeatureService.get_system_features()
-        app_code = request.headers.get('X-App-Code')
+        app_code = request.headers.get("X-App-Code")
         if app_code is None:
-            raise Unauthorized('X-App-Code header is missing.')
+            raise Unauthorized("X-App-Code header is missing.")
 
         if system_features.sso_enforced_for_web:
-            app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get('enabled', False)
+            app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False)
             if app_web_sso_enabled:
                 raise WebSSOAuthRequiredError()
-        
+
         # get site from db and check if it is normal
-        site = db.session.query(Site).filter(
-            Site.code == app_code,
-            Site.status == 'normal'
-        ).first()
+        site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
         if not site:
             raise NotFound()
         # get app from db and check if it is normal and enable_site
         app_model = db.session.query(App).filter(App.id == site.app_id).first()
-        if not app_model or app_model.status != 'normal' or not app_model.enable_site:
+        if not app_model or app_model.status != "normal" or not app_model.enable_site:
             raise NotFound()
 
         end_user = EndUser(
             tenant_id=app_model.tenant_id,
             app_id=app_model.id,
-            type='browser',
+            type="browser",
             is_anonymous=True,
             session_id=generate_session_id(),
         )
@@ -51,20 +49,20 @@ class PassportResource(Resource):
 
         payload = {
             "iss": site.app_id,
-            'sub': 'Web API Passport',
-            'app_id': site.app_id,
-            'app_code': app_code,
-            'end_user_id': end_user.id,
+            "sub": "Web API Passport",
+            "app_id": site.app_id,
+            "app_code": app_code,
+            "end_user_id": end_user.id,
         }
 
         tk = PassportService().issue(payload)
 
         return {
-            'access_token': tk,
+            "access_token": tk,
         }
 
 
-api.add_resource(PassportResource, '/passport')
+api.add_resource(PassportResource, "/passport")
 
 
 def generate_session_id():
@@ -73,7 +71,6 @@ def generate_session_id():
     """
     while True:
         session_id = str(uuid.uuid4())
-        existing_count = db.session.query(EndUser) \
-            .filter(EndUser.session_id == session_id).count()
+        existing_count = db.session.query(EndUser).filter(EndUser.session_id == session_id).count()
         if existing_count == 0:
             return session_id

+ 23 - 25
api/controllers/web/saved_message.py

@@ -10,67 +10,65 @@ from libs.helper import TimestampField, uuid_value
 from services.errors.message import MessageNotExistsError
 from services.saved_message_service import SavedMessageService
 
-feedback_fields = {
-    'rating': fields.String
-}
+feedback_fields = {"rating": fields.String}
 
 message_fields = {
-    'id': fields.String,
-    'inputs': fields.Raw,
-    'query': fields.String,
-    'answer': fields.String,
-    'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
-    'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
-    'created_at': TimestampField
+    "id": fields.String,
+    "inputs": fields.Raw,
+    "query": fields.String,
+    "answer": fields.String,
+    "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
+    "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
+    "created_at": TimestampField,
 }
 
 
 class SavedMessageListApi(WebApiResource):
     saved_message_infinite_scroll_pagination_fields = {
-        'limit': fields.Integer,
-        'has_more': fields.Boolean,
-        'data': fields.List(fields.Nested(message_fields))
+        "limit": fields.Integer,
+        "has_more": fields.Boolean,
+        "data": fields.List(fields.Nested(message_fields)),
     }
 
     @marshal_with(saved_message_infinite_scroll_pagination_fields)
     def get(self, app_model, end_user):
-        if app_model.mode != 'completion':
+        if app_model.mode != "completion":
             raise NotCompletionAppError()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('last_id', type=uuid_value, location='args')
-        parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
+        parser.add_argument("last_id", type=uuid_value, location="args")
+        parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
         args = parser.parse_args()
 
-        return SavedMessageService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit'])
+        return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"])
 
     def post(self, app_model, end_user):
-        if app_model.mode != 'completion':
+        if app_model.mode != "completion":
             raise NotCompletionAppError()
 
         parser = reqparse.RequestParser()
-        parser.add_argument('message_id', type=uuid_value, required=True, location='json')
+        parser.add_argument("message_id", type=uuid_value, required=True, location="json")
         args = parser.parse_args()
 
         try:
-            SavedMessageService.save(app_model, end_user, args['message_id'])
+            SavedMessageService.save(app_model, end_user, args["message_id"])
         except MessageNotExistsError:
             raise NotFound("Message Not Exists.")
 
-        return {'result': 'success'}
+        return {"result": "success"}
 
 
 class SavedMessageApi(WebApiResource):
     def delete(self, app_model, end_user, message_id):
         message_id = str(message_id)
 
-        if app_model.mode != 'completion':
+        if app_model.mode != "completion":
             raise NotCompletionAppError()
 
         SavedMessageService.delete(app_model, end_user, message_id)
 
-        return {'result': 'success'}
+        return {"result": "success"}
 
 
-api.add_resource(SavedMessageListApi, '/saved-messages')
-api.add_resource(SavedMessageApi, '/saved-messages/<uuid:message_id>')
+api.add_resource(SavedMessageListApi, "/saved-messages")
+api.add_resource(SavedMessageApi, "/saved-messages/<uuid:message_id>")

部分文件因为文件数量过多而无法显示