indexing_runner.py

import concurrent
import datetime
import json
import logging
import re
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, List, cast

from flask_login import current_user
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter

from core.data_loader.file_extractor import FileExtractor
from core.data_loader.loader.notion import NotionLoader
from core.docstore.dataset_docstore import DatesetDocumentStore
from core.generator.llm_generator import LLMGenerator
from core.index.index import IndexBuilder
from core.llm.error import ProviderTokenNotInitError
from core.llm.llm_builder import LLMBuilder
from core.llm.streamable_open_ai import StreamableOpenAI
from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
from core.llm.token_calculator import TokenCalculator
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
from models.dataset import Document as DatasetDocument
from models.dataset import Dataset, DocumentSegment, DatasetProcessRule
from models.model import UploadFile
from models.source import DataSourceBinding

class IndexingRunner:

    def __init__(self, embedding_model_name: str = "text-embedding-ada-002"):
        self.storage = storage
        self.embedding_model_name = embedding_model_name

    def run(self, dataset_documents: List[DatasetDocument]):
        """Run the indexing process."""
        for dataset_document in dataset_documents:
            try:
                # get dataset
                dataset = Dataset.query.filter_by(
                    id=dataset_document.dataset_id
                ).first()

                if not dataset:
                    raise ValueError("no dataset found")

                # load file
                text_docs = self._load_data(dataset_document)

                # get the process rule
                processing_rule = db.session.query(DatasetProcessRule). \
                    filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                    first()

                # get splitter
                splitter = self._get_splitter(processing_rule)

                # split to documents
                documents = self._step_split(
                    text_docs=text_docs,
                    splitter=splitter,
                    dataset=dataset,
                    dataset_document=dataset_document,
                    processing_rule=processing_rule
                )

                # new_documents = []
                # for document in documents:
                #     response = LLMGenerator.generate_qa_document(dataset.tenant_id, document.page_content)
                #     document_qa_list = self.format_split_text(response)
                #     for result in document_qa_list:
                #         document = Document(page_content=result['question'], metadata={'source': result['answer']})
                #         new_documents.append(document)

                # build index
                self._build_index(
                    dataset=dataset,
                    dataset_document=dataset_document,
                    documents=documents
                )
            except DocumentIsPausedException:
                raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
            except ProviderTokenNotInitError as e:
                dataset_document.indexing_status = 'error'
                dataset_document.error = str(e.description)
                dataset_document.stopped_at = datetime.datetime.utcnow()
                db.session.commit()
            except Exception as e:
                logging.exception("consume document failed")
                dataset_document.indexing_status = 'error'
                dataset_document.error = str(e)
                dataset_document.stopped_at = datetime.datetime.utcnow()
                db.session.commit()

    def format_split_text(self, text):
        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"
        matches = re.findall(regex, text, re.MULTILINE)
        result = []
        for match in matches:
            q = match[0]
            a = match[1]
            if q and a:
                result.append({
                    "question": q,
                    "answer": re.sub(r"\n\s*", "\n", a.strip())
                })

        return result
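
    # Illustrative input/output for format_split_text above (hypothetical text):
    #
    #   format_split_text("Q1: What is indexing?\nA1: Building a search index.")
    #   -> [{"question": "What is indexing?", "answer": "Building a search index."}]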

    def run_in_splitting_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is splitting."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(
                id=dataset_document.dataset_id
            ).first()

            if not dataset:
                raise ValueError("no dataset found")

            # get the existing document_segment list and delete it
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id,
                document_id=dataset_document.id
            ).all()

            # db.session.delete() only accepts a single instance, so delete segments one by one
            for document_segment in document_segments:
                db.session.delete(document_segment)
            db.session.commit()

            # load file
            text_docs = self._load_data(dataset_document)

            # get the process rule
            processing_rule = db.session.query(DatasetProcessRule). \
                filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                first()

            # get splitter
            splitter = self._get_splitter(processing_rule)

            # split to documents
            documents = self._step_split(
                text_docs=text_docs,
                splitter=splitter,
                dataset=dataset,
                dataset_document=dataset_document,
                processing_rule=processing_rule
            )

            # build index
            self._build_index(
                dataset=dataset,
                dataset_document=dataset_document,
                documents=documents
            )
        except DocumentIsPausedException:
            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()

    def run_in_indexing_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is indexing."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(
                id=dataset_document.dataset_id
            ).first()

            if not dataset:
                raise ValueError("no dataset found")

            # get the existing document_segment list and rebuild the uncompleted ones
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id,
                document_id=dataset_document.id
            ).all()

            documents = []
            if document_segments:
                for document_segment in document_segments:
                    # transform segment to node
                    if document_segment.status != "completed":
                        document = Document(
                            page_content=document_segment.content,
                            metadata={
                                "doc_id": document_segment.index_node_id,
                                "doc_hash": document_segment.index_node_hash,
                                "document_id": document_segment.document_id,
                                "dataset_id": document_segment.dataset_id,
                            }
                        )
                        documents.append(document)

            # build index
            self._build_index(
                dataset=dataset,
                dataset_document=dataset_document,
                documents=documents
            )
        except DocumentIsPausedException:
            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()

    def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict,
                               doc_form: str = None) -> dict:
        """
        Estimate the indexing cost for the given upload files.
        """
        tokens = 0
        preview_texts = []
        total_segments = 0
        for file_detail in file_details:
            # load data from file
            text_docs = FileExtractor.load(file_detail)

            processing_rule = DatasetProcessRule(
                mode=tmp_processing_rule["mode"],
                rules=json.dumps(tmp_processing_rule["rules"])
            )

            # get splitter
            splitter = self._get_splitter(processing_rule)

            # split to documents
            documents = self._split_to_documents_for_estimate(
                text_docs=text_docs,
                splitter=splitter,
                processing_rule=processing_rule
            )

            total_segments += len(documents)
            for document in documents:
                if len(preview_texts) < 5:
                    preview_texts.append(document.page_content)

                tokens += TokenCalculator.get_num_tokens(self.embedding_model_name,
                                                         self.filter_string(document.page_content))

        if doc_form and doc_form == 'qa_model':
            if len(preview_texts) > 0:
                # qa model document
                llm: StreamableOpenAI = LLMBuilder.to_llm(
                    tenant_id=current_user.current_tenant_id,
                    model_name='gpt-3.5-turbo',
                    max_tokens=2000
                )
                response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0])
                document_qa_list = self.format_split_text(response)
                return {
                    "total_segments": total_segments * 20,
                    "tokens": total_segments * 2000,
                    "total_price": '{:f}'.format(
                        TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')),
                    "currency": TokenCalculator.get_currency(self.embedding_model_name),
                    "qa_preview": document_qa_list,
                    "preview": preview_texts
                }

        return {
            "total_segments": total_segments,
            "tokens": tokens,
            "total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)),
            "currency": TokenCalculator.get_currency(self.embedding_model_name),
            "preview": preview_texts
        }
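
    # Note on the qa_model estimate above: Q&A generation runs only on the first preview text,
    # and the totals are extrapolated with fixed heuristics (20 Q&A pairs and 2000 completion
    # tokens per segment, priced at the gpt-3.5-turbo completion rate).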

    def notion_indexing_estimate(self, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict:
        """
        Estimate the indexing cost for the given Notion pages.
        """
        # load data from notion
        tokens = 0
        preview_texts = []
        total_segments = 0
        for notion_info in notion_info_list:
            workspace_id = notion_info['workspace_id']
            data_source_binding = DataSourceBinding.query.filter(
                db.and_(
                    DataSourceBinding.tenant_id == current_user.current_tenant_id,
                    DataSourceBinding.provider == 'notion',
                    DataSourceBinding.disabled == False,
                    DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
                )
            ).first()
            if not data_source_binding:
                raise ValueError('Data source binding not found.')

            for page in notion_info['pages']:
                loader = NotionLoader(
                    notion_access_token=data_source_binding.access_token,
                    notion_workspace_id=workspace_id,
                    notion_obj_id=page['page_id'],
                    notion_page_type=page['type']
                )
                documents = loader.load()

                processing_rule = DatasetProcessRule(
                    mode=tmp_processing_rule["mode"],
                    rules=json.dumps(tmp_processing_rule["rules"])
                )

                # get splitter
                splitter = self._get_splitter(processing_rule)

                # split to documents
                documents = self._split_to_documents_for_estimate(
                    text_docs=documents,
                    splitter=splitter,
                    processing_rule=processing_rule
                )

                total_segments += len(documents)
                for document in documents:
                    if len(preview_texts) < 5:
                        preview_texts.append(document.page_content)

                    tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)

        if doc_form and doc_form == 'qa_model':
            if len(preview_texts) > 0:
                # qa model document
                llm: StreamableOpenAI = LLMBuilder.to_llm(
                    tenant_id=current_user.current_tenant_id,
                    model_name='gpt-3.5-turbo',
                    max_tokens=2000
                )
                response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0])
                document_qa_list = self.format_split_text(response)
                return {
                    "total_segments": total_segments * 20,
                    "tokens": total_segments * 2000,
                    "total_price": '{:f}'.format(
                        TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')),
                    "currency": TokenCalculator.get_currency(self.embedding_model_name),
                    "qa_preview": document_qa_list,
                    "preview": preview_texts
                }

        return {
            "total_segments": total_segments,
            "tokens": tokens,
            "total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)),
            "currency": TokenCalculator.get_currency(self.embedding_model_name),
            "preview": preview_texts
        }

    def _load_data(self, dataset_document: DatasetDocument) -> List[Document]:
        # load file
        if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
            return []

        data_source_info = dataset_document.data_source_info_dict
        text_docs = []
        if dataset_document.data_source_type == 'upload_file':
            if not data_source_info or 'upload_file_id' not in data_source_info:
                raise ValueError("no upload file found")

            file_detail = db.session.query(UploadFile). \
                filter(UploadFile.id == data_source_info['upload_file_id']). \
                one_or_none()

            text_docs = FileExtractor.load(file_detail)
        elif dataset_document.data_source_type == 'notion_import':
            loader = NotionLoader.from_document(dataset_document)
            text_docs = loader.load()

        # update document status to splitting
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="splitting",
            extra_update_params={
                DatasetDocument.word_count: sum([len(text_doc.page_content) for text_doc in text_docs]),
                DatasetDocument.parsing_completed_at: datetime.datetime.utcnow()
            }
        )

        # replace doc id with the document model id
        text_docs = cast(List[Document], text_docs)
        for text_doc in text_docs:
            # remove invalid symbols
            text_doc.page_content = self.filter_string(text_doc.page_content)
            text_doc.metadata['document_id'] = dataset_document.id
            text_doc.metadata['dataset_id'] = dataset_document.dataset_id

        return text_docs

    def filter_string(self, text):
        text = re.sub(r'<\|', '<', text)
        text = re.sub(r'\|>', '>', text)
        text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]', '', text)
        return text
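
    # Note on filter_string above: it rewrites "<|" / "|>" delimiters to plain "<" / ">" and
    # strips most control characters (keeping tab, newline and carriage return) plus code
    # points \x7F-\xFF before the text is token-counted and indexed.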

    def _get_splitter(self, processing_rule: DatasetProcessRule) -> TextSplitter:
        """
        Get the TextSplitter according to the processing rule.
        """
        if processing_rule.mode == "custom":
            # The user-defined segmentation rule
            rules = json.loads(processing_rule.rules)
            segmentation = rules["segmentation"]
            if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > 1000:
                raise ValueError("Custom segment length should be between 50 and 1000.")

            separator = segmentation["separator"]
            if separator:
                separator = separator.replace('\\n', '\n')

            character_splitter = FixedRecursiveCharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=segmentation["max_tokens"],
                chunk_overlap=0,
                fixed_separator=separator,
                separators=["\n\n", "。", ".", " ", ""]
            )
        else:
            # Automatic segmentation
            character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
                chunk_overlap=0,
                separators=["\n\n", "。", ".", " ", ""]
            )

        return character_splitter
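
    # Illustrative shape of a "custom" processing rule consumed by _get_splitter above and by
    # _document_clean below (field names taken from those lookups; the values are hypothetical):
    #
    #   {
    #       "segmentation": {"separator": "\\n", "max_tokens": 500},
    #       "pre_processing_rules": [
    #           {"id": "remove_extra_spaces", "enabled": True},
    #           {"id": "remove_urls_emails", "enabled": False}
    #       ]
    #   }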

    def _step_split(self, text_docs: List[Document], splitter: TextSplitter,
                    dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \
            -> List[Document]:
        """
        Split the text documents and persist the pieces as document segments.
        """
        documents = self._split_to_documents(
            text_docs=text_docs,
            splitter=splitter,
            processing_rule=processing_rule,
            tenant_id=dataset.tenant_id,
            document_form=dataset_document.doc_form
        )

        # save nodes to document segments
        doc_store = DatesetDocumentStore(
            dataset=dataset,
            user_id=dataset_document.created_by,
            embedding_model_name=self.embedding_model_name,
            document_id=dataset_document.id
        )

        # add document segments
        doc_store.add_documents(documents)

        # update document status to indexing
        cur_time = datetime.datetime.utcnow()
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="indexing",
            extra_update_params={
                DatasetDocument.cleaning_completed_at: cur_time,
                DatasetDocument.splitting_completed_at: cur_time,
            }
        )

        # update segment status to indexing
        self._update_segments_by_document(
            dataset_document_id=dataset_document.id,
            update_params={
                DocumentSegment.status: "indexing",
                DocumentSegment.indexing_at: datetime.datetime.utcnow()
            }
        )

        return documents

    def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
                            processing_rule: DatasetProcessRule, tenant_id: str, document_form: str) -> List[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents = []
        all_qa_documents = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])
            split_documents = []
            for document_node in documents:
                doc_id = str(uuid.uuid4())
                hash = helper.generate_text_hash(document_node.page_content)
                document_node.metadata['doc_id'] = doc_id
                document_node.metadata['doc_hash'] = hash

                split_documents.append(document_node)
            all_documents.extend(split_documents)

        # processing qa document
        if document_form == 'qa_model':
            llm: StreamableOpenAI = LLMBuilder.to_llm(
                tenant_id=tenant_id,
                model_name='gpt-3.5-turbo',
                max_tokens=2000
            )
            for i in range(0, len(all_documents), 10):
                threads = []
                sub_documents = all_documents[i:i + 10]
                for doc in sub_documents:
                    document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={
                        'llm': llm, 'document_node': doc, 'all_qa_documents': all_qa_documents})
                    threads.append(document_format_thread)
                    document_format_thread.start()
                for thread in threads:
                    thread.join()
            return all_qa_documents

        return all_documents
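
    # Note on the qa_model branch above: split chunks are processed in batches of 10, with one
    # thread per chunk calling format_qa_document(); every thread extends the shared
    # all_qa_documents list, and join() makes each batch finish before the next one starts.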

    def format_qa_document(self, llm: StreamableOpenAI, document_node, all_qa_documents):
        format_documents = []
        if document_node.page_content is None or not document_node.page_content.strip():
            return

        try:
            # qa model document
            response = LLMGenerator.generate_qa_document_sync(llm, document_node.page_content)
            document_qa_list = self.format_split_text(response)
            qa_documents = []
            for result in document_qa_list:
                qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
                doc_id = str(uuid.uuid4())
                hash = helper.generate_text_hash(result['question'])
                qa_document.metadata['answer'] = result['answer']
                qa_document.metadata['doc_id'] = doc_id
                qa_document.metadata['doc_hash'] = hash
                qa_documents.append(qa_document)
            format_documents.extend(qa_documents)
        except Exception as e:
            logging.error(str(e))

        all_qa_documents.extend(format_documents)

    def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,
                                         processing_rule: DatasetProcessRule) -> List[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])
            split_documents = []
            for document in documents:
                if document.page_content is None or not document.page_content.strip():
                    continue

                doc_id = str(uuid.uuid4())
                hash = helper.generate_text_hash(document.page_content)
                document.metadata['doc_id'] = doc_id
                document.metadata['doc_hash'] = hash

                split_documents.append(document)
            all_documents.extend(split_documents)

        return all_documents

    def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str:
        """
        Clean the document text according to the processing rules.
        """
        if processing_rule.mode == "automatic":
            rules = DatasetProcessRule.AUTOMATIC_RULES
        else:
            rules = json.loads(processing_rule.rules) if processing_rule.rules else {}

        if 'pre_processing_rules' in rules:
            pre_processing_rules = rules["pre_processing_rules"]
            for pre_processing_rule in pre_processing_rules:
                if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True:
                    # Remove extra spaces
                    pattern = r'\n{3,}'
                    text = re.sub(pattern, '\n\n', text)
                    pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}'
                    text = re.sub(pattern, ' ', text)
                elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True:
                    # Remove email
                    pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)'
                    text = re.sub(pattern, '', text)

                    # Remove URL
                    pattern = r'https?://[^\s]+'
                    text = re.sub(pattern, '', text)

        return text
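
    # Illustrative effect of the "remove_extra_spaces" rule above (hypothetical input):
    #
    #   "line one\n\n\n\nline   two"  ->  "line one\n\nline two"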

    def format_split_text(self, text):
        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"  # regex that matches Q and A pairs
        matches = re.findall(regex, text, re.MULTILINE)  # get all matches
        result = []  # store the final results
        for match in matches:
            q = match[0]
            a = match[1]
            if q and a:
                # append only when both Q and A are present
                result.append({
                    "question": q,
                    "answer": re.sub(r"\n\s*", "\n", a.strip())
                })

        return result

    def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
        """
        Build the index for the document.
        """
        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
        keyword_table_index = IndexBuilder.get_index(dataset, 'economy')

        # chunk nodes by chunk size
        indexing_start_at = time.perf_counter()
        tokens = 0
        chunk_size = 100
        for i in range(0, len(documents), chunk_size):
            # check document is paused
            self._check_document_paused_status(dataset_document.id)

            chunk_documents = documents[i:i + chunk_size]

            tokens += sum(
                TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
                for document in chunk_documents
            )

            # save vector index
            if vector_index:
                vector_index.add_texts(chunk_documents)

            # save keyword index
            keyword_table_index.add_texts(chunk_documents)

            document_ids = [document.metadata['doc_id'] for document in chunk_documents]
            db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.index_node_id.in_(document_ids),
                DocumentSegment.status == "indexing"
            ).update({
                DocumentSegment.status: "completed",
                DocumentSegment.completed_at: datetime.datetime.utcnow()
            })

            db.session.commit()

        indexing_end_at = time.perf_counter()

        # update document status to completed
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="completed",
            extra_update_params={
                DatasetDocument.tokens: tokens,
                DatasetDocument.completed_at: datetime.datetime.utcnow(),
                DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
            }
        )
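
    # Note on _build_index above: documents are indexed in chunks of 100; each chunk is added
    # to the vector index when the dataset has one and always to the keyword table index, and
    # the matching DocumentSegment rows are flipped from "indexing" to "completed" per chunk.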

    def _check_document_paused_status(self, document_id: str):
        indexing_cache_key = 'document_{}_is_paused'.format(document_id)
        result = redis_client.get(indexing_cache_key)
        if result:
            raise DocumentIsPausedException()

    def _update_document_index_status(self, document_id: str, after_indexing_status: str,
                                      extra_update_params: Optional[dict] = None) -> None:
        """
        Update the document indexing status.
        """
        count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
        if count > 0:
            raise DocumentIsPausedException()

        update_params = {
            DatasetDocument.indexing_status: after_indexing_status
        }

        if extra_update_params:
            update_params.update(extra_update_params)

        DatasetDocument.query.filter_by(id=document_id).update(update_params)
        db.session.commit()

    def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None:
        """
        Update the document segments by document id.
        """
        DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
        db.session.commit()


class DocumentIsPausedException(Exception):
    pass
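

# A minimal usage sketch (hypothetical; in the application this class is normally driven from
# background tasks rather than invoked directly, and the 'waiting' status value is assumed):
#
#   dataset_documents = DatasetDocument.query.filter_by(indexing_status='waiting').all()
#   IndexingRunner().run(dataset_documents)
#
# run_in_splitting_status() and run_in_indexing_status() resume documents interrupted in those
# stages, and DocumentIsPausedException is raised whenever the Redis flag
# 'document_{id}_is_paused' is set or the document row is marked is_paused.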