# dataset_retrieval.py

import json
import math
import re
import threading
from collections import Counter, defaultdict
from collections.abc import Generator, Mapping
from typing import Any, Optional, Union, cast

from flask import Flask, current_app
from sqlalchemy import Integer, and_, or_, text
from sqlalchemy import cast as sqlalchemy_cast

from core.app.app_config.entities import (
    DatasetEntity,
    DatasetRetrieveConfigEntity,
    MetadataFilteringCondition,
    ModelConfig,
)
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.context_entities import DocumentContext
from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
from core.rag.retrieval.template_prompts import (
    METADATA_FILTER_ASSISTANT_PROMPT_1,
    METADATA_FILTER_ASSISTANT_PROMPT_2,
    METADATA_FILTER_COMPLETION_PROMPT,
    METADATA_FILTER_SYSTEM_PROMPT,
    METADATA_FILTER_USER_PROMPT_1,
    METADATA_FILTER_USER_PROMPT_2,
    METADATA_FILTER_USER_PROMPT_3,
)
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from libs.json_in_md_parser import parse_and_check_json_markdown
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService

default_retrieval_model: dict[str, Any] = {
    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
    "reranking_enable": False,
    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
    "top_k": 2,
    "score_threshold_enabled": False,
}
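# The dict above is the fallback applied whenever a dataset has no stored
# retrieval_model of its own (see the `dataset.retrieval_model or
# default_retrieval_model` lookups below).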


class DatasetRetrieval:
    def __init__(self, application_generate_entity=None):
        self.application_generate_entity = application_generate_entity

    def retrieve(
        self,
        app_id: str,
        user_id: str,
        tenant_id: str,
        model_config: ModelConfigWithCredentialsEntity,
        config: DatasetEntity,
        query: str,
        invoke_from: InvokeFrom,
        show_retrieve_source: bool,
        hit_callback: DatasetIndexToolCallbackHandler,
        message_id: str,
        memory: Optional[TokenBufferMemory] = None,
        inputs: Optional[Mapping[str, Any]] = None,
    ) -> Optional[str]:
        """
        Retrieve documents from the configured datasets and assemble them into a context string.

        :param app_id: app id
        :param user_id: user id
        :param tenant_id: tenant id
        :param model_config: model config
        :param config: dataset config
        :param query: query
        :param invoke_from: invoke from
        :param show_retrieve_source: whether to report retrieval sources via the hit callback
        :param hit_callback: hit callback
        :param message_id: message id
        :param memory: memory
        :param inputs: variable inputs used for metadata filter value substitution
        :return: the joined context string, or None if no dataset is configured or the model schema is missing
        """
        dataset_ids = config.dataset_ids
        if len(dataset_ids) == 0:
            return None
        retrieve_config = config.retrieve_config

        # check whether the model supports tool calling
        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)
        model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
            tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model
        )

        # get model schema
        model_schema = model_type_instance.get_model_schema(
            model=model_config.model, credentials=model_config.credentials
        )
        if not model_schema:
            return None

        planning_strategy = PlanningStrategy.REACT_ROUTER
        features = model_schema.features
        if features:
            if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
                planning_strategy = PlanningStrategy.ROUTER
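        # Models that support (multi) tool calling route between datasets with the
        # native function-call router; all other models fall back to ReAct-style routing.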
        available_datasets = []
        for dataset_id in dataset_ids:
            # get dataset from dataset id
            dataset = (
                db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
            )

            # skip if the dataset does not exist
            if not dataset:
                continue
            # skip internal datasets that have no available documents
            if dataset and dataset.available_document_count == 0 and dataset.provider != "external":
                continue
            available_datasets.append(dataset)

        if inputs:
            inputs = {key: str(value) for key, value in inputs.items()}
        else:
            inputs = {}

        available_datasets_ids = [dataset.id for dataset in available_datasets]
        metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition(
            available_datasets_ids,
            query,
            tenant_id,
            user_id,
            retrieve_config.metadata_filtering_mode,  # type: ignore
            retrieve_config.metadata_model_config,  # type: ignore
            retrieve_config.metadata_filtering_conditions,
            inputs,
        )
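        # metadata_filter_document_ids maps dataset_id -> allowed document ids for
        # internal datasets; metadata_condition carries the structured filter that is
        # forwarded to external knowledge providers.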
        all_documents = []
        user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
        if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
            all_documents = self.single_retrieve(
                app_id,
                tenant_id,
                user_id,
                user_from,
                available_datasets,
                query,
                model_instance,
                model_config,
                planning_strategy,
                message_id,
                metadata_filter_document_ids,
                metadata_condition,
            )
        elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
            all_documents = self.multiple_retrieve(
                app_id,
                tenant_id,
                user_id,
                user_from,
                available_datasets,
                query,
                retrieve_config.top_k or 0,
                retrieve_config.score_threshold or 0,
                retrieve_config.rerank_mode or "reranking_model",
                retrieve_config.reranking_model,
                retrieve_config.weights,
                # default to True only when the flag is unset; the previous
                # `reranking_enabled or True` ignored an explicit False
                True if retrieve_config.reranking_enabled is None else retrieve_config.reranking_enabled,
                message_id,
                metadata_filter_document_ids,
                metadata_condition,
            )
        dify_documents = [item for item in all_documents if item.provider == "dify"]
        external_documents = [item for item in all_documents if item.provider == "external"]
        document_context_list = []
        retrieval_resource_list = []
        # deal with external documents
        for item in external_documents:
            document_context_list.append(DocumentContext(content=item.page_content, score=item.metadata.get("score")))
            source = {
                "dataset_id": item.metadata.get("dataset_id"),
                "dataset_name": item.metadata.get("dataset_name"),
                "document_name": item.metadata.get("title"),
                "data_source_type": "external",
                "retriever_from": invoke_from.to_source(),
                "score": item.metadata.get("score"),
                "content": item.page_content,
            }
            retrieval_resource_list.append(source)
        # deal with dify documents
        if dify_documents:
            records = RetrievalService.format_retrieval_documents(dify_documents)
            if records:
                for record in records:
                    segment = record.segment
                    if segment.answer:
                        document_context_list.append(
                            DocumentContext(
                                content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
                                score=record.score,
                            )
                        )
                    else:
                        document_context_list.append(
                            DocumentContext(
                                content=segment.get_sign_content(),
                                score=record.score,
                            )
                        )
                if show_retrieve_source:
                    for record in records:
                        segment = record.segment
                        dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
                        document = DatasetDocument.query.filter(
                            DatasetDocument.id == segment.document_id,
                            DatasetDocument.enabled == True,
                            DatasetDocument.archived == False,
                        ).first()
                        if dataset and document:
                            source = {
                                "dataset_id": dataset.id,
                                "dataset_name": dataset.name,
                                "document_id": document.id,
                                "document_name": document.name,
                                "data_source_type": document.data_source_type,
                                "segment_id": segment.id,
                                "retriever_from": invoke_from.to_source(),
                                "score": record.score or 0.0,
                                "doc_metadata": document.doc_metadata,
                            }
                            if invoke_from.to_source() == "dev":
                                source["hit_count"] = segment.hit_count
                                source["word_count"] = segment.word_count
                                source["segment_position"] = segment.position
                                source["index_node_hash"] = segment.index_node_hash
                            if segment.answer:
                                source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
                            else:
                                source["content"] = segment.content
                            retrieval_resource_list.append(source)
        if hit_callback and retrieval_resource_list:
            retrieval_resource_list = sorted(
                retrieval_resource_list, key=lambda x: x.get("score") or 0.0, reverse=True
            )
            for position, item in enumerate(retrieval_resource_list, start=1):
                item["position"] = position
            hit_callback.return_retriever_resource_info(retrieval_resource_list)
        if document_context_list:
            document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
            return "\n".join([document_context.content for document_context in document_context_list])
        return ""
    def single_retrieve(
        self,
        app_id: str,
        tenant_id: str,
        user_id: str,
        user_from: str,
        available_datasets: list,
        query: str,
        model_instance: ModelInstance,
        model_config: ModelConfigWithCredentialsEntity,
        planning_strategy: PlanningStrategy,
        message_id: Optional[str] = None,
        metadata_filter_document_ids: Optional[dict[str, list[str]]] = None,
        metadata_condition: Optional[MetadataCondition] = None,
    ):
        tools = []
        for dataset in available_datasets:
            description = dataset.description
            if not description:
                description = "useful for when you want to answer queries about the " + dataset.name

            description = description.replace("\n", "").replace("\r", "")
            message_tool = PromptMessageTool(
                name=dataset.id,
                description=description,
                parameters={
                    "type": "object",
                    "properties": {},
                    "required": [],
                },
            )
            tools.append(message_tool)

        dataset_id = None
        if planning_strategy == PlanningStrategy.REACT_ROUTER:
            react_multi_dataset_router = ReactMultiDatasetRouter()
            dataset_id = react_multi_dataset_router.invoke(
                query, tools, model_config, model_instance, user_id, tenant_id
            )
        elif planning_strategy == PlanningStrategy.ROUTER:
            function_call_router = FunctionCallMultiDatasetRouter()
            dataset_id = function_call_router.invoke(query, tools, model_config, model_instance)

        if dataset_id:
            # get retrieval model config
            dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
            if dataset:
                results = []
                if dataset.provider == "external":
                    external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
                        tenant_id=dataset.tenant_id,
                        dataset_id=dataset_id,
                        query=query,
                        external_retrieval_parameters=dataset.retrieval_model,
                        metadata_condition=metadata_condition,
                    )
                    for external_document in external_documents:
                        document = Document(
                            page_content=external_document.get("content"),
                            metadata=external_document.get("metadata"),
                            provider="external",
                        )
                        if document.metadata is not None:
                            document.metadata["score"] = external_document.get("score")
                            document.metadata["title"] = external_document.get("title")
                            document.metadata["dataset_id"] = dataset_id
                            document.metadata["dataset_name"] = dataset.name
                        results.append(document)
                else:
                    if metadata_condition and not metadata_filter_document_ids:
                        return []
                    document_ids_filter = None
                    if metadata_filter_document_ids:
                        document_ids = metadata_filter_document_ids.get(dataset.id, [])
                        if document_ids:
                            document_ids_filter = document_ids
                        else:
                            return []
                    retrieval_model_config = dataset.retrieval_model or default_retrieval_model

                    # get top k
                    top_k = retrieval_model_config["top_k"]
                    # get retrieval method
                    if dataset.indexing_technique == "economy":
                        retrieval_method = "keyword_search"
                    else:
                        retrieval_method = retrieval_model_config["search_method"]
                    # get reranking model
                    reranking_model = (
                        retrieval_model_config["reranking_model"]
                        if retrieval_model_config["reranking_enable"]
                        else None
                    )
                    # get score threshold
                    score_threshold = 0.0
                    score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
                    if score_threshold_enabled:
                        score_threshold = retrieval_model_config.get("score_threshold", 0.0)

                    with measure_time() as timer:
                        results = RetrievalService.retrieve(
                            retrieval_method=retrieval_method,
                            dataset_id=dataset.id,
                            query=query,
                            top_k=top_k,
                            score_threshold=score_threshold,
                            reranking_model=reranking_model,
                            reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"),
                            weights=retrieval_model_config.get("weights", None),
                            document_ids_filter=document_ids_filter,
                        )
                    self._on_query(query, [dataset_id], app_id, user_from, user_id)
                    if results:
                        self._on_retrieval_end(results, message_id, timer)
                return results
        return []
    def multiple_retrieve(
        self,
        app_id: str,
        tenant_id: str,
        user_id: str,
        user_from: str,
        available_datasets: list,
        query: str,
        top_k: int,
        score_threshold: float,
        reranking_mode: str,
        reranking_model: Optional[dict] = None,
        weights: Optional[dict[str, Any]] = None,
        reranking_enable: bool = True,
        message_id: Optional[str] = None,
        metadata_filter_document_ids: Optional[dict[str, list[str]]] = None,
        metadata_condition: Optional[MetadataCondition] = None,
    ):
        if not available_datasets:
            return []
        threads = []
        all_documents: list[Document] = []
        dataset_ids = [dataset.id for dataset in available_datasets]
        index_type_check = all(
            item.indexing_technique == available_datasets[0].indexing_technique for item in available_datasets
        )
        if not index_type_check and (not reranking_enable or reranking_mode != RerankMode.RERANKING_MODEL):
            raise ValueError(
                "The configured knowledge bases use different indexing techniques; please set a reranking model."
            )
        index_type = available_datasets[0].indexing_technique
        if index_type == "high_quality":
            embedding_model_check = all(
                item.embedding_model == available_datasets[0].embedding_model for item in available_datasets
            )
            embedding_model_provider_check = all(
                item.embedding_model_provider == available_datasets[0].embedding_model_provider
                for item in available_datasets
            )
            if (
                reranking_enable
                and reranking_mode == "weighted_score"
                and (not embedding_model_check or not embedding_model_provider_check)
            ):
                raise ValueError(
                    "The configured knowledge bases use different embedding models; please set a reranking model."
                )
            if reranking_enable and reranking_mode == RerankMode.WEIGHTED_SCORE:
                if weights is not None:
                    weights["vector_setting"]["embedding_provider_name"] = available_datasets[
                        0
                    ].embedding_model_provider
                    weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
        for dataset in available_datasets:
            index_type = dataset.indexing_technique
            document_ids_filter = None
            if dataset.provider != "external":
                if metadata_condition and not metadata_filter_document_ids:
                    continue
                if metadata_filter_document_ids:
                    document_ids = metadata_filter_document_ids.get(dataset.id, [])
                    if document_ids:
                        document_ids_filter = document_ids
                    else:
                        continue
            retrieval_thread = threading.Thread(
                target=self._retriever,
                kwargs={
                    "flask_app": current_app._get_current_object(),  # type: ignore
                    "dataset_id": dataset.id,
                    "query": query,
                    "top_k": top_k,
                    "all_documents": all_documents,
                    "document_ids_filter": document_ids_filter,
                    "metadata_condition": metadata_condition,
                },
            )
            threads.append(retrieval_thread)
            retrieval_thread.start()
        for thread in threads:
            thread.join()
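        # Each worker thread has pushed its results into the shared `all_documents`
        # list; list.append is atomic in CPython, so no extra locking is used here.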

        with measure_time() as timer:
            if reranking_enable:
                # do rerank for searched documents
                data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
                all_documents = data_post_processor.invoke(
                    query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k
                )
            else:
                if index_type == "economy":
                    all_documents = self.calculate_keyword_score(query, all_documents, top_k)
                elif index_type == "high_quality":
                    all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold)

        self._on_query(query, dataset_ids, app_id, user_from, user_id)

        if all_documents:
            self._on_retrieval_end(all_documents, message_id, timer)

        return all_documents
    def _on_retrieval_end(
        self, documents: list[Document], message_id: Optional[str] = None, timer: Optional[dict] = None
    ) -> None:
        """Handle retrieval end."""
        dify_documents = [document for document in documents if document.provider == "dify"]
        for document in dify_documents:
            if document.metadata is not None:
                dataset_document = DatasetDocument.query.filter(
                    DatasetDocument.id == document.metadata["document_id"]
                ).first()
                if dataset_document:
                    if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
                        child_chunk = ChildChunk.query.filter(
                            ChildChunk.index_node_id == document.metadata["doc_id"],
                            ChildChunk.dataset_id == dataset_document.dataset_id,
                            ChildChunk.document_id == dataset_document.id,
                        ).first()
                        if child_chunk:
                            DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update(
                                {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
                            )
                            db.session.commit()
                    else:
                        query = db.session.query(DocumentSegment).filter(
                            DocumentSegment.index_node_id == document.metadata["doc_id"]
                        )
                        if "dataset_id" in document.metadata:
                            query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
                        # add hit count to document segment
                        query.update(
                            {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
                        )
                        db.session.commit()

        # get tracing instance
        trace_manager: TraceQueueManager | None = (
            self.application_generate_entity.trace_manager if self.application_generate_entity else None
        )
        if trace_manager:
            trace_manager.add_trace_task(
                TraceTask(
                    TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
                )
            )
    def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str) -> None:
        """
        Handle query.
        """
        if not query:
            return
        dataset_queries = []
        for dataset_id in dataset_ids:
            dataset_query = DatasetQuery(
                dataset_id=dataset_id,
                content=query,
                source="app",
                source_app_id=app_id,
                created_by_role=user_from,
                created_by=user_id,
            )
            dataset_queries.append(dataset_query)
        if dataset_queries:
            db.session.add_all(dataset_queries)
        db.session.commit()
    def _retriever(
        self,
        flask_app: Flask,
        dataset_id: str,
        query: str,
        top_k: int,
        all_documents: list,
        document_ids_filter: Optional[list[str]] = None,
        metadata_condition: Optional[MetadataCondition] = None,
    ):
        with flask_app.app_context():
            dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
            if not dataset:
                return []
            if dataset.provider == "external":
                external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
                    tenant_id=dataset.tenant_id,
                    dataset_id=dataset_id,
                    query=query,
                    external_retrieval_parameters=dataset.retrieval_model,
                    metadata_condition=metadata_condition,
                )
                for external_document in external_documents:
                    document = Document(
                        page_content=external_document.get("content"),
                        metadata=external_document.get("metadata"),
                        provider="external",
                    )
                    if document.metadata is not None:
                        document.metadata["score"] = external_document.get("score")
                        document.metadata["title"] = external_document.get("title")
                        document.metadata["dataset_id"] = dataset_id
                        document.metadata["dataset_name"] = dataset.name
                    all_documents.append(document)
            else:
                # get the retrieval model; if none is set, fall back to the default
                retrieval_model = dataset.retrieval_model or default_retrieval_model
                if dataset.indexing_technique == "economy":
                    # use keyword table query
                    documents = RetrievalService.retrieve(
                        retrieval_method="keyword_search",
                        dataset_id=dataset.id,
                        query=query,
                        top_k=top_k,
                        document_ids_filter=document_ids_filter,
                    )
                    if documents:
                        all_documents.extend(documents)
                else:
                    if top_k > 0:
                        # retrieval source
                        documents = RetrievalService.retrieve(
                            retrieval_method=retrieval_model["search_method"],
                            dataset_id=dataset.id,
                            query=query,
                            top_k=retrieval_model.get("top_k") or 2,
                            score_threshold=retrieval_model.get("score_threshold", 0.0)
                            if retrieval_model["score_threshold_enabled"]
                            else 0.0,
                            reranking_model=retrieval_model.get("reranking_model", None)
                            if retrieval_model["reranking_enable"]
                            else None,
                            reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
                            weights=retrieval_model.get("weights", None),
                            document_ids_filter=document_ids_filter,
                        )
                        all_documents.extend(documents)
    def to_dataset_retriever_tool(
        self,
        tenant_id: str,
        dataset_ids: list[str],
        retrieve_config: DatasetRetrieveConfigEntity,
        return_resource: bool,
        invoke_from: InvokeFrom,
        hit_callback: DatasetIndexToolCallbackHandler,
    ) -> Optional[list[DatasetRetrieverBaseTool]]:
        """
        A dataset tool is a tool that can be used to retrieve information from a dataset.

        :param tenant_id: tenant id
        :param dataset_ids: dataset ids
        :param retrieve_config: retrieve config
        :param return_resource: return resource
        :param invoke_from: invoke from
        :param hit_callback: hit callback
        """
        tools = []
        available_datasets = []
        for dataset_id in dataset_ids:
            # get dataset from dataset id
            dataset = (
                db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
            )

            # skip if the dataset does not exist
            if not dataset:
                continue
            # skip internal datasets that have no available documents
            if dataset and dataset.provider != "external" and dataset.available_document_count == 0:
                continue
            available_datasets.append(dataset)

        if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
            # get retrieval model config
            default_retrieval_model = {
                "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
                "reranking_enable": False,
                "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
                "top_k": 2,
                "score_threshold_enabled": False,
            }
            for dataset in available_datasets:
                retrieval_model_config = dataset.retrieval_model or default_retrieval_model

                # get top k
                top_k = retrieval_model_config["top_k"]

                # get score threshold
                score_threshold = None
                score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
                if score_threshold_enabled:
                    score_threshold = retrieval_model_config.get("score_threshold")

                from core.tools.utils.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool

                tool = DatasetRetrieverTool.from_dataset(
                    dataset=dataset,
                    top_k=top_k,
                    score_threshold=score_threshold,
                    hit_callbacks=[hit_callback],
                    return_resource=return_resource,
                    retriever_from=invoke_from.to_source(),
                )
                tools.append(tool)
        elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
            from core.tools.utils.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool

            if retrieve_config.reranking_model is None:
                raise ValueError("Reranking model is required for multiple retrieval")

            tool = DatasetMultiRetrieverTool.from_dataset(
                dataset_ids=[dataset.id for dataset in available_datasets],
                tenant_id=tenant_id,
                top_k=retrieve_config.top_k or 2,
                score_threshold=retrieve_config.score_threshold,
                hit_callbacks=[hit_callback],
                return_resource=return_resource,
                retriever_from=invoke_from.to_source(),
                reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"),
                reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"),
            )
            tools.append(tool)

        return tools
    def calculate_keyword_score(self, query: str, documents: list[Document], top_k: int) -> list[Document]:
        """
        Calculate keyword scores.

        :param query: search query
        :param documents: documents for reranking
        :param top_k: number of documents to keep
        :return: documents sorted by TF-IDF cosine similarity to the query
        """
        keyword_table_handler = JiebaKeywordTableHandler()
        query_keywords = keyword_table_handler.extract_keywords(query, None)
        documents_keywords = []
        for document in documents:
            if document.metadata is not None:
                # get the document keywords
                document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
                document.metadata["keywords"] = document_keywords
                documents_keywords.append(document_keywords)

        # count query keywords (term frequency)
        query_keyword_counts = Counter(query_keywords)

        # total documents
        total_documents = len(documents)

        # calculate the IDF of every keyword across all documents
        all_keywords = set()
        for document_keywords in documents_keywords:
            all_keywords.update(document_keywords)

        keyword_idf = {}
        for keyword in all_keywords:
            # count the documents containing this keyword
            doc_count_containing_keyword = sum(1 for doc_keywords in documents_keywords if keyword in doc_keywords)
            # smoothed IDF
            keyword_idf[keyword] = math.log((1 + total_documents) / (1 + doc_count_containing_keyword)) + 1

        query_tfidf = {}
        for keyword, count in query_keyword_counts.items():
            tf = count
            idf = keyword_idf.get(keyword, 0)
            query_tfidf[keyword] = tf * idf

        # calculate the TF-IDF vector of every document
        documents_tfidf = []
        for document_keywords in documents_keywords:
            document_keyword_counts = Counter(document_keywords)
            document_tfidf = {}
            for keyword, count in document_keyword_counts.items():
                tf = count
                idf = keyword_idf.get(keyword, 0)
                document_tfidf[keyword] = tf * idf
            documents_tfidf.append(document_tfidf)

        def cosine_similarity(vec1, vec2):
            intersection = set(vec1.keys()) & set(vec2.keys())
            numerator = sum(vec1[x] * vec2[x] for x in intersection)
            sum1 = sum(vec1[x] ** 2 for x in vec1)
            sum2 = sum(vec2[x] ** 2 for x in vec2)
            denominator = math.sqrt(sum1) * math.sqrt(sum2)
            if not denominator:
                return 0.0
            else:
                return float(numerator) / denominator

        similarities = []
        for document_tfidf in documents_tfidf:
            similarity = cosine_similarity(query_tfidf, document_tfidf)
            similarities.append(similarity)

        for document, score in zip(documents, similarities):
            # attach the score to the document
            if document.metadata is not None:
                document.metadata["score"] = score
        documents = sorted(documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True)
        return documents[:top_k] if top_k else documents
    def calculate_vector_score(
        self, all_documents: list[Document], top_k: int, score_threshold: float
    ) -> list[Document]:
        filter_documents = []
        for document in all_documents:
            if score_threshold is None or (document.metadata and document.metadata.get("score", 0) >= score_threshold):
                filter_documents.append(document)

        if not filter_documents:
            return []

        filter_documents = sorted(
            filter_documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True
        )
        return filter_documents[:top_k] if top_k else filter_documents
    def _get_metadata_filter_condition(
        self,
        dataset_ids: list,
        query: str,
        tenant_id: str,
        user_id: str,
        metadata_filtering_mode: str,
        metadata_model_config: ModelConfig,
        metadata_filtering_conditions: Optional[MetadataFilteringCondition],
        inputs: dict,
    ) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
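        # Three modes: "disabled" skips filtering entirely, "automatic" asks an LLM to
        # derive filter conditions from the query, and "manual" applies the conditions
        # configured by the user (with {{variable}} substitution from inputs).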
        document_query = db.session.query(DatasetDocument).filter(
            DatasetDocument.dataset_id.in_(dataset_ids),
            DatasetDocument.indexing_status == "completed",
            DatasetDocument.enabled == True,
            DatasetDocument.archived == False,
        )
        filters = []  # type: ignore
        metadata_condition = None
        if metadata_filtering_mode == "disabled":
            return None, None
        elif metadata_filtering_mode == "automatic":
            automatic_metadata_filters = self._automatic_metadata_filter_func(
                dataset_ids, query, tenant_id, user_id, metadata_model_config
            )
            if automatic_metadata_filters:
                conditions = []
                for filter in automatic_metadata_filters:
                    self._process_metadata_filter_func(
                        filter.get("condition"),  # type: ignore
                        filter.get("metadata_name"),  # type: ignore
                        filter.get("value"),
                        filters,  # type: ignore
                    )
                    conditions.append(
                        Condition(
                            name=filter.get("metadata_name"),  # type: ignore
                            comparison_operator=filter.get("condition"),  # type: ignore
                            value=filter.get("value"),
                        )
                    )
                metadata_condition = MetadataCondition(
                    logical_operator=metadata_filtering_conditions.logical_operator,  # type: ignore
                    conditions=conditions,
                )
        elif metadata_filtering_mode == "manual":
            if metadata_filtering_conditions:
                metadata_condition = MetadataCondition(**metadata_filtering_conditions.model_dump())
                for condition in metadata_filtering_conditions.conditions:  # type: ignore
                    metadata_name = condition.name
                    expected_value = condition.value
                    if expected_value is not None or condition.comparison_operator in ("empty", "not empty"):
                        if isinstance(expected_value, str):
                            expected_value = self._replace_metadata_filter_value(expected_value, inputs)
                        filters = self._process_metadata_filter_func(
                            condition.comparison_operator, metadata_name, expected_value, filters
                        )
        else:
            raise ValueError("Invalid metadata filtering mode")
        if filters:
            if metadata_filtering_conditions.logical_operator == "or":  # type: ignore
                document_query = document_query.filter(or_(*filters))
            else:
                document_query = document_query.filter(and_(*filters))
        documents = document_query.all()
        # group document ids by dataset_id
        metadata_filter_document_ids = defaultdict(list) if documents else None  # type: ignore
        for document in documents:
            metadata_filter_document_ids[document.dataset_id].append(document.id)  # type: ignore
        return metadata_filter_document_ids, metadata_condition
    def _replace_metadata_filter_value(self, text: str, inputs: dict) -> str:
        def replacer(match):
            key = match.group(1)
            return str(inputs.get(key, f"{{{{{key}}}}}"))
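        # e.g. "author is {{author}}" with inputs={"author": "Alice"} -> "author is Alice";
        # unknown keys are left in place as "{{key}}".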
        pattern = re.compile(r"\{\{(\w+)\}\}")
        return pattern.sub(replacer, text)
    def _automatic_metadata_filter_func(
        self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
    ) -> Optional[list[dict[str, Any]]]:
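        # Ask the configured LLM to propose metadata filters for the query; the model
        # is prompted few-shot (see _get_prompt_template) and must answer with JSON
        # containing a "metadata_map" list. Any failure falls back to no filtering.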
        # get all metadata fields
        metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
        all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
        # get metadata model config
        if metadata_model_config is None:
            raise ValueError("metadata_model_config is required")
        # fetch model instance and config
        model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config)

        # fetch prompt messages
        prompt_messages, stop = self._get_prompt_template(
            model_config=model_config,
            mode=metadata_model_config.mode,
            metadata_fields=all_metadata_fields,
            query=query or "",
        )
        result_text = ""
        try:
            # invoke the model
            invoke_result = cast(
                Generator[LLMResult, None, None],
                model_instance.invoke_llm(
                    prompt_messages=prompt_messages,
                    model_parameters=model_config.parameters,
                    stop=stop,
                    stream=True,
                    user=user_id,
                ),
            )
            # handle invoke result
            result_text, usage = self._handle_invoke_result(invoke_result=invoke_result)
            result_text_json = parse_and_check_json_markdown(result_text, [])
            automatic_metadata_filters = []
            if "metadata_map" in result_text_json:
                metadata_map = result_text_json["metadata_map"]
                for item in metadata_map:
                    if item.get("metadata_field_name") in all_metadata_fields:
                        automatic_metadata_filters.append(
                            {
                                "metadata_name": item.get("metadata_field_name"),
                                "value": item.get("metadata_field_value"),
                                "condition": item.get("comparison_operator"),
                            }
                        )
        except Exception:
            return None
        return automatic_metadata_filters
    def _process_metadata_filter_func(self, condition: str, metadata_name: str, value: Optional[Any], filters: list):
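        # Map a UI comparison operator onto a SQLAlchemy filter against the JSON
        # doc_metadata column: string operators use ->> text extraction, numeric
        # comparisons cast the extracted text to Integer.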
        match condition:
            case "contains":
                filters.append(
                    (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}%")
                )
            case "not contains":
                filters.append(
                    (text("documents.doc_metadata ->> :key NOT LIKE :value")).params(
                        key=metadata_name, value=f"%{value}%"
                    )
                )
            case "start with":
                filters.append(
                    (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"{value}%")
                )
            case "end with":
                filters.append(
                    (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}")
                )
            case "is" | "=":
                if isinstance(value, str):
                    filters.append(DatasetDocument.doc_metadata[metadata_name] == f'"{value}"')
                else:
                    filters.append(
                        sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) == value
                    )
            case "is not" | "≠":
                if isinstance(value, str):
                    filters.append(DatasetDocument.doc_metadata[metadata_name] != f'"{value}"')
                else:
                    filters.append(
                        sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) != value
                    )
            case "empty":
                filters.append(DatasetDocument.doc_metadata[metadata_name].is_(None))
            case "not empty":
                filters.append(DatasetDocument.doc_metadata[metadata_name].isnot(None))
            case "before" | "<":
                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) < value)
            case "after" | ">":
                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) > value)
            # the original pattern was "≤" | ">=", which shadowed the ">=" case below
            case "≤" | "<=":
                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) <= value)
            case "≥" | ">=":
                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) >= value)
            case _:
                pass
        return filters
    def _fetch_model_config(
        self, tenant_id: str, model: ModelConfig
    ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
        """
        Fetch model config.

        :param tenant_id: tenant id
        :param model: model config
        :return: model instance and model config with credentials
        """
        if model is None:
            raise ValueError("single_retrieval_config is required")
        model_name = model.name
        provider_name = model.provider

        model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
            tenant_id=tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
        )

        provider_model_bundle = model_instance.provider_model_bundle
        model_type_instance = model_instance.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        model_credentials = model_instance.credentials

        # check model
        provider_model = provider_model_bundle.configuration.get_provider_model(
            model=model_name, model_type=ModelType.LLM
        )
        if provider_model is None:
            raise ValueError(f"Model {model_name} does not exist.")

        if provider_model.status == ModelStatus.NO_CONFIGURE:
            raise ValueError(f"Model {model_name} credentials are not initialized.")
        elif provider_model.status == ModelStatus.NO_PERMISSION:
            raise ValueError(f"Dify Hosted OpenAI {model_name} is currently not supported.")
        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
            raise ValueError(f"Model provider {provider_name} quota exceeded.")

        # model config
        completion_params = model.completion_params
        stop = []
        if "stop" in completion_params:
            stop = completion_params["stop"]
            del completion_params["stop"]

        # get model mode
        model_mode = model.mode
        if not model_mode:
            raise ValueError("LLM mode is required.")

        model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
        if not model_schema:
            raise ValueError(f"Model {model_name} does not exist.")

        return model_instance, ModelConfigWithCredentialsEntity(
            provider=provider_name,
            model=model_name,
            model_schema=model_schema,
            mode=model_mode,
            provider_model_bundle=provider_model_bundle,
            credentials=model_credentials,
            parameters=completion_params,
            stop=stop,
        )
    def _get_prompt_template(
        self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str
    ):
        model_mode = ModelMode.value_of(mode)
        input_text = query

        prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]
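        # Chat models get a few-shot prompt: system instructions, two worked
        # user/assistant example turns, then the real query together with the
        # available metadata fields serialized as JSON.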
        if model_mode == ModelMode.CHAT:
            prompt_template = []
            system_prompt_messages = ChatModelMessage(
                role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT
            )
            prompt_template.append(system_prompt_messages)
            user_prompt_message_1 = ChatModelMessage(role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1)
            prompt_template.append(user_prompt_message_1)
            assistant_prompt_message_1 = ChatModelMessage(
                role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1
            )
            prompt_template.append(assistant_prompt_message_1)
            user_prompt_message_2 = ChatModelMessage(role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_2)
            prompt_template.append(user_prompt_message_2)
            assistant_prompt_message_2 = ChatModelMessage(
                role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2
            )
            prompt_template.append(assistant_prompt_message_2)
            user_prompt_message_3 = ChatModelMessage(
                role=PromptMessageRole.USER,
                text=METADATA_FILTER_USER_PROMPT_3.format(
                    input_text=input_text,
                    metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
                ),
            )
            prompt_template.append(user_prompt_message_3)
        elif model_mode == ModelMode.COMPLETION:
            prompt_template = CompletionModelPromptTemplate(
                text=METADATA_FILTER_COMPLETION_PROMPT.format(
                    input_text=input_text,
                    metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
                )
            )
        else:
            raise ValueError(f"Model mode {model_mode} is not supported.")

        prompt_transform = AdvancedPromptTransform()
        prompt_messages = prompt_transform.get_prompt(
            prompt_template=prompt_template,
            inputs={},
            query=query or "",
            files=[],
            context=None,
            memory_config=None,
            memory=None,
            model_config=model_config,
        )
        stop = model_config.stop

        return prompt_messages, stop
    def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
        """
        Handle invoke result.

        :param invoke_result: invoke result
        :return: full text and usage
        """
        model = None
        prompt_messages: list[PromptMessage] = []
        full_text = ""
        usage = None
        for result in invoke_result:
            text = result.delta.message.content
            full_text += text
            if not model:
                model = result.model
            if not prompt_messages:
                prompt_messages = result.prompt_messages
            if not usage and result.delta.usage:
                usage = result.delta.usage
        if not usage:
            usage = LLMUsage.empty_usage()
        return full_text, usage
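

# Minimal usage sketch (illustrative only; in Dify these objects are wired up by
# the app runner, and the variable names below are hypothetical):
#
#   retrieval = DatasetRetrieval(application_generate_entity)
#   context = retrieval.retrieve(
#       app_id=app.id,
#       user_id=user.id,
#       tenant_id=tenant.id,
#       model_config=model_config,        # ModelConfigWithCredentialsEntity
#       config=dataset_entity,            # DatasetEntity with dataset_ids + retrieve_config
#       query="What is our refund policy?",
#       invoke_from=InvokeFrom.DEBUGGER,
#       show_retrieve_source=True,
#       hit_callback=hit_callback,        # DatasetIndexToolCallbackHandler
#       message_id=message.id,
#   )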