dataset.py

import base64
import enum
import hashlib
import hmac
import json
import logging
import os
import pickle
import re
import time
from json import JSONDecodeError

from sqlalchemy import func
from sqlalchemy.dialects.postgresql import JSONB

from configs import dify_config
from core.rag.retrieval.retrival_methods import RetrievalMethod
from extensions.ext_database import db
from extensions.ext_storage import storage

from .account import Account
from .model import App, Tag, TagBinding, UploadFile
from .types import StringUUID
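

# Who in a workspace may access a dataset.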
class DatasetPermissionEnum(str, enum.Enum):
    ONLY_ME = 'only_me'
    ALL_TEAM = 'all_team_members'
    PARTIAL_TEAM = 'partial_members'
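

# A knowledge base. Retrieval settings live in the JSONB `retrieval_model` column
# (GIN-indexed); document/segment/word counts are computed on access via the
# properties below.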
class Dataset(db.Model):
    __tablename__ = 'datasets'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='dataset_pkey'),
        db.Index('dataset_tenant_idx', 'tenant_id'),
        db.Index('retrieval_model_idx', "retrieval_model", postgresql_using='gin')
    )

    INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy', None]

    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(StringUUID, nullable=False)
    name = db.Column(db.String(255), nullable=False)
    description = db.Column(db.Text, nullable=True)
    provider = db.Column(db.String(255), nullable=False,
                         server_default=db.text("'vendor'::character varying"))
    permission = db.Column(db.String(255), nullable=False,
                           server_default=db.text("'only_me'::character varying"))
    data_source_type = db.Column(db.String(255))
    indexing_technique = db.Column(db.String(255), nullable=True)
    index_struct = db.Column(db.Text, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
    embedding_model = db.Column(db.String(255), nullable=True)
    embedding_model_provider = db.Column(db.String(255), nullable=True)
    collection_binding_id = db.Column(StringUUID, nullable=True)
    retrieval_model = db.Column(JSONB, nullable=True)

    @property
    def dataset_keyword_table(self):
        dataset_keyword_table = db.session.query(DatasetKeywordTable).filter(
            DatasetKeywordTable.dataset_id == self.id).first()
        if dataset_keyword_table:
            return dataset_keyword_table

        return None

    @property
    def index_struct_dict(self):
        return json.loads(self.index_struct) if self.index_struct else None

    @property
    def created_by_account(self):
        return db.session.get(Account, self.created_by)

    @property
    def latest_process_rule(self):
        return DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id) \
            .order_by(DatasetProcessRule.created_at.desc()).first()

    @property
    def app_count(self):
        return db.session.query(func.count(AppDatasetJoin.id)).filter(
            AppDatasetJoin.dataset_id == self.id,
            App.id == AppDatasetJoin.app_id
        ).scalar()

    @property
    def document_count(self):
        return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar()

    @property
    def available_document_count(self):
        return db.session.query(func.count(Document.id)).filter(
            Document.dataset_id == self.id,
            Document.indexing_status == 'completed',
            Document.enabled == True,
            Document.archived == False
        ).scalar()

    @property
    def available_segment_count(self):
        return db.session.query(func.count(DocumentSegment.id)).filter(
            DocumentSegment.dataset_id == self.id,
            DocumentSegment.status == 'completed',
            DocumentSegment.enabled == True
        ).scalar()

    @property
    def word_count(self):
        return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \
            .filter(Document.dataset_id == self.id).scalar()

    @property
    def doc_form(self):
        document = db.session.query(Document).filter(
            Document.dataset_id == self.id).first()
        if document:
            return document.doc_form
        return None

    @property
    def retrieval_model_dict(self):
        default_retrieval_model = {
            'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
            'reranking_enable': False,
            'reranking_model': {
                'reranking_provider_name': '',
                'reranking_model_name': ''
            },
            'top_k': 2,
            'score_threshold_enabled': False
        }
        return self.retrieval_model if self.retrieval_model else default_retrieval_model

    @property
    def tags(self):
        tags = db.session.query(Tag).join(
            TagBinding,
            Tag.id == TagBinding.tag_id
        ).filter(
            TagBinding.target_id == self.id,
            TagBinding.tenant_id == self.tenant_id,
            Tag.tenant_id == self.tenant_id,
            Tag.type == 'knowledge'
        ).all()

        return tags if tags else []

    @staticmethod
    def gen_collection_name_by_id(dataset_id: str) -> str:
        normalized_dataset_id = dataset_id.replace("-", "_")
        return f'Vector_index_{normalized_dataset_id}_Node'
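

# Pre-processing and segmentation rules applied when documents are imported
# into a dataset; `rules` is stored as a JSON string.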
class DatasetProcessRule(db.Model):
    __tablename__ = 'dataset_process_rules'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='dataset_process_rule_pkey'),
        db.Index('dataset_process_rule_dataset_id_idx', 'dataset_id'),
    )

    id = db.Column(StringUUID, nullable=False,
                   server_default=db.text('uuid_generate_v4()'))
    dataset_id = db.Column(StringUUID, nullable=False)
    mode = db.Column(db.String(255), nullable=False,
                     server_default=db.text("'automatic'::character varying"))
    rules = db.Column(db.Text, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))

    MODES = ['automatic', 'custom']
    PRE_PROCESSING_RULES = ['remove_stopwords', 'remove_extra_spaces', 'remove_urls_emails']
    AUTOMATIC_RULES = {
        'pre_processing_rules': [
            {'id': 'remove_extra_spaces', 'enabled': True},
            {'id': 'remove_urls_emails', 'enabled': False}
        ],
        'segmentation': {
            'delimiter': '\n',
            'max_tokens': 500,
            'chunk_overlap': 50
        }
    }

    def to_dict(self):
        return {
            'id': self.id,
            'dataset_id': self.dataset_id,
            'mode': self.mode,
            'rules': self.rules_dict,
            'created_by': self.created_by,
            'created_at': self.created_at,
        }

    @property
    def rules_dict(self):
        try:
            return json.loads(self.rules) if self.rules else None
        except JSONDecodeError:
            return None
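

# A source document within a dataset. Columns are grouped by indexing stage
# (parsing, cleaning, splitting, indexing) plus pause/error/archive bookkeeping.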
class Document(db.Model):
    __tablename__ = 'documents'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='document_pkey'),
        db.Index('document_dataset_id_idx', 'dataset_id'),
        db.Index('document_is_paused_idx', 'is_paused'),
        db.Index('document_tenant_idx', 'tenant_id'),
    )

    # initial fields
    id = db.Column(StringUUID, nullable=False,
                   server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    position = db.Column(db.Integer, nullable=False)
    data_source_type = db.Column(db.String(255), nullable=False)
    data_source_info = db.Column(db.Text, nullable=True)
    dataset_process_rule_id = db.Column(StringUUID, nullable=True)
    batch = db.Column(db.String(255), nullable=False)
    name = db.Column(db.String(255), nullable=False)
    created_from = db.Column(db.String(255), nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_api_request_id = db.Column(StringUUID, nullable=True)
    created_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))

    # start processing
    processing_started_at = db.Column(db.DateTime, nullable=True)

    # parsing
    file_id = db.Column(db.Text, nullable=True)
    word_count = db.Column(db.Integer, nullable=True)
    parsing_completed_at = db.Column(db.DateTime, nullable=True)

    # cleaning
    cleaning_completed_at = db.Column(db.DateTime, nullable=True)

    # split
    splitting_completed_at = db.Column(db.DateTime, nullable=True)

    # indexing
    tokens = db.Column(db.Integer, nullable=True)
    indexing_latency = db.Column(db.Float, nullable=True)
    completed_at = db.Column(db.DateTime, nullable=True)

    # pause
    is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text('false'))
    paused_by = db.Column(StringUUID, nullable=True)
    paused_at = db.Column(db.DateTime, nullable=True)

    # error
    error = db.Column(db.Text, nullable=True)
    stopped_at = db.Column(db.DateTime, nullable=True)

    # basic fields
    indexing_status = db.Column(db.String(255), nullable=False,
                                server_default=db.text("'waiting'::character varying"))
    enabled = db.Column(db.Boolean, nullable=False,
                        server_default=db.text('true'))
    disabled_at = db.Column(db.DateTime, nullable=True)
    disabled_by = db.Column(StringUUID, nullable=True)
    archived = db.Column(db.Boolean, nullable=False,
                         server_default=db.text('false'))
    archived_reason = db.Column(db.String(255), nullable=True)
    archived_by = db.Column(StringUUID, nullable=True)
    archived_at = db.Column(db.DateTime, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
    doc_type = db.Column(db.String(40), nullable=True)
    doc_metadata = db.Column(db.JSON, nullable=True)
    doc_form = db.Column(db.String(255), nullable=False,
                         server_default=db.text("'text_model'::character varying"))
    doc_language = db.Column(db.String(255), nullable=True)

    DATA_SOURCES = ['upload_file', 'notion_import', 'website_crawl']

    @property
    def display_status(self):
        status = None
        if self.indexing_status == 'waiting':
            status = 'queuing'
        elif self.indexing_status not in ['completed', 'error', 'waiting'] and self.is_paused:
            status = 'paused'
        elif self.indexing_status in ['parsing', 'cleaning', 'splitting', 'indexing']:
            status = 'indexing'
        elif self.indexing_status == 'error':
            status = 'error'
        elif self.indexing_status == 'completed' and not self.archived and self.enabled:
            status = 'available'
        elif self.indexing_status == 'completed' and not self.archived and not self.enabled:
            status = 'disabled'
        elif self.indexing_status == 'completed' and self.archived:
            status = 'archived'
        return status

    @property
    def data_source_info_dict(self):
        if self.data_source_info:
            try:
                data_source_info_dict = json.loads(self.data_source_info)
            except JSONDecodeError:
                data_source_info_dict = {}

            return data_source_info_dict
        return None

    @property
    def data_source_detail_dict(self):
        if self.data_source_info:
            if self.data_source_type == 'upload_file':
                data_source_info_dict = json.loads(self.data_source_info)
                file_detail = db.session.query(UploadFile). \
                    filter(UploadFile.id == data_source_info_dict['upload_file_id']). \
                    one_or_none()
                if file_detail:
                    return {
                        'upload_file': {
                            'id': file_detail.id,
                            'name': file_detail.name,
                            'size': file_detail.size,
                            'extension': file_detail.extension,
                            'mime_type': file_detail.mime_type,
                            'created_by': file_detail.created_by,
                            'created_at': file_detail.created_at.timestamp()
                        }
                    }
            elif self.data_source_type == 'notion_import' or self.data_source_type == 'website_crawl':
                return json.loads(self.data_source_info)
        return {}

    @property
    def average_segment_length(self):
        if self.word_count and self.word_count != 0 and self.segment_count and self.segment_count != 0:
            return self.word_count // self.segment_count
        return 0

    @property
    def dataset_process_rule(self):
        if self.dataset_process_rule_id:
            return db.session.get(DatasetProcessRule, self.dataset_process_rule_id)
        return None

    @property
    def dataset(self):
        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).one_or_none()

    @property
    def segment_count(self):
        return DocumentSegment.query.filter(DocumentSegment.document_id == self.id).count()

    @property
    def hit_count(self):
        return DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count))) \
            .filter(DocumentSegment.document_id == self.id).scalar()

    def to_dict(self):
        return {
            'id': self.id,
            'tenant_id': self.tenant_id,
            'dataset_id': self.dataset_id,
            'position': self.position,
            'data_source_type': self.data_source_type,
            'data_source_info': self.data_source_info,
            'dataset_process_rule_id': self.dataset_process_rule_id,
            'batch': self.batch,
            'name': self.name,
            'created_from': self.created_from,
            'created_by': self.created_by,
            'created_api_request_id': self.created_api_request_id,
            'created_at': self.created_at,
            'processing_started_at': self.processing_started_at,
            'file_id': self.file_id,
            'word_count': self.word_count,
            'parsing_completed_at': self.parsing_completed_at,
            'cleaning_completed_at': self.cleaning_completed_at,
            'splitting_completed_at': self.splitting_completed_at,
            'tokens': self.tokens,
            'indexing_latency': self.indexing_latency,
            'completed_at': self.completed_at,
            'is_paused': self.is_paused,
            'paused_by': self.paused_by,
            'paused_at': self.paused_at,
            'error': self.error,
            'stopped_at': self.stopped_at,
            'indexing_status': self.indexing_status,
            'enabled': self.enabled,
            'disabled_at': self.disabled_at,
            'disabled_by': self.disabled_by,
            'archived': self.archived,
            'archived_reason': self.archived_reason,
            'archived_by': self.archived_by,
            'archived_at': self.archived_at,
            'updated_at': self.updated_at,
            'doc_type': self.doc_type,
            'doc_metadata': self.doc_metadata,
            'doc_form': self.doc_form,
            'doc_language': self.doc_language,
            'display_status': self.display_status,
            'data_source_info_dict': self.data_source_info_dict,
            'average_segment_length': self.average_segment_length,
            'dataset_process_rule': self.dataset_process_rule.to_dict() if self.dataset_process_rule else None,
            'dataset': self.dataset.to_dict() if self.dataset else None,
            'segment_count': self.segment_count,
            'hit_count': self.hit_count
        }

    @classmethod
    def from_dict(cls, data: dict):
        return cls(
            id=data.get('id'),
            tenant_id=data.get('tenant_id'),
            dataset_id=data.get('dataset_id'),
            position=data.get('position'),
            data_source_type=data.get('data_source_type'),
            data_source_info=data.get('data_source_info'),
            dataset_process_rule_id=data.get('dataset_process_rule_id'),
            batch=data.get('batch'),
            name=data.get('name'),
            created_from=data.get('created_from'),
            created_by=data.get('created_by'),
            created_api_request_id=data.get('created_api_request_id'),
            created_at=data.get('created_at'),
            processing_started_at=data.get('processing_started_at'),
            file_id=data.get('file_id'),
            word_count=data.get('word_count'),
            parsing_completed_at=data.get('parsing_completed_at'),
            cleaning_completed_at=data.get('cleaning_completed_at'),
            splitting_completed_at=data.get('splitting_completed_at'),
            tokens=data.get('tokens'),
            indexing_latency=data.get('indexing_latency'),
            completed_at=data.get('completed_at'),
            is_paused=data.get('is_paused'),
            paused_by=data.get('paused_by'),
            paused_at=data.get('paused_at'),
            error=data.get('error'),
            stopped_at=data.get('stopped_at'),
            indexing_status=data.get('indexing_status'),
            enabled=data.get('enabled'),
            disabled_at=data.get('disabled_at'),
            disabled_by=data.get('disabled_by'),
            archived=data.get('archived'),
            archived_reason=data.get('archived_reason'),
            archived_by=data.get('archived_by'),
            archived_at=data.get('archived_at'),
            updated_at=data.get('updated_at'),
            doc_type=data.get('doc_type'),
            doc_metadata=data.get('doc_metadata'),
            doc_form=data.get('doc_form'),
            doc_language=data.get('doc_language')
        )
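

# A chunk produced by splitting a document; tracks its own index node id/hash,
# status and hit counter.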
class DocumentSegment(db.Model):
    __tablename__ = 'document_segments'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='document_segment_pkey'),
        db.Index('document_segment_dataset_id_idx', 'dataset_id'),
        db.Index('document_segment_document_id_idx', 'document_id'),
        db.Index('document_segment_tenant_dataset_idx', 'dataset_id', 'tenant_id'),
        db.Index('document_segment_tenant_document_idx', 'document_id', 'tenant_id'),
        db.Index('document_segment_dataset_node_idx', 'dataset_id', 'index_node_id'),
        db.Index('document_segment_tenant_idx', 'tenant_id'),
    )

    # initial fields
    id = db.Column(StringUUID, nullable=False,
                   server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    document_id = db.Column(StringUUID, nullable=False)
    position = db.Column(db.Integer, nullable=False)
    content = db.Column(db.Text, nullable=False)
    answer = db.Column(db.Text, nullable=True)
    word_count = db.Column(db.Integer, nullable=False)
    tokens = db.Column(db.Integer, nullable=False)

    # indexing fields
    keywords = db.Column(db.JSON, nullable=True)
    index_node_id = db.Column(db.String(255), nullable=True)
    index_node_hash = db.Column(db.String(255), nullable=True)

    # basic fields
    hit_count = db.Column(db.Integer, nullable=False, default=0)
    enabled = db.Column(db.Boolean, nullable=False,
                        server_default=db.text('true'))
    disabled_at = db.Column(db.DateTime, nullable=True)
    disabled_by = db.Column(StringUUID, nullable=True)
    status = db.Column(db.String(255), nullable=False,
                       server_default=db.text("'waiting'::character varying"))
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
    indexing_at = db.Column(db.DateTime, nullable=True)
    completed_at = db.Column(db.DateTime, nullable=True)
    error = db.Column(db.Text, nullable=True)
    stopped_at = db.Column(db.DateTime, nullable=True)

    @property
    def dataset(self):
        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first()

    @property
    def document(self):
        return db.session.query(Document).filter(Document.id == self.document_id).first()

    @property
    def previous_segment(self):
        return db.session.query(DocumentSegment).filter(
            DocumentSegment.document_id == self.document_id,
            DocumentSegment.position == self.position - 1
        ).first()

    @property
    def next_segment(self):
        return db.session.query(DocumentSegment).filter(
            DocumentSegment.document_id == self.document_id,
            DocumentSegment.position == self.position + 1
        ).first()
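
    # Sign any "/files/<uuid>/image-preview" links embedded in the segment content:
    # an HMAC-SHA256 signature keyed with dify_config.SECRET_KEY is computed over
    # "image-preview|<file_id>|<timestamp>|<nonce>" and appended, together with the
    # timestamp and nonce, as query parameters, producing time-limited URLs.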
    def get_sign_content(self):
        pattern = r"/files/([a-f0-9\-]+)/image-preview"
        text = self.content
        matches = re.finditer(pattern, text)
        signed_urls = []
        for match in matches:
            upload_file_id = match.group(1)
            nonce = os.urandom(16).hex()
            timestamp = str(int(time.time()))
            data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
            secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b''
            sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
            encoded_sign = base64.urlsafe_b64encode(sign).decode()

            params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
            signed_url = f"{match.group(0)}?{params}"
            signed_urls.append((match.start(), match.end(), signed_url))

        # Reconstruct the text with signed URLs
        offset = 0
        for start, end, signed_url in signed_urls:
            text = text[:start + offset] + signed_url + text[end + offset:]
            offset += len(signed_url) - (end - start)

        return text
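

# Join table recording which datasets an app is allowed to retrieve from.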
class AppDatasetJoin(db.Model):
    __tablename__ = 'app_dataset_joins'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='app_dataset_join_pkey'),
        db.Index('app_dataset_join_app_dataset_idx', 'dataset_id', 'app_id'),
    )

    id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())

    @property
    def app(self):
        return db.session.get(App, self.app_id)
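

# One logged retrieval query against a dataset, with its source and creator.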
class DatasetQuery(db.Model):
    __tablename__ = 'dataset_queries'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='dataset_query_pkey'),
        db.Index('dataset_query_dataset_id_idx', 'dataset_id'),
    )

    id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()'))
    dataset_id = db.Column(StringUUID, nullable=False)
    content = db.Column(db.Text, nullable=False)
    source = db.Column(db.String(255), nullable=False)
    source_app_id = db.Column(StringUUID, nullable=True)
    created_by_role = db.Column(db.String, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
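

# Keyword index of a dataset. The serialized table lives either in the
# `keyword_table` column or in object storage, depending on `data_source_type`.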
class DatasetKeywordTable(db.Model):
    __tablename__ = 'dataset_keyword_tables'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='dataset_keyword_table_pkey'),
        db.Index('dataset_keyword_table_dataset_id_idx', 'dataset_id'),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
    dataset_id = db.Column(StringUUID, nullable=False, unique=True)
    keyword_table = db.Column(db.Text, nullable=False)
    data_source_type = db.Column(db.String(255), nullable=False,
                                 server_default=db.text("'database'::character varying"))

    @property
    def keyword_table_dict(self):
        class SetDecoder(json.JSONDecoder):
            def __init__(self, *args, **kwargs):
                super().__init__(object_hook=self.object_hook, *args, **kwargs)

            def object_hook(self, dct):
                if isinstance(dct, dict):
                    for keyword, node_idxs in dct.items():
                        if isinstance(node_idxs, list):
                            dct[keyword] = set(node_idxs)
                return dct

        # get dataset
        dataset = Dataset.query.filter_by(
            id=self.dataset_id
        ).first()
        if not dataset:
            return None
        if self.data_source_type == 'database':
            return json.loads(self.keyword_table, cls=SetDecoder) if self.keyword_table else None
        else:
            file_key = 'keyword_files/' + dataset.tenant_id + '/' + self.dataset_id + '.txt'
            try:
                keyword_table_text = storage.load_once(file_key)
                if keyword_table_text:
                    return json.loads(keyword_table_text.decode('utf-8'), cls=SetDecoder)
                return None
            except Exception as e:
                logging.exception(str(e))
                return None
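

# Cached embedding vector keyed by (provider, model, content hash); the vector
# itself is pickled into a binary column.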
class Embedding(db.Model):
    __tablename__ = 'embeddings'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='embedding_pkey'),
        db.UniqueConstraint('model_name', 'hash', 'provider_name', name='embedding_hash_idx'),
        db.Index('created_at_idx', 'created_at')
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
    model_name = db.Column(db.String(255), nullable=False,
                           server_default=db.text("'text-embedding-ada-002'::character varying"))
    hash = db.Column(db.String(64), nullable=False)
    embedding = db.Column(db.LargeBinary, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    provider_name = db.Column(db.String(255), nullable=False,
                              server_default=db.text("''::character varying"))

    def set_embedding(self, embedding_data: list[float]):
        self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL)

    def get_embedding(self) -> list[float]:
        return pickle.loads(self.embedding)
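

# Binds an embedding provider/model pair to the named vector collection that stores its vectors.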
class DatasetCollectionBinding(db.Model):
    __tablename__ = 'dataset_collection_bindings'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='dataset_collection_bindings_pkey'),
        db.Index('provider_model_name_idx', 'provider_name', 'model_name')
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
    provider_name = db.Column(db.String(40), nullable=False)
    model_name = db.Column(db.String(255), nullable=False)
    type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)
    collection_name = db.Column(db.String(64), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
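

# Per-account access grant for a dataset within a tenant.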
class DatasetPermission(db.Model):
    __tablename__ = 'dataset_permissions'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='dataset_permission_pkey'),
        db.Index('idx_dataset_permissions_dataset_id', 'dataset_id'),
        db.Index('idx_dataset_permissions_account_id', 'account_id'),
        db.Index('idx_dataset_permissions_tenant_id', 'tenant_id')
    )

    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'), primary_key=True)
    dataset_id = db.Column(StringUUID, nullable=False)
    account_id = db.Column(StringUUID, nullable=False)
    tenant_id = db.Column(StringUUID, nullable=False)
    has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text('true'))
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))