# datasets_segments.py
  1. import uuid
  2. import pandas as pd
  3. from flask import request
  4. from flask_login import current_user # type: ignore
  5. from flask_restful import Resource, marshal, reqparse # type: ignore
  6. from werkzeug.exceptions import Forbidden, NotFound
  7. import services
  8. from controllers.console import api
  9. from controllers.console.app.error import ProviderNotInitializeError
  10. from controllers.console.datasets.error import (
  11. ChildChunkDeleteIndexError,
  12. ChildChunkIndexingError,
  13. InvalidActionError,
  14. NoFileUploadedError,
  15. TooManyFilesError,
  16. )
  17. from controllers.console.wraps import (
  18. account_initialization_required,
  19. cloud_edition_billing_knowledge_limit_check,
  20. cloud_edition_billing_rate_limit_check,
  21. cloud_edition_billing_resource_check,
  22. setup_required,
  23. )
  24. from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
  25. from core.model_manager import ModelManager
  26. from core.model_runtime.entities.model_entities import ModelType
  27. from extensions.ext_redis import redis_client
  28. from fields.segment_fields import child_chunk_fields, segment_fields
  29. from libs.login import login_required
  30. from models.dataset import ChildChunk, DocumentSegment
  31. from services.dataset_service import DatasetService, DocumentService, SegmentService
  32. from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
  33. from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
  34. from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
  35. from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
  36. class DatasetDocumentSegmentListApi(Resource):
  37. @setup_required
  38. @login_required
  39. @account_initialization_required
  40. def get(self, dataset_id, document_id):
  41. dataset_id = str(dataset_id)
  42. document_id = str(document_id)
  43. dataset = DatasetService.get_dataset(dataset_id)
  44. if not dataset:
  45. raise NotFound("Dataset not found.")
  46. try:
  47. DatasetService.check_dataset_permission(dataset, current_user)
  48. except services.errors.account.NoPermissionError as e:
  49. raise Forbidden(str(e))
  50. document = DocumentService.get_document(dataset_id, document_id)
  51. if not document:
  52. raise NotFound("Document not found.")
  53. parser = reqparse.RequestParser()
  54. parser.add_argument("limit", type=int, default=20, location="args")
  55. parser.add_argument("status", type=str, action="append", default=[], location="args")
  56. parser.add_argument("hit_count_gte", type=int, default=None, location="args")
  57. parser.add_argument("enabled", type=str, default="all", location="args")
  58. parser.add_argument("keyword", type=str, default=None, location="args")
  59. parser.add_argument("page", type=int, default=1, location="args")
  60. args = parser.parse_args()
  61. page = args["page"]
  62. limit = min(args["limit"], 100)
  63. status_list = args["status"]
  64. hit_count_gte = args["hit_count_gte"]
  65. keyword = args["keyword"]
  66. query = DocumentSegment.query.filter(
  67. DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id
  68. ).order_by(DocumentSegment.position.asc())
  69. if status_list:
  70. query = query.filter(DocumentSegment.status.in_(status_list))
  71. if hit_count_gte is not None:
  72. query = query.filter(DocumentSegment.hit_count >= hit_count_gte)
  73. if keyword:
  74. query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
  75. if args["enabled"].lower() != "all":
  76. if args["enabled"].lower() == "true":
  77. query = query.filter(DocumentSegment.enabled == True)
  78. elif args["enabled"].lower() == "false":
  79. query = query.filter(DocumentSegment.enabled == False)
  80. segments = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
  81. response = {
  82. "data": marshal(segments.items, segment_fields),
  83. "limit": limit,
  84. "total": segments.total,
  85. "total_pages": segments.pages,
  86. "page": page,
  87. }
  88. return response, 200
  89. @setup_required
  90. @login_required
  91. @account_initialization_required
  92. @cloud_edition_billing_rate_limit_check("knowledge")
  93. def delete(self, dataset_id, document_id):
  94. # check dataset
  95. dataset_id = str(dataset_id)
  96. dataset = DatasetService.get_dataset(dataset_id)
  97. if not dataset:
  98. raise NotFound("Dataset not found.")
  99. # check user's model setting
  100. DatasetService.check_dataset_model_setting(dataset)
  101. # check document
  102. document_id = str(document_id)
  103. document = DocumentService.get_document(dataset_id, document_id)
  104. if not document:
  105. raise NotFound("Document not found.")
  106. segment_ids = request.args.getlist("segment_id")
  107. # The role of the current user in the ta table must be admin or owner
  108. if not current_user.is_editor:
  109. raise Forbidden()
  110. try:
  111. DatasetService.check_dataset_permission(dataset, current_user)
  112. except services.errors.account.NoPermissionError as e:
  113. raise Forbidden(str(e))
  114. SegmentService.delete_segments(segment_ids, document, dataset)
  115. return {"result": "success"}, 200
  116. class DatasetDocumentSegmentApi(Resource):
  117. @setup_required
  118. @login_required
  119. @account_initialization_required
  120. @cloud_edition_billing_resource_check("vector_space")
  121. @cloud_edition_billing_rate_limit_check("knowledge")
  122. def patch(self, dataset_id, document_id, action):
  123. dataset_id = str(dataset_id)
  124. dataset = DatasetService.get_dataset(dataset_id)
  125. if not dataset:
  126. raise NotFound("Dataset not found.")
  127. document_id = str(document_id)
  128. document = DocumentService.get_document(dataset_id, document_id)
  129. if not document:
  130. raise NotFound("Document not found.")
  131. # check user's model setting
  132. DatasetService.check_dataset_model_setting(dataset)
  133. # The role of the current user in the ta table must be admin, owner, or editor
  134. if not current_user.is_editor:
  135. raise Forbidden()
  136. try:
  137. DatasetService.check_dataset_permission(dataset, current_user)
  138. except services.errors.account.NoPermissionError as e:
  139. raise Forbidden(str(e))
  140. if dataset.indexing_technique == "high_quality":
  141. # check embedding model setting
  142. try:
  143. model_manager = ModelManager()
  144. model_manager.get_model_instance(
  145. tenant_id=current_user.current_tenant_id,
  146. provider=dataset.embedding_model_provider,
  147. model_type=ModelType.TEXT_EMBEDDING,
  148. model=dataset.embedding_model,
  149. )
  150. except LLMBadRequestError:
  151. raise ProviderNotInitializeError(
  152. "No Embedding Model available. Please configure a valid provider "
  153. "in the Settings -> Model Provider."
  154. )
  155. except ProviderTokenNotInitError as ex:
  156. raise ProviderNotInitializeError(ex.description)
  157. segment_ids = request.args.getlist("segment_id")
  158. document_indexing_cache_key = "document_{}_indexing".format(document.id)
  159. cache_result = redis_client.get(document_indexing_cache_key)
  160. if cache_result is not None:
  161. raise InvalidActionError("Document is being indexed, please try again later")
  162. try:
  163. SegmentService.update_segments_status(segment_ids, action, dataset, document)
  164. except Exception as e:
  165. raise InvalidActionError(str(e))
  166. return {"result": "success"}, 200
  167. class DatasetDocumentSegmentAddApi(Resource):
  168. @setup_required
  169. @login_required
  170. @account_initialization_required
  171. @cloud_edition_billing_resource_check("vector_space")
  172. @cloud_edition_billing_knowledge_limit_check("add_segment")
  173. @cloud_edition_billing_rate_limit_check("knowledge")
  174. def post(self, dataset_id, document_id):
  175. # check dataset
  176. dataset_id = str(dataset_id)
  177. dataset = DatasetService.get_dataset(dataset_id)
  178. if not dataset:
  179. raise NotFound("Dataset not found.")
  180. # check document
  181. document_id = str(document_id)
  182. document = DocumentService.get_document(dataset_id, document_id)
  183. if not document:
  184. raise NotFound("Document not found.")
  185. if not current_user.is_editor:
  186. raise Forbidden()
  187. # check embedding model setting
  188. if dataset.indexing_technique == "high_quality":
  189. try:
  190. model_manager = ModelManager()
  191. model_manager.get_model_instance(
  192. tenant_id=current_user.current_tenant_id,
  193. provider=dataset.embedding_model_provider,
  194. model_type=ModelType.TEXT_EMBEDDING,
  195. model=dataset.embedding_model,
  196. )
  197. except LLMBadRequestError:
  198. raise ProviderNotInitializeError(
  199. "No Embedding Model available. Please configure a valid provider "
  200. "in the Settings -> Model Provider."
  201. )
  202. except ProviderTokenNotInitError as ex:
  203. raise ProviderNotInitializeError(ex.description)
  204. try:
  205. DatasetService.check_dataset_permission(dataset, current_user)
  206. except services.errors.account.NoPermissionError as e:
  207. raise Forbidden(str(e))
  208. # validate args
  209. parser = reqparse.RequestParser()
  210. parser.add_argument("content", type=str, required=True, nullable=False, location="json")
  211. parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
  212. parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
  213. args = parser.parse_args()
  214. SegmentService.segment_create_args_validate(args, document)
  215. segment = SegmentService.create_segment(args, document, dataset)
  216. return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
  217. class DatasetDocumentSegmentUpdateApi(Resource):
  218. @setup_required
  219. @login_required
  220. @account_initialization_required
  221. @cloud_edition_billing_resource_check("vector_space")
  222. @cloud_edition_billing_rate_limit_check("knowledge")
  223. def patch(self, dataset_id, document_id, segment_id):
  224. # check dataset
  225. dataset_id = str(dataset_id)
  226. dataset = DatasetService.get_dataset(dataset_id)
  227. if not dataset:
  228. raise NotFound("Dataset not found.")
  229. # check user's model setting
  230. DatasetService.check_dataset_model_setting(dataset)
  231. # check document
  232. document_id = str(document_id)
  233. document = DocumentService.get_document(dataset_id, document_id)
  234. if not document:
  235. raise NotFound("Document not found.")
  236. if dataset.indexing_technique == "high_quality":
  237. # check embedding model setting
  238. try:
  239. model_manager = ModelManager()
  240. model_manager.get_model_instance(
  241. tenant_id=current_user.current_tenant_id,
  242. provider=dataset.embedding_model_provider,
  243. model_type=ModelType.TEXT_EMBEDDING,
  244. model=dataset.embedding_model,
  245. )
  246. except LLMBadRequestError:
  247. raise ProviderNotInitializeError(
  248. "No Embedding Model available. Please configure a valid provider "
  249. "in the Settings -> Model Provider."
  250. )
  251. except ProviderTokenNotInitError as ex:
  252. raise ProviderNotInitializeError(ex.description)
  253. # check segment
  254. segment_id = str(segment_id)
  255. segment = DocumentSegment.query.filter(
  256. DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
  257. ).first()
  258. if not segment:
  259. raise NotFound("Segment not found.")
  260. # The role of the current user in the ta table must be admin, owner, or editor
  261. if not current_user.is_editor:
  262. raise Forbidden()
  263. try:
  264. DatasetService.check_dataset_permission(dataset, current_user)
  265. except services.errors.account.NoPermissionError as e:
  266. raise Forbidden(str(e))
  267. # validate args
  268. parser = reqparse.RequestParser()
  269. parser.add_argument("content", type=str, required=True, nullable=False, location="json")
  270. parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
  271. parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
  272. parser.add_argument(
  273. "regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
  274. )
  275. args = parser.parse_args()
  276. SegmentService.segment_create_args_validate(args, document)
  277. segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset)
  278. return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
  279. @setup_required
  280. @login_required
  281. @account_initialization_required
  282. @cloud_edition_billing_rate_limit_check("knowledge")
  283. def delete(self, dataset_id, document_id, segment_id):
  284. # check dataset
  285. dataset_id = str(dataset_id)
  286. dataset = DatasetService.get_dataset(dataset_id)
  287. if not dataset:
  288. raise NotFound("Dataset not found.")
  289. # check user's model setting
  290. DatasetService.check_dataset_model_setting(dataset)
  291. # check document
  292. document_id = str(document_id)
  293. document = DocumentService.get_document(dataset_id, document_id)
  294. if not document:
  295. raise NotFound("Document not found.")
  296. # check segment
  297. segment_id = str(segment_id)
  298. segment = DocumentSegment.query.filter(
  299. DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
  300. ).first()
  301. if not segment:
  302. raise NotFound("Segment not found.")
  303. # The role of the current user in the ta table must be admin or owner
  304. if not current_user.is_editor:
  305. raise Forbidden()
  306. try:
  307. DatasetService.check_dataset_permission(dataset, current_user)
  308. except services.errors.account.NoPermissionError as e:
  309. raise Forbidden(str(e))
  310. SegmentService.delete_segment(segment, document, dataset)
  311. return {"result": "success"}, 200
  312. class DatasetDocumentSegmentBatchImportApi(Resource):
  313. @setup_required
  314. @login_required
  315. @account_initialization_required
  316. @cloud_edition_billing_resource_check("vector_space")
  317. @cloud_edition_billing_knowledge_limit_check("add_segment")
  318. @cloud_edition_billing_rate_limit_check("knowledge")
  319. def post(self, dataset_id, document_id):
  320. # check dataset
  321. dataset_id = str(dataset_id)
  322. dataset = DatasetService.get_dataset(dataset_id)
  323. if not dataset:
  324. raise NotFound("Dataset not found.")
  325. # check document
  326. document_id = str(document_id)
  327. document = DocumentService.get_document(dataset_id, document_id)
  328. if not document:
  329. raise NotFound("Document not found.")
  330. # get file from request
  331. file = request.files["file"]
  332. # check file
  333. if "file" not in request.files:
  334. raise NoFileUploadedError()
  335. if len(request.files) > 1:
  336. raise TooManyFilesError()
  337. # check file type
  338. if not file.filename.endswith(".csv"):
  339. raise ValueError("Invalid file type. Only CSV files are allowed")
  340. try:
  341. # Skip the first row
  342. df = pd.read_csv(file)
  343. result = []
  344. for index, row in df.iterrows():
  345. if document.doc_form == "qa_model":
  346. data = {"content": row[0], "answer": row[1]}
  347. else:
  348. data = {"content": row[0]}
  349. result.append(data)
  350. if len(result) == 0:
  351. raise ValueError("The CSV file is empty.")
  352. # async job
  353. job_id = str(uuid.uuid4())
  354. indexing_cache_key = "segment_batch_import_{}".format(str(job_id))
  355. # send batch add segments task
  356. redis_client.setnx(indexing_cache_key, "waiting")
  357. batch_create_segment_to_index_task.delay(
  358. str(job_id), result, dataset_id, document_id, current_user.current_tenant_id, current_user.id
  359. )
  360. except Exception as e:
  361. return {"error": str(e)}, 500
  362. return {"job_id": job_id, "job_status": "waiting"}, 200
  363. @setup_required
  364. @login_required
  365. @account_initialization_required
  366. def get(self, job_id):
  367. job_id = str(job_id)
  368. indexing_cache_key = "segment_batch_import_{}".format(job_id)
  369. cache_result = redis_client.get(indexing_cache_key)
  370. if cache_result is None:
  371. raise ValueError("The job is not exist.")
  372. return {"job_id": job_id, "job_status": cache_result.decode()}, 200
  373. class ChildChunkAddApi(Resource):
  374. @setup_required
  375. @login_required
  376. @account_initialization_required
  377. @cloud_edition_billing_resource_check("vector_space")
  378. @cloud_edition_billing_knowledge_limit_check("add_segment")
  379. @cloud_edition_billing_rate_limit_check("knowledge")
  380. def post(self, dataset_id, document_id, segment_id):
  381. # check dataset
  382. dataset_id = str(dataset_id)
  383. dataset = DatasetService.get_dataset(dataset_id)
  384. if not dataset:
  385. raise NotFound("Dataset not found.")
  386. # check document
  387. document_id = str(document_id)
  388. document = DocumentService.get_document(dataset_id, document_id)
  389. if not document:
  390. raise NotFound("Document not found.")
  391. # check segment
  392. segment_id = str(segment_id)
  393. segment = DocumentSegment.query.filter(
  394. DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
  395. ).first()
  396. if not segment:
  397. raise NotFound("Segment not found.")
  398. if not current_user.is_editor:
  399. raise Forbidden()
  400. # check embedding model setting
  401. if dataset.indexing_technique == "high_quality":
  402. try:
  403. model_manager = ModelManager()
  404. model_manager.get_model_instance(
  405. tenant_id=current_user.current_tenant_id,
  406. provider=dataset.embedding_model_provider,
  407. model_type=ModelType.TEXT_EMBEDDING,
  408. model=dataset.embedding_model,
  409. )
  410. except LLMBadRequestError:
  411. raise ProviderNotInitializeError(
  412. "No Embedding Model available. Please configure a valid provider "
  413. "in the Settings -> Model Provider."
  414. )
  415. except ProviderTokenNotInitError as ex:
  416. raise ProviderNotInitializeError(ex.description)
  417. try:
  418. DatasetService.check_dataset_permission(dataset, current_user)
  419. except services.errors.account.NoPermissionError as e:
  420. raise Forbidden(str(e))
  421. # validate args
  422. parser = reqparse.RequestParser()
  423. parser.add_argument("content", type=str, required=True, nullable=False, location="json")
  424. args = parser.parse_args()
  425. try:
  426. child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset)
  427. except ChildChunkIndexingServiceError as e:
  428. raise ChildChunkIndexingError(str(e))
  429. return {"data": marshal(child_chunk, child_chunk_fields)}, 200
  430. @setup_required
  431. @login_required
  432. @account_initialization_required
  433. def get(self, dataset_id, document_id, segment_id):
  434. # check dataset
  435. dataset_id = str(dataset_id)
  436. dataset = DatasetService.get_dataset(dataset_id)
  437. if not dataset:
  438. raise NotFound("Dataset not found.")
  439. # check user's model setting
  440. DatasetService.check_dataset_model_setting(dataset)
  441. # check document
  442. document_id = str(document_id)
  443. document = DocumentService.get_document(dataset_id, document_id)
  444. if not document:
  445. raise NotFound("Document not found.")
  446. # check segment
  447. segment_id = str(segment_id)
  448. segment = DocumentSegment.query.filter(
  449. DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
  450. ).first()
  451. if not segment:
  452. raise NotFound("Segment not found.")
  453. parser = reqparse.RequestParser()
  454. parser.add_argument("limit", type=int, default=20, location="args")
  455. parser.add_argument("keyword", type=str, default=None, location="args")
  456. parser.add_argument("page", type=int, default=1, location="args")
  457. args = parser.parse_args()
  458. page = args["page"]
  459. limit = min(args["limit"], 100)
  460. keyword = args["keyword"]
  461. child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
  462. return {
  463. "data": marshal(child_chunks.items, child_chunk_fields),
  464. "total": child_chunks.total,
  465. "total_pages": child_chunks.pages,
  466. "page": page,
  467. "limit": limit,
  468. }, 200
  469. @setup_required
  470. @login_required
  471. @account_initialization_required
  472. @cloud_edition_billing_resource_check("vector_space")
  473. @cloud_edition_billing_rate_limit_check("knowledge")
  474. def patch(self, dataset_id, document_id, segment_id):
  475. # check dataset
  476. dataset_id = str(dataset_id)
  477. dataset = DatasetService.get_dataset(dataset_id)
  478. if not dataset:
  479. raise NotFound("Dataset not found.")
  480. # check user's model setting
  481. DatasetService.check_dataset_model_setting(dataset)
  482. # check document
  483. document_id = str(document_id)
  484. document = DocumentService.get_document(dataset_id, document_id)
  485. if not document:
  486. raise NotFound("Document not found.")
  487. # check segment
  488. segment_id = str(segment_id)
  489. segment = DocumentSegment.query.filter(
  490. DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
  491. ).first()
  492. if not segment:
  493. raise NotFound("Segment not found.")
  494. # The role of the current user in the ta table must be admin, owner, or editor
  495. if not current_user.is_editor:
  496. raise Forbidden()
  497. try:
  498. DatasetService.check_dataset_permission(dataset, current_user)
  499. except services.errors.account.NoPermissionError as e:
  500. raise Forbidden(str(e))
  501. # validate args
  502. parser = reqparse.RequestParser()
  503. parser.add_argument("chunks", type=list, required=True, nullable=False, location="json")
  504. args = parser.parse_args()
  505. try:
  506. chunks = [ChildChunkUpdateArgs(**chunk) for chunk in args.get("chunks")]
  507. child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
  508. except ChildChunkIndexingServiceError as e:
  509. raise ChildChunkIndexingError(str(e))
  510. return {"data": marshal(child_chunks, child_chunk_fields)}, 200
  511. class ChildChunkUpdateApi(Resource):
  512. @setup_required
  513. @login_required
  514. @account_initialization_required
  515. @cloud_edition_billing_rate_limit_check("knowledge")
  516. def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
  517. # check dataset
  518. dataset_id = str(dataset_id)
  519. dataset = DatasetService.get_dataset(dataset_id)
  520. if not dataset:
  521. raise NotFound("Dataset not found.")
  522. # check user's model setting
  523. DatasetService.check_dataset_model_setting(dataset)
  524. # check document
  525. document_id = str(document_id)
  526. document = DocumentService.get_document(dataset_id, document_id)
  527. if not document:
  528. raise NotFound("Document not found.")
  529. # check segment
  530. segment_id = str(segment_id)
  531. segment = DocumentSegment.query.filter(
  532. DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
  533. ).first()
  534. if not segment:
  535. raise NotFound("Segment not found.")
  536. # check child chunk
  537. child_chunk_id = str(child_chunk_id)
  538. child_chunk = ChildChunk.query.filter(
  539. ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id
  540. ).first()
  541. if not child_chunk:
  542. raise NotFound("Child chunk not found.")
  543. # The role of the current user in the ta table must be admin or owner
  544. if not current_user.is_editor:
  545. raise Forbidden()
  546. try:
  547. DatasetService.check_dataset_permission(dataset, current_user)
  548. except services.errors.account.NoPermissionError as e:
  549. raise Forbidden(str(e))
  550. try:
  551. SegmentService.delete_child_chunk(child_chunk, dataset)
  552. except ChildChunkDeleteIndexServiceError as e:
  553. raise ChildChunkDeleteIndexError(str(e))
  554. return {"result": "success"}, 200
  555. @setup_required
  556. @login_required
  557. @account_initialization_required
  558. @cloud_edition_billing_resource_check("vector_space")
  559. @cloud_edition_billing_rate_limit_check("knowledge")
  560. def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
  561. # check dataset
  562. dataset_id = str(dataset_id)
  563. dataset = DatasetService.get_dataset(dataset_id)
  564. if not dataset:
  565. raise NotFound("Dataset not found.")
  566. # check user's model setting
  567. DatasetService.check_dataset_model_setting(dataset)
  568. # check document
  569. document_id = str(document_id)
  570. document = DocumentService.get_document(dataset_id, document_id)
  571. if not document:
  572. raise NotFound("Document not found.")
  573. # check segment
  574. segment_id = str(segment_id)
  575. segment = DocumentSegment.query.filter(
  576. DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
  577. ).first()
  578. if not segment:
  579. raise NotFound("Segment not found.")
  580. # check child chunk
  581. child_chunk_id = str(child_chunk_id)
  582. child_chunk = ChildChunk.query.filter(
  583. ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id
  584. ).first()
  585. if not child_chunk:
  586. raise NotFound("Child chunk not found.")
  587. # The role of the current user in the ta table must be admin or owner
  588. if not current_user.is_editor:
  589. raise Forbidden()
  590. try:
  591. DatasetService.check_dataset_permission(dataset, current_user)
  592. except services.errors.account.NoPermissionError as e:
  593. raise Forbidden(str(e))
  594. # validate args
  595. parser = reqparse.RequestParser()
  596. parser.add_argument("content", type=str, required=True, nullable=False, location="json")
  597. args = parser.parse_args()
  598. try:
  599. child_chunk = SegmentService.update_child_chunk(
  600. args.get("content"), child_chunk, segment, document, dataset
  601. )
  602. except ChildChunkIndexingServiceError as e:
  603. raise ChildChunkIndexingError(str(e))
  604. return {"data": marshal(child_chunk, child_chunk_fields)}, 200
# Route registrations: bind each Resource to its URL rule(s).
# <uuid:...> converters reject non-UUID path components before dispatch.
api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
api.add_resource(
    DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>"
)
api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
api.add_resource(
    DatasetDocumentSegmentUpdateApi,
    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>",
)
# Two rules share this resource: POST uses the first, GET (status poll) the second.
api.add_resource(
    DatasetDocumentSegmentBatchImportApi,
    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
    "/datasets/batch_import_status/<uuid:job_id>",
)
api.add_resource(
    ChildChunkAddApi,
    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks",
)
api.add_resource(
    ChildChunkUpdateApi,
    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>",
)