datasets.py

# -*- coding:utf-8 -*-
from flask import request
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, reqparse, fields, marshal, marshal_with
from werkzeug.exceptions import NotFound, Forbidden

import services
from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.indexing_runner import IndexingRunner
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.model_params import ModelType
from libs.helper import TimestampField
from extensions.ext_database import db
from models.dataset import DocumentSegment, Document
from models.model import UploadFile
from services.dataset_service import DatasetService, DocumentService
from services.provider_service import ProviderService

dataset_detail_fields = {
    'id': fields.String,
    'name': fields.String,
    'description': fields.String,
    'provider': fields.String,
    'permission': fields.String,
    'data_source_type': fields.String,
    'indexing_technique': fields.String,
    'app_count': fields.Integer,
    'document_count': fields.Integer,
    'word_count': fields.Integer,
    'created_by': fields.String,
    'created_at': TimestampField,
    'updated_by': fields.String,
    'updated_at': TimestampField,
    'embedding_model': fields.String,
    'embedding_model_provider': fields.String,
    'embedding_available': fields.Boolean
}

dataset_query_detail_fields = {
    "id": fields.String,
    "content": fields.String,
    "source": fields.String,
    "source_app_id": fields.String,
    "created_by_role": fields.String,
    "created_by": fields.String,
    "created_at": TimestampField
}


def _validate_name(name):
    if not name or len(name) < 1 or len(name) > 40:
        raise ValueError('Name must be between 1 and 40 characters.')
    return name


def _validate_description_length(description):
    if len(description) > 400:
        raise ValueError('Description cannot exceed 400 characters.')
    return description


class DatasetListApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        page = request.args.get('page', default=1, type=int)
        limit = request.args.get('limit', default=20, type=int)
        ids = request.args.getlist('ids')
        provider = request.args.get('provider', default="vendor")
        if ids:
            datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
        else:
            datasets, total = DatasetService.get_datasets(page, limit, provider,
                                                          current_user.current_tenant_id, current_user)

        # check embedding setting
        provider_service = ProviderService()
        valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, ModelType.EMBEDDINGS.value)
        # if len(valid_model_list) == 0:
        #     raise ProviderNotInitializeError(
        #         f"No Embedding Model available. Please configure a valid provider "
        #         f"in the Settings -> Model Provider.")
        model_names = []
        for valid_model in valid_model_list:
            model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")

        data = marshal(datasets, dataset_detail_fields)
        for item in data:
            if item['indexing_technique'] == 'high_quality':
                item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
                if item_model in model_names:
                    item['embedding_available'] = True
                else:
                    item['embedding_available'] = False
            else:
                item['embedding_available'] = True

        response = {
            'data': data,
            'has_more': len(datasets) == limit,
            'limit': limit,
            'total': total,
            'page': page
        }
        return response, 200

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument('name', nullable=False, required=True,
                            help='Name is required. Name must be between 1 and 40 characters.',
                            type=_validate_name)
        parser.add_argument('indexing_technique', type=str, location='json',
                            choices=('high_quality', 'economy'),
                            help='Invalid indexing technique.')
        args = parser.parse_args()

        # The current user's role in the tenant must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        try:
            dataset = DatasetService.create_empty_dataset(
                tenant_id=current_user.current_tenant_id,
                name=args['name'],
                indexing_technique=args['indexing_technique'],
                account=current_user
            )
        except services.errors.dataset.DatasetNameDuplicateError:
            raise DatasetNameDuplicateError()

        return marshal(dataset, dataset_detail_fields), 201


class DatasetApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        try:
            DatasetService.check_dataset_permission(
                dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        data = marshal(dataset, dataset_detail_fields)

        # check embedding setting
        provider_service = ProviderService()
        # get valid model list
        valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, ModelType.EMBEDDINGS.value)
        model_names = []
        for valid_model in valid_model_list:
            model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")

        if data['indexing_technique'] == 'high_quality':
            item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
            if item_model in model_names:
                data['embedding_available'] = True
            else:
                data['embedding_available'] = False
        else:
            data['embedding_available'] = True

        return data, 200

    @setup_required
    @login_required
    @account_initialization_required
    def patch(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)

        parser = reqparse.RequestParser()
        parser.add_argument('name', nullable=False,
                            help='Name must be between 1 and 40 characters.',
                            type=_validate_name)
        parser.add_argument('description',
                            location='json', store_missing=False,
                            type=_validate_description_length)
        parser.add_argument('indexing_technique', type=str, location='json',
                            choices=('high_quality', 'economy'),
                            help='Invalid indexing technique.')
        parser.add_argument('permission', type=str, location='json', choices=(
            'only_me', 'all_team_members'), help='Invalid permission.')
        args = parser.parse_args()

        # The current user's role in the tenant must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        dataset = DatasetService.update_dataset(
            dataset_id_str, args, current_user)

        if dataset is None:
            raise NotFound("Dataset not found.")

        return marshal(dataset, dataset_detail_fields), 200

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, dataset_id):
        dataset_id_str = str(dataset_id)

        # The current user's role in the tenant must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        if DatasetService.delete_dataset(dataset_id_str, current_user):
            return {'result': 'success'}, 204
        else:
            raise NotFound("Dataset not found.")


class DatasetQueryApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        page = request.args.get('page', default=1, type=int)
        limit = request.args.get('limit', default=20, type=int)

        dataset_queries, total = DatasetService.get_dataset_queries(
            dataset_id=dataset.id,
            page=page,
            per_page=limit
        )

        response = {
            'data': marshal(dataset_queries, dataset_query_detail_fields),
            'has_more': len(dataset_queries) == limit,
            'limit': limit,
            'total': total,
            'page': page
        }
        return response, 200


class DatasetIndexingEstimateApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
        parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
        parser.add_argument('indexing_technique', type=str, required=True, nullable=True, location='json')
        parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
        parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
        args = parser.parse_args()

        # validate args
        DocumentService.estimate_args_validate(args)

        if args['info_list']['data_source_type'] == 'upload_file':
            file_ids = args['info_list']['file_info_list']['file_ids']
            file_details = db.session.query(UploadFile).filter(
                UploadFile.tenant_id == current_user.current_tenant_id,
                UploadFile.id.in_(file_ids)
            ).all()

            if not file_details:
                raise NotFound("File not found.")

            indexing_runner = IndexingRunner()
            try:
                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
                                                                  args['process_rule'], args['doc_form'],
                                                                  args['doc_language'], args['dataset_id'],
                                                                  args['indexing_technique'])
            except LLMBadRequestError:
                raise ProviderNotInitializeError(
                    f"No Embedding Model available. Please configure a valid provider "
                    f"in the Settings -> Model Provider.")
            except ProviderTokenNotInitError as ex:
                raise ProviderNotInitializeError(ex.description)
        elif args['info_list']['data_source_type'] == 'notion_import':
            indexing_runner = IndexingRunner()
            try:
                response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
                                                                    args['info_list']['notion_info_list'],
                                                                    args['process_rule'], args['doc_form'],
                                                                    args['doc_language'], args['dataset_id'],
                                                                    args['indexing_technique'])
            except LLMBadRequestError:
                raise ProviderNotInitializeError(
                    f"No Embedding Model available. Please configure a valid provider "
                    f"in the Settings -> Model Provider.")
            except ProviderTokenNotInitError as ex:
                raise ProviderNotInitializeError(ex.description)
        else:
            raise ValueError('Data source type is not supported.')

        return response, 200
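
# Illustrative request body for POST /datasets/indexing-estimate (a sketch, not taken from this
# file): the handler above expects 'info_list.data_source_type' to be 'upload_file' (with
# 'file_info_list.file_ids') or 'notion_import' (with 'notion_info_list'). The 'process_rule'
# shape is validated by DocumentService.estimate_args_validate and is only assumed here; the
# file id is a placeholder.
#
# {
#     "info_list": {
#         "data_source_type": "upload_file",
#         "file_info_list": {"file_ids": ["<upload_file_id>"]}
#     },
#     "process_rule": {"mode": "automatic", "rules": {}},
#     "indexing_technique": "high_quality",
#     "doc_form": "text_model",
#     "doc_language": "English"
# }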


class DatasetRelatedAppListApi(Resource):
    app_detail_kernel_fields = {
        'id': fields.String,
        'name': fields.String,
        'mode': fields.String,
        'icon': fields.String,
        'icon_background': fields.String,
    }

    related_app_list = {
        'data': fields.List(fields.Nested(app_detail_kernel_fields)),
        'total': fields.Integer,
    }

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(related_app_list)
    def get(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        app_dataset_joins = DatasetService.get_related_apps(dataset.id)

        related_apps = []
        for app_dataset_join in app_dataset_joins:
            app_model = app_dataset_join.app
            if app_model:
                related_apps.append(app_model)

        return {
            'data': related_apps,
            'total': len(related_apps)
        }, 200


class DatasetIndexingStatusApi(Resource):
    document_status_fields = {
        'id': fields.String,
        'indexing_status': fields.String,
        'processing_started_at': TimestampField,
        'parsing_completed_at': TimestampField,
        'cleaning_completed_at': TimestampField,
        'splitting_completed_at': TimestampField,
        'completed_at': TimestampField,
        'paused_at': TimestampField,
        'error': fields.String,
        'stopped_at': TimestampField,
        'completed_segments': fields.Integer,
        'total_segments': fields.Integer,
    }

    document_status_fields_list = {
        'data': fields.List(fields.Nested(document_status_fields))
    }

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        dataset_id = str(dataset_id)
        documents = db.session.query(Document).filter(
            Document.dataset_id == dataset_id,
            Document.tenant_id == current_user.current_tenant_id
        ).all()

        documents_status = []
        for document in documents:
            completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
                                                              DocumentSegment.document_id == str(document.id),
                                                              DocumentSegment.status != 're_segment').count()
            total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
                                                          DocumentSegment.status != 're_segment').count()
            document.completed_segments = completed_segments
            document.total_segments = total_segments
            documents_status.append(marshal(document, self.document_status_fields))

        data = {
            'data': documents_status
        }
        return data


api.add_resource(DatasetListApi, '/datasets')
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status')
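
# Endpoint summary for quick reference (paths are relative to wherever the console `api` object
# is mounted; the setup/login/account decorators above still apply to every handler):
#   GET    /datasets                                    list datasets (page, limit, ids, provider)
#   POST   /datasets                                    create an empty dataset (name, indexing_technique)
#   GET    /datasets/<uuid:dataset_id>                  fetch one dataset with embedding availability
#   PATCH  /datasets/<uuid:dataset_id>                  update name/description/indexing_technique/permission
#   DELETE /datasets/<uuid:dataset_id>                  delete a dataset (admin or owner only)
#   GET    /datasets/<uuid:dataset_id>/queries          paginated query history for a dataset
#   POST   /datasets/indexing-estimate                  estimate indexing cost for files or Notion pages
#   GET    /datasets/<uuid:dataset_id>/related-apps     apps linked to the dataset
#   GET    /datasets/<uuid:dataset_id>/indexing-status  per-document indexing progress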