data_source.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. import datetime
  2. import json
  3. from flask import request
  4. from flask_login import current_user # type: ignore
  5. from flask_restful import Resource, marshal_with, reqparse # type: ignore
  6. from werkzeug.exceptions import NotFound
  7. from controllers.console import api
  8. from controllers.console.wraps import account_initialization_required, setup_required
  9. from core.indexing_runner import IndexingRunner
  10. from core.rag.extractor.entity.extract_setting import ExtractSetting
  11. from core.rag.extractor.notion_extractor import NotionExtractor
  12. from extensions.ext_database import db
  13. from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
  14. from libs.login import login_required
  15. from models import DataSourceOauthBinding, Document
  16. from services.dataset_service import DatasetService, DocumentService
  17. from tasks.document_indexing_sync_task import document_indexing_sync_task
  18. class DataSourceApi(Resource):
  19. @setup_required
  20. @login_required
  21. @account_initialization_required
  22. @marshal_with(integrate_list_fields)
  23. def get(self):
  24. # get workspace data source integrates
  25. data_source_integrates = (
  26. db.session.query(DataSourceOauthBinding)
  27. .filter(
  28. DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
  29. DataSourceOauthBinding.disabled == False,
  30. )
  31. .all()
  32. )
  33. base_url = request.url_root.rstrip("/")
  34. data_source_oauth_base_path = "/console/api/oauth/data-source"
  35. providers = ["notion"]
  36. integrate_data = []
  37. for provider in providers:
  38. # existing_integrate = next((ai for ai in data_source_integrates if ai.provider == provider), None)
  39. existing_integrates = filter(lambda item: item.provider == provider, data_source_integrates)
  40. if existing_integrates:
  41. for existing_integrate in list(existing_integrates):
  42. integrate_data.append(
  43. {
  44. "id": existing_integrate.id,
  45. "provider": provider,
  46. "created_at": existing_integrate.created_at,
  47. "is_bound": True,
  48. "disabled": existing_integrate.disabled,
  49. "source_info": existing_integrate.source_info,
  50. "link": f"{base_url}{data_source_oauth_base_path}/{provider}",
  51. }
  52. )
  53. else:
  54. integrate_data.append(
  55. {
  56. "id": None,
  57. "provider": provider,
  58. "created_at": None,
  59. "source_info": None,
  60. "is_bound": False,
  61. "disabled": None,
  62. "link": f"{base_url}{data_source_oauth_base_path}/{provider}",
  63. }
  64. )
  65. return {"data": integrate_data}, 200
  66. @setup_required
  67. @login_required
  68. @account_initialization_required
  69. def patch(self, binding_id, action):
  70. binding_id = str(binding_id)
  71. action = str(action)
  72. data_source_binding = DataSourceOauthBinding.query.filter_by(id=binding_id).first()
  73. if data_source_binding is None:
  74. raise NotFound("Data source binding not found.")
  75. # enable binding
  76. if action == "enable":
  77. if data_source_binding.disabled:
  78. data_source_binding.disabled = False
  79. data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
  80. db.session.add(data_source_binding)
  81. db.session.commit()
  82. else:
  83. raise ValueError("Data source is not disabled.")
  84. # disable binding
  85. if action == "disable":
  86. if not data_source_binding.disabled:
  87. data_source_binding.disabled = True
  88. data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
  89. db.session.add(data_source_binding)
  90. db.session.commit()
  91. else:
  92. raise ValueError("Data source is disabled.")
  93. return {"result": "success"}, 200
  94. class DataSourceNotionListApi(Resource):
  95. @setup_required
  96. @login_required
  97. @account_initialization_required
  98. @marshal_with(integrate_notion_info_list_fields)
  99. def get(self):
  100. dataset_id = request.args.get("dataset_id", default=None, type=str)
  101. exist_page_ids = []
  102. # import notion in the exist dataset
  103. if dataset_id:
  104. dataset = DatasetService.get_dataset(dataset_id)
  105. if not dataset:
  106. raise NotFound("Dataset not found.")
  107. if dataset.data_source_type != "notion_import":
  108. raise ValueError("Dataset is not notion type.")
  109. documents = Document.query.filter_by(
  110. dataset_id=dataset_id,
  111. tenant_id=current_user.current_tenant_id,
  112. data_source_type="notion_import",
  113. enabled=True,
  114. ).all()
  115. if documents:
  116. for document in documents:
  117. data_source_info = json.loads(document.data_source_info)
  118. exist_page_ids.append(data_source_info["notion_page_id"])
  119. # get all authorized pages
  120. data_source_bindings = DataSourceOauthBinding.query.filter_by(
  121. tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
  122. ).all()
  123. if not data_source_bindings:
  124. return {"notion_info": []}, 200
  125. pre_import_info_list = []
  126. for data_source_binding in data_source_bindings:
  127. source_info = data_source_binding.source_info
  128. pages = source_info["pages"]
  129. # Filter out already bound pages
  130. for page in pages:
  131. if page["page_id"] in exist_page_ids:
  132. page["is_bound"] = True
  133. else:
  134. page["is_bound"] = False
  135. pre_import_info = {
  136. "workspace_name": source_info["workspace_name"],
  137. "workspace_icon": source_info["workspace_icon"],
  138. "workspace_id": source_info["workspace_id"],
  139. "pages": pages,
  140. }
  141. pre_import_info_list.append(pre_import_info)
  142. return {"notion_info": pre_import_info_list}, 200
  143. class DataSourceNotionApi(Resource):
  144. @setup_required
  145. @login_required
  146. @account_initialization_required
  147. def get(self, workspace_id, page_id, page_type):
  148. workspace_id = str(workspace_id)
  149. page_id = str(page_id)
  150. data_source_binding = DataSourceOauthBinding.query.filter(
  151. db.and_(
  152. DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
  153. DataSourceOauthBinding.provider == "notion",
  154. DataSourceOauthBinding.disabled == False,
  155. DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
  156. )
  157. ).first()
  158. if not data_source_binding:
  159. raise NotFound("Data source binding not found.")
  160. extractor = NotionExtractor(
  161. notion_workspace_id=workspace_id,
  162. notion_obj_id=page_id,
  163. notion_page_type=page_type,
  164. notion_access_token=data_source_binding.access_token,
  165. tenant_id=current_user.current_tenant_id,
  166. )
  167. text_docs = extractor.extract()
  168. return {"content": "\n".join([doc.page_content for doc in text_docs])}, 200
  169. @setup_required
  170. @login_required
  171. @account_initialization_required
  172. def post(self):
  173. parser = reqparse.RequestParser()
  174. parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
  175. parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
  176. parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
  177. parser.add_argument(
  178. "doc_language", type=str, default="English", required=False, nullable=False, location="json"
  179. )
  180. args = parser.parse_args()
  181. # validate args
  182. DocumentService.estimate_args_validate(args)
  183. notion_info_list = args["notion_info_list"]
  184. extract_settings = []
  185. for notion_info in notion_info_list:
  186. workspace_id = notion_info["workspace_id"]
  187. for page in notion_info["pages"]:
  188. extract_setting = ExtractSetting(
  189. datasource_type="notion_import",
  190. notion_info={
  191. "notion_workspace_id": workspace_id,
  192. "notion_obj_id": page["page_id"],
  193. "notion_page_type": page["type"],
  194. "tenant_id": current_user.current_tenant_id,
  195. },
  196. document_model=args["doc_form"],
  197. )
  198. extract_settings.append(extract_setting)
  199. indexing_runner = IndexingRunner()
  200. response = indexing_runner.indexing_estimate(
  201. current_user.current_tenant_id,
  202. extract_settings,
  203. args["process_rule"],
  204. args["doc_form"],
  205. args["doc_language"],
  206. )
  207. return response.model_dump(), 200
  208. class DataSourceNotionDatasetSyncApi(Resource):
  209. @setup_required
  210. @login_required
  211. @account_initialization_required
  212. def get(self, dataset_id):
  213. dataset_id_str = str(dataset_id)
  214. dataset = DatasetService.get_dataset(dataset_id_str)
  215. if dataset is None:
  216. raise NotFound("Dataset not found.")
  217. documents = DocumentService.get_document_by_dataset_id(dataset_id_str)
  218. for document in documents:
  219. document_indexing_sync_task.delay(dataset_id_str, document.id)
  220. return 200
  221. class DataSourceNotionDocumentSyncApi(Resource):
  222. @setup_required
  223. @login_required
  224. @account_initialization_required
  225. def get(self, dataset_id, document_id):
  226. dataset_id_str = str(dataset_id)
  227. document_id_str = str(document_id)
  228. dataset = DatasetService.get_dataset(dataset_id_str)
  229. if dataset is None:
  230. raise NotFound("Dataset not found.")
  231. document = DocumentService.get_document(dataset_id_str, document_id_str)
  232. if document is None:
  233. raise NotFound("Document not found.")
  234. document_indexing_sync_task.delay(dataset_id_str, document_id_str)
  235. return 200
  236. api.add_resource(DataSourceApi, "/data-source/integrates", "/data-source/integrates/<uuid:binding_id>/<string:action>")
  237. api.add_resource(DataSourceNotionListApi, "/notion/pre-import/pages")
  238. api.add_resource(
  239. DataSourceNotionApi,
  240. "/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
  241. "/datasets/notion-indexing-estimate",
  242. )
  243. api.add_resource(DataSourceNotionDatasetSyncApi, "/datasets/<uuid:dataset_id>/notion/sync")
  244. api.add_resource(
  245. DataSourceNotionDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync"
  246. )