@@ -38,28 +38,39 @@ from services.errors.dataset import DatasetNameDuplicateError
 from services.errors.document import DocumentIndexingError
 from services.errors.file import FileNotExistsError
 from services.feature_service import FeatureModel, FeatureService
+from services.tag_service import TagService
 from services.vector_service import VectorService
 from tasks.clean_notion_document_task import clean_notion_document_task
 from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
 from tasks.delete_segment_from_index_task import delete_segment_from_index_task
 from tasks.document_indexing_task import document_indexing_task
 from tasks.document_indexing_update_task import document_indexing_update_task
+from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
 from tasks.recover_document_indexing_task import recover_document_indexing_task
+from tasks.retry_document_indexing_task import retry_document_indexing_task


 class DatasetService:

     @staticmethod
-    def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None):
+    def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None, search=None, tag_ids=None):
         if user:
             permission_filter = db.or_(Dataset.created_by == user.id,
                                        Dataset.permission == 'all_team_members')
         else:
             permission_filter = Dataset.permission == 'all_team_members'
-        datasets = Dataset.query.filter(
+        query = Dataset.query.filter(
             db.and_(Dataset.provider == provider, Dataset.tenant_id == tenant_id, permission_filter)) \
-            .order_by(Dataset.created_at.desc()) \
-            .paginate(
+            .order_by(Dataset.created_at.desc())
+        if search:
+            query = query.filter(db.and_(Dataset.name.ilike(f'%{search}%')))
+        if tag_ids:
+            target_ids = TagService.get_target_ids_by_tag_ids('knowledge', tenant_id, tag_ids)
+            if target_ids:
+                query = query.filter(db.and_(Dataset.id.in_(target_ids)))
+            else:
+                return [], 0
+        datasets = query.paginate(
             page=page,
             per_page=per_page,
             max_per_page=100,
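Together, the new `search` and `tag_ids` parameters turn `get_datasets` into a filterable listing: `search` narrows by case-insensitive name match, and `tag_ids` is resolved to dataset ids through `TagService` before pagination, with an early `[], 0` exit when no dataset carries the requested tags. A minimal caller sketch, assuming the module is importable as `services.dataset_service` and an `(items, total)` return consistent with that early exit; the tenant, user, and tag values below are invented:

```python
# Sketch only: tenant/user/tag values are placeholders, not from the diff.
from services.dataset_service import DatasetService

datasets, total = DatasetService.get_datasets(
    page=1,
    per_page=20,
    tenant_id=current_tenant_id,   # workspace scoping
    user=current_user,             # drives permission_filter
    search='faq',                  # ilike '%faq%' on Dataset.name
    tag_ids=['tag-id-1'],          # mapped to dataset ids via TagService
)
```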
@@ -165,9 +176,36 @@ class DatasetService:
                 # get embedding model setting
                 try:
                     model_manager = ModelManager()
-                    embedding_model = model_manager.get_default_model_instance(
+                    embedding_model = model_manager.get_model_instance(
+                        tenant_id=current_user.current_tenant_id,
+                        provider=data['embedding_model_provider'],
+                        model_type=ModelType.TEXT_EMBEDDING,
+                        model=data['embedding_model']
+                    )
+                    filtered_data['embedding_model'] = embedding_model.model
+                    filtered_data['embedding_model_provider'] = embedding_model.provider
+                    dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
+                        embedding_model.provider,
+                        embedding_model.model
+                    )
+                    filtered_data['collection_binding_id'] = dataset_collection_binding.id
+                except LLMBadRequestError:
+                    raise ValueError(
+                        "No Embedding Model available. Please configure a valid provider "
+                        "in the Settings -> Model Provider.")
+                except ProviderTokenNotInitError as ex:
+                    raise ValueError(ex.description)
+        else:
+            if data['embedding_model_provider'] != dataset.embedding_model_provider or \
+                    data['embedding_model'] != dataset.embedding_model:
+                action = 'update'
+                try:
+                    model_manager = ModelManager()
+                    embedding_model = model_manager.get_model_instance(
                         tenant_id=current_user.current_tenant_id,
-                        model_type=ModelType.TEXT_EMBEDDING
+                        provider=data['embedding_model_provider'],
+                        model_type=ModelType.TEXT_EMBEDDING,
+                        model=data['embedding_model']
                     )
                     filtered_data['embedding_model'] = embedding_model.model
                     filtered_data['embedding_model_provider'] = embedding_model.provider
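Switching from `get_default_model_instance` to `get_model_instance` makes the update path validate the exact provider/model pair named in the request instead of silently using the workspace default, and the new `else` branch flags an embedding-model change as an `'update'` action. A hypothetical `data` fragment that would reach that branch; these two keys are the ones the code reads, the values are invented:

```python
# Hypothetical payload: pair differs from the dataset's stored values,
# so the branch above sets action = 'update'.
data = {
    'embedding_model_provider': 'openai',
    'embedding_model': 'text-embedding-3-small',
}
```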
@@ -376,6 +414,15 @@ class DocumentService:

         return documents

+    @staticmethod
+    def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]:
+        documents = db.session.query(Document).filter(
+            Document.dataset_id == dataset_id,
+            Document.indexing_status.in_(['error', 'paused'])
+        ).all()
+
+        return documents
+
     @staticmethod
     def get_batch_documents(dataset_id: str, batch: str) -> list[Document]:
         documents = db.session.query(Document).filter(
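The status check has to be a single SQL `IN` clause: chaining the two comparisons with Python's `or` would not build the intended filter, because SQLAlchemy column comparisons are not combined by boolean operators and the expression silently collapses to one condition. A usage sketch, with `dataset_id` assumed to come from the surrounding handler:

```python
# Sketch: list documents stuck in 'error' or 'paused' for one dataset.
from services.dataset_service import DocumentService

stuck = DocumentService.get_error_documents_by_dataset_id(dataset_id)
print([(doc.id, doc.indexing_status) for doc in stuck])
```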
@@ -440,6 +487,20 @@ class DocumentService:
         # trigger async task
         recover_document_indexing_task.delay(document.dataset_id, document.id)

+    @staticmethod
+    def retry_document(dataset_id: str, documents: list[Document]):
+        for document in documents:
+            # retry document indexing
+            document.indexing_status = 'waiting'
+            db.session.add(document)
+            db.session.commit()
+            # add retry flag
+            retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id)
+            redis_client.setex(retry_indexing_cache_key, 600, 1)
+        # trigger async task
+        document_ids = [document.id for document in documents]
+        retry_document_indexing_task.delay(dataset_id, document_ids)
+
     @staticmethod
     def get_documents_position(dataset_id):
         document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
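`retry_document` resets each document to `'waiting'`, drops a ten-minute Redis flag (`document_<id>_is_retried`) that other code paths can consult to avoid double-submitting, and enqueues the whole batch as a single Celery task. A plausible composition of the two new helpers; the surrounding API handler and the `dataset` object are assumptions:

```python
# Sketch: re-queue everything that failed or stalled in one dataset.
from services.dataset_service import DocumentService

error_docs = DocumentService.get_error_documents_by_dataset_id(dataset.id)
if error_docs:
    DocumentService.retry_document(dataset.id, error_docs)
```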
@@ -537,6 +598,7 @@ class DocumentService:
                 db.session.commit()
             position = DocumentService.get_documents_position(dataset.id)
             document_ids = []
+            duplicate_document_ids = []
             if document_data["data_source"]["type"] == "upload_file":
                 upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
                 for file_id in upload_file_list:
@@ -553,6 +615,28 @@ class DocumentService:
                     data_source_info = {
                         "upload_file_id": file_id,
                     }
+                    # check duplicate
+                    if document_data.get('duplicate', False):
+                        document = Document.query.filter_by(
+                            dataset_id=dataset.id,
+                            tenant_id=current_user.current_tenant_id,
+                            data_source_type='upload_file',
+                            enabled=True,
+                            name=file_name
+                        ).first()
+                        if document:
+                            document.dataset_process_rule_id = dataset_process_rule.id
+                            document.updated_at = datetime.datetime.utcnow()
+                            document.created_from = created_from
+                            document.doc_form = document_data['doc_form']
+                            document.doc_language = document_data['doc_language']
+                            document.data_source_info = json.dumps(data_source_info)
+                            document.batch = batch
+                            document.indexing_status = 'waiting'
+                            db.session.add(document)
+                            documents.append(document)
+                            duplicate_document_ids.append(document.id)
+                            continue
                     document = DocumentService.build_document(dataset, dataset_process_rule.id,
                                                               document_data["data_source"]["type"],
                                                               document_data["doc_form"],
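The duplicate branch only fires when the request sets `duplicate`: an existing enabled upload-file document with the same name is re-pointed at the new process rule and re-queued as `'waiting'` instead of inserting a second row, and its id is collected separately so it can be routed to the dedicated indexing task. A hypothetical `document_data` shape covering the keys this block reads (all values invented):

```python
# Hypothetical document_data for the duplicate path above.
document_data = {
    'duplicate': True,
    'doc_form': 'text_model',
    'doc_language': 'English',
    'data_source': {
        'type': 'upload_file',
        'info_list': {'file_info_list': {'file_ids': ['file-id-1']}},
    },
}
```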
@@ -618,7 +702,10 @@ class DocumentService:
             db.session.commit()

             # trigger async task
-            document_indexing_task.delay(dataset.id, document_ids)
+            if document_ids:
+                document_indexing_task.delay(dataset.id, document_ids)
+            if duplicate_document_ids:
+                duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)

         return documents, batch

@@ -626,7 +713,8 @@ class DocumentService:
     def check_documents_upload_quota(count: int, features: FeatureModel):
         can_upload_size = features.documents_upload_quota.limit - features.documents_upload_quota.size
         if count > can_upload_size:
-            raise ValueError(f'You have reached the limit of your subscription. Only {can_upload_size} documents can be uploaded.')
+            raise ValueError(
+                f'You have reached the limit of your subscription. Only {can_upload_size} documents can be uploaded.')

     @staticmethod
     def build_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
@@ -752,7 +840,6 @@ class DocumentService:
         db.session.commit()
         # trigger async task
         document_indexing_update_task.delay(document.dataset_id, document.id)
-
         return document

     @staticmethod