indexing_runner.py

import datetime
import json
import logging
import re
import threading
import time
import uuid
from typing import Optional, cast

from flask import Flask, current_app
from flask_login import current_user
from langchain.schema import Document
from langchain.text_splitter import TextSplitter
from sqlalchemy.orm.exc import ObjectDeletedError

from core.data_loader.file_extractor import FileExtractor
from core.data_loader.loader.notion import NotionLoader
from core.docstore.dataset_docstore import DatasetDocumentStore
from core.errors.error import ProviderTokenNotInitError
from core.generator.llm_generator import LLMGenerator
from core.index.index import IndexBuilder
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType, PriceType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
from models.source import DataSourceBinding
from services.feature_service import FeatureService

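# IndexingRunner drives the document indexing pipeline: it loads the raw source
# (an uploaded file or an imported Notion page), cleans and splits the text into
# segments, optionally rewrites segments into Q&A form, and writes the results to
# the vector and keyword indexes while tracking per-document status in the database.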
class IndexingRunner:

    def __init__(self):
        self.storage = storage
        self.model_manager = ModelManager()

    def run(self, dataset_documents: list[DatasetDocument]):
        """Run the indexing process."""
        for dataset_document in dataset_documents:
            try:
                # get dataset
                dataset = Dataset.query.filter_by(
                    id=dataset_document.dataset_id
                ).first()

                if not dataset:
                    raise ValueError("no dataset found")

                # get the process rule
                processing_rule = db.session.query(DatasetProcessRule). \
                    filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                    first()

                # load file
                text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')

                # get embedding model instance
                embedding_model_instance = None
                if dataset.indexing_technique == 'high_quality':
                    if dataset.embedding_model_provider:
                        embedding_model_instance = self.model_manager.get_model_instance(
                            tenant_id=dataset.tenant_id,
                            provider=dataset.embedding_model_provider,
                            model_type=ModelType.TEXT_EMBEDDING,
                            model=dataset.embedding_model
                        )
                    else:
                        embedding_model_instance = self.model_manager.get_default_model_instance(
                            tenant_id=dataset.tenant_id,
                            model_type=ModelType.TEXT_EMBEDDING,
                        )

                # get splitter
                splitter = self._get_splitter(processing_rule, embedding_model_instance)

                # split to documents
                documents = self._step_split(
                    text_docs=text_docs,
                    splitter=splitter,
                    dataset=dataset,
                    dataset_document=dataset_document,
                    processing_rule=processing_rule
                )

                self._build_index(
                    dataset=dataset,
                    dataset_document=dataset_document,
                    documents=documents
                )
            except DocumentIsPausedException:
                raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
            except ProviderTokenNotInitError as e:
                dataset_document.indexing_status = 'error'
                dataset_document.error = str(e.description)
                dataset_document.stopped_at = datetime.datetime.utcnow()
                db.session.commit()
            except ObjectDeletedError:
                logging.warning('Document deleted, document id: {}'.format(dataset_document.id))
            except Exception as e:
                logging.exception("consume document failed")
                dataset_document.indexing_status = 'error'
                dataset_document.error = str(e)
                dataset_document.stopped_at = datetime.datetime.utcnow()
                db.session.commit()

    def run_in_splitting_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is splitting."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(
                id=dataset_document.dataset_id
            ).first()

            if not dataset:
                raise ValueError("no dataset found")

            # fetch the existing document segments and delete them
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id,
                document_id=dataset_document.id
            ).all()

            for document_segment in document_segments:
                db.session.delete(document_segment)
            db.session.commit()

            # get the process rule
            processing_rule = db.session.query(DatasetProcessRule). \
                filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                first()

            # load file
            text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')

            # get embedding model instance
            embedding_model_instance = None
            if dataset.indexing_technique == 'high_quality':
                if dataset.embedding_model_provider:
                    embedding_model_instance = self.model_manager.get_model_instance(
                        tenant_id=dataset.tenant_id,
                        provider=dataset.embedding_model_provider,
                        model_type=ModelType.TEXT_EMBEDDING,
                        model=dataset.embedding_model
                    )
                else:
                    embedding_model_instance = self.model_manager.get_default_model_instance(
                        tenant_id=dataset.tenant_id,
                        model_type=ModelType.TEXT_EMBEDDING,
                    )

            # get splitter
            splitter = self._get_splitter(processing_rule, embedding_model_instance)

            # split to documents
            documents = self._step_split(
                text_docs=text_docs,
                splitter=splitter,
                dataset=dataset,
                dataset_document=dataset_document,
                processing_rule=processing_rule
            )

            # build index
            self._build_index(
                dataset=dataset,
                dataset_document=dataset_document,
                documents=documents
            )
        except DocumentIsPausedException:
            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()

    def run_in_indexing_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is indexing."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(
                id=dataset_document.dataset_id
            ).first()

            if not dataset:
                raise ValueError("no dataset found")

            # fetch the existing document segments and rebuild documents from those
            # that have not been indexed yet
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id,
                document_id=dataset_document.id
            ).all()

            documents = []
            if document_segments:
                for document_segment in document_segments:
                    # transform segment to node
                    if document_segment.status != "completed":
                        document = Document(
                            page_content=document_segment.content,
                            metadata={
                                "doc_id": document_segment.index_node_id,
                                "doc_hash": document_segment.index_node_hash,
                                "document_id": document_segment.document_id,
                                "dataset_id": document_segment.dataset_id,
                            }
                        )
                        documents.append(document)

            # build index
            self._build_index(
                dataset=dataset,
                dataset_document=dataset_document,
                documents=documents
            )
        except DocumentIsPausedException:
            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()

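    # Pre-upload cost estimate for file-based sources: the files are split with the
    # same rules indexing would use, token counts and prices come from the embedding
    # model, and for the Q&A document form the totals are roughly extrapolated from
    # the segment count. Nothing is written to the indexes.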
    def file_indexing_estimate(self, tenant_id: str, file_details: list[UploadFile], tmp_processing_rule: dict,
                               doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
                               indexing_technique: str = 'economy') -> dict:
        """
        Estimate the indexing cost (segments, tokens and price) for the given uploaded files.
        """
        # check document limit
        features = FeatureService.get_features(tenant_id)
        if features.billing.enabled:
            count = len(file_details)
            batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
            if count > batch_upload_limit:
                raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")

        embedding_model_instance = None
        if dataset_id:
            dataset = Dataset.query.filter_by(
                id=dataset_id
            ).first()
            if not dataset:
                raise ValueError('Dataset not found.')
            if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
                if dataset.embedding_model_provider:
                    embedding_model_instance = self.model_manager.get_model_instance(
                        tenant_id=tenant_id,
                        provider=dataset.embedding_model_provider,
                        model_type=ModelType.TEXT_EMBEDDING,
                        model=dataset.embedding_model
                    )
                else:
                    embedding_model_instance = self.model_manager.get_default_model_instance(
                        tenant_id=tenant_id,
                        model_type=ModelType.TEXT_EMBEDDING,
                    )
        else:
            if indexing_technique == 'high_quality':
                embedding_model_instance = self.model_manager.get_default_model_instance(
                    tenant_id=tenant_id,
                    model_type=ModelType.TEXT_EMBEDDING,
                )

        tokens = 0
        preview_texts = []
        total_segments = 0
        total_price = 0
        currency = 'USD'
        for file_detail in file_details:
            processing_rule = DatasetProcessRule(
                mode=tmp_processing_rule["mode"],
                rules=json.dumps(tmp_processing_rule["rules"])
            )

            # load data from file
            text_docs = FileExtractor.load(file_detail, is_automatic=processing_rule.mode == 'automatic')

            # get splitter
            splitter = self._get_splitter(processing_rule, embedding_model_instance)

            # split to documents
            documents = self._split_to_documents_for_estimate(
                text_docs=text_docs,
                splitter=splitter,
                processing_rule=processing_rule
            )

            total_segments += len(documents)
            for document in documents:
                if len(preview_texts) < 5:
                    preview_texts.append(document.page_content)

                if indexing_technique == 'high_quality' or embedding_model_instance:
                    embedding_model_type_instance = embedding_model_instance.model_type_instance
                    embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
                    tokens += embedding_model_type_instance.get_num_tokens(
                        model=embedding_model_instance.model,
                        credentials=embedding_model_instance.credentials,
                        texts=[self.filter_string(document.page_content)]
                    )

        if doc_form and doc_form == 'qa_model':
            model_instance = self.model_manager.get_default_model_instance(
                tenant_id=tenant_id,
                model_type=ModelType.LLM
            )

            model_type_instance = model_instance.model_type_instance
            model_type_instance = cast(LargeLanguageModel, model_type_instance)

            if len(preview_texts) > 0:
                # qa model document
                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
                                                             doc_language)
                document_qa_list = self.format_split_text(response)

                price_info = model_type_instance.get_price(
                    model=model_instance.model,
                    credentials=model_instance.credentials,
                    price_type=PriceType.INPUT,
                    tokens=total_segments * 2000,
                )

                return {
                    "total_segments": total_segments * 20,
                    "tokens": total_segments * 2000,
                    "total_price": '{:f}'.format(price_info.total_amount),
                    "currency": price_info.currency,
                    "qa_preview": document_qa_list,
                    "preview": preview_texts
                }

        if embedding_model_instance:
            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_instance.model_type_instance)
            embedding_price_info = embedding_model_type_instance.get_price(
                model=embedding_model_instance.model,
                credentials=embedding_model_instance.credentials,
                price_type=PriceType.INPUT,
                tokens=tokens
            )
            total_price = '{:f}'.format(embedding_price_info.total_amount)
            currency = embedding_price_info.currency

        return {
            "total_segments": total_segments,
            "tokens": tokens,
            "total_price": total_price,
            "currency": currency,
            "preview": preview_texts
        }

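    # Notion counterpart of file_indexing_estimate: pages are pulled through
    # NotionLoader using the workspace's data source binding, then split and
    # priced the same way as uploaded files.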
    def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict,
                                 doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
                                 indexing_technique: str = 'economy') -> dict:
        """
        Estimate the indexing cost (segments, tokens and price) for the given Notion pages.
        """
        # check document limit
        features = FeatureService.get_features(tenant_id)
        if features.billing.enabled:
            count = len(notion_info_list)
            batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
            if count > batch_upload_limit:
                raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")

        embedding_model_instance = None
        if dataset_id:
            dataset = Dataset.query.filter_by(
                id=dataset_id
            ).first()
            if not dataset:
                raise ValueError('Dataset not found.')
            if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
                if dataset.embedding_model_provider:
                    embedding_model_instance = self.model_manager.get_model_instance(
                        tenant_id=tenant_id,
                        provider=dataset.embedding_model_provider,
                        model_type=ModelType.TEXT_EMBEDDING,
                        model=dataset.embedding_model
                    )
                else:
                    embedding_model_instance = self.model_manager.get_default_model_instance(
                        tenant_id=tenant_id,
                        model_type=ModelType.TEXT_EMBEDDING,
                    )
        else:
            if indexing_technique == 'high_quality':
                embedding_model_instance = self.model_manager.get_default_model_instance(
                    tenant_id=tenant_id,
                    model_type=ModelType.TEXT_EMBEDDING
                )

        # load data from notion
        tokens = 0
        preview_texts = []
        total_segments = 0
        total_price = 0
        currency = 'USD'
        for notion_info in notion_info_list:
            workspace_id = notion_info['workspace_id']
            data_source_binding = DataSourceBinding.query.filter(
                db.and_(
                    DataSourceBinding.tenant_id == current_user.current_tenant_id,
                    DataSourceBinding.provider == 'notion',
                    DataSourceBinding.disabled == False,
                    DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
                )
            ).first()
            if not data_source_binding:
                raise ValueError('Data source binding not found.')

            for page in notion_info['pages']:
                loader = NotionLoader(
                    notion_access_token=data_source_binding.access_token,
                    notion_workspace_id=workspace_id,
                    notion_obj_id=page['page_id'],
                    notion_page_type=page['type']
                )
                documents = loader.load()

                processing_rule = DatasetProcessRule(
                    mode=tmp_processing_rule["mode"],
                    rules=json.dumps(tmp_processing_rule["rules"])
                )

                # get splitter
                splitter = self._get_splitter(processing_rule, embedding_model_instance)

                # split to documents
                documents = self._split_to_documents_for_estimate(
                    text_docs=documents,
                    splitter=splitter,
                    processing_rule=processing_rule
                )

                total_segments += len(documents)

                embedding_model_type_instance = None
                if embedding_model_instance:
                    embedding_model_type_instance = embedding_model_instance.model_type_instance
                    embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)

                for document in documents:
                    if len(preview_texts) < 5:
                        preview_texts.append(document.page_content)

                    if indexing_technique == 'high_quality' and embedding_model_type_instance:
                        tokens += embedding_model_type_instance.get_num_tokens(
                            model=embedding_model_instance.model,
                            credentials=embedding_model_instance.credentials,
                            texts=[document.page_content]
                        )

        if doc_form and doc_form == 'qa_model':
            model_instance = self.model_manager.get_default_model_instance(
                tenant_id=tenant_id,
                model_type=ModelType.LLM
            )

            model_type_instance = model_instance.model_type_instance
            model_type_instance = cast(LargeLanguageModel, model_type_instance)

            if len(preview_texts) > 0:
                # qa model document
                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
                                                             doc_language)
                document_qa_list = self.format_split_text(response)

                price_info = model_type_instance.get_price(
                    model=model_instance.model,
                    credentials=model_instance.credentials,
                    price_type=PriceType.INPUT,
                    tokens=total_segments * 2000,
                )

                return {
                    "total_segments": total_segments * 20,
                    "tokens": total_segments * 2000,
                    "total_price": '{:f}'.format(price_info.total_amount),
                    "currency": price_info.currency,
                    "qa_preview": document_qa_list,
                    "preview": preview_texts
                }

        if embedding_model_instance:
            embedding_model_type_instance = embedding_model_instance.model_type_instance
            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
            embedding_price_info = embedding_model_type_instance.get_price(
                model=embedding_model_instance.model,
                credentials=embedding_model_instance.credentials,
                price_type=PriceType.INPUT,
                tokens=tokens
            )
            total_price = '{:f}'.format(embedding_price_info.total_amount)
            currency = embedding_price_info.currency

        return {
            "total_segments": total_segments,
            "tokens": tokens,
            "total_price": total_price,
            "currency": currency,
            "preview": preview_texts
        }

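    # Load the raw documents for a dataset document from its data source (an
    # uploaded file or an imported Notion page), mark the document as "splitting",
    # strip invalid characters, and tag each loaded document with the owning
    # document/dataset ids.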
    def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> list[Document]:
        # load file
        if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
            return []

        data_source_info = dataset_document.data_source_info_dict
        text_docs = []
        if dataset_document.data_source_type == 'upload_file':
            if not data_source_info or 'upload_file_id' not in data_source_info:
                raise ValueError("no upload file found")

            file_detail = db.session.query(UploadFile). \
                filter(UploadFile.id == data_source_info['upload_file_id']). \
                one_or_none()

            if file_detail:
                text_docs = FileExtractor.load(file_detail, is_automatic=automatic)
        elif dataset_document.data_source_type == 'notion_import':
            loader = NotionLoader.from_document(dataset_document)
            text_docs = loader.load()

        # update document status to splitting
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="splitting",
            extra_update_params={
                DatasetDocument.word_count: sum([len(text_doc.page_content) for text_doc in text_docs]),
                DatasetDocument.parsing_completed_at: datetime.datetime.utcnow()
            }
        )

        # attach the document and dataset ids to each loaded document
        text_docs = cast(list[Document], text_docs)
        for text_doc in text_docs:
            # remove invalid symbols
            text_doc.page_content = self.filter_string(text_doc.page_content)
            text_doc.metadata['document_id'] = dataset_document.id
            text_doc.metadata['dataset_id'] = dataset_document.dataset_id

        return text_docs

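    # Sanitize raw text before it is tokenized or stored: "<|" / "|>" are collapsed to
    # plain angle brackets (presumably so they cannot form special-token sequences),
    # and ASCII control characters plus the non-character U+FFFE are dropped.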
    def filter_string(self, text):
        text = re.sub(r'<\|', '<', text)
        text = re.sub(r'\|>', '>', text)
        text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text)
        # Unicode U+FFFE
        text = re.sub('\uFFFE', '', text)
        return text

    def _get_splitter(self, processing_rule: DatasetProcessRule,
                      embedding_model_instance: Optional[ModelInstance]) -> TextSplitter:
        """
        Get the TextSplitter according to the processing rule.
        """
        if processing_rule.mode == "custom":
            # user-defined segmentation rule
            rules = json.loads(processing_rule.rules)
            segmentation = rules["segmentation"]
            if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > 1000:
                raise ValueError("Custom segment length should be between 50 and 1000.")

            separator = segmentation["separator"]
            if separator:
                separator = separator.replace('\\n', '\n')

            character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
                chunk_size=segmentation["max_tokens"],
                chunk_overlap=segmentation.get('chunk_overlap', 0),
                fixed_separator=separator,
                separators=["\n\n", "。", ".", " ", ""],
                embedding_model_instance=embedding_model_instance
            )
        else:
            # automatic segmentation
            character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
                chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
                chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'],
                separators=["\n\n", "。", ".", " ", ""],
                embedding_model_instance=embedding_model_instance
            )

        return character_splitter

    def _step_split(self, text_docs: list[Document], splitter: TextSplitter,
                    dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \
            -> list[Document]:
        """
        Split the text documents into segments, save them as document segments and
        move the document and its segments into the "indexing" state.
        """
        documents = self._split_to_documents(
            text_docs=text_docs,
            splitter=splitter,
            processing_rule=processing_rule,
            tenant_id=dataset.tenant_id,
            document_form=dataset_document.doc_form,
            document_language=dataset_document.doc_language
        )

        # save nodes to document segments
        doc_store = DatasetDocumentStore(
            dataset=dataset,
            user_id=dataset_document.created_by,
            document_id=dataset_document.id
        )

        # add document segments
        doc_store.add_documents(documents)

        # update document status to indexing
        cur_time = datetime.datetime.utcnow()
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="indexing",
            extra_update_params={
                DatasetDocument.cleaning_completed_at: cur_time,
                DatasetDocument.splitting_completed_at: cur_time,
            }
        )

        # update segment status to indexing
        self._update_segments_by_document(
            dataset_document_id=dataset_document.id,
            update_params={
                DocumentSegment.status: "indexing",
                DocumentSegment.indexing_at: datetime.datetime.utcnow()
            }
        )

        return documents

    def _split_to_documents(self, text_docs: list[Document], splitter: TextSplitter,
                            processing_rule: DatasetProcessRule, tenant_id: str,
                            document_form: str, document_language: str) -> list[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents = []
        all_qa_documents = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])
            split_documents = []
            for document_node in documents:
                if document_node.page_content.strip():
                    doc_id = str(uuid.uuid4())
                    hash = helper.generate_text_hash(document_node.page_content)
                    document_node.metadata['doc_id'] = doc_id
                    document_node.metadata['doc_hash'] = hash

                    # drop a leading splitter character left over from splitting
                    page_content = document_node.page_content
                    if page_content.startswith(".") or page_content.startswith("。"):
                        page_content = page_content[1:]
                    document_node.page_content = page_content

                    if document_node.page_content:
                        split_documents.append(document_node)

            all_documents.extend(split_documents)

        # process qa documents: rewrite chunks into question/answer pairs with the LLM,
        # ten chunks at a time with one worker thread per chunk
        if document_form == 'qa_model':
            for i in range(0, len(all_documents), 10):
                threads = []
                sub_documents = all_documents[i:i + 10]
                for doc in sub_documents:
                    document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={
                        'flask_app': current_app._get_current_object(),
                        'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents,
                        'document_language': document_language})
                    threads.append(document_format_thread)
                    document_format_thread.start()
                for thread in threads:
                    thread.join()
            return all_qa_documents

        return all_documents

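    # Worker run in a thread by _split_to_documents: inside a Flask app context it asks
    # the LLM to rewrite one chunk into question/answer pairs and appends the results to
    # the shared all_qa_documents list. Exceptions are caught and logged so one failed
    # chunk does not abort the batch.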
    def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
        format_documents = []
        if document_node.page_content is None or not document_node.page_content.strip():
            return
        with flask_app.app_context():
            try:
                # qa model document
                response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language)
                document_qa_list = self.format_split_text(response)
                qa_documents = []
                for result in document_qa_list:
                    qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
                    doc_id = str(uuid.uuid4())
                    hash = helper.generate_text_hash(result['question'])
                    qa_document.metadata['answer'] = result['answer']
                    qa_document.metadata['doc_id'] = doc_id
                    qa_document.metadata['doc_hash'] = hash
                    qa_documents.append(qa_document)
                format_documents.extend(qa_documents)
            except Exception as e:
                logging.exception(e)

            all_qa_documents.extend(format_documents)

    def _split_to_documents_for_estimate(self, text_docs: list[Document], splitter: TextSplitter,
                                         processing_rule: DatasetProcessRule) -> list[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])
            split_documents = []
            for document in documents:
                if document.page_content is None or not document.page_content.strip():
                    continue

                doc_id = str(uuid.uuid4())
                hash = helper.generate_text_hash(document.page_content)
                document.metadata['doc_id'] = doc_id
                document.metadata['doc_hash'] = hash

                split_documents.append(document)

            all_documents.extend(split_documents)

        return all_documents

    def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str:
        """
        Clean the document text according to the processing rules.
        """
        if processing_rule.mode == "automatic":
            rules = DatasetProcessRule.AUTOMATIC_RULES
        else:
            rules = json.loads(processing_rule.rules) if processing_rule.rules else {}

        if 'pre_processing_rules' in rules:
            pre_processing_rules = rules["pre_processing_rules"]
            for pre_processing_rule in pre_processing_rules:
                if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True:
                    # Remove extra spaces
                    pattern = r'\n{3,}'
                    text = re.sub(pattern, '\n\n', text)
                    pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}'
                    text = re.sub(pattern, ' ', text)
                elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True:
                    # Remove email
                    pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)'
                    text = re.sub(pattern, '', text)

                    # Remove URL
                    pattern = r'https?://[^\s]+'
                    text = re.sub(pattern, '', text)

        return text

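    # Parse the LLM's Q&A output, which is expected to look like
    # "Q1: ...\nA1: ...\nQ2: ...", into a list of {"question", "answer"} dicts.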
    def format_split_text(self, text):
        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
        matches = re.findall(regex, text, re.UNICODE)

        return [
            {
                "question": q,
                "answer": re.sub(r"\n\s*", "\n", a.strip())
            }
            for q, a in matches if q and a
        ]

    def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document]) -> None:
        """
        Build the index for the document.
        """
        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
        keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
        embedding_model_instance = None
        if dataset.indexing_technique == 'high_quality':
            embedding_model_instance = self.model_manager.get_model_instance(
                tenant_id=dataset.tenant_id,
                provider=dataset.embedding_model_provider,
                model_type=ModelType.TEXT_EMBEDDING,
                model=dataset.embedding_model
            )

        # chunk nodes by chunk size
        indexing_start_at = time.perf_counter()
        tokens = 0
        chunk_size = 100

        embedding_model_type_instance = None
        if embedding_model_instance:
            embedding_model_type_instance = embedding_model_instance.model_type_instance
            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)

        for i in range(0, len(documents), chunk_size):
            # stop indexing if the document has been paused
            self._check_document_paused_status(dataset_document.id)
            chunk_documents = documents[i:i + chunk_size]
            if dataset.indexing_technique == 'high_quality' or embedding_model_type_instance:
                tokens += sum(
                    embedding_model_type_instance.get_num_tokens(
                        embedding_model_instance.model,
                        embedding_model_instance.credentials,
                        [document.page_content]
                    )
                    for document in chunk_documents
                )

            # save vector index
            if vector_index:
                vector_index.add_texts(chunk_documents)

            # save keyword index
            keyword_table_index.add_texts(chunk_documents)

            document_ids = [document.metadata['doc_id'] for document in chunk_documents]
            db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.index_node_id.in_(document_ids),
                DocumentSegment.status == "indexing"
            ).update({
                DocumentSegment.status: "completed",
                DocumentSegment.enabled: True,
                DocumentSegment.completed_at: datetime.datetime.utcnow()
            })

            db.session.commit()

        indexing_end_at = time.perf_counter()

        # update document status to completed
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="completed",
            extra_update_params={
                DatasetDocument.tokens: tokens,
                DatasetDocument.completed_at: datetime.datetime.utcnow(),
                DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
            }
        )

    def _check_document_paused_status(self, document_id: str):
        indexing_cache_key = 'document_{}_is_paused'.format(document_id)
        result = redis_client.get(indexing_cache_key)
        if result:
            raise DocumentIsPausedException()

    def _update_document_index_status(self, document_id: str, after_indexing_status: str,
                                      extra_update_params: Optional[dict] = None) -> None:
        """
        Update the document indexing status.
        """
        count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
        if count > 0:
            raise DocumentIsPausedException()
        document = DatasetDocument.query.filter_by(id=document_id).first()
        if not document:
            raise DocumentIsDeletedPausedException()

        update_params = {
            DatasetDocument.indexing_status: after_indexing_status
        }

        if extra_update_params:
            update_params.update(extra_update_params)

        DatasetDocument.query.filter_by(id=document_id).update(update_params)
        db.session.commit()

    def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None:
        """
        Update the document segments by document id.
        """
        DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
        db.session.commit()

    def batch_add_segments(self, segments: list[DocumentSegment], dataset: Dataset):
        """
        Add a batch of existing segments to the vector and keyword indexes.
        """
        documents = []
        for segment in segments:
            document = Document(
                page_content=segment.content,
                metadata={
                    "doc_id": segment.index_node_id,
                    "doc_hash": segment.index_node_hash,
                    "document_id": segment.document_id,
                    "dataset_id": segment.dataset_id,
                }
            )
            documents.append(document)

        # save vector index
        index = IndexBuilder.get_index(dataset, 'high_quality')
        if index:
            index.add_texts(documents, duplicate_check=True)

        # save keyword index
        index = IndexBuilder.get_index(dataset, 'economy')
        if index:
            index.add_texts(documents)


class DocumentIsPausedException(Exception):
    pass


class DocumentIsDeletedPausedException(Exception):
    pass