indexing_runner.py

import concurrent.futures
import datetime
import json
import logging
import re
import threading
import time
import uuid
from typing import Any, Optional, cast

from flask import Flask, current_app
from flask_login import current_user  # type: ignore
from sqlalchemy.orm.exc import ObjectDeletedError

from configs import dify_config
from core.errors.error import ProviderTokenNotInitError
from core.llm_generator.llm_generator import LLMGenerator
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from core.rag.splitter.fixed_text_splitter import (
    EnhanceRecursiveCharacterTextSplitter,
    FixedRecursiveCharacterTextSplitter,
)
from core.rag.splitter.text_splitter import TextSplitter
from core.tools.utils.text_processing_utils import remove_leading_symbols
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
from services.feature_service import FeatureService


class IndexingRunner:
    def __init__(self):
        self.storage = storage
        self.model_manager = ModelManager()

    def run(self, dataset_documents: list[DatasetDocument]):
        """Run the indexing process."""
        for dataset_document in dataset_documents:
            try:
                # get dataset
                dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()
                if not dataset:
                    raise ValueError("no dataset found")

                # get the process rule
                processing_rule = (
                    db.session.query(DatasetProcessRule)
                    .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
                    .first()
                )
                if not processing_rule:
                    raise ValueError("no process rule found")
                index_type = dataset_document.doc_form
                index_processor = IndexProcessorFactory(index_type).init_index_processor()

                # extract
                text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())

                # transform
                documents = self._transform(
                    index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict()
                )
                # save segment
                self._load_segments(dataset, dataset_document, documents)

                # load
                self._load(
                    index_processor=index_processor,
                    dataset=dataset,
                    dataset_document=dataset_document,
                    documents=documents,
                )
            except DocumentIsPausedError:
                raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
            except ProviderTokenNotInitError as e:
                dataset_document.indexing_status = "error"
                dataset_document.error = str(e.description)
                dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
                db.session.commit()
            except ObjectDeletedError:
                logging.warning("Document deleted, document id: {}".format(dataset_document.id))
            except Exception as e:
                logging.exception("consume document failed")
                dataset_document.indexing_status = "error"
                dataset_document.error = str(e)
                dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
                db.session.commit()

    def run_in_splitting_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is splitting."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()
            if not dataset:
                raise ValueError("no dataset found")

            # get the existing document segments and delete them
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id, document_id=dataset_document.id
            ).all()

            for document_segment in document_segments:
                db.session.delete(document_segment)
            db.session.commit()

            # get the process rule
            processing_rule = (
                db.session.query(DatasetProcessRule)
                .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
                .first()
            )
            if not processing_rule:
                raise ValueError("no process rule found")

            index_type = dataset_document.doc_form
            index_processor = IndexProcessorFactory(index_type).init_index_processor()

            # extract
            text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())

            # transform
            documents = self._transform(
                index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict()
            )
            # save segment
            self._load_segments(dataset, dataset_document, documents)

            # load
            self._load(
                index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
            )
        except DocumentIsPausedError:
            raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = "error"
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = "error"
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            db.session.commit()

    def run_in_indexing_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is indexing."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()
            if not dataset:
                raise ValueError("no dataset found")

            # get the existing document segments
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id, document_id=dataset_document.id
            ).all()

            documents = []
            if document_segments:
                for document_segment in document_segments:
                    # transform segment to node
                    if document_segment.status != "completed":
                        document = Document(
                            page_content=document_segment.content,
                            metadata={
                                "doc_id": document_segment.index_node_id,
                                "doc_hash": document_segment.index_node_hash,
                                "document_id": document_segment.document_id,
                                "dataset_id": document_segment.dataset_id,
                            },
                        )
                        documents.append(document)

            # build index
            # get the process rule
            processing_rule = (
                db.session.query(DatasetProcessRule)
                .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
                .first()
            )

            index_type = dataset_document.doc_form
            index_processor = IndexProcessorFactory(index_type).init_index_processor()
            self._load(
                index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
            )
        except DocumentIsPausedError:
            raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = "error"
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = "error"
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            db.session.commit()

    def indexing_estimate(
        self,
        tenant_id: str,
        extract_settings: list[ExtractSetting],
        tmp_processing_rule: dict,
        doc_form: Optional[str] = None,
        doc_language: str = "English",
        dataset_id: Optional[str] = None,
        indexing_technique: str = "economy",
    ) -> dict:
        """
        Estimate the indexing for the document.
        """
        # check document limit
        features = FeatureService.get_features(tenant_id)
        if features.billing.enabled:
            count = len(extract_settings)
            batch_upload_limit = dify_config.BATCH_UPLOAD_LIMIT
            if count > batch_upload_limit:
                raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")

        embedding_model_instance = None
        if dataset_id:
            dataset = Dataset.query.filter_by(id=dataset_id).first()
            if not dataset:
                raise ValueError("Dataset not found.")
            if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality":
                if dataset.embedding_model_provider:
                    embedding_model_instance = self.model_manager.get_model_instance(
                        tenant_id=tenant_id,
                        provider=dataset.embedding_model_provider,
                        model_type=ModelType.TEXT_EMBEDDING,
                        model=dataset.embedding_model,
                    )
                else:
                    embedding_model_instance = self.model_manager.get_default_model_instance(
                        tenant_id=tenant_id,
                        model_type=ModelType.TEXT_EMBEDDING,
                    )
        else:
            if indexing_technique == "high_quality":
                embedding_model_instance = self.model_manager.get_default_model_instance(
                    tenant_id=tenant_id,
                    model_type=ModelType.TEXT_EMBEDDING,
                )

        preview_texts: list[str] = []
        total_segments = 0
        index_type = doc_form
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        all_text_docs = []
        for extract_setting in extract_settings:
            # extract
            text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
            all_text_docs.extend(text_docs)

            processing_rule = DatasetProcessRule(
                mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"])
            )

            # get splitter
            splitter = self._get_splitter(processing_rule, embedding_model_instance)

            # split to documents
            documents = self._split_to_documents_for_estimate(
                text_docs=text_docs, splitter=splitter, processing_rule=processing_rule
            )

            total_segments += len(documents)
            for document in documents:
                if len(preview_texts) < 5:
                    preview_texts.append(document.page_content)

                # delete image files and related db records
                image_upload_file_ids = get_image_upload_file_ids(document.page_content)
                for upload_file_id in image_upload_file_ids:
                    image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
                    if not image_file:
                        continue
                    try:
                        storage.delete(image_file.key)
                    except Exception:
                        logging.exception(
                            "Delete image_files failed while indexing_estimate, "
                            "image_upload_file_id: {}".format(upload_file_id)
                        )
                    db.session.delete(image_file)

        if doc_form and doc_form == "qa_model":
            if len(preview_texts) > 0:
                # qa model document
                response = LLMGenerator.generate_qa_document(
                    current_user.current_tenant_id, preview_texts[0], doc_language
                )
                document_qa_list = self.format_split_text(response)

                return {
                    "total_segments": total_segments * 20,
                    "qa_preview": document_qa_list,
                    "preview": preview_texts,
                }
        return {"total_segments": total_segments, "preview": preview_texts}

    def _extract(
        self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
    ) -> list[Document]:
        # load file
        if dataset_document.data_source_type not in {"upload_file", "notion_import", "website_crawl"}:
            return []

        data_source_info = dataset_document.data_source_info_dict
        text_docs = []
        if dataset_document.data_source_type == "upload_file":
            if not data_source_info or "upload_file_id" not in data_source_info:
                raise ValueError("no upload file found")

            file_detail = (
                db.session.query(UploadFile).filter(UploadFile.id == data_source_info["upload_file_id"]).one_or_none()
            )

            if file_detail:
                extract_setting = ExtractSetting(
                    datasource_type="upload_file", upload_file=file_detail, document_model=dataset_document.doc_form
                )
                text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
        elif dataset_document.data_source_type == "notion_import":
            if (
                not data_source_info
                or "notion_workspace_id" not in data_source_info
                or "notion_page_id" not in data_source_info
            ):
                raise ValueError("no notion import info found")
            extract_setting = ExtractSetting(
                datasource_type="notion_import",
                notion_info={
                    "notion_workspace_id": data_source_info["notion_workspace_id"],
                    "notion_obj_id": data_source_info["notion_page_id"],
                    "notion_page_type": data_source_info["type"],
                    "document": dataset_document,
                    "tenant_id": dataset_document.tenant_id,
                },
                document_model=dataset_document.doc_form,
            )
            text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
        elif dataset_document.data_source_type == "website_crawl":
            if (
                not data_source_info
                or "provider" not in data_source_info
                or "url" not in data_source_info
                or "job_id" not in data_source_info
            ):
                raise ValueError("no website import info found")
            extract_setting = ExtractSetting(
                datasource_type="website_crawl",
                website_info={
                    "provider": data_source_info["provider"],
                    "job_id": data_source_info["job_id"],
                    "tenant_id": dataset_document.tenant_id,
                    "url": data_source_info["url"],
                    "mode": data_source_info["mode"],
                    "only_main_content": data_source_info["only_main_content"],
                },
                document_model=dataset_document.doc_form,
            )
            text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])

        # update document status to splitting
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="splitting",
            extra_update_params={
                DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs),
                DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
            },
        )

        # attach the document model id and dataset id to each extracted doc
        text_docs = cast(list[Document], text_docs)
        for text_doc in text_docs:
            if text_doc.metadata is not None:
                text_doc.metadata["document_id"] = dataset_document.id
                text_doc.metadata["dataset_id"] = dataset_document.dataset_id

        return text_docs

    @staticmethod
    def filter_string(text):
        text = re.sub(r"<\|", "<", text)
        text = re.sub(r"\|>", ">", text)
        text = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]", "", text)
        # Unicode U+FFFE
        text = re.sub("\ufffe", "", text)
        return text

    @staticmethod
    def _get_splitter(
        processing_rule: DatasetProcessRule, embedding_model_instance: Optional[ModelInstance]
    ) -> TextSplitter:
        """
        Get the NodeParser object according to the processing rule.
        """
        character_splitter: TextSplitter
        if processing_rule.mode == "custom":
            # The user-defined segmentation rule
            rules = json.loads(processing_rule.rules)
            segmentation = rules["segmentation"]
            max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
            if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length:
                raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")

            separator = segmentation["separator"]
            if separator:
                separator = separator.replace("\\n", "\n")

            if segmentation.get("chunk_overlap"):
                chunk_overlap = segmentation["chunk_overlap"]
            else:
                chunk_overlap = 0

            character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
                chunk_size=segmentation["max_tokens"],
                chunk_overlap=chunk_overlap,
                fixed_separator=separator,
                separators=["\n\n", "。", ". ", " ", ""],
                embedding_model_instance=embedding_model_instance,
            )
        else:
            # Automatic segmentation
            automatic_rules: dict[str, Any] = dict(DatasetProcessRule.AUTOMATIC_RULES["segmentation"])
            character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
                chunk_size=automatic_rules["max_tokens"],
                chunk_overlap=automatic_rules["chunk_overlap"],
                separators=["\n\n", "。", ". ", " ", ""],
                embedding_model_instance=embedding_model_instance,
            )

        return character_splitter

    def _step_split(
        self,
        text_docs: list[Document],
        splitter: TextSplitter,
        dataset: Dataset,
        dataset_document: DatasetDocument,
        processing_rule: DatasetProcessRule,
    ) -> list[Document]:
        """
        Split the text documents into documents and save them to the document segment.
        """
        documents = self._split_to_documents(
            text_docs=text_docs,
            splitter=splitter,
            processing_rule=processing_rule,
            tenant_id=dataset.tenant_id,
            document_form=dataset_document.doc_form,
            document_language=dataset_document.doc_language,
        )

        # save node to document segment
        doc_store = DatasetDocumentStore(
            dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id
        )

        # add document segments
        doc_store.add_documents(documents)

        # update document status to indexing
        cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="indexing",
            extra_update_params={
                DatasetDocument.cleaning_completed_at: cur_time,
                DatasetDocument.splitting_completed_at: cur_time,
            },
        )

        # update segment status to indexing
        self._update_segments_by_document(
            dataset_document_id=dataset_document.id,
            update_params={
                DocumentSegment.status: "indexing",
                DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
            },
        )

        return documents

    def _split_to_documents(
        self,
        text_docs: list[Document],
        splitter: TextSplitter,
        processing_rule: DatasetProcessRule,
        tenant_id: str,
        document_form: str,
        document_language: str,
    ) -> list[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents: list[Document] = []
        all_qa_documents: list[Document] = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])
            split_documents = []
            for document_node in documents:
                if document_node.page_content.strip():
                    if document_node.metadata is not None:
                        doc_id = str(uuid.uuid4())
                        hash = helper.generate_text_hash(document_node.page_content)
                        document_node.metadata["doc_id"] = doc_id
                        document_node.metadata["doc_hash"] = hash
                    # remove leading symbols left over from splitting
                    page_content = document_node.page_content
                    document_node.page_content = remove_leading_symbols(page_content)

                    if document_node.page_content:
                        split_documents.append(document_node)
            all_documents.extend(split_documents)

        # processing qa document
        if document_form == "qa_model":
            for i in range(0, len(all_documents), 10):
                threads = []
                sub_documents = all_documents[i : i + 10]
                for doc in sub_documents:
                    document_format_thread = threading.Thread(
                        target=self.format_qa_document,
                        kwargs={
                            "flask_app": current_app._get_current_object(),  # type: ignore
                            "tenant_id": tenant_id,
                            "document_node": doc,
                            "all_qa_documents": all_qa_documents,
                            "document_language": document_language,
                        },
                    )
                    threads.append(document_format_thread)
                    document_format_thread.start()
                for thread in threads:
                    thread.join()
            return all_qa_documents

        return all_documents

    def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
        format_documents = []
        if document_node.page_content is None or not document_node.page_content.strip():
            return
        with flask_app.app_context():
            try:
                # qa model document
                response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language)
                document_qa_list = self.format_split_text(response)
                qa_documents = []
                for result in document_qa_list:
                    qa_document = Document(
                        page_content=result["question"], metadata=document_node.metadata.model_copy()
                    )
                    if qa_document.metadata is not None:
                        doc_id = str(uuid.uuid4())
                        hash = helper.generate_text_hash(result["question"])
                        qa_document.metadata["answer"] = result["answer"]
                        qa_document.metadata["doc_id"] = doc_id
                        qa_document.metadata["doc_hash"] = hash
                    qa_documents.append(qa_document)
                format_documents.extend(qa_documents)
            except Exception:
                logging.exception("Failed to format qa document")

            all_qa_documents.extend(format_documents)

    def _split_to_documents_for_estimate(
        self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule
    ) -> list[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents: list[Document] = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])
            split_documents = []
            for document in documents:
                if document.page_content is None or not document.page_content.strip():
                    continue
                if document.metadata is not None:
                    doc_id = str(uuid.uuid4())
                    hash = helper.generate_text_hash(document.page_content)
                    document.metadata["doc_id"] = doc_id
                    document.metadata["doc_hash"] = hash

                split_documents.append(document)

            all_documents.extend(split_documents)

        return all_documents

    @staticmethod
    def _document_clean(text: str, processing_rule: DatasetProcessRule) -> str:
        """
        Clean the document text according to the processing rules.
        """
        if processing_rule.mode == "automatic":
            rules = DatasetProcessRule.AUTOMATIC_RULES
        else:
            rules = json.loads(processing_rule.rules) if processing_rule.rules else {}
        document_text = CleanProcessor.clean(text, {"rules": rules})

        return document_text

    @staticmethod
    def format_split_text(text):
        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
        matches = re.findall(regex, text, re.UNICODE)

        return [{"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())} for q, a in matches if q and a]

    def _load(
        self,
        index_processor: BaseIndexProcessor,
        dataset: Dataset,
        dataset_document: DatasetDocument,
        documents: list[Document],
    ) -> None:
        """
        insert index and update document/segment status to completed
        """
        embedding_model_instance = None
        if dataset.indexing_technique == "high_quality":
            embedding_model_instance = self.model_manager.get_model_instance(
                tenant_id=dataset.tenant_id,
                provider=dataset.embedding_model_provider,
                model_type=ModelType.TEXT_EMBEDDING,
                model=dataset.embedding_model,
            )

        # chunk nodes by chunk size
        indexing_start_at = time.perf_counter()
        tokens = 0
        chunk_size = 10

        # create keyword index
        create_keyword_thread = threading.Thread(
            target=self._process_keyword_index,
            args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents),  # type: ignore
        )
        create_keyword_thread.start()

        if dataset.indexing_technique == "high_quality":
            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                futures = []
                for i in range(0, len(documents), chunk_size):
                    chunk_documents = documents[i : i + chunk_size]
                    futures.append(
                        executor.submit(
                            self._process_chunk,
                            current_app._get_current_object(),  # type: ignore
                            index_processor,
                            chunk_documents,
                            dataset,
                            dataset_document,
                            embedding_model_instance,
                        )
                    )

                for future in futures:
                    tokens += future.result()

        create_keyword_thread.join()
        indexing_end_at = time.perf_counter()

        # update document status to completed
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="completed",
            extra_update_params={
                DatasetDocument.tokens: tokens,
                DatasetDocument.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
                DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
                DatasetDocument.error: None,
            },
        )

    @staticmethod
    def _process_keyword_index(flask_app, dataset_id, document_id, documents):
        with flask_app.app_context():
            dataset = Dataset.query.filter_by(id=dataset_id).first()
            if not dataset:
                raise ValueError("no dataset found")
            keyword = Keyword(dataset)
            keyword.create(documents)
            if dataset.indexing_technique != "high_quality":
                document_ids = [document.metadata["doc_id"] for document in documents]
                db.session.query(DocumentSegment).filter(
                    DocumentSegment.document_id == document_id,
                    DocumentSegment.dataset_id == dataset_id,
                    DocumentSegment.index_node_id.in_(document_ids),
                    DocumentSegment.status == "indexing",
                ).update(
                    {
                        DocumentSegment.status: "completed",
                        DocumentSegment.enabled: True,
                        DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
                    }
                )

                db.session.commit()

    def _process_chunk(
        self, flask_app, index_processor, chunk_documents, dataset, dataset_document, embedding_model_instance
    ):
        with flask_app.app_context():
            # check document is paused
            self._check_document_paused_status(dataset_document.id)

            tokens = 0
            if embedding_model_instance:
                tokens += sum(
                    embedding_model_instance.get_text_embedding_num_tokens([document.page_content])
                    for document in chunk_documents
                )

            # load index
            index_processor.load(dataset, chunk_documents, with_keywords=False)

            document_ids = [document.metadata["doc_id"] for document in chunk_documents]
            db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.dataset_id == dataset.id,
                DocumentSegment.index_node_id.in_(document_ids),
                DocumentSegment.status == "indexing",
            ).update(
                {
                    DocumentSegment.status: "completed",
                    DocumentSegment.enabled: True,
                    DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
                }
            )

            db.session.commit()

            return tokens

    @staticmethod
    def _check_document_paused_status(document_id: str):
        indexing_cache_key = "document_{}_is_paused".format(document_id)
        result = redis_client.get(indexing_cache_key)
        if result:
            raise DocumentIsPausedError()

    @staticmethod
    def _update_document_index_status(
        document_id: str, after_indexing_status: str, extra_update_params: Optional[dict] = None
    ) -> None:
        """
        Update the document indexing status.
        """
        count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
        if count > 0:
            raise DocumentIsPausedError()
        document = DatasetDocument.query.filter_by(id=document_id).first()
        if not document:
            raise DocumentIsDeletedPausedError()

        update_params = {DatasetDocument.indexing_status: after_indexing_status}

        if extra_update_params:
            update_params.update(extra_update_params)

        DatasetDocument.query.filter_by(id=document_id).update(update_params)
        db.session.commit()

    @staticmethod
    def _update_segments_by_document(dataset_document_id: str, update_params: dict) -> None:
        """
        Update the document segment by document id.
        """
        DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
        db.session.commit()

    @staticmethod
    def batch_add_segments(segments: list[DocumentSegment], dataset: Dataset):
        """
        Batch add segments index processing
        """
        documents = []
        for segment in segments:
            document = Document(
                page_content=segment.content,
                metadata={
                    "doc_id": segment.index_node_id,
                    "doc_hash": segment.index_node_hash,
                    "document_id": segment.document_id,
                    "dataset_id": segment.dataset_id,
                },
            )
            documents.append(document)

        # save vector index
        index_type = dataset.doc_form
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        index_processor.load(dataset, documents)

    def _transform(
        self,
        index_processor: BaseIndexProcessor,
        dataset: Dataset,
        text_docs: list[Document],
        doc_language: str,
        process_rule: dict,
    ) -> list[Document]:
        # get embedding model instance
        embedding_model_instance = None
        if dataset.indexing_technique == "high_quality":
            if dataset.embedding_model_provider:
                embedding_model_instance = self.model_manager.get_model_instance(
                    tenant_id=dataset.tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model,
                )
            else:
                embedding_model_instance = self.model_manager.get_default_model_instance(
                    tenant_id=dataset.tenant_id,
                    model_type=ModelType.TEXT_EMBEDDING,
                )

        documents = index_processor.transform(
            text_docs,
            embedding_model_instance=embedding_model_instance,
            process_rule=process_rule,
            tenant_id=dataset.tenant_id,
            doc_language=doc_language,
        )

        return documents

    def _load_segments(self, dataset, dataset_document, documents):
        # save node to document segment
        doc_store = DatasetDocumentStore(
            dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id
        )

        # add document segments
        doc_store.add_documents(documents)

        # update document status to indexing
        cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="indexing",
            extra_update_params={
                DatasetDocument.cleaning_completed_at: cur_time,
                DatasetDocument.splitting_completed_at: cur_time,
            },
        )

        # update segment status to indexing
        self._update_segments_by_document(
            dataset_document_id=dataset_document.id,
            update_params={
                DocumentSegment.status: "indexing",
                DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
            },
        )


class DocumentIsPausedError(Exception):
    pass


class DocumentIsDeletedPausedError(Exception):
    pass
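
Usage note (not part of the file above): a minimal sketch of how this runner might be driven, assuming a background task that loads queued documents and hands them to run(). The function name, the import path core.indexing_runner, and the id-based query are illustrative assumptions; only IndexingRunner, DocumentIsPausedError, db, and the DatasetDocument model come from the code above.

# Hypothetical driver sketch; names and import path are assumptions for illustration.
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from extensions.ext_database import db
from models.dataset import Document as DatasetDocument


def index_queued_documents(document_ids: list[str]) -> None:
    # load the dataset documents that were queued for indexing
    dataset_documents = (
        db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(document_ids)).all()
    )
    try:
        # run() performs extract -> transform -> load and updates each
        # document's indexing_status / error fields itself
        IndexingRunner().run(dataset_documents)
    except DocumentIsPausedError:
        # a paused document aborts the run; it can be resumed later via
        # run_in_splitting_status() or run_in_indexing_status()
        pass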