dataset.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. from flask import request
  2. from flask_restful import marshal, reqparse
  3. from werkzeug.exceptions import NotFound
  4. import services.dataset_service
  5. from controllers.service_api import api
  6. from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError
  7. from controllers.service_api.wraps import DatasetApiResource
  8. from core.model_runtime.entities.model_entities import ModelType
  9. from core.provider_manager import ProviderManager
  10. from fields.dataset_fields import dataset_detail_fields
  11. from libs.login import current_user
  12. from models.dataset import Dataset, DatasetPermissionEnum
  13. from services.dataset_service import DatasetService
  14. def _validate_name(name):
  15. if not name or len(name) < 1 or len(name) > 40:
  16. raise ValueError('Name must be between 1 to 40 characters.')
  17. return name
  18. class DatasetListApi(DatasetApiResource):
  19. """Resource for datasets."""
  20. def get(self, tenant_id):
  21. """Resource for getting datasets."""
  22. page = request.args.get('page', default=1, type=int)
  23. limit = request.args.get('limit', default=20, type=int)
  24. provider = request.args.get('provider', default="vendor")
  25. search = request.args.get('keyword', default=None, type=str)
  26. tag_ids = request.args.getlist('tag_ids')
  27. datasets, total = DatasetService.get_datasets(page, limit, provider,
  28. tenant_id, current_user, search, tag_ids)
  29. # check embedding setting
  30. provider_manager = ProviderManager()
  31. configurations = provider_manager.get_configurations(
  32. tenant_id=current_user.current_tenant_id
  33. )
  34. embedding_models = configurations.get_models(
  35. model_type=ModelType.TEXT_EMBEDDING,
  36. only_active=True
  37. )
  38. model_names = []
  39. for embedding_model in embedding_models:
  40. model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
  41. data = marshal(datasets, dataset_detail_fields)
  42. for item in data:
  43. if item['indexing_technique'] == 'high_quality':
  44. item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
  45. if item_model in model_names:
  46. item['embedding_available'] = True
  47. else:
  48. item['embedding_available'] = False
  49. else:
  50. item['embedding_available'] = True
  51. response = {
  52. 'data': data,
  53. 'has_more': len(datasets) == limit,
  54. 'limit': limit,
  55. 'total': total,
  56. 'page': page
  57. }
  58. return response, 200
  59. def post(self, tenant_id):
  60. """Resource for creating datasets."""
  61. parser = reqparse.RequestParser()
  62. parser.add_argument('name', nullable=False, required=True,
  63. help='type is required. Name must be between 1 to 40 characters.',
  64. type=_validate_name)
  65. parser.add_argument('indexing_technique', type=str, location='json',
  66. choices=Dataset.INDEXING_TECHNIQUE_LIST,
  67. help='Invalid indexing technique.')
  68. parser.add_argument('permission', type=str, location='json', choices=(
  69. DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), help='Invalid permission.', required=False, nullable=False)
  70. args = parser.parse_args()
  71. try:
  72. dataset = DatasetService.create_empty_dataset(
  73. tenant_id=tenant_id,
  74. name=args['name'],
  75. indexing_technique=args['indexing_technique'],
  76. account=current_user,
  77. permission=args['permission']
  78. )
  79. except services.errors.dataset.DatasetNameDuplicateError:
  80. raise DatasetNameDuplicateError()
  81. return marshal(dataset, dataset_detail_fields), 200
  82. class DatasetApi(DatasetApiResource):
  83. """Resource for dataset."""
  84. def delete(self, _, dataset_id):
  85. """
  86. Deletes a dataset given its ID.
  87. Args:
  88. dataset_id (UUID): The ID of the dataset to be deleted.
  89. Returns:
  90. dict: A dictionary with a key 'result' and a value 'success'
  91. if the dataset was successfully deleted. Omitted in HTTP response.
  92. int: HTTP status code 204 indicating that the operation was successful.
  93. Raises:
  94. NotFound: If the dataset with the given ID does not exist.
  95. """
  96. dataset_id_str = str(dataset_id)
  97. try:
  98. if DatasetService.delete_dataset(dataset_id_str, current_user):
  99. return {'result': 'success'}, 204
  100. else:
  101. raise NotFound("Dataset not found.")
  102. except services.errors.dataset.DatasetInUseError:
  103. raise DatasetInUseError()
  104. api.add_resource(DatasetListApi, '/datasets')
  105. api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')