@@ -3,7 +3,7 @@ import logging
 import datetime
 import time
 import random
-from typing import Optional
+from typing import Optional, List
 from extensions.ext_redis import redis_client
 from flask_login import current_user

@@ -14,10 +14,12 @@ from extensions.ext_database import db
 from models.account import Account
 from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment
 from models.model import UploadFile
+from models.source import DataSourceBinding
 from services.errors.account import NoPermissionError
 from services.errors.dataset import DatasetNameDuplicateError
 from services.errors.document import DocumentIndexingError
 from services.errors.file import FileNotExistsError
+from tasks.clean_notion_document_task import clean_notion_document_task
 from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
 from tasks.document_indexing_task import document_indexing_task
 from tasks.document_indexing_update_task import document_indexing_update_task
@@ -286,6 +288,24 @@ class DocumentService:
         return document

     @staticmethod
+    def get_document_by_dataset_id(dataset_id: str) -> List[Document]:
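+        # only documents that are still enabled are returned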
+        documents = db.session.query(Document).filter(
+            Document.dataset_id == dataset_id,
+            Document.enabled == True
+        ).all()
+
+        return documents
+
+    @staticmethod
+    def get_batch_documents(dataset_id: str, batch: str) -> List[Document]:
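+        # a batch groups all documents created by a single import request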
+        documents = db.session.query(Document).filter(
+            Document.batch == batch,
+            Document.dataset_id == dataset_id,
+            Document.tenant_id == current_user.current_tenant_id
+        ).all()
+
+        return documents
+
+    @staticmethod
     def get_document_file_detail(file_id: str):
         file_detail = db.session.query(UploadFile). \
             filter(UploadFile.id == file_id). \
@@ -344,9 +364,9 @@ class DocumentService:

     @staticmethod
     def get_documents_position(dataset_id):
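+        # derive the next position from the current maximum rather than the
+        # row count, so positions stay correct after documents are deleted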
-        documents = Document.query.filter_by(dataset_id=dataset_id).all()
-        if documents:
-            return len(documents) + 1
+        document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
+        if document:
+            return document.position + 1
         else:
             return 1

@@ -363,9 +383,11 @@ class DocumentService:

         if dataset.indexing_technique == 'high_quality':
             IndexBuilder.get_default_service_context(dataset.tenant_id)
-
+        documents = []
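+        # one shared batch id (timestamp plus random suffix) for this request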
+        batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
+        document_ids = []
         if 'original_document_id' in document_data and document_data["original_document_id"]:
             document = DocumentService.update_document_with_dataset_id(dataset, document_data, account)
+            documents.append(document)
         else:

             if not dataset_process_rule:
@@ -386,46 +408,114 @@ class DocumentService:
             )
             db.session.add(dataset_process_rule)
             db.session.commit()
-
-            file_name = ''
-            data_source_info = {}
-            if document_data["data_source"]["type"] == "upload_file":
-                file_id = document_data["data_source"]["info"]
-                file = db.session.query(UploadFile).filter(
-                    UploadFile.tenant_id == dataset.tenant_id,
-                    UploadFile.id == file_id
-                ).first()
-
-
-                if not file:
-                    raise FileNotExistsError()
-
-                file_name = file.name
-                data_source_info = {
-                    "upload_file_id": file_id,
-                }
-
-
             position = DocumentService.get_documents_position(dataset.id)
-            document = Document(
-                tenant_id=dataset.tenant_id,
-                dataset_id=dataset.id,
-                position=position,
-                data_source_type=document_data["data_source"]["type"],
-                data_source_info=json.dumps(data_source_info),
-                dataset_process_rule_id=dataset_process_rule.id,
-                batch=time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)),
-                name=file_name,
-                created_from=created_from,
-                created_by=account.id,
-
-            )
-
-            db.session.add(document)
+            if document_data["data_source"]["type"] == "upload_file":
+                upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
+                for file_id in upload_file_list:
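+                    # look up the upload under the current tenant; fail fast if missing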
+                    file = db.session.query(UploadFile).filter(
+                        UploadFile.tenant_id == dataset.tenant_id,
+                        UploadFile.id == file_id
+                    ).first()
+
+                    if not file:
+                        raise FileNotExistsError()
+
+                    file_name = file.name
+                    data_source_info = {
+                        "upload_file_id": file_id,
+                    }
+                    document = DocumentService.save_document(dataset, dataset_process_rule.id,
+                                                             document_data["data_source"]["type"],
+                                                             data_source_info, created_from, position,
+                                                             account, file_name, batch)
+                    db.session.add(document)
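+                    # flush so the generated document id is available immediately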
+                    db.session.flush()
+                    document_ids.append(document.id)
+                    documents.append(document)
+                    position += 1
+            elif document_data["data_source"]["type"] == "notion_import":
+                notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
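+                # collect the notion pages already imported into this dataset
+                # so they are reused instead of being imported twice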
+                exist_page_ids = []
+                exist_document = dict()
+                exist_documents = Document.query.filter_by(
+                    dataset_id=dataset.id,
+                    tenant_id=current_user.current_tenant_id,
+                    data_source_type='notion_import',
+                    enabled=True
+                ).all()
+                if exist_documents:
+                    for document in exist_documents:
+                        data_source_info = json.loads(document.data_source_info)
+                        exist_page_ids.append(data_source_info['notion_page_id'])
+                        exist_document[data_source_info['notion_page_id']] = document.id
+                for notion_info in notion_info_list:
+                    workspace_id = notion_info['workspace_id']
+                    data_source_binding = DataSourceBinding.query.filter(
+                        db.and_(
+                            DataSourceBinding.tenant_id == current_user.current_tenant_id,
+                            DataSourceBinding.provider == 'notion',
+                            DataSourceBinding.disabled == False,
+                            DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
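+                            # source_info is a JSON column, so the id is
+                            # compared in its JSON-encoded (quoted) form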
+                        )
+                    ).first()
+                    if not data_source_binding:
+                        raise ValueError('Data source binding not found.')
+                    for page in notion_info['pages']:
+                        if page['page_id'] not in exist_page_ids:
+                            data_source_info = {
+                                "notion_workspace_id": workspace_id,
+                                "notion_page_id": page['page_id'],
+                                "notion_page_icon": page['page_icon'],
+                                "type": page['type']
+                            }
+                            document = DocumentService.save_document(dataset, dataset_process_rule.id,
+                                                                     document_data["data_source"]["type"],
+                                                                     data_source_info, created_from, position,
+                                                                     account, page['page_name'], batch)
+                            db.session.add(document)
+                            db.session.flush()
+                            document_ids.append(document.id)
+                            documents.append(document)
+                            position += 1
+                        else:
+                            exist_document.pop(page['page_id'])
+
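+                # previously imported pages that are no longer selected are
+                # cleaned up asynchronously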
+                if len(exist_document) > 0:
+                    clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
         db.session.commit()

-        document_indexing_task.delay(document.dataset_id, document.id)
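+        # pass every document id created in this batch to the indexing task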
+        document_indexing_task.delay(dataset.id, document_ids)
+
+        return documents, batch
+
+    @staticmethod
+    def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, data_source_info: dict,
+                      created_from: str, position: int, account: Account, name: str, batch: str):
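+        # shared helper that builds a Document row for any data source type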
+        document = Document(
+            tenant_id=dataset.tenant_id,
+            dataset_id=dataset.id,
+            position=position,
+            data_source_type=data_source_type,
+            data_source_info=json.dumps(data_source_info),
+            dataset_process_rule_id=process_rule_id,
+            batch=batch,
+            name=name,
+            created_from=created_from,
+            created_by=account.id,
+        )

         return document

     @staticmethod
@@ -460,20 +550,42 @@ class DocumentService:
|
|
|
file_name = ''
|
|
|
data_source_info = {}
|
|
|
if document_data["data_source"]["type"] == "upload_file":
|
|
|
- file_id = document_data["data_source"]["info"]
|
|
|
- file = db.session.query(UploadFile).filter(
|
|
|
- UploadFile.tenant_id == dataset.tenant_id,
|
|
|
- UploadFile.id == file_id
|
|
|
- ).first()
|
|
|
-
|
|
|
-
|
|
|
- if not file:
|
|
|
- raise FileNotExistsError()
|
|
|
-
|
|
|
- file_name = file.name
|
|
|
- data_source_info = {
|
|
|
- "upload_file_id": file_id,
|
|
|
- }
|
|
|
+ upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
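+            # the update flow edits a single existing document, so the loop
+            # below effectively keeps the last file's info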
+            for file_id in upload_file_list:
+                file = db.session.query(UploadFile).filter(
+                    UploadFile.tenant_id == dataset.tenant_id,
+                    UploadFile.id == file_id
+                ).first()
+
+                if not file:
+                    raise FileNotExistsError()
+
+                file_name = file.name
+                data_source_info = {
+                    "upload_file_id": file_id,
+                }
+        elif document_data["data_source"]["type"] == "notion_import":
+            notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
+            for notion_info in notion_info_list:
+                workspace_id = notion_info['workspace_id']
+                data_source_binding = DataSourceBinding.query.filter(
+                    db.and_(
+                        DataSourceBinding.tenant_id == current_user.current_tenant_id,
+                        DataSourceBinding.provider == 'notion',
+                        DataSourceBinding.disabled == False,
+                        DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
+                    )
+                ).first()
+                if not data_source_binding:
+                    raise ValueError('Data source binding not found.')
+                for page in notion_info['pages']:
+                    data_source_info = {
+                        "notion_workspace_id": workspace_id,
+                        "notion_page_id": page['page_id'],
+                        "notion_page_icon": page['page_icon'],
+                        "type": page['type']
+                    }
         document.data_source_type = document_data["data_source"]["type"]
         document.data_source_info = json.dumps(data_source_info)
         document.name = file_name
@@ -513,15 +625,15 @@ class DocumentService:
         db.session.add(dataset)
         db.session.flush()

-        document = DocumentService.save_document_with_dataset_id(dataset, document_data, account)
+        documents, batch = DocumentService.save_document_with_dataset_id(dataset, document_data, account)

         cut_length = 18
-        cut_name = document.name[:cut_length]
-        dataset.name = cut_name + '...' if len(document.name) > cut_length else cut_name
-        dataset.description = 'useful for when you want to answer queries about the ' + document.name
+        cut_name = documents[0].name[:cut_length]
+        dataset.name = cut_name + '...' if len(documents[0].name) > cut_length else cut_name
+        dataset.description = 'useful for when you want to answer queries about the ' + documents[0].name
         db.session.commit()

-        return dataset, document
+        return dataset, documents, batch

     @classmethod
     def document_create_args_validate(cls, args: dict):
@@ -552,9 +664,15 @@ class DocumentService:
|
|
|
if args['data_source']['type'] not in Document.DATA_SOURCES:
|
|
|
raise ValueError("Data source type is invalid")
|
|
|
|
|
|
+ if 'info_list' not in args['data_source'] or not args['data_source']['info_list']:
|
|
|
+ raise ValueError("Data source info is required")
|
|
|
+
|
|
|
if args['data_source']['type'] == 'upload_file':
|
|
|
- if 'info' not in args['data_source'] or not args['data_source']['info']:
|
|
|
- raise ValueError("Data source info is required")
|
|
|
+ if 'file_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list']['file_info_list']:
|
|
|
+ raise ValueError("File source info is required")
|
|
|
+ if args['data_source']['type'] == 'notion_import':
|
|
|
+ if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list']['notion_info_list']:
|
|
|
+ raise ValueError("Notion source info is required")
|
|
|
|
|
|
@classmethod
|
|
|
def process_rule_args_validate(cls, args: dict):
|
|
@@ -624,3 +742,78 @@ class DocumentService:
|
|
|
|
|
|
if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int):
|
|
|
raise ValueError("Process rule segmentation max_tokens is invalid")
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def estimate_args_validate(cls, args: dict):
|
|
|
+ if 'info_list' not in args or not args['info_list']:
|
|
|
+ raise ValueError("Data source info is required")
|
|
|
+
|
|
|
+ if not isinstance(args['info_list'], dict):
|
|
|
+ raise ValueError("Data info is invalid")
|
|
|
+
|
|
|
+ if 'process_rule' not in args or not args['process_rule']:
|
|
|
+ raise ValueError("Process rule is required")
|
|
|
+
|
|
|
+ if not isinstance(args['process_rule'], dict):
|
|
|
+ raise ValueError("Process rule is invalid")
|
|
|
+
|
|
|
+ if 'mode' not in args['process_rule'] or not args['process_rule']['mode']:
|
|
|
+ raise ValueError("Process rule mode is required")
|
|
|
+
|
|
|
+ if args['process_rule']['mode'] not in DatasetProcessRule.MODES:
|
|
|
+ raise ValueError("Process rule mode is invalid")
|
|
|
+
|
|
|
+ if args['process_rule']['mode'] == 'automatic':
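+            # automatic mode: rules are system-managed, so any
+            # client-supplied rules are discarded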
+            args['process_rule']['rules'] = {}
+        else:
+            if 'rules' not in args['process_rule'] or not args['process_rule']['rules']:
+                raise ValueError("Process rule rules is required")
+
+            if not isinstance(args['process_rule']['rules'], dict):
+                raise ValueError("Process rule rules is invalid")
+
+            if 'pre_processing_rules' not in args['process_rule']['rules'] \
+                    or args['process_rule']['rules']['pre_processing_rules'] is None:
+                raise ValueError("Process rule pre_processing_rules is required")
+
+            if not isinstance(args['process_rule']['rules']['pre_processing_rules'], list):
+                raise ValueError("Process rule pre_processing_rules is invalid")
+
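+            # deduplicate pre-processing rules by id (the last occurrence wins)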
+            unique_pre_processing_rule_dicts = {}
+            for pre_processing_rule in args['process_rule']['rules']['pre_processing_rules']:
+                if 'id' not in pre_processing_rule or not pre_processing_rule['id']:
+                    raise ValueError("Process rule pre_processing_rules id is required")
+
+                if pre_processing_rule['id'] not in DatasetProcessRule.PRE_PROCESSING_RULES:
+                    raise ValueError("Process rule pre_processing_rules id is invalid")
+
+                if 'enabled' not in pre_processing_rule or pre_processing_rule['enabled'] is None:
+                    raise ValueError("Process rule pre_processing_rules enabled is required")
+
+                if not isinstance(pre_processing_rule['enabled'], bool):
+                    raise ValueError("Process rule pre_processing_rules enabled is invalid")
+
+                unique_pre_processing_rule_dicts[pre_processing_rule['id']] = pre_processing_rule
+
+            args['process_rule']['rules']['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values())
+
+            if 'segmentation' not in args['process_rule']['rules'] \
+                    or args['process_rule']['rules']['segmentation'] is None:
+                raise ValueError("Process rule segmentation is required")
+
+            if not isinstance(args['process_rule']['rules']['segmentation'], dict):
+                raise ValueError("Process rule segmentation is invalid")
+
+            if 'separator' not in args['process_rule']['rules']['segmentation'] \
+                    or not args['process_rule']['rules']['segmentation']['separator']:
+                raise ValueError("Process rule segmentation separator is required")
+
+            if not isinstance(args['process_rule']['rules']['segmentation']['separator'], str):
+                raise ValueError("Process rule segmentation separator is invalid")
+
+            if 'max_tokens' not in args['process_rule']['rules']['segmentation'] \
+                    or not args['process_rule']['rules']['segmentation']['max_tokens']:
+                raise ValueError("Process rule segmentation max_tokens is required")
+
+            if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int):
+                raise ValueError("Process rule segmentation max_tokens is invalid")