dataset.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. from flask import request
  2. from flask_restful import reqparse, marshal
  3. import services.dataset_service
  4. from controllers.service_api import api
  5. from controllers.service_api.dataset.error import DatasetNameDuplicateError
  6. from controllers.service_api.wraps import DatasetApiResource
  7. from core.login.login import current_user
  8. from core.model_providers.models.entity.model_params import ModelType
  9. from extensions.ext_database import db
  10. from fields.dataset_fields import dataset_detail_fields
  11. from models.account import Account, TenantAccountJoin
  12. from models.dataset import Dataset
  13. from services.dataset_service import DatasetService
  14. from services.provider_service import ProviderService
  15. def _validate_name(name):
  16. if not name or len(name) < 1 or len(name) > 40:
  17. raise ValueError('Name must be between 1 to 40 characters.')
  18. return name
  19. class DatasetApi(DatasetApiResource):
  20. """Resource for get datasets."""
  21. def get(self, tenant_id):
  22. page = request.args.get('page', default=1, type=int)
  23. limit = request.args.get('limit', default=20, type=int)
  24. provider = request.args.get('provider', default="vendor")
  25. datasets, total = DatasetService.get_datasets(page, limit, provider,
  26. tenant_id, current_user)
  27. # check embedding setting
  28. provider_service = ProviderService()
  29. valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
  30. ModelType.EMBEDDINGS.value)
  31. model_names = []
  32. for valid_model in valid_model_list:
  33. model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
  34. data = marshal(datasets, dataset_detail_fields)
  35. for item in data:
  36. if item['indexing_technique'] == 'high_quality':
  37. item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
  38. if item_model in model_names:
  39. item['embedding_available'] = True
  40. else:
  41. item['embedding_available'] = False
  42. else:
  43. item['embedding_available'] = True
  44. response = {
  45. 'data': data,
  46. 'has_more': len(datasets) == limit,
  47. 'limit': limit,
  48. 'total': total,
  49. 'page': page
  50. }
  51. return response, 200
  52. """Resource for datasets."""
  53. def post(self, tenant_id):
  54. parser = reqparse.RequestParser()
  55. parser.add_argument('name', nullable=False, required=True,
  56. help='type is required. Name must be between 1 to 40 characters.',
  57. type=_validate_name)
  58. parser.add_argument('indexing_technique', type=str, location='json',
  59. choices=('high_quality', 'economy'),
  60. help='Invalid indexing technique.')
  61. args = parser.parse_args()
  62. try:
  63. dataset = DatasetService.create_empty_dataset(
  64. tenant_id=tenant_id,
  65. name=args['name'],
  66. indexing_technique=args['indexing_technique'],
  67. account=current_user
  68. )
  69. except services.errors.dataset.DatasetNameDuplicateError:
  70. raise DatasetNameDuplicateError()
  71. return marshal(dataset, dataset_detail_fields), 200
  72. api.add_resource(DatasetApi, '/datasets')