import base64
import enum
import hashlib
import hmac
import json
import logging
import os
import pickle
import re
import time
from json import JSONDecodeError
from typing import Any, cast

from sqlalchemy import func
from sqlalchemy.dialects.postgresql import JSONB

from configs import dify_config
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_storage import storage
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule

from .account import Account
from .engine import db
from .model import App, Tag, TagBinding, UploadFile
from .types import StringUUID


class DatasetPermissionEnum(enum.StrEnum):
    ONLY_ME = "only_me"
    ALL_TEAM = "all_team_members"
    PARTIAL_TEAM = "partial_members"


class Dataset(db.Model):  # type: ignore[name-defined]
    __tablename__ = "datasets"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_pkey"),
        db.Index("dataset_tenant_idx", "tenant_id"),
        db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"),
    )

    INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None]
    PROVIDER_LIST = ["vendor", "external", None]

    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    name = db.Column(db.String(255), nullable=False)
    description = db.Column(db.Text, nullable=True)
    provider = db.Column(db.String(255), nullable=False, server_default=db.text("'vendor'::character varying"))
    permission = db.Column(db.String(255), nullable=False, server_default=db.text("'only_me'::character varying"))
    data_source_type = db.Column(db.String(255))
    indexing_technique = db.Column(db.String(255), nullable=True)
    index_struct = db.Column(db.Text, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    embedding_model = db.Column(db.String(255), nullable=True)
    embedding_model_provider = db.Column(db.String(255), nullable=True)
    collection_binding_id = db.Column(StringUUID, nullable=True)
    retrieval_model = db.Column(JSONB, nullable=True)

    @property
    def dataset_keyword_table(self):
        dataset_keyword_table = (
            db.session.query(DatasetKeywordTable).filter(DatasetKeywordTable.dataset_id == self.id).first()
        )
        if dataset_keyword_table:
            return dataset_keyword_table

        return None

    @property
    def index_struct_dict(self):
        return json.loads(self.index_struct) if self.index_struct else None

    @property
    def external_retrieval_model(self):
        default_retrieval_model = {
            "top_k": 2,
            "score_threshold": 0.0,
        }
        return self.retrieval_model or default_retrieval_model

    @property
    def created_by_account(self):
        return db.session.get(Account, self.created_by)

    @property
    def latest_process_rule(self):
        return (
            DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id)
            .order_by(DatasetProcessRule.created_at.desc())
            .first()
        )

    @property
    def app_count(self):
        return (
            db.session.query(func.count(AppDatasetJoin.id))
            .filter(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id)
            .scalar()
        )

    @property
    def document_count(self):
        return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar()

    @property
    def available_document_count(self):
        return (
            db.session.query(func.count(Document.id))
            .filter(
                Document.dataset_id == self.id,
                Document.indexing_status == "completed",
                Document.enabled == True,
                Document.archived == False,
            )
            .scalar()
        )

    @property
    def available_segment_count(self):
        return (
            db.session.query(func.count(DocumentSegment.id))
            .filter(
                DocumentSegment.dataset_id == self.id,
                DocumentSegment.status == "completed",
                DocumentSegment.enabled == True,
            )
            .scalar()
        )

    @property
    def word_count(self):
        # Coalesce to 0 so an empty dataset reports 0 instead of None.
        return (
            Document.query.with_entities(func.coalesce(func.sum(Document.word_count), 0))
            .filter(Document.dataset_id == self.id)
            .scalar()
        )

    @property
    def doc_form(self):
        document = db.session.query(Document).filter(Document.dataset_id == self.id).first()
        if document:
            return document.doc_form
        return None

    @property
    def retrieval_model_dict(self):
        default_retrieval_model = {
            "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
            "reranking_enable": False,
            "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
            "top_k": 2,
            "score_threshold_enabled": False,
        }
        return self.retrieval_model or default_retrieval_model

    @property
    def tags(self):
        tags = (
            db.session.query(Tag)
            .join(TagBinding, Tag.id == TagBinding.tag_id)
            .filter(
                TagBinding.target_id == self.id,
                TagBinding.tenant_id == self.tenant_id,
                Tag.tenant_id == self.tenant_id,
                Tag.type == "knowledge",
            )
            .all()
        )
        return tags or []

    @property
    def external_knowledge_info(self):
        if self.provider != "external":
            return None
        external_knowledge_binding = (
            db.session.query(ExternalKnowledgeBindings).filter(ExternalKnowledgeBindings.dataset_id == self.id).first()
        )
        if not external_knowledge_binding:
            return None
        external_knowledge_api = (
            db.session.query(ExternalKnowledgeApis)
            .filter(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id)
            .first()
        )
        if not external_knowledge_api:
            return None
        return {
            "external_knowledge_id": external_knowledge_binding.external_knowledge_id,
            "external_knowledge_api_id": external_knowledge_api.id,
            "external_knowledge_api_name": external_knowledge_api.name,
            "external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""),
        }

    @staticmethod
    def gen_collection_name_by_id(dataset_id: str) -> str:
        normalized_dataset_id = dataset_id.replace("-", "_")
        return f"Vector_index_{normalized_dataset_id}_Node"


class DatasetProcessRule(db.Model):  # type: ignore[name-defined]
    __tablename__ = "dataset_process_rules"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"),
        db.Index("dataset_process_rule_dataset_id_idx", "dataset_id"),
    )

    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    dataset_id = db.Column(StringUUID, nullable=False)
    mode = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
    rules = db.Column(db.Text, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

    MODES = ["automatic", "custom", "hierarchical"]
    PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"]
    AUTOMATIC_RULES: dict[str, Any] = {
        "pre_processing_rules": [
            {"id": "remove_extra_spaces", "enabled": True},
            {"id": "remove_urls_emails", "enabled": False},
        ],
        "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50},
    }

    def to_dict(self):
        return {
            "id": self.id,
            "dataset_id": self.dataset_id,
            "mode": self.mode,
            "rules": self.rules_dict,
        }

    @property
    def rules_dict(self):
        try:
            return json.loads(self.rules) if self.rules else None
        except JSONDecodeError:
            return None
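

# Sketch of a custom rules payload (shape inferred from AUTOMATIC_RULES above;
# the exact keys accepted in "custom" mode are an assumption, and `dataset` and
# `account` are hypothetical objects):
#   rule = DatasetProcessRule(
#       dataset_id=dataset.id,
#       mode="custom",
#       rules=json.dumps(
#           {
#               "pre_processing_rules": [{"id": "remove_extra_spaces", "enabled": True}],
#               "segmentation": {"delimiter": "\n\n", "max_tokens": 800, "chunk_overlap": 100},
#           }
#       ),
#       created_by=account.id,
#   )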


class Document(db.Model):  # type: ignore[name-defined]
    __tablename__ = "documents"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="document_pkey"),
        db.Index("document_dataset_id_idx", "dataset_id"),
        db.Index("document_is_paused_idx", "is_paused"),
        db.Index("document_tenant_idx", "tenant_id"),
    )

    # initial fields
    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    position = db.Column(db.Integer, nullable=False)
    data_source_type = db.Column(db.String(255), nullable=False)
    data_source_info = db.Column(db.Text, nullable=True)
    dataset_process_rule_id = db.Column(StringUUID, nullable=True)
    batch = db.Column(db.String(255), nullable=False)
    name = db.Column(db.String(255), nullable=False)
    created_from = db.Column(db.String(255), nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_api_request_id = db.Column(StringUUID, nullable=True)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

    # start processing
    processing_started_at = db.Column(db.DateTime, nullable=True)

    # parsing
    file_id = db.Column(db.Text, nullable=True)
    word_count = db.Column(db.Integer, nullable=True)
    parsing_completed_at = db.Column(db.DateTime, nullable=True)

    # cleaning
    cleaning_completed_at = db.Column(db.DateTime, nullable=True)

    # split
    splitting_completed_at = db.Column(db.DateTime, nullable=True)

    # indexing
    tokens = db.Column(db.Integer, nullable=True)
    indexing_latency = db.Column(db.Float, nullable=True)
    completed_at = db.Column(db.DateTime, nullable=True)

    # pause
    is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text("false"))
    paused_by = db.Column(StringUUID, nullable=True)
    paused_at = db.Column(db.DateTime, nullable=True)

    # error
    error = db.Column(db.Text, nullable=True)
    stopped_at = db.Column(db.DateTime, nullable=True)

    # basic fields
    indexing_status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying"))
    enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
    disabled_at = db.Column(db.DateTime, nullable=True)
    disabled_by = db.Column(StringUUID, nullable=True)
    archived = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
    archived_reason = db.Column(db.String(255), nullable=True)
    archived_by = db.Column(StringUUID, nullable=True)
    archived_at = db.Column(db.DateTime, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    doc_type = db.Column(db.String(40), nullable=True)
    doc_metadata = db.Column(db.JSON, nullable=True)
    doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying"))
    doc_language = db.Column(db.String(255), nullable=True)

    DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"]

    @property
    def display_status(self):
        status = None
        if self.indexing_status == "waiting":
            status = "queuing"
        elif self.indexing_status not in {"completed", "error", "waiting"} and self.is_paused:
            status = "paused"
        elif self.indexing_status in {"parsing", "cleaning", "splitting", "indexing"}:
            status = "indexing"
        elif self.indexing_status == "error":
            status = "error"
        elif self.indexing_status == "completed" and not self.archived and self.enabled:
            status = "available"
        elif self.indexing_status == "completed" and not self.archived and not self.enabled:
            status = "disabled"
        elif self.indexing_status == "completed" and self.archived:
            status = "archived"
        return status
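
    # display_status collapses (indexing_status, is_paused, enabled, archived)
    # into one user-facing value:
    #   waiting                              -> "queuing"
    #   in progress + is_paused              -> "paused"
    #   parsing/cleaning/splitting/indexing  -> "indexing"
    #   error                                -> "error"
    #   completed + enabled + not archived   -> "available"
    #   completed + disabled + not archived  -> "disabled"
    #   completed + archived                 -> "archived"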

    @property
    def data_source_info_dict(self):
        if self.data_source_info:
            try:
                data_source_info_dict = json.loads(self.data_source_info)
            except JSONDecodeError:
                data_source_info_dict = {}

            return data_source_info_dict
        return None

    @property
    def data_source_detail_dict(self):
        if self.data_source_info:
            if self.data_source_type == "upload_file":
                data_source_info_dict = json.loads(self.data_source_info)
                file_detail = (
                    db.session.query(UploadFile)
                    .filter(UploadFile.id == data_source_info_dict["upload_file_id"])
                    .one_or_none()
                )
                if file_detail:
                    return {
                        "upload_file": {
                            "id": file_detail.id,
                            "name": file_detail.name,
                            "size": file_detail.size,
                            "extension": file_detail.extension,
                            "mime_type": file_detail.mime_type,
                            "created_by": file_detail.created_by,
                            "created_at": file_detail.created_at.timestamp(),
                        }
                    }
            elif self.data_source_type in {"notion_import", "website_crawl"}:
                return json.loads(self.data_source_info)
        return {}

    @property
    def average_segment_length(self):
        if self.word_count and self.segment_count:
            return self.word_count // self.segment_count
        return 0

    @property
    def dataset_process_rule(self):
        if self.dataset_process_rule_id:
            return db.session.get(DatasetProcessRule, self.dataset_process_rule_id)
        return None

    @property
    def dataset(self):
        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).one_or_none()

    @property
    def segment_count(self):
        return DocumentSegment.query.filter(DocumentSegment.document_id == self.id).count()

    @property
    def hit_count(self):
        # Coalesce to 0 so a document with no segments reports 0 instead of None.
        return (
            DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0))
            .filter(DocumentSegment.document_id == self.id)
            .scalar()
        )

    @property
    def process_rule_dict(self):
        # Guard against a dangling dataset_process_rule_id whose rule row is gone.
        if self.dataset_process_rule_id and self.dataset_process_rule:
            return self.dataset_process_rule.to_dict()
        return None

    def to_dict(self):
        return {
            "id": self.id,
            "tenant_id": self.tenant_id,
            "dataset_id": self.dataset_id,
            "position": self.position,
            "data_source_type": self.data_source_type,
            "data_source_info": self.data_source_info,
            "dataset_process_rule_id": self.dataset_process_rule_id,
            "batch": self.batch,
            "name": self.name,
            "created_from": self.created_from,
            "created_by": self.created_by,
            "created_api_request_id": self.created_api_request_id,
            "created_at": self.created_at,
            "processing_started_at": self.processing_started_at,
            "file_id": self.file_id,
            "word_count": self.word_count,
            "parsing_completed_at": self.parsing_completed_at,
            "cleaning_completed_at": self.cleaning_completed_at,
            "splitting_completed_at": self.splitting_completed_at,
            "tokens": self.tokens,
            "indexing_latency": self.indexing_latency,
            "completed_at": self.completed_at,
            "is_paused": self.is_paused,
            "paused_by": self.paused_by,
            "paused_at": self.paused_at,
            "error": self.error,
            "stopped_at": self.stopped_at,
            "indexing_status": self.indexing_status,
            "enabled": self.enabled,
            "disabled_at": self.disabled_at,
            "disabled_by": self.disabled_by,
            "archived": self.archived,
            "archived_reason": self.archived_reason,
            "archived_by": self.archived_by,
            "archived_at": self.archived_at,
            "updated_at": self.updated_at,
            "doc_type": self.doc_type,
            "doc_metadata": self.doc_metadata,
            "doc_form": self.doc_form,
            "doc_language": self.doc_language,
            "display_status": self.display_status,
            "data_source_info_dict": self.data_source_info_dict,
            "average_segment_length": self.average_segment_length,
            "dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None,
            "dataset": self.dataset.to_dict() if self.dataset else None,
            "segment_count": self.segment_count,
            "hit_count": self.hit_count,
        }

    @classmethod
    def from_dict(cls, data: dict):
        return cls(
            id=data.get("id"),
            tenant_id=data.get("tenant_id"),
            dataset_id=data.get("dataset_id"),
            position=data.get("position"),
            data_source_type=data.get("data_source_type"),
            data_source_info=data.get("data_source_info"),
            dataset_process_rule_id=data.get("dataset_process_rule_id"),
            batch=data.get("batch"),
            name=data.get("name"),
            created_from=data.get("created_from"),
            created_by=data.get("created_by"),
            created_api_request_id=data.get("created_api_request_id"),
            created_at=data.get("created_at"),
            processing_started_at=data.get("processing_started_at"),
            file_id=data.get("file_id"),
            word_count=data.get("word_count"),
            parsing_completed_at=data.get("parsing_completed_at"),
            cleaning_completed_at=data.get("cleaning_completed_at"),
            splitting_completed_at=data.get("splitting_completed_at"),
            tokens=data.get("tokens"),
            indexing_latency=data.get("indexing_latency"),
            completed_at=data.get("completed_at"),
            is_paused=data.get("is_paused"),
            paused_by=data.get("paused_by"),
            paused_at=data.get("paused_at"),
            error=data.get("error"),
            stopped_at=data.get("stopped_at"),
            indexing_status=data.get("indexing_status"),
            enabled=data.get("enabled"),
            disabled_at=data.get("disabled_at"),
            disabled_by=data.get("disabled_by"),
            archived=data.get("archived"),
            archived_reason=data.get("archived_reason"),
            archived_by=data.get("archived_by"),
            archived_at=data.get("archived_at"),
            updated_at=data.get("updated_at"),
            doc_type=data.get("doc_type"),
            doc_metadata=data.get("doc_metadata"),
            doc_form=data.get("doc_form"),
            doc_language=data.get("doc_language"),
        )
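

# Serialization sketch: to_dict()/from_dict() give a plain-dict round trip over
# the column fields (hypothetical usage; assumes `doc` is a loaded Document):
#   payload = doc.to_dict()
#   clone = Document.from_dict(payload)
#   assert clone.id == doc.id and clone.name == doc.name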


class DocumentSegment(db.Model):  # type: ignore[name-defined]
    __tablename__ = "document_segments"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="document_segment_pkey"),
        db.Index("document_segment_dataset_id_idx", "dataset_id"),
        db.Index("document_segment_document_id_idx", "document_id"),
        db.Index("document_segment_tenant_dataset_idx", "dataset_id", "tenant_id"),
        db.Index("document_segment_tenant_document_idx", "document_id", "tenant_id"),
        db.Index("document_segment_dataset_node_idx", "dataset_id", "index_node_id"),
        db.Index("document_segment_tenant_idx", "tenant_id"),
    )

    # initial fields
    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    document_id = db.Column(StringUUID, nullable=False)
    position = db.Column(db.Integer, nullable=False)
    content = db.Column(db.Text, nullable=False)
    answer = db.Column(db.Text, nullable=True)
    word_count = db.Column(db.Integer, nullable=False)
    tokens = db.Column(db.Integer, nullable=False)

    # indexing fields
    keywords = db.Column(db.JSON, nullable=True)
    index_node_id = db.Column(db.String(255), nullable=True)
    index_node_hash = db.Column(db.String(255), nullable=True)

    # basic fields
    hit_count = db.Column(db.Integer, nullable=False, default=0)
    enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
    disabled_at = db.Column(db.DateTime, nullable=True)
    disabled_by = db.Column(StringUUID, nullable=True)
    status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying"))
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    indexing_at = db.Column(db.DateTime, nullable=True)
    completed_at = db.Column(db.DateTime, nullable=True)
    error = db.Column(db.Text, nullable=True)
    stopped_at = db.Column(db.DateTime, nullable=True)

    @property
    def dataset(self):
        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first()

    @property
    def document(self):
        return db.session.query(Document).filter(Document.id == self.document_id).first()

    @property
    def previous_segment(self):
        return (
            db.session.query(DocumentSegment)
            .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position - 1)
            .first()
        )

    @property
    def next_segment(self):
        return (
            db.session.query(DocumentSegment)
            .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position + 1)
            .first()
        )

    @property
    def child_chunks(self):
        # Child chunks only exist for hierarchical (parent-child) datasets whose
        # parent mode is not full-doc; guard against a missing document or
        # process rule instead of dereferencing them blindly.
        document = self.document
        process_rule = document.dataset_process_rule if document else None
        if process_rule and process_rule.mode == "hierarchical":
            rules = Rule(**process_rule.rules_dict)
            if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
                child_chunks = (
                    db.session.query(ChildChunk)
                    .filter(ChildChunk.segment_id == self.id)
                    .order_by(ChildChunk.position.asc())
                    .all()
                )
                return child_chunks or []
        return []
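
    # Usage sketch (hypothetical; assumes `segment` is a loaded DocumentSegment
    # from a dataset indexed in parent-child mode):
    #   for chunk in segment.child_chunks:
    #       print(chunk.position, chunk.content[:40])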

    def get_sign_content(self):
        signed_urls = []
        text = self.content

        # Sign both URL styles: "image-preview" (data before v0.10.0) and
        # "file-preview" (data from v0.10.0 onward).
        for pattern, sign_type in (
            (r"/files/([a-f0-9\-]+)/image-preview", "image-preview"),
            (r"/files/([a-f0-9\-]+)/file-preview", "file-preview"),
        ):
            for match in re.finditer(pattern, text):
                upload_file_id = match.group(1)
                nonce = os.urandom(16).hex()
                timestamp = str(int(time.time()))
                data_to_sign = f"{sign_type}|{upload_file_id}|{timestamp}|{nonce}"
                secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
                sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
                encoded_sign = base64.urlsafe_b64encode(sign).decode()

                params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
                signed_url = f"{match.group(0)}?{params}"
                signed_urls.append((match.start(), match.end(), signed_url))

        # Reconstruct the text with signed URLs; apply replacements in positional
        # order so the accumulated offset stays correct.
        signed_urls.sort(key=lambda item: item[0])
        offset = 0
        for start, end, signed_url in signed_urls:
            text = text[: start + offset] + signed_url + text[end + offset :]
            offset += len(signed_url) - (end - start)

        return text
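
    # The rewritten content embeds query-signed preview URLs of the form
    # (illustrative values):
    #   /files/<file_id>/file-preview?timestamp=1700000000&nonce=9f2c...&sign=ZmFr...
    # where the signature is HMAC-SHA256 over
    # "file-preview|<file_id>|<timestamp>|<nonce>" keyed with SECRET_KEY,
    # then URL-safe base64 encoded.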


class ChildChunk(db.Model):  # type: ignore[name-defined]
    __tablename__ = "child_chunks"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
        db.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"),
    )

    # initial fields
    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    document_id = db.Column(StringUUID, nullable=False)
    segment_id = db.Column(StringUUID, nullable=False)
    position = db.Column(db.Integer, nullable=False)
    content = db.Column(db.Text, nullable=False)
    word_count = db.Column(db.Integer, nullable=False)

    # indexing fields
    index_node_id = db.Column(db.String(255), nullable=True)
    index_node_hash = db.Column(db.String(255), nullable=True)
    type = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    indexing_at = db.Column(db.DateTime, nullable=True)
    completed_at = db.Column(db.DateTime, nullable=True)
    error = db.Column(db.Text, nullable=True)

    @property
    def dataset(self):
        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first()

    @property
    def document(self):
        return db.session.query(Document).filter(Document.id == self.document_id).first()

    @property
    def segment(self):
        return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first()


class AppDatasetJoin(db.Model):  # type: ignore[name-defined]
    __tablename__ = "app_dataset_joins"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"),
        db.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"),
    )

    id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()"))
    app_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())

    @property
    def app(self):
        return db.session.get(App, self.app_id)


class DatasetQuery(db.Model):  # type: ignore[name-defined]
    __tablename__ = "dataset_queries"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_query_pkey"),
        db.Index("dataset_query_dataset_id_idx", "dataset_id"),
    )

    id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()"))
    dataset_id = db.Column(StringUUID, nullable=False)
    content = db.Column(db.Text, nullable=False)
    source = db.Column(db.String(255), nullable=False)
    source_app_id = db.Column(StringUUID, nullable=True)
    created_by_role = db.Column(db.String, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())


class DatasetKeywordTable(db.Model):  # type: ignore[name-defined]
    __tablename__ = "dataset_keyword_tables"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"),
        db.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    dataset_id = db.Column(StringUUID, nullable=False, unique=True)
    keyword_table = db.Column(db.Text, nullable=False)
    data_source_type = db.Column(
        db.String(255), nullable=False, server_default=db.text("'database'::character varying")
    )

    @property
    def keyword_table_dict(self):
        class SetDecoder(json.JSONDecoder):
            def __init__(self, *args, **kwargs):
                super().__init__(object_hook=self.object_hook, *args, **kwargs)

            def object_hook(self, dct):
                if isinstance(dct, dict):
                    for keyword, node_idxs in dct.items():
                        if isinstance(node_idxs, list):
                            dct[keyword] = set(node_idxs)
                return dct

        # get dataset
        dataset = Dataset.query.filter_by(id=self.dataset_id).first()
        if not dataset:
            return None
        if self.data_source_type == "database":
            return json.loads(self.keyword_table, cls=SetDecoder) if self.keyword_table else None
        else:
            file_key = "keyword_files/" + dataset.tenant_id + "/" + self.dataset_id + ".txt"
            try:
                keyword_table_text = storage.load_once(file_key)
                if keyword_table_text:
                    return json.loads(keyword_table_text.decode("utf-8"), cls=SetDecoder)
                return None
            except Exception:
                # logging.exception already records the stack trace; pass the key
                # lazily instead of interpolating it with an f-string.
                logging.exception("Failed to load keyword table from file: %s", file_key)
                return None
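

# Decoding sketch: the nested SetDecoder turns each keyword's JSON list of node
# IDs back into a Python set (illustrative data):
#   json.loads('{"alpha": ["node-1", "node-2"]}', cls=SetDecoder)
#   # -> {"alpha": {"node-1", "node-2"}}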


class Embedding(db.Model):  # type: ignore[name-defined]
    __tablename__ = "embeddings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="embedding_pkey"),
        db.UniqueConstraint("model_name", "hash", "provider_name", name="embedding_hash_idx"),
        db.Index("created_at_idx", "created_at"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    model_name = db.Column(
        db.String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying")
    )
    hash = db.Column(db.String(64), nullable=False)
    embedding = db.Column(db.LargeBinary, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    provider_name = db.Column(db.String(255), nullable=False, server_default=db.text("''::character varying"))

    def set_embedding(self, embedding_data: list[float]):
        self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL)

    def get_embedding(self) -> list[float]:
        return cast(list[float], pickle.loads(self.embedding))
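

# Round-trip sketch: cached vectors are stored pickled in the LargeBinary
# column (illustrative, hypothetical values):
#   entry = Embedding(model_name="text-embedding-ada-002", provider_name="openai", hash="abc123")
#   entry.set_embedding([0.12, -0.03, 0.88])
#   entry.get_embedding()  # -> [0.12, -0.03, 0.88]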


class DatasetCollectionBinding(db.Model):  # type: ignore[name-defined]
    __tablename__ = "dataset_collection_bindings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"),
        db.Index("provider_model_name_idx", "provider_name", "model_name"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    provider_name = db.Column(db.String(40), nullable=False)
    model_name = db.Column(db.String(255), nullable=False)
    type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)
    collection_name = db.Column(db.String(64), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())


class TidbAuthBinding(db.Model):  # type: ignore[name-defined]
    __tablename__ = "tidb_auth_bindings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
        db.Index("tidb_auth_bindings_tenant_idx", "tenant_id"),
        db.Index("tidb_auth_bindings_active_idx", "active"),
        db.Index("tidb_auth_bindings_created_at_idx", "created_at"),
        db.Index("tidb_auth_bindings_status_idx", "status"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=True)
    cluster_id = db.Column(db.String(255), nullable=False)
    cluster_name = db.Column(db.String(255), nullable=False)
    active = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
    # Quote the default so it is a string literal (like the other string defaults
    # in this module), not a bare SQL identifier.
    status = db.Column(db.String(255), nullable=False, server_default=db.text("'CREATING'::character varying"))
    account = db.Column(db.String(255), nullable=False)
    password = db.Column(db.String(255), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())


class Whitelist(db.Model):  # type: ignore[name-defined]
    __tablename__ = "whitelists"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="whitelists_pkey"),
        db.Index("whitelists_tenant_idx", "tenant_id"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=True)
    category = db.Column(db.String(255), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())


class DatasetPermission(db.Model):  # type: ignore[name-defined]
    __tablename__ = "dataset_permissions"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"),
        db.Index("idx_dataset_permissions_dataset_id", "dataset_id"),
        db.Index("idx_dataset_permissions_account_id", "account_id"),
        db.Index("idx_dataset_permissions_tenant_id", "tenant_id"),
    )

    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"), primary_key=True)
    dataset_id = db.Column(StringUUID, nullable=False)
    account_id = db.Column(StringUUID, nullable=False)
    tenant_id = db.Column(StringUUID, nullable=False)
    has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())


class ExternalKnowledgeApis(db.Model):  # type: ignore[name-defined]
    __tablename__ = "external_knowledge_apis"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"),
        db.Index("external_knowledge_apis_tenant_idx", "tenant_id"),
        db.Index("external_knowledge_apis_name_idx", "name"),
    )

    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    name = db.Column(db.String(255), nullable=False)
    description = db.Column(db.String(255), nullable=False)
    tenant_id = db.Column(StringUUID, nullable=False)
    settings = db.Column(db.Text, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

    def to_dict(self):
        return {
            "id": self.id,
            "tenant_id": self.tenant_id,
            "name": self.name,
            "description": self.description,
            "settings": self.settings_dict,
            "dataset_bindings": self.dataset_bindings,
            "created_by": self.created_by,
            "created_at": self.created_at.isoformat(),
        }

    @property
    def settings_dict(self):
        try:
            return json.loads(self.settings) if self.settings else None
        except JSONDecodeError:
            return None

    @property
    def dataset_bindings(self):
        external_knowledge_bindings = (
            db.session.query(ExternalKnowledgeBindings)
            .filter(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
            .all()
        )
        dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings]
        datasets = db.session.query(Dataset).filter(Dataset.id.in_(dataset_ids)).all()
        return [{"id": dataset.id, "name": dataset.name} for dataset in datasets]


class ExternalKnowledgeBindings(db.Model):  # type: ignore[name-defined]
    __tablename__ = "external_knowledge_bindings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"),
        db.Index("external_knowledge_bindings_tenant_idx", "tenant_id"),
        db.Index("external_knowledge_bindings_dataset_idx", "dataset_id"),
        db.Index("external_knowledge_bindings_external_knowledge_idx", "external_knowledge_id"),
        db.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"),
    )

    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    external_knowledge_api_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    external_knowledge_id = db.Column(db.Text, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())


class DatasetAutoDisableLog(db.Model):  # type: ignore[name-defined]
    __tablename__ = "dataset_auto_disable_logs"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"),
        db.Index("dataset_auto_disable_log_tenant_idx", "tenant_id"),
        db.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"),
        db.Index("dataset_auto_disable_log_created_atx", "created_at"),
    )

    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    document_id = db.Column(StringUUID, nullable=False)
    notified = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))