indexing_runner.py

import datetime
import json
import logging
import re
import threading
import time
import uuid
from typing import Optional, List, cast

from flask_login import current_user
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter

from core.data_loader.file_extractor import FileExtractor
from core.data_loader.loader.notion import NotionLoader
from core.docstore.dataset_docstore import DatesetDocumentStore
from core.generator.llm_generator import LLMGenerator
from core.index.index import IndexBuilder
from core.model_providers.error import ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import MessageType
from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
from models.dataset import Document as DatasetDocument
from models.dataset import Dataset, DocumentSegment, DatasetProcessRule
from models.model import UploadFile
from models.source import DataSourceBinding

class IndexingRunner:

    def __init__(self):
        self.storage = storage

    def run(self, dataset_documents: List[DatasetDocument]):
        """Run the indexing process."""
        for dataset_document in dataset_documents:
            try:
                # get dataset
                dataset = Dataset.query.filter_by(
                    id=dataset_document.dataset_id
                ).first()

                if not dataset:
                    raise ValueError("no dataset found")

                # load file
                text_docs = self._load_data(dataset_document)

                # get the process rule
                processing_rule = db.session.query(DatasetProcessRule). \
                    filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                    first()

                # get splitter
                splitter = self._get_splitter(processing_rule)

                # split to documents
                documents = self._step_split(
                    text_docs=text_docs,
                    splitter=splitter,
                    dataset=dataset,
                    dataset_document=dataset_document,
                    processing_rule=processing_rule
                )

                # new_documents = []
                # for document in documents:
                #     response = LLMGenerator.generate_qa_document(dataset.tenant_id, document.page_content)
                #     document_qa_list = self.format_split_text(response)
                #     for result in document_qa_list:
                #         document = Document(page_content=result['question'], metadata={'source': result['answer']})
                #         new_documents.append(document)

                # build index
                self._build_index(
                    dataset=dataset,
                    dataset_document=dataset_document,
                    documents=documents
                )
            except DocumentIsPausedException:
                raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
            except ProviderTokenNotInitError as e:
                dataset_document.indexing_status = 'error'
                dataset_document.error = str(e.description)
                dataset_document.stopped_at = datetime.datetime.utcnow()
                db.session.commit()
            except Exception as e:
                logging.exception("consume document failed")
                dataset_document.indexing_status = 'error'
                dataset_document.error = str(e)
                dataset_document.stopped_at = datetime.datetime.utcnow()
                db.session.commit()

    def format_split_text(self, text):
        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"
        matches = re.findall(regex, text, re.MULTILINE)

        result = []
        for match in matches:
            q = match[0]
            a = match[1]
            if q and a:
                result.append({
                    "question": q,
                    "answer": re.sub(r"\n\s*", "\n", a.strip())
                })

        return result

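    # Illustrative note (not from the original source): format_split_text expects LLM output in a
    # "Q1: ... A1: ... Q2: ... A2: ..." layout, as produced by LLMGenerator.generate_qa_document,
    # and returns question/answer dicts. For example:
    #   format_split_text("Q1: What is the runner for? A1: Splitting and indexing documents.")
    #   -> [{"question": "What is the runner for?", "answer": "Splitting and indexing documents."}]
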
    def run_in_splitting_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is splitting."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(
                id=dataset_document.dataset_id
            ).first()

            if not dataset:
                raise ValueError("no dataset found")

            # get existing document_segment list and delete
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id,
                document_id=dataset_document.id
            ).all()

            # Session.delete() takes a single instance, so remove stale segments one by one.
            for document_segment in document_segments:
                db.session.delete(document_segment)
            db.session.commit()

            # load file
            text_docs = self._load_data(dataset_document)

            # get the process rule
            processing_rule = db.session.query(DatasetProcessRule). \
                filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                first()

            # get splitter
            splitter = self._get_splitter(processing_rule)

            # split to documents
            documents = self._step_split(
                text_docs=text_docs,
                splitter=splitter,
                dataset=dataset,
                dataset_document=dataset_document,
                processing_rule=processing_rule
            )

            # build index
            self._build_index(
                dataset=dataset,
                dataset_document=dataset_document,
                documents=documents
            )
        except DocumentIsPausedException:
            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()

    def run_in_indexing_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is indexing."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(
                id=dataset_document.dataset_id
            ).first()

            if not dataset:
                raise ValueError("no dataset found")

            # get existing document_segment list
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id,
                document_id=dataset_document.id
            ).all()

            documents = []
            if document_segments:
                for document_segment in document_segments:
                    # transform segment to node
                    if document_segment.status != "completed":
                        document = Document(
                            page_content=document_segment.content,
                            metadata={
                                "doc_id": document_segment.index_node_id,
                                "doc_hash": document_segment.index_node_hash,
                                "document_id": document_segment.document_id,
                                "dataset_id": document_segment.dataset_id,
                            }
                        )
                        documents.append(document)

            # build index
            self._build_index(
                dataset=dataset,
                dataset_document=dataset_document,
                documents=documents
            )
        except DocumentIsPausedException:
            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()

    def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,
                               doc_form: str = None) -> dict:
        """
        Estimate the indexing for the document.
        """
        embedding_model = ModelFactory.get_embedding_model(
            tenant_id=tenant_id
        )

        tokens = 0
        preview_texts = []
        total_segments = 0
        for file_detail in file_details:
            # load data from file
            text_docs = FileExtractor.load(file_detail)

            processing_rule = DatasetProcessRule(
                mode=tmp_processing_rule["mode"],
                rules=json.dumps(tmp_processing_rule["rules"])
            )

            # get splitter
            splitter = self._get_splitter(processing_rule)

            # split to documents
            documents = self._split_to_documents_for_estimate(
                text_docs=text_docs,
                splitter=splitter,
                processing_rule=processing_rule
            )

            total_segments += len(documents)
            for document in documents:
                if len(preview_texts) < 5:
                    preview_texts.append(document.page_content)

                tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))

        text_generation_model = ModelFactory.get_text_generation_model(
            tenant_id=tenant_id
        )

        if doc_form and doc_form == 'qa_model':
            if len(preview_texts) > 0:
                # qa model document
                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0])
                document_qa_list = self.format_split_text(response)
                return {
                    "total_segments": total_segments * 20,
                    "tokens": total_segments * 2000,
                    "total_price": '{:f}'.format(
                        text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
                    "currency": embedding_model.get_currency(),
                    "qa_preview": document_qa_list,
                    "preview": preview_texts
                }
        return {
            "total_segments": total_segments,
            "tokens": tokens,
            "total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
            "currency": embedding_model.get_currency(),
            "preview": preview_texts
        }

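    # Illustrative note (not from the original source): in 'qa_model' mode the figures returned above
    # are a heuristic rather than a measurement. They assume roughly 20 generated QA segments and
    # about 2000 generation tokens per source segment, priced against the text generation model
    # instead of the embedding model.
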
    def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict,
                                 doc_form: str = None) -> dict:
        """
        Estimate the indexing for the document.
        """
        embedding_model = ModelFactory.get_embedding_model(
            tenant_id=tenant_id
        )

        # load data from notion
        tokens = 0
        preview_texts = []
        total_segments = 0
        for notion_info in notion_info_list:
            workspace_id = notion_info['workspace_id']
            data_source_binding = DataSourceBinding.query.filter(
                db.and_(
                    DataSourceBinding.tenant_id == current_user.current_tenant_id,
                    DataSourceBinding.provider == 'notion',
                    DataSourceBinding.disabled == False,
                    DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
                )
            ).first()
            if not data_source_binding:
                raise ValueError('Data source binding not found.')

            for page in notion_info['pages']:
                loader = NotionLoader(
                    notion_access_token=data_source_binding.access_token,
                    notion_workspace_id=workspace_id,
                    notion_obj_id=page['page_id'],
                    notion_page_type=page['type']
                )
                documents = loader.load()

                processing_rule = DatasetProcessRule(
                    mode=tmp_processing_rule["mode"],
                    rules=json.dumps(tmp_processing_rule["rules"])
                )

                # get splitter
                splitter = self._get_splitter(processing_rule)

                # split to documents
                documents = self._split_to_documents_for_estimate(
                    text_docs=documents,
                    splitter=splitter,
                    processing_rule=processing_rule
                )

                total_segments += len(documents)
                for document in documents:
                    if len(preview_texts) < 5:
                        preview_texts.append(document.page_content)

                    tokens += embedding_model.get_num_tokens(document.page_content)

        text_generation_model = ModelFactory.get_text_generation_model(
            tenant_id=tenant_id
        )

        if doc_form and doc_form == 'qa_model':
            if len(preview_texts) > 0:
                # qa model document
                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0])
                document_qa_list = self.format_split_text(response)
                return {
                    "total_segments": total_segments * 20,
                    "tokens": total_segments * 2000,
                    "total_price": '{:f}'.format(
                        text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
                    "currency": embedding_model.get_currency(),
                    "qa_preview": document_qa_list,
                    "preview": preview_texts
                }
        return {
            "total_segments": total_segments,
            "tokens": tokens,
            "total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
            "currency": embedding_model.get_currency(),
            "preview": preview_texts
        }

    def _load_data(self, dataset_document: DatasetDocument) -> List[Document]:
        # load file
        if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
            return []

        data_source_info = dataset_document.data_source_info_dict
        text_docs = []
        if dataset_document.data_source_type == 'upload_file':
            if not data_source_info or 'upload_file_id' not in data_source_info:
                raise ValueError("no upload file found")

            file_detail = db.session.query(UploadFile). \
                filter(UploadFile.id == data_source_info['upload_file_id']). \
                one_or_none()

            text_docs = FileExtractor.load(file_detail)
        elif dataset_document.data_source_type == 'notion_import':
            loader = NotionLoader.from_document(dataset_document)
            text_docs = loader.load()

        # update document status to splitting
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="splitting",
            extra_update_params={
                DatasetDocument.word_count: sum([len(text_doc.page_content) for text_doc in text_docs]),
                DatasetDocument.parsing_completed_at: datetime.datetime.utcnow()
            }
        )

        # clean content and attach document/dataset ids to each loaded document
        text_docs = cast(List[Document], text_docs)
        for text_doc in text_docs:
            # remove invalid symbols
            text_doc.page_content = self.filter_string(text_doc.page_content)
            text_doc.metadata['document_id'] = dataset_document.id
            text_doc.metadata['dataset_id'] = dataset_document.dataset_id

        return text_docs

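    # Illustrative note (assumption, not from the original source): for 'upload_file' documents,
    # data_source_info_dict is expected to carry the uploaded file reference, e.g.
    #   {"upload_file_id": "<UploadFile.id>"}
    # while 'notion_import' documents are handled entirely by NotionLoader.from_document(), which
    # presumably reads the Notion workspace/page reference from the same document record.
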
    def filter_string(self, text):
        # neutralize '<|' / '|>' special-token delimiters and strip control characters and
        # other bytes in the \x00-\x1F, \x7F and \x80-\xFF ranges
        text = re.sub(r'<\|', '<', text)
        text = re.sub(r'\|>', '>', text)
        text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]', '', text)
        return text

    def _get_splitter(self, processing_rule: DatasetProcessRule) -> TextSplitter:
        """
        Get the TextSplitter object according to the processing rule.
        """
        if processing_rule.mode == "custom":
            # The user-defined segmentation rule
            rules = json.loads(processing_rule.rules)
            segmentation = rules["segmentation"]
            if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > 1000:
                raise ValueError("Custom segment length should be between 50 and 1000.")

            separator = segmentation["separator"]
            if separator:
                separator = separator.replace('\\n', '\n')

            character_splitter = FixedRecursiveCharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=segmentation["max_tokens"],
                chunk_overlap=0,
                fixed_separator=separator,
                separators=["\n\n", "。", ".", " ", ""]
            )
        else:
            # Automatic segmentation
            character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
                chunk_overlap=0,
                separators=["\n\n", "。", ".", " ", ""]
            )

        return character_splitter

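    # Illustrative note (assumption, not from the original source): a "custom" processing rule is
    # expected to deserialize to something like
    #   {"segmentation": {"separator": "\\n", "max_tokens": 500}, "pre_processing_rules": [...]}
    # where max_tokens must stay within [50, 1000]. Any other mode falls back to
    # DatasetProcessRule.AUTOMATIC_RULES for the chunk size.
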
    def _step_split(self, text_docs: List[Document], splitter: TextSplitter,
                    dataset: Dataset, dataset_document: DatasetDocument,
                    processing_rule: DatasetProcessRule) -> List[Document]:
        """
        Split the text documents into documents and save them to the document segment.
        """
        documents = self._split_to_documents(
            text_docs=text_docs,
            splitter=splitter,
            processing_rule=processing_rule,
            tenant_id=dataset.tenant_id,
            document_form=dataset_document.doc_form
        )

        # save node to document segment
        doc_store = DatesetDocumentStore(
            dataset=dataset,
            user_id=dataset_document.created_by,
            document_id=dataset_document.id
        )

        # add document segments
        doc_store.add_documents(documents)

        # update document status to indexing
        cur_time = datetime.datetime.utcnow()
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="indexing",
            extra_update_params={
                DatasetDocument.cleaning_completed_at: cur_time,
                DatasetDocument.splitting_completed_at: cur_time,
            }
        )

        # update segment status to indexing
        self._update_segments_by_document(
            dataset_document_id=dataset_document.id,
            update_params={
                DocumentSegment.status: "indexing",
                DocumentSegment.indexing_at: datetime.datetime.utcnow()
            }
        )

        return documents

    def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
                            processing_rule: DatasetProcessRule, tenant_id: str,
                            document_form: str) -> List[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents = []
        all_qa_documents = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])
            split_documents = []
            for document_node in documents:
                doc_id = str(uuid.uuid4())
                hash = helper.generate_text_hash(document_node.page_content)

                document_node.metadata['doc_id'] = doc_id
                document_node.metadata['doc_hash'] = hash

                split_documents.append(document_node)
            all_documents.extend(split_documents)

        # processing qa document: fan out QA generation in batches of 10 threads
        if document_form == 'qa_model':
            for i in range(0, len(all_documents), 10):
                threads = []
                sub_documents = all_documents[i:i + 10]
                for doc in sub_documents:
                    document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={
                        'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents})
                    threads.append(document_format_thread)
                    document_format_thread.start()
                for thread in threads:
                    thread.join()
            return all_qa_documents
        return all_documents

    def format_qa_document(self, tenant_id: str, document_node, all_qa_documents):
        format_documents = []
        if document_node.page_content is None or not document_node.page_content.strip():
            return
        try:
            # qa model document
            response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content)
            document_qa_list = self.format_split_text(response)
            qa_documents = []
            for result in document_qa_list:
                qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
                doc_id = str(uuid.uuid4())
                hash = helper.generate_text_hash(result['question'])
                qa_document.metadata['answer'] = result['answer']
                qa_document.metadata['doc_id'] = doc_id
                qa_document.metadata['doc_hash'] = hash
                qa_documents.append(qa_document)
            format_documents.extend(qa_documents)
        except Exception as e:
            logging.error(str(e))

        all_qa_documents.extend(format_documents)

    def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,
                                         processing_rule: DatasetProcessRule) -> List[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])

            split_documents = []
            for document in documents:
                if document.page_content is None or not document.page_content.strip():
                    continue
                doc_id = str(uuid.uuid4())
                hash = helper.generate_text_hash(document.page_content)

                document.metadata['doc_id'] = doc_id
                document.metadata['doc_hash'] = hash

                split_documents.append(document)

            all_documents.extend(split_documents)

        return all_documents

    def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str:
        """
        Clean the document text according to the processing rules.
        """
        if processing_rule.mode == "automatic":
            rules = DatasetProcessRule.AUTOMATIC_RULES
        else:
            rules = json.loads(processing_rule.rules) if processing_rule.rules else {}

        if 'pre_processing_rules' in rules:
            pre_processing_rules = rules["pre_processing_rules"]
            for pre_processing_rule in pre_processing_rules:
                if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True:
                    # Remove extra spaces
                    pattern = r'\n{3,}'
                    text = re.sub(pattern, '\n\n', text)
                    pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}'
                    text = re.sub(pattern, ' ', text)
                elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True:
                    # Remove email
                    pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)'
                    text = re.sub(pattern, '', text)

                    # Remove URL
                    pattern = r'https?://[^\s]+'
                    text = re.sub(pattern, '', text)

        return text

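    # Illustrative note (assumption, not from the original source): pre_processing_rules is a list
    # of toggleable cleanup steps keyed by id, e.g.
    #   [{"id": "remove_extra_spaces", "enabled": True},
    #    {"id": "remove_urls_emails", "enabled": False}]
    # Only the two ids handled above have any effect here.
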
    def format_split_text(self, text):
        # NOTE: this duplicates format_split_text defined earlier in the class and overrides it.
        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"  # pattern matching Q/A pairs
        matches = re.findall(regex, text, re.MULTILINE)  # collect all matches
        result = []  # store the final result
        for match in matches:
            q = match[0]
            a = match[1]
            if q and a:
                # keep the pair only when both Q and A are present
                result.append({
                    "question": q,
                    "answer": re.sub(r"\n\s*", "\n", a.strip())
                })
        return result

    def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
        """
        Build the index for the document.
        """
        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
        keyword_table_index = IndexBuilder.get_index(dataset, 'economy')

        embedding_model = ModelFactory.get_embedding_model(
            tenant_id=dataset.tenant_id
        )

        # chunk nodes by chunk size
        indexing_start_at = time.perf_counter()
        tokens = 0
        chunk_size = 100
        for i in range(0, len(documents), chunk_size):
            # check document is paused
            self._check_document_paused_status(dataset_document.id)
            chunk_documents = documents[i:i + chunk_size]

            tokens += sum(
                embedding_model.get_num_tokens(document.page_content)
                for document in chunk_documents
            )

            # save vector index
            if vector_index:
                vector_index.add_texts(chunk_documents)

            # save keyword index
            keyword_table_index.add_texts(chunk_documents)

            document_ids = [document.metadata['doc_id'] for document in chunk_documents]
            db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.index_node_id.in_(document_ids),
                DocumentSegment.status == "indexing"
            ).update({
                DocumentSegment.status: "completed",
                DocumentSegment.completed_at: datetime.datetime.utcnow()
            })

            db.session.commit()

        indexing_end_at = time.perf_counter()

        # update document status to completed
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="completed",
            extra_update_params={
                DatasetDocument.tokens: tokens,
                DatasetDocument.completed_at: datetime.datetime.utcnow(),
                DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
            }
        )

    def _check_document_paused_status(self, document_id: str):
        indexing_cache_key = 'document_{}_is_paused'.format(document_id)
        result = redis_client.get(indexing_cache_key)
        if result:
            raise DocumentIsPausedException()

    def _update_document_index_status(self, document_id: str, after_indexing_status: str,
                                      extra_update_params: Optional[dict] = None) -> None:
        """
        Update the document indexing status.
        """
        count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
        if count > 0:
            raise DocumentIsPausedException()

        update_params = {
            DatasetDocument.indexing_status: after_indexing_status
        }

        if extra_update_params:
            update_params.update(extra_update_params)

        DatasetDocument.query.filter_by(id=document_id).update(update_params)
        db.session.commit()

    def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None:
        """
        Update the document segments by document id.
        """
        DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
        db.session.commit()

class DocumentIsPausedException(Exception):
    pass
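

# Illustrative usage sketch (assumption, not part of the original module): a background worker
# would typically drive this runner roughly as follows, where `dataset_documents` is a list of
# models.dataset.Document rows queued for indexing:
#
#     runner = IndexingRunner()
#     try:
#         runner.run(dataset_documents)
#     except DocumentIsPausedException:
#         logging.info("document indexing paused")
#
# run_in_splitting_status() and run_in_indexing_status() serve as resume entry points for
# documents interrupted mid-pipeline.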