# datasets_segments.py
  1. # -*- coding:utf-8 -*-
  2. import uuid
  3. from datetime import datetime
  4. from flask import request
  5. from flask_login import login_required, current_user
  6. from flask_restful import Resource, reqparse, fields, marshal
  7. from werkzeug.exceptions import NotFound, Forbidden
  8. import services
  9. from controllers.console import api
  10. from controllers.console.app.error import ProviderNotInitializeError
  11. from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
  12. from controllers.console.setup import setup_required
  13. from controllers.console.wraps import account_initialization_required
  14. from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
  15. from core.model_providers.model_factory import ModelFactory
  16. from extensions.ext_database import db
  17. from extensions.ext_redis import redis_client
  18. from models.dataset import DocumentSegment
  19. from libs.helper import TimestampField
  20. from services.dataset_service import DatasetService, DocumentService, SegmentService
  21. from tasks.enable_segment_to_index_task import enable_segment_to_index_task
  22. from tasks.disable_segment_from_index_task import disable_segment_from_index_task
  23. from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
  24. import pandas as pd
  25. segment_fields = {
  26. 'id': fields.String,
  27. 'position': fields.Integer,
  28. 'document_id': fields.String,
  29. 'content': fields.String,
  30. 'answer': fields.String,
  31. 'word_count': fields.Integer,
  32. 'tokens': fields.Integer,
  33. 'keywords': fields.List(fields.String),
  34. 'index_node_id': fields.String,
  35. 'index_node_hash': fields.String,
  36. 'hit_count': fields.Integer,
  37. 'enabled': fields.Boolean,
  38. 'disabled_at': TimestampField,
  39. 'disabled_by': fields.String,
  40. 'status': fields.String,
  41. 'created_by': fields.String,
  42. 'created_at': TimestampField,
  43. 'indexing_at': TimestampField,
  44. 'completed_at': TimestampField,
  45. 'error': fields.String,
  46. 'stopped_at': TimestampField
  47. }
  48. segment_list_response = {
  49. 'data': fields.List(fields.Nested(segment_fields)),
  50. 'has_more': fields.Boolean,
  51. 'limit': fields.Integer
  52. }
  53. class DatasetDocumentSegmentListApi(Resource):
  54. @setup_required
  55. @login_required
  56. @account_initialization_required
  57. def get(self, dataset_id, document_id):
  58. dataset_id = str(dataset_id)
  59. document_id = str(document_id)
  60. dataset = DatasetService.get_dataset(dataset_id)
  61. if not dataset:
  62. raise NotFound('Dataset not found.')
  63. try:
  64. DatasetService.check_dataset_permission(dataset, current_user)
  65. except services.errors.account.NoPermissionError as e:
  66. raise Forbidden(str(e))
  67. document = DocumentService.get_document(dataset_id, document_id)
  68. if not document:
  69. raise NotFound('Document not found.')
  70. parser = reqparse.RequestParser()
  71. parser.add_argument('last_id', type=str, default=None, location='args')
  72. parser.add_argument('limit', type=int, default=20, location='args')
  73. parser.add_argument('status', type=str,
  74. action='append', default=[], location='args')
  75. parser.add_argument('hit_count_gte', type=int,
  76. default=None, location='args')
  77. parser.add_argument('enabled', type=str, default='all', location='args')
  78. parser.add_argument('keyword', type=str, default=None, location='args')
  79. args = parser.parse_args()
  80. last_id = args['last_id']
  81. limit = min(args['limit'], 100)
  82. status_list = args['status']
  83. hit_count_gte = args['hit_count_gte']
  84. keyword = args['keyword']
  85. query = DocumentSegment.query.filter(
  86. DocumentSegment.document_id == str(document_id),
  87. DocumentSegment.tenant_id == current_user.current_tenant_id
  88. )
  89. if last_id is not None:
  90. last_segment = DocumentSegment.query.get(str(last_id))
  91. if last_segment:
  92. query = query.filter(
  93. DocumentSegment.position > last_segment.position)
  94. else:
  95. return {'data': [], 'has_more': False, 'limit': limit}, 200
  96. if status_list:
  97. query = query.filter(DocumentSegment.status.in_(status_list))
  98. if hit_count_gte is not None:
  99. query = query.filter(DocumentSegment.hit_count >= hit_count_gte)
  100. if keyword:
  101. query = query.where(DocumentSegment.content.ilike(f'%{keyword}%'))
  102. if args['enabled'].lower() != 'all':
  103. if args['enabled'].lower() == 'true':
  104. query = query.filter(DocumentSegment.enabled == True)
  105. elif args['enabled'].lower() == 'false':
  106. query = query.filter(DocumentSegment.enabled == False)
  107. total = query.count()
  108. segments = query.order_by(DocumentSegment.position).limit(limit + 1).all()
  109. has_more = False
  110. if len(segments) > limit:
  111. has_more = True
  112. segments = segments[:-1]
  113. return {
  114. 'data': marshal(segments, segment_fields),
  115. 'doc_form': document.doc_form,
  116. 'has_more': has_more,
  117. 'limit': limit,
  118. 'total': total
  119. }, 200
  120. class DatasetDocumentSegmentApi(Resource):
  121. @setup_required
  122. @login_required
  123. @account_initialization_required
  124. def patch(self, dataset_id, segment_id, action):
  125. dataset_id = str(dataset_id)
  126. dataset = DatasetService.get_dataset(dataset_id)
  127. if not dataset:
  128. raise NotFound('Dataset not found.')
  129. # The role of the current user in the ta table must be admin or owner
  130. if current_user.current_tenant.current_role not in ['admin', 'owner']:
  131. raise Forbidden()
  132. try:
  133. DatasetService.check_dataset_permission(dataset, current_user)
  134. except services.errors.account.NoPermissionError as e:
  135. raise Forbidden(str(e))
  136. # check embedding model setting
  137. try:
  138. ModelFactory.get_embedding_model(
  139. tenant_id=current_user.current_tenant_id,
  140. model_provider_name=dataset.embedding_model_provider,
  141. model_name=dataset.embedding_model
  142. )
  143. except LLMBadRequestError:
  144. raise ProviderNotInitializeError(
  145. f"No Embedding Model available. Please configure a valid provider "
  146. f"in the Settings -> Model Provider.")
  147. except ProviderTokenNotInitError as ex:
  148. raise ProviderNotInitializeError(ex.description)
  149. segment = DocumentSegment.query.filter(
  150. DocumentSegment.id == str(segment_id),
  151. DocumentSegment.tenant_id == current_user.current_tenant_id
  152. ).first()
  153. if not segment:
  154. raise NotFound('Segment not found.')
  155. document_indexing_cache_key = 'document_{}_indexing'.format(segment.document_id)
  156. cache_result = redis_client.get(document_indexing_cache_key)
  157. if cache_result is not None:
  158. raise InvalidActionError("Document is being indexed, please try again later")
  159. indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
  160. cache_result = redis_client.get(indexing_cache_key)
  161. if cache_result is not None:
  162. raise InvalidActionError("Segment is being indexed, please try again later")
  163. if action == "enable":
  164. if segment.enabled:
  165. raise InvalidActionError("Segment is already enabled.")
  166. segment.enabled = True
  167. segment.disabled_at = None
  168. segment.disabled_by = None
  169. db.session.commit()
  170. # Set cache to prevent indexing the same segment multiple times
  171. redis_client.setex(indexing_cache_key, 600, 1)
  172. enable_segment_to_index_task.delay(segment.id)
  173. return {'result': 'success'}, 200
  174. elif action == "disable":
  175. if not segment.enabled:
  176. raise InvalidActionError("Segment is already disabled.")
  177. segment.enabled = False
  178. segment.disabled_at = datetime.utcnow()
  179. segment.disabled_by = current_user.id
  180. db.session.commit()
  181. # Set cache to prevent indexing the same segment multiple times
  182. redis_client.setex(indexing_cache_key, 600, 1)
  183. disable_segment_from_index_task.delay(segment.id)
  184. return {'result': 'success'}, 200
  185. else:
  186. raise InvalidActionError()
  187. class DatasetDocumentSegmentAddApi(Resource):
  188. @setup_required
  189. @login_required
  190. @account_initialization_required
  191. def post(self, dataset_id, document_id):
  192. # check dataset
  193. dataset_id = str(dataset_id)
  194. dataset = DatasetService.get_dataset(dataset_id)
  195. if not dataset:
  196. raise NotFound('Dataset not found.')
  197. # check document
  198. document_id = str(document_id)
  199. document = DocumentService.get_document(dataset_id, document_id)
  200. if not document:
  201. raise NotFound('Document not found.')
  202. # The role of the current user in the ta table must be admin or owner
  203. if current_user.current_tenant.current_role not in ['admin', 'owner']:
  204. raise Forbidden()
  205. # check embedding model setting
  206. try:
  207. ModelFactory.get_embedding_model(
  208. tenant_id=current_user.current_tenant_id,
  209. model_provider_name=dataset.embedding_model_provider,
  210. model_name=dataset.embedding_model
  211. )
  212. except LLMBadRequestError:
  213. raise ProviderNotInitializeError(
  214. f"No Embedding Model available. Please configure a valid provider "
  215. f"in the Settings -> Model Provider.")
  216. except ProviderTokenNotInitError as ex:
  217. raise ProviderNotInitializeError(ex.description)
  218. try:
  219. DatasetService.check_dataset_permission(dataset, current_user)
  220. except services.errors.account.NoPermissionError as e:
  221. raise Forbidden(str(e))
  222. # validate args
  223. parser = reqparse.RequestParser()
  224. parser.add_argument('content', type=str, required=True, nullable=False, location='json')
  225. parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
  226. parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
  227. args = parser.parse_args()
  228. SegmentService.segment_create_args_validate(args, document)
  229. segment = SegmentService.create_segment(args, document, dataset)
  230. return {
  231. 'data': marshal(segment, segment_fields),
  232. 'doc_form': document.doc_form
  233. }, 200
  234. class DatasetDocumentSegmentUpdateApi(Resource):
  235. @setup_required
  236. @login_required
  237. @account_initialization_required
  238. def patch(self, dataset_id, document_id, segment_id):
  239. # check dataset
  240. dataset_id = str(dataset_id)
  241. dataset = DatasetService.get_dataset(dataset_id)
  242. if not dataset:
  243. raise NotFound('Dataset not found.')
  244. # check document
  245. document_id = str(document_id)
  246. document = DocumentService.get_document(dataset_id, document_id)
  247. if not document:
  248. raise NotFound('Document not found.')
  249. # check embedding model setting
  250. try:
  251. ModelFactory.get_embedding_model(
  252. tenant_id=current_user.current_tenant_id,
  253. model_provider_name=dataset.embedding_model_provider,
  254. model_name=dataset.embedding_model
  255. )
  256. except LLMBadRequestError:
  257. raise ProviderNotInitializeError(
  258. f"No Embedding Model available. Please configure a valid provider "
  259. f"in the Settings -> Model Provider.")
  260. except ProviderTokenNotInitError as ex:
  261. raise ProviderNotInitializeError(ex.description)
  262. # check segment
  263. segment_id = str(segment_id)
  264. segment = DocumentSegment.query.filter(
  265. DocumentSegment.id == str(segment_id),
  266. DocumentSegment.tenant_id == current_user.current_tenant_id
  267. ).first()
  268. if not segment:
  269. raise NotFound('Segment not found.')
  270. # The role of the current user in the ta table must be admin or owner
  271. if current_user.current_tenant.current_role not in ['admin', 'owner']:
  272. raise Forbidden()
  273. try:
  274. DatasetService.check_dataset_permission(dataset, current_user)
  275. except services.errors.account.NoPermissionError as e:
  276. raise Forbidden(str(e))
  277. # validate args
  278. parser = reqparse.RequestParser()
  279. parser.add_argument('content', type=str, required=True, nullable=False, location='json')
  280. parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
  281. parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
  282. args = parser.parse_args()
  283. SegmentService.segment_create_args_validate(args, document)
  284. segment = SegmentService.update_segment(args, segment, document, dataset)
  285. return {
  286. 'data': marshal(segment, segment_fields),
  287. 'doc_form': document.doc_form
  288. }, 200
  289. @setup_required
  290. @login_required
  291. @account_initialization_required
  292. def delete(self, dataset_id, document_id, segment_id):
  293. # check dataset
  294. dataset_id = str(dataset_id)
  295. dataset = DatasetService.get_dataset(dataset_id)
  296. if not dataset:
  297. raise NotFound('Dataset not found.')
  298. # check document
  299. document_id = str(document_id)
  300. document = DocumentService.get_document(dataset_id, document_id)
  301. if not document:
  302. raise NotFound('Document not found.')
  303. # check segment
  304. segment_id = str(segment_id)
  305. segment = DocumentSegment.query.filter(
  306. DocumentSegment.id == str(segment_id),
  307. DocumentSegment.tenant_id == current_user.current_tenant_id
  308. ).first()
  309. if not segment:
  310. raise NotFound('Segment not found.')
  311. # The role of the current user in the ta table must be admin or owner
  312. if current_user.current_tenant.current_role not in ['admin', 'owner']:
  313. raise Forbidden()
  314. try:
  315. DatasetService.check_dataset_permission(dataset, current_user)
  316. except services.errors.account.NoPermissionError as e:
  317. raise Forbidden(str(e))
  318. SegmentService.delete_segment(segment, document, dataset)
  319. return {'result': 'success'}, 200
  320. class DatasetDocumentSegmentBatchImportApi(Resource):
  321. @setup_required
  322. @login_required
  323. @account_initialization_required
  324. def post(self, dataset_id, document_id):
  325. # check dataset
  326. dataset_id = str(dataset_id)
  327. dataset = DatasetService.get_dataset(dataset_id)
  328. if not dataset:
  329. raise NotFound('Dataset not found.')
  330. # check document
  331. document_id = str(document_id)
  332. document = DocumentService.get_document(dataset_id, document_id)
  333. if not document:
  334. raise NotFound('Document not found.')
  335. try:
  336. ModelFactory.get_embedding_model(
  337. tenant_id=current_user.current_tenant_id,
  338. model_provider_name=dataset.embedding_model_provider,
  339. model_name=dataset.embedding_model
  340. )
  341. except LLMBadRequestError:
  342. raise ProviderNotInitializeError(
  343. f"No Embedding Model available. Please configure a valid provider "
  344. f"in the Settings -> Model Provider.")
  345. except ProviderTokenNotInitError as ex:
  346. raise ProviderNotInitializeError(ex.description)
  347. # get file from request
  348. file = request.files['file']
  349. # check file
  350. if 'file' not in request.files:
  351. raise NoFileUploadedError()
  352. if len(request.files) > 1:
  353. raise TooManyFilesError()
  354. # check file type
  355. if not file.filename.endswith('.csv'):
  356. raise ValueError("Invalid file type. Only CSV files are allowed")
  357. try:
  358. # Skip the first row
  359. df = pd.read_csv(file)
  360. result = []
  361. for index, row in df.iterrows():
  362. if document.doc_form == 'qa_model':
  363. data = {'content': row[0], 'answer': row[1]}
  364. else:
  365. data = {'content': row[0]}
  366. result.append(data)
  367. if len(result) == 0:
  368. raise ValueError("The CSV file is empty.")
  369. # async job
  370. job_id = str(uuid.uuid4())
  371. indexing_cache_key = 'segment_batch_import_{}'.format(str(job_id))
  372. # send batch add segments task
  373. redis_client.setnx(indexing_cache_key, 'waiting')
  374. batch_create_segment_to_index_task.delay(str(job_id), result, dataset_id, document_id,
  375. current_user.current_tenant_id, current_user.id)
  376. except Exception as e:
  377. return {'error': str(e)}, 500
  378. return {
  379. 'job_id': job_id,
  380. 'job_status': 'waiting'
  381. }, 200
  382. @setup_required
  383. @login_required
  384. @account_initialization_required
  385. def get(self, job_id):
  386. job_id = str(job_id)
  387. indexing_cache_key = 'segment_batch_import_{}'.format(job_id)
  388. cache_result = redis_client.get(indexing_cache_key)
  389. if cache_result is None:
  390. raise ValueError("The job is not exist.")
  391. return {
  392. 'job_id': job_id,
  393. 'job_status': cache_result.decode()
  394. }, 200
  395. api.add_resource(DatasetDocumentSegmentListApi,
  396. '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
  397. api.add_resource(DatasetDocumentSegmentApi,
  398. '/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>')
  399. api.add_resource(DatasetDocumentSegmentAddApi,
  400. '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment')
  401. api.add_resource(DatasetDocumentSegmentUpdateApi,
  402. '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
  403. api.add_resource(DatasetDocumentSegmentBatchImportApi,
  404. '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import',
  405. '/datasets/batch_import_status/<uuid:job_id>')