datasets_segments.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657
  1. import uuid
  2. import pandas as pd
  3. from flask import request
  4. from flask_login import current_user # type: ignore
  5. from flask_restful import Resource, marshal, reqparse # type: ignore
  6. from werkzeug.exceptions import Forbidden, NotFound
  7. import services
  8. from controllers.console import api
  9. from controllers.console.app.error import ProviderNotInitializeError
  10. from controllers.console.datasets.error import (
  11. ChildChunkDeleteIndexError,
  12. ChildChunkIndexingError,
  13. InvalidActionError,
  14. NoFileUploadedError,
  15. TooManyFilesError,
  16. )
  17. from controllers.console.wraps import (
  18. account_initialization_required,
  19. cloud_edition_billing_knowledge_limit_check,
  20. cloud_edition_billing_resource_check,
  21. setup_required,
  22. )
  23. from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
  24. from core.model_manager import ModelManager
  25. from core.model_runtime.entities.model_entities import ModelType
  26. from extensions.ext_redis import redis_client
  27. from fields.segment_fields import child_chunk_fields, segment_fields
  28. from libs.login import login_required
  29. from models.dataset import ChildChunk, DocumentSegment
  30. from services.dataset_service import DatasetService, DocumentService, SegmentService
  31. from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
  32. from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
  33. from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
  34. from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
  35. class DatasetDocumentSegmentListApi(Resource):
  36. @setup_required
  37. @login_required
  38. @account_initialization_required
  39. def get(self, dataset_id, document_id):
  40. dataset_id = str(dataset_id)
  41. document_id = str(document_id)
  42. dataset = DatasetService.get_dataset(dataset_id)
  43. if not dataset:
  44. raise NotFound("Dataset not found.")
  45. try:
  46. DatasetService.check_dataset_permission(dataset, current_user)
  47. except services.errors.account.NoPermissionError as e:
  48. raise Forbidden(str(e))
  49. document = DocumentService.get_document(dataset_id, document_id)
  50. if not document:
  51. raise NotFound("Document not found.")
  52. parser = reqparse.RequestParser()
  53. parser.add_argument("limit", type=int, default=20, location="args")
  54. parser.add_argument("status", type=str, action="append", default=[], location="args")
  55. parser.add_argument("hit_count_gte", type=int, default=None, location="args")
  56. parser.add_argument("enabled", type=str, default="all", location="args")
  57. parser.add_argument("keyword", type=str, default=None, location="args")
  58. parser.add_argument("page", type=int, default=1, location="args")
  59. args = parser.parse_args()
  60. page = args["page"]
  61. limit = min(args["limit"], 100)
  62. status_list = args["status"]
  63. hit_count_gte = args["hit_count_gte"]
  64. keyword = args["keyword"]
  65. query = DocumentSegment.query.filter(
  66. DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id
  67. ).order_by(DocumentSegment.position.asc())
  68. if status_list:
  69. query = query.filter(DocumentSegment.status.in_(status_list))
  70. if hit_count_gte is not None:
  71. query = query.filter(DocumentSegment.hit_count >= hit_count_gte)
  72. if keyword:
  73. query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
  74. if args["enabled"].lower() != "all":
  75. if args["enabled"].lower() == "true":
  76. query = query.filter(DocumentSegment.enabled == True)
  77. elif args["enabled"].lower() == "false":
  78. query = query.filter(DocumentSegment.enabled == False)
  79. segments = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
  80. response = {
  81. "data": marshal(segments.items, segment_fields),
  82. "limit": limit,
  83. "total": segments.total,
  84. "total_pages": segments.pages,
  85. "page": page,
  86. }
  87. return response, 200
  88. @setup_required
  89. @login_required
  90. @account_initialization_required
  91. def delete(self, dataset_id, document_id):
  92. # check dataset
  93. dataset_id = str(dataset_id)
  94. dataset = DatasetService.get_dataset(dataset_id)
  95. if not dataset:
  96. raise NotFound("Dataset not found.")
  97. # check user's model setting
  98. DatasetService.check_dataset_model_setting(dataset)
  99. # check document
  100. document_id = str(document_id)
  101. document = DocumentService.get_document(dataset_id, document_id)
  102. if not document:
  103. raise NotFound("Document not found.")
  104. segment_ids = request.args.getlist("segment_id")
  105. # The role of the current user in the ta table must be admin or owner
  106. if not current_user.is_editor:
  107. raise Forbidden()
  108. try:
  109. DatasetService.check_dataset_permission(dataset, current_user)
  110. except services.errors.account.NoPermissionError as e:
  111. raise Forbidden(str(e))
  112. SegmentService.delete_segments(segment_ids, document, dataset)
  113. return {"result": "success"}, 200
class DatasetDocumentSegmentApi(Resource):
    """Apply a status action (the ``action`` URL part) to document segments in bulk."""

    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check("vector_space")
    def patch(self, dataset_id, document_id, action):
        """Apply ``action`` to the segments listed in the repeated
        ``segment_id`` query argument.

        Raises NotFound for a missing dataset/document, Forbidden for
        insufficient role/permission, ProviderNotInitializeError when the
        embedding model is unavailable, and InvalidActionError while the
        document is still being indexed or when the service call fails.
        """
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound("Dataset not found.")
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound("Document not found.")
        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)
        # The role of the current user in the ta table must be admin, owner, or editor
        if not current_user.is_editor:
            raise Forbidden()
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))
        if dataset.indexing_technique == "high_quality":
            # check embedding model setting: fail fast before mutating segment
            # status if the tenant's embedding provider is not usable
            try:
                model_manager = ModelManager()
                model_manager.get_model_instance(
                    tenant_id=current_user.current_tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model,
                )
            except LLMBadRequestError:
                raise ProviderNotInitializeError(
                    "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
                )
            except ProviderTokenNotInitError as ex:
                raise ProviderNotInitializeError(ex.description)
        segment_ids = request.args.getlist("segment_id")
        # Refuse concurrent updates while the whole document is being indexed
        # (presence of the redis key marks an in-flight indexing job).
        document_indexing_cache_key = "document_{}_indexing".format(document.id)
        cache_result = redis_client.get(document_indexing_cache_key)
        if cache_result is not None:
            raise InvalidActionError("Document is being indexed, please try again later")
        try:
            SegmentService.update_segments_status(segment_ids, action, dataset, document)
        except Exception as e:
            # Any service-level failure is surfaced as an invalid-action error.
            raise InvalidActionError(str(e))
        return {"result": "success"}, 200
class DatasetDocumentSegmentAddApi(Resource):
    """Create a single segment in a dataset document."""

    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check("vector_space")
    @cloud_edition_billing_knowledge_limit_check("add_segment")
    def post(self, dataset_id, document_id):
        """Create one segment from the JSON body.

        Body: ``content`` (required), ``answer`` (optional, qa form),
        ``keywords`` (optional list). Returns the marshalled segment plus the
        document's ``doc_form``.

        Raises NotFound, Forbidden, and ProviderNotInitializeError as in the
        sibling segment endpoints.
        """
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound("Dataset not found.")
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound("Document not found.")
        if not current_user.is_editor:
            raise Forbidden()
        # check embedding model setting
        if dataset.indexing_technique == "high_quality":
            try:
                model_manager = ModelManager()
                model_manager.get_model_instance(
                    tenant_id=current_user.current_tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model,
                )
            except LLMBadRequestError:
                raise ProviderNotInitializeError(
                    "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
                )
            except ProviderTokenNotInitError as ex:
                raise ProviderNotInitializeError(ex.description)
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))
        # validate args
        parser = reqparse.RequestParser()
        parser.add_argument("content", type=str, required=True, nullable=False, location="json")
        parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
        parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
        args = parser.parse_args()
        SegmentService.segment_create_args_validate(args, document)
        segment = SegmentService.create_segment(args, document, dataset)
        return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
class DatasetDocumentSegmentUpdateApi(Resource):
    """Update or delete a single document segment."""

    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check("vector_space")
    def patch(self, dataset_id, document_id, segment_id):
        """Update a segment's content/answer/keywords; optionally regenerate
        its child chunks (``regenerate_child_chunks`` flag).

        Raises NotFound, Forbidden, and ProviderNotInitializeError as in the
        sibling segment endpoints.
        """
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound("Dataset not found.")
        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound("Document not found.")
        if dataset.indexing_technique == "high_quality":
            # check embedding model setting
            try:
                model_manager = ModelManager()
                model_manager.get_model_instance(
                    tenant_id=current_user.current_tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model,
                )
            except LLMBadRequestError:
                raise ProviderNotInitializeError(
                    "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
                )
            except ProviderTokenNotInitError as ex:
                raise ProviderNotInitializeError(ex.description)
        # check segment (scoped to the current tenant)
        segment_id = str(segment_id)
        segment = DocumentSegment.query.filter(
            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
        ).first()
        if not segment:
            raise NotFound("Segment not found.")
        # The role of the current user in the ta table must be admin, owner, or editor
        if not current_user.is_editor:
            raise Forbidden()
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))
        # validate args
        parser = reqparse.RequestParser()
        parser.add_argument("content", type=str, required=True, nullable=False, location="json")
        parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
        parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
        parser.add_argument(
            "regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
        )
        args = parser.parse_args()
        SegmentService.segment_create_args_validate(args, document)
        segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset)
        return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, dataset_id, document_id, segment_id):
        """Delete a single segment (and its index entries via the service)."""
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound("Dataset not found.")
        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound("Document not found.")
        # check segment (scoped to the current tenant)
        segment_id = str(segment_id)
        segment = DocumentSegment.query.filter(
            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
        ).first()
        if not segment:
            raise NotFound("Segment not found.")
        # The current user must have an editing role (admin, owner, or editor).
        if not current_user.is_editor:
            raise Forbidden()
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))
        SegmentService.delete_segment(segment, document, dataset)
        return {"result": "success"}, 200
  303. class DatasetDocumentSegmentBatchImportApi(Resource):
  304. @setup_required
  305. @login_required
  306. @account_initialization_required
  307. @cloud_edition_billing_resource_check("vector_space")
  308. @cloud_edition_billing_knowledge_limit_check("add_segment")
  309. def post(self, dataset_id, document_id):
  310. # check dataset
  311. dataset_id = str(dataset_id)
  312. dataset = DatasetService.get_dataset(dataset_id)
  313. if not dataset:
  314. raise NotFound("Dataset not found.")
  315. # check document
  316. document_id = str(document_id)
  317. document = DocumentService.get_document(dataset_id, document_id)
  318. if not document:
  319. raise NotFound("Document not found.")
  320. # get file from request
  321. file = request.files["file"]
  322. # check file
  323. if "file" not in request.files:
  324. raise NoFileUploadedError()
  325. if len(request.files) > 1:
  326. raise TooManyFilesError()
  327. # check file type
  328. if not file.filename.endswith(".csv"):
  329. raise ValueError("Invalid file type. Only CSV files are allowed")
  330. try:
  331. # Skip the first row
  332. df = pd.read_csv(file)
  333. result = []
  334. for index, row in df.iterrows():
  335. if document.doc_form == "qa_model":
  336. data = {"content": row.iloc[0], "answer": row.iloc[1]}
  337. else:
  338. data = {"content": row.iloc[0]}
  339. result.append(data)
  340. if len(result) == 0:
  341. raise ValueError("The CSV file is empty.")
  342. # async job
  343. job_id = str(uuid.uuid4())
  344. indexing_cache_key = "segment_batch_import_{}".format(str(job_id))
  345. # send batch add segments task
  346. redis_client.setnx(indexing_cache_key, "waiting")
  347. batch_create_segment_to_index_task.delay(
  348. str(job_id), result, dataset_id, document_id, current_user.current_tenant_id, current_user.id
  349. )
  350. except Exception as e:
  351. return {"error": str(e)}, 500
  352. return {"job_id": job_id, "job_status": "waiting"}, 200
  353. @setup_required
  354. @login_required
  355. @account_initialization_required
  356. def get(self, job_id):
  357. job_id = str(job_id)
  358. indexing_cache_key = "segment_batch_import_{}".format(job_id)
  359. cache_result = redis_client.get(indexing_cache_key)
  360. if cache_result is None:
  361. raise ValueError("The job is not exist.")
  362. return {"job_id": job_id, "job_status": cache_result.decode()}, 200
class ChildChunkAddApi(Resource):
    """Create, list, and bulk-replace the child chunks of a segment."""

    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check("vector_space")
    @cloud_edition_billing_knowledge_limit_check("add_segment")
    def post(self, dataset_id, document_id, segment_id):
        """Create one child chunk from JSON body ``content``.

        Raises NotFound / Forbidden / ProviderNotInitializeError as in the
        segment endpoints, and ChildChunkIndexingError when indexing the new
        chunk fails in the service layer.
        """
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound("Dataset not found.")
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound("Document not found.")
        # check segment (scoped to the current tenant)
        segment_id = str(segment_id)
        segment = DocumentSegment.query.filter(
            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
        ).first()
        if not segment:
            raise NotFound("Segment not found.")
        if not current_user.is_editor:
            raise Forbidden()
        # check embedding model setting
        if dataset.indexing_technique == "high_quality":
            try:
                model_manager = ModelManager()
                model_manager.get_model_instance(
                    tenant_id=current_user.current_tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model,
                )
            except LLMBadRequestError:
                raise ProviderNotInitializeError(
                    "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
                )
            except ProviderTokenNotInitError as ex:
                raise ProviderNotInitializeError(ex.description)
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))
        # validate args
        parser = reqparse.RequestParser()
        parser.add_argument("content", type=str, required=True, nullable=False, location="json")
        args = parser.parse_args()
        try:
            child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset)
        except ChildChunkIndexingServiceError as e:
            # Translate the service-layer error into the controller-layer one.
            raise ChildChunkIndexingError(str(e))
        return {"data": marshal(child_chunk, child_chunk_fields)}, 200

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id, document_id, segment_id):
        """Return a paginated list of a segment's child chunks.

        Query args: ``limit`` (capped at 100), ``keyword``, ``page``.
        """
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound("Dataset not found.")
        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound("Document not found.")
        # check segment (scoped to the current tenant)
        segment_id = str(segment_id)
        segment = DocumentSegment.query.filter(
            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
        ).first()
        if not segment:
            raise NotFound("Segment not found.")
        parser = reqparse.RequestParser()
        parser.add_argument("limit", type=int, default=20, location="args")
        parser.add_argument("keyword", type=str, default=None, location="args")
        parser.add_argument("page", type=int, default=1, location="args")
        args = parser.parse_args()
        page = args["page"]
        limit = min(args["limit"], 100)
        keyword = args["keyword"]
        child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
        return {
            "data": marshal(child_chunks.items, child_chunk_fields),
            "total": child_chunks.total,
            "total_pages": child_chunks.pages,
            "page": page,
            "limit": limit,
        }, 200

    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check("vector_space")
    def patch(self, dataset_id, document_id, segment_id):
        """Replace the segment's child chunks with the JSON body ``chunks`` list.

        Each list entry is validated through ChildChunkUpdateArgs before the
        bulk update is delegated to the service layer.
        """
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound("Dataset not found.")
        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound("Document not found.")
        # check segment (scoped to the current tenant)
        segment_id = str(segment_id)
        segment = DocumentSegment.query.filter(
            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
        ).first()
        if not segment:
            raise NotFound("Segment not found.")
        # The role of the current user in the ta table must be admin, owner, or editor
        if not current_user.is_editor:
            raise Forbidden()
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))
        # validate args
        parser = reqparse.RequestParser()
        parser.add_argument("chunks", type=list, required=True, nullable=False, location="json")
        args = parser.parse_args()
        try:
            chunks = [ChildChunkUpdateArgs(**chunk) for chunk in args.get("chunks")]
            child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
        except ChildChunkIndexingServiceError as e:
            raise ChildChunkIndexingError(str(e))
        return {"data": marshal(child_chunks, child_chunk_fields)}, 200
class ChildChunkUpdateApi(Resource):
    """Update or delete a single child chunk of a segment."""

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
        """Delete one child chunk (and its index entry via the service).

        Raises NotFound for a missing dataset/document/segment/chunk,
        Forbidden for insufficient role/permission, and
        ChildChunkDeleteIndexError when index removal fails.
        """
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound("Dataset not found.")
        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound("Document not found.")
        # check segment (scoped to the current tenant)
        segment_id = str(segment_id)
        segment = DocumentSegment.query.filter(
            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
        ).first()
        if not segment:
            raise NotFound("Segment not found.")
        # check child chunk (scoped to the current tenant)
        child_chunk_id = str(child_chunk_id)
        child_chunk = ChildChunk.query.filter(
            ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id
        ).first()
        if not child_chunk:
            raise NotFound("Child chunk not found.")
        # The current user must have an editing role (admin, owner, or editor).
        if not current_user.is_editor:
            raise Forbidden()
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))
        try:
            SegmentService.delete_child_chunk(child_chunk, dataset)
        except ChildChunkDeleteIndexServiceError as e:
            # Translate the service-layer error into the controller-layer one.
            raise ChildChunkDeleteIndexError(str(e))
        return {"result": "success"}, 200

    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check("vector_space")
    def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
        """Update one child chunk's ``content`` from the JSON body.

        Raises ChildChunkIndexingError when re-indexing the chunk fails.
        """
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound("Dataset not found.")
        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound("Document not found.")
        # check segment (scoped to the current tenant)
        segment_id = str(segment_id)
        segment = DocumentSegment.query.filter(
            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
        ).first()
        if not segment:
            raise NotFound("Segment not found.")
        # check child chunk (scoped to the current tenant)
        child_chunk_id = str(child_chunk_id)
        child_chunk = ChildChunk.query.filter(
            ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id
        ).first()
        if not child_chunk:
            raise NotFound("Child chunk not found.")
        # The current user must have an editing role (admin, owner, or editor).
        if not current_user.is_editor:
            raise Forbidden()
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))
        # validate args
        parser = reqparse.RequestParser()
        parser.add_argument("content", type=str, required=True, nullable=False, location="json")
        args = parser.parse_args()
        try:
            child_chunk = SegmentService.update_child_chunk(
                args.get("content"), child_chunk, segment, document, dataset
            )
        except ChildChunkIndexingServiceError as e:
            raise ChildChunkIndexingError(str(e))
        return {"data": marshal(child_chunk, child_chunk_fields)}, 200
# Route registrations: segment CRUD, batch import/status, and child-chunk endpoints.
api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
api.add_resource(
    DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>"
)
api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
api.add_resource(
    DatasetDocumentSegmentUpdateApi,
    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>",
)
# One resource serves both routes: POST starts the import, GET polls the job.
api.add_resource(
    DatasetDocumentSegmentBatchImportApi,
    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
    "/datasets/batch_import_status/<uuid:job_id>",
)
api.add_resource(
    ChildChunkAddApi,
    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks",
)
api.add_resource(
    ChildChunkUpdateApi,
    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>",
)