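"""Console API resources for dataset management: CRUD, indexing estimates, API keys, and retrieval settings."""
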
import flask_restful  # type: ignore
from flask import request
from flask_login import current_user  # type: ignore
from flask_restful import Resource, marshal, marshal_with, reqparse  # type: ignore
from werkzeug.exceptions import Forbidden, NotFound

import services
from configs import dify_config
from controllers.console import api
from controllers.console.apikey import api_key_fields, api_key_list
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType
from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
from fields.document_fields import document_status_fields
from libs.login import login_required
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService


def _validate_name(name):
    if not name or len(name) < 1 or len(name) > 40:
        raise ValueError("Name must be between 1 and 40 characters.")
    return name


def _validate_description_length(description):
    if len(description) > 400:
        raise ValueError("Description cannot exceed 400 characters.")
    return description


class DatasetListApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def get(self):
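        """List datasets for the current tenant, with pagination and optional keyword/tag filters."""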
        page = request.args.get("page", default=1, type=int)
        limit = request.args.get("limit", default=20, type=int)
        ids = request.args.getlist("ids")
        # provider = request.args.get("provider", default="vendor")
        search = request.args.get("keyword", default=None, type=str)
        tag_ids = request.args.getlist("tag_ids")
        include_all = request.args.get("include_all", default="false").lower() == "true"

        if ids:
            datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
        else:
            datasets, total = DatasetService.get_datasets(
                page, limit, current_user.current_tenant_id, current_user, search, tag_ids, include_all
            )

        # check embedding setting
        provider_manager = ProviderManager()
        configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
        embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)

        model_names = []
        for embedding_model in embedding_models:
            model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")

        data = marshal(datasets, dataset_detail_fields)
        for item in data:
            # convert embedding_model_provider to plugin standard format
            if item["indexing_technique"] == "high_quality":
                item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
                item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
                item["embedding_available"] = item_model in model_names
            else:
                item["embedding_available"] = True

            if item.get("permission") == "partial_members":
                part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item["id"])
                item.update({"partial_member_list": part_users_list})
            else:
                item.update({"partial_member_list": []})

        response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
        return response, 200

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
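        """Create an empty dataset for the current tenant."""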
        parser = reqparse.RequestParser()
        parser.add_argument(
            "name",
            nullable=False,
            required=True,
            help="Name is required. Name must be between 1 and 40 characters.",
            type=_validate_name,
        )
        parser.add_argument(
            "description",
            type=str,
            nullable=True,
            required=False,
            default="",
        )
        parser.add_argument(
            "indexing_technique",
            type=str,
            location="json",
            choices=Dataset.INDEXING_TECHNIQUE_LIST,
            nullable=True,
            help="Invalid indexing technique.",
        )
        parser.add_argument(
            "external_knowledge_api_id",
            type=str,
            nullable=True,
            required=False,
        )
        parser.add_argument(
            "provider",
            type=str,
            nullable=True,
            choices=Dataset.PROVIDER_LIST,
            required=False,
            default="vendor",
        )
        parser.add_argument(
            "external_knowledge_id",
            type=str,
            nullable=True,
            required=False,
        )
        args = parser.parse_args()

        # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
        if not current_user.is_dataset_editor:
            raise Forbidden()

        try:
            dataset = DatasetService.create_empty_dataset(
                tenant_id=current_user.current_tenant_id,
                name=args["name"],
                description=args["description"],
                indexing_technique=args["indexing_technique"],
                account=current_user,
                permission=DatasetPermissionEnum.ONLY_ME,
                provider=args["provider"],
                external_knowledge_api_id=args["external_knowledge_api_id"],
                external_knowledge_id=args["external_knowledge_id"],
            )
        except services.errors.dataset.DatasetNameDuplicateError:
            raise DatasetNameDuplicateError()

        return marshal(dataset, dataset_detail_fields), 201


class DatasetApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
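        """Return dataset details, including embedding model availability."""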
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))
        data = marshal(dataset, dataset_detail_fields)

        # check embedding setting
        provider_manager = ProviderManager()
        configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
        embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)

        model_names = []
        for embedding_model in embedding_models:
            model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")

        if data["indexing_technique"] == "high_quality":
            item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
            data["embedding_available"] = item_model in model_names
        else:
            data["embedding_available"] = True

        if data.get("permission") == "partial_members":
            part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
            data.update({"partial_member_list": part_users_list})

        return data, 200

    @setup_required
    @login_required
    @account_initialization_required
    def patch(self, dataset_id):
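        """Update dataset settings and synchronize the partial member list."""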
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")

        parser = reqparse.RequestParser()
        parser.add_argument(
            "name",
            nullable=False,
            help="Name is required. Name must be between 1 and 40 characters.",
            type=_validate_name,
        )
        parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
        parser.add_argument(
            "indexing_technique",
            type=str,
            location="json",
            choices=Dataset.INDEXING_TECHNIQUE_LIST,
            nullable=True,
            help="Invalid indexing technique.",
        )
        parser.add_argument(
            "permission",
            type=str,
            location="json",
            choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
            help="Invalid permission.",
        )
        parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
        parser.add_argument(
            "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
        )
        parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
        parser.add_argument("partial_member_list", type=list, location="json", help="Invalid partial member list.")
        parser.add_argument(
            "external_retrieval_model",
            type=dict,
            required=False,
            nullable=True,
            location="json",
            help="Invalid external retrieval model.",
        )
        parser.add_argument(
            "external_knowledge_id",
            type=str,
            required=False,
            nullable=True,
            location="json",
            help="Invalid external knowledge id.",
        )
        parser.add_argument(
            "external_knowledge_api_id",
            type=str,
            required=False,
            nullable=True,
            location="json",
            help="Invalid external knowledge api id.",
        )
        args = parser.parse_args()
        data = request.get_json()

        # check embedding model setting
        if data.get("indexing_technique") == "high_quality":
            DatasetService.check_embedding_model_setting(
                dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
            )

        # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
        DatasetPermissionService.check_permission(
            current_user, dataset, data.get("permission"), data.get("partial_member_list")
        )

        dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
        if dataset is None:
            raise NotFound("Dataset not found.")

        result_data = marshal(dataset, dataset_detail_fields)
        tenant_id = current_user.current_tenant_id

        if data.get("partial_member_list") and data.get("permission") == "partial_members":
            DatasetPermissionService.update_partial_member_list(
                tenant_id, dataset_id_str, data.get("partial_member_list")
            )
        # clear partial member list when permission is only_me or all_team_members
        elif data.get("permission") in (DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM):
            DatasetPermissionService.clear_partial_member_list(dataset_id_str)

        partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
        result_data.update({"partial_member_list": partial_member_list})

        return result_data, 200

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, dataset_id):
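        """Delete a dataset and clear its partial member list."""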
        dataset_id_str = str(dataset_id)

        # The role of the current user in the ta table must be admin, owner, or editor
        if not current_user.is_editor or current_user.is_dataset_operator:
            raise Forbidden()

        try:
            if DatasetService.delete_dataset(dataset_id_str, current_user):
                DatasetPermissionService.clear_partial_member_list(dataset_id_str)
                return {"result": "success"}, 204
            else:
                raise NotFound("Dataset not found.")
        except services.errors.dataset.DatasetInUseError:
            raise DatasetInUseError()


class DatasetUseCheckApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
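        """Check whether the dataset is currently in use."""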
        dataset_id_str = str(dataset_id)
        dataset_is_using = DatasetService.dataset_use_check(dataset_id_str)
        return {"is_using": dataset_is_using}, 200


class DatasetQueryApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
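        """Return a paginated list of queries executed against the dataset."""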
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        page = request.args.get("page", default=1, type=int)
        limit = request.args.get("limit", default=20, type=int)

        dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit)

        response = {
            "data": marshal(dataset_queries, dataset_query_detail_fields),
            "has_more": len(dataset_queries) == limit,
            "limit": limit,
            "total": total,
            "page": page,
        }
        return response, 200


class DatasetIndexingEstimateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
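        """Estimate the indexing consumption for the given data source and process rule."""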
        parser = reqparse.RequestParser()
        parser.add_argument("info_list", type=dict, required=True, nullable=True, location="json")
        parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
        parser.add_argument(
            "indexing_technique",
            type=str,
            required=True,
            choices=Dataset.INDEXING_TECHNIQUE_LIST,
            nullable=True,
            location="json",
        )
        parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
        parser.add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
        parser.add_argument(
            "doc_language", type=str, default="English", required=False, nullable=False, location="json"
        )
        args = parser.parse_args()
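
        # The shape of args["info_list"] depends on its data_source_type; for example
        # (shapes inferred from the branches below):
        #   upload_file:   {"data_source_type": "upload_file",
        #                   "file_info_list": {"file_ids": [...]}}
        #   notion_import: {"data_source_type": "notion_import",
        #                   "notion_info_list": [{"workspace_id": ..., "pages": [...]}]}
        #   website_crawl: {"data_source_type": "website_crawl",
        #                   "website_info_list": {"provider": ..., "job_id": ...,
        #                                         "urls": [...], "only_main_content": ...}}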
        # validate args
        DocumentService.estimate_args_validate(args)
        extract_settings = []
        if args["info_list"]["data_source_type"] == "upload_file":
            file_ids = args["info_list"]["file_info_list"]["file_ids"]
            file_details = (
                db.session.query(UploadFile)
                .filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids))
                .all()
            )

            if file_details is None:
                raise NotFound("File not found.")

            if file_details:
                for file_detail in file_details:
                    extract_setting = ExtractSetting(
                        datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"]
                    )
                    extract_settings.append(extract_setting)
        elif args["info_list"]["data_source_type"] == "notion_import":
            notion_info_list = args["info_list"]["notion_info_list"]
            for notion_info in notion_info_list:
                workspace_id = notion_info["workspace_id"]
                for page in notion_info["pages"]:
                    extract_setting = ExtractSetting(
                        datasource_type="notion_import",
                        notion_info={
                            "notion_workspace_id": workspace_id,
                            "notion_obj_id": page["page_id"],
                            "notion_page_type": page["type"],
                            "tenant_id": current_user.current_tenant_id,
                        },
                        document_model=args["doc_form"],
                    )
                    extract_settings.append(extract_setting)
        elif args["info_list"]["data_source_type"] == "website_crawl":
            website_info_list = args["info_list"]["website_info_list"]
            for url in website_info_list["urls"]:
                extract_setting = ExtractSetting(
                    datasource_type="website_crawl",
                    website_info={
                        "provider": website_info_list["provider"],
                        "job_id": website_info_list["job_id"],
                        "url": url,
                        "tenant_id": current_user.current_tenant_id,
                        "mode": "crawl",
                        "only_main_content": website_info_list["only_main_content"],
                    },
                    document_model=args["doc_form"],
                )
                extract_settings.append(extract_setting)
        else:
            raise ValueError("Data source type is not supported.")
        indexing_runner = IndexingRunner()
        try:
            response = indexing_runner.indexing_estimate(
                current_user.current_tenant_id,
                extract_settings,
                args["process_rule"],
                args["doc_form"],
                args["doc_language"],
                args["dataset_id"],
                args["indexing_technique"],
            )
        except LLMBadRequestError:
            raise ProviderNotInitializeError(
                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
            )
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except Exception as e:
            raise IndexingEstimateError(str(e))

        return response.model_dump(), 200


class DatasetRelatedAppListApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(related_app_list)
    def get(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        app_dataset_joins = DatasetService.get_related_apps(dataset.id)

        related_apps = []
        for app_dataset_join in app_dataset_joins:
            app_model = app_dataset_join.app
            if app_model:
                related_apps.append(app_model)

        return {"data": related_apps, "total": len(related_apps)}, 200


class DatasetIndexingStatusApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
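        """Return per-document indexing progress as completed/total segment counts."""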
        dataset_id = str(dataset_id)
        documents = (
            db.session.query(Document)
            .filter(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id)
            .all()
        )
        documents_status = []
        for document in documents:
            completed_segments = DocumentSegment.query.filter(
                DocumentSegment.completed_at.isnot(None),
                DocumentSegment.document_id == str(document.id),
                DocumentSegment.status != "re_segment",
            ).count()
            total_segments = DocumentSegment.query.filter(
                DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
            ).count()
            document.completed_segments = completed_segments
            document.total_segments = total_segments
            documents_status.append(marshal(document, document_status_fields))
        data = {"data": documents_status}
        return data


class DatasetApiKeyApi(Resource):
    max_keys = 10
    token_prefix = "dataset-"
    resource_type = "dataset"

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(api_key_list)
    def get(self):
        keys = (
            db.session.query(ApiToken)
            .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
            .all()
        )
        return {"items": keys}

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(api_key_fields)
    def post(self):
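        """Create a dataset API key, up to `max_keys` per tenant."""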
        # The role of the current user in the ta table must be admin or owner
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        current_key_count = (
            db.session.query(ApiToken)
            .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
            .count()
        )

        if current_key_count >= self.max_keys:
            flask_restful.abort(
                400,
                message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
                code="max_keys_exceeded",
            )

        key = ApiToken.generate_api_key(self.token_prefix, 24)
        api_token = ApiToken()
        api_token.tenant_id = current_user.current_tenant_id
        api_token.token = key
        api_token.type = self.resource_type
        db.session.add(api_token)
        db.session.commit()
        return api_token, 200


class DatasetApiDeleteApi(Resource):
    resource_type = "dataset"

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, api_key_id):
        api_key_id = str(api_key_id)

        # The role of the current user in the ta table must be admin or owner
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        key = (
            db.session.query(ApiToken)
            .filter(
                ApiToken.tenant_id == current_user.current_tenant_id,
                ApiToken.type == self.resource_type,
                ApiToken.id == api_key_id,
            )
            .first()
        )

        if key is None:
            flask_restful.abort(404, message="API key not found")

        db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
        db.session.commit()

        return {"result": "success"}, 204


class DatasetApiBaseUrlApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        return {"api_base_url": (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"}


class DatasetRetrievalSettingApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
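        """Return the retrieval methods supported by the configured vector store."""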
        vector_type = dify_config.VECTOR_STORE
        match vector_type:
            case (
                VectorType.RELYT
                | VectorType.TIDB_VECTOR
                | VectorType.CHROMA
                | VectorType.TENCENT
                | VectorType.PGVECTO_RS
                | VectorType.BAIDU
                | VectorType.VIKINGDB
                | VectorType.UPSTASH
                | VectorType.OCEANBASE
            ):
                return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
            case (
                VectorType.QDRANT
                | VectorType.WEAVIATE
                | VectorType.OPENSEARCH
                | VectorType.ANALYTICDB
                | VectorType.MYSCALE
                | VectorType.ORACLE
                | VectorType.ELASTICSEARCH
                | VectorType.ELASTICSEARCH_JA
                | VectorType.PGVECTOR
                | VectorType.TIDB_ON_QDRANT
                | VectorType.LINDORM
                | VectorType.COUCHBASE
                | VectorType.MILVUS
            ):
                return {
                    "retrieval_method": [
                        RetrievalMethod.SEMANTIC_SEARCH.value,
                        RetrievalMethod.FULL_TEXT_SEARCH.value,
                        RetrievalMethod.HYBRID_SEARCH.value,
                    ]
                }
            case _:
                raise ValueError(f"Unsupported vector db type {vector_type}.")


class DatasetRetrievalSettingMockApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, vector_type):
        match vector_type:
            case (
                VectorType.MILVUS
                | VectorType.RELYT
                | VectorType.TIDB_VECTOR
                | VectorType.CHROMA
                | VectorType.TENCENT
                | VectorType.PGVECTO_RS
                | VectorType.BAIDU
                | VectorType.VIKINGDB
                | VectorType.UPSTASH
                | VectorType.OCEANBASE
            ):
                return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
            case (
                VectorType.QDRANT
                | VectorType.WEAVIATE
                | VectorType.OPENSEARCH
                | VectorType.ANALYTICDB
                | VectorType.MYSCALE
                | VectorType.ORACLE
                | VectorType.ELASTICSEARCH
                | VectorType.ELASTICSEARCH_JA
                | VectorType.COUCHBASE
                | VectorType.PGVECTOR
                | VectorType.LINDORM
            ):
                return {
                    "retrieval_method": [
                        RetrievalMethod.SEMANTIC_SEARCH.value,
                        RetrievalMethod.FULL_TEXT_SEARCH.value,
                        RetrievalMethod.HYBRID_SEARCH.value,
                    ]
                }
            case _:
                raise ValueError(f"Unsupported vector db type {vector_type}.")


class DatasetErrorDocs(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str)

        return {"data": [marshal(item, document_status_fields) for item in results], "total": len(results)}, 200


class DatasetPermissionUserListApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)

        return {
            "data": partial_members_list,
        }, 200


class DatasetAutoDisableLogApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")

        return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200
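

# Route registrations; the <uuid:...> converters validate path IDs as UUIDs before dispatch.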
api.add_resource(DatasetListApi, "/datasets")
api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
api.add_resource(DatasetUseCheckApi, "/datasets/<uuid:dataset_id>/use-check")
api.add_resource(DatasetQueryApi, "/datasets/<uuid:dataset_id>/queries")
api.add_resource(DatasetErrorDocs, "/datasets/<uuid:dataset_id>/error-docs")
api.add_resource(DatasetIndexingEstimateApi, "/datasets/indexing-estimate")
api.add_resource(DatasetRelatedAppListApi, "/datasets/<uuid:dataset_id>/related-apps")
api.add_resource(DatasetIndexingStatusApi, "/datasets/<uuid:dataset_id>/indexing-status")
api.add_resource(DatasetApiKeyApi, "/datasets/api-keys")
api.add_resource(DatasetApiDeleteApi, "/datasets/api-keys/<uuid:api_key_id>")
api.add_resource(DatasetApiBaseUrlApi, "/datasets/api-base-info")
api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting")
api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/<string:vector_type>")
api.add_resource(DatasetPermissionUserListApi, "/datasets/<uuid:dataset_id>/permission-part-users")
api.add_resource(DatasetAutoDisableLogApi, "/datasets/<uuid:dataset_id>/auto-disable-logs")