@@ -14,6 +14,7 @@ from configs import dify_config
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
+from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from events.dataset_event import dataset_was_deleted
 from events.document_event import document_was_deleted
@@ -23,7 +24,9 @@ from libs import helper
 from models.account import Account, TenantAccountRole
 from models.dataset import (
     AppDatasetJoin,
+    ChildChunk,
     Dataset,
+    DatasetAutoDisableLog,
     DatasetCollectionBinding,
     DatasetPermission,
     DatasetPermissionEnum,
@@ -35,8 +38,14 @@ from models.dataset import (
 )
 from models.model import UploadFile
 from models.source import DataSourceOauthBinding
-from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateEntity
-from services.errors.account import NoPermissionError
+from services.entities.knowledge_entities.knowledge_entities import (
+    ChildChunkUpdateArgs,
+    KnowledgeConfig,
+    RetrievalModel,
+    SegmentUpdateArgs,
+)
+from services.errors.account import InvalidActionError, NoPermissionError
+from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
 from services.errors.dataset import DatasetNameDuplicateError
 from services.errors.document import DocumentIndexingError
 from services.errors.file import FileNotExistsError
@@ -44,13 +53,16 @@ from services.external_knowledge_service import ExternalDatasetService
 from services.feature_service import FeatureModel, FeatureService
 from services.tag_service import TagService
 from services.vector_service import VectorService
+from tasks.batch_clean_document_task import batch_clean_document_task
 from tasks.clean_notion_document_task import clean_notion_document_task
 from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
 from tasks.delete_segment_from_index_task import delete_segment_from_index_task
 from tasks.disable_segment_from_index_task import disable_segment_from_index_task
+from tasks.disable_segments_from_index_task import disable_segments_from_index_task
 from tasks.document_indexing_task import document_indexing_task
 from tasks.document_indexing_update_task import document_indexing_update_task
 from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
+from tasks.enable_segments_to_index_task import enable_segments_to_index_task
 from tasks.recover_document_indexing_task import recover_document_indexing_task
 from tasks.retry_document_indexing_task import retry_document_indexing_task
 from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task
@@ -408,6 +420,24 @@ class DatasetService:
             .all()
         )
 
+    @staticmethod
+    def get_dataset_auto_disable_logs(dataset_id: str) -> dict:
+
+        start_date = datetime.datetime.now() - datetime.timedelta(days=30)
+        dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(
+            DatasetAutoDisableLog.dataset_id == dataset_id,
+            DatasetAutoDisableLog.created_at >= start_date,
+        ).all()
+        if dataset_auto_disable_logs:
+            return {
+                "document_ids": [log.document_id for log in dataset_auto_disable_logs],
+                "count": len(dataset_auto_disable_logs),
+            }
+        return {
+            "document_ids": [],
+            "count": 0,
+        }
+
 
 class DocumentService:
     DEFAULT_RULES = {
@@ -588,6 +618,20 @@ class DocumentService:
         db.session.delete(document)
         db.session.commit()
 
+    @staticmethod
+    def delete_documents(dataset: Dataset, document_ids: list[str]):
+        documents = db.session.query(Document).filter(Document.id.in_(document_ids)).all()
+        file_ids = [
+            document.data_source_info_dict["upload_file_id"]
+            for document in documents
+            if document.data_source_type == "upload_file"
+        ]
+        batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids)
+
+        for document in documents:
+            db.session.delete(document)
+        db.session.commit()
+
     @staticmethod
     def rename_document(dataset_id: str, document_id: str, name: str) -> Document:
         dataset = DatasetService.get_dataset(dataset_id)
@@ -689,7 +733,7 @@ class DocumentService:
     @staticmethod
     def save_document_with_dataset_id(
         dataset: Dataset,
-        document_data: dict,
+        knowledge_config: KnowledgeConfig,
         account: Account | Any,
         dataset_process_rule: Optional[DatasetProcessRule] = None,
         created_from: str = "web",
@@ -698,18 +742,18 @@ class DocumentService:
         features = FeatureService.get_features(current_user.current_tenant_id)
 
         if features.billing.enabled:
-            if "original_document_id" not in document_data or not document_data["original_document_id"]:
+            if not knowledge_config.original_document_id:
                 count = 0
-                if document_data["data_source"]["type"] == "upload_file":
-                    upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"]
+                if knowledge_config.data_source.info_list.data_source_type == "upload_file":
+                    upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
                     count = len(upload_file_list)
-                elif document_data["data_source"]["type"] == "notion_import":
-                    notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"]
+                elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
+                    notion_info_list = knowledge_config.data_source.info_list.notion_info_list
                     for notion_info in notion_info_list:
-                        count = count + len(notion_info["pages"])
-                elif document_data["data_source"]["type"] == "website_crawl":
-                    website_info = document_data["data_source"]["info_list"]["website_info_list"]
-                    count = len(website_info["urls"])
+                        count = count + len(notion_info.pages)
+                elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
+                    website_info = knowledge_config.data_source.info_list.website_info_list
+                    count = len(website_info.urls)
                 batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
                 if count > batch_upload_limit:
                     raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@@ -718,17 +762,14 @@ class DocumentService:
 
 
         if not dataset.data_source_type:
-            dataset.data_source_type = document_data["data_source"]["type"]
+            dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type
 
         if not dataset.indexing_technique:
-            if (
-                "indexing_technique" not in document_data
-                or document_data["indexing_technique"] not in Dataset.INDEXING_TECHNIQUE_LIST
-            ):
-                raise ValueError("Indexing technique is required")
+            if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
+                raise ValueError("Indexing technique is invalid")
 
-            dataset.indexing_technique = document_data["indexing_technique"]
-            if document_data["indexing_technique"] == "high_quality":
+            dataset.indexing_technique = knowledge_config.indexing_technique
+            if knowledge_config.indexing_technique == "high_quality":
                 model_manager = ModelManager()
                 embedding_model = model_manager.get_default_model_instance(
                     tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
@@ -748,29 +789,29 @@ class DocumentService:
                         "score_threshold_enabled": False,
                     }
 
-                    dataset.retrieval_model = document_data.get("retrieval_model") or default_retrieval_model
+                    dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model
 
         documents = []
-        if document_data.get("original_document_id"):
-            document = DocumentService.update_document_with_dataset_id(dataset, document_data, account)
+        if knowledge_config.original_document_id:
+            document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account)
             documents.append(document)
             batch = document.batch
         else:
             batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
 
             if not dataset_process_rule:
-                process_rule = document_data["process_rule"]
-                if process_rule["mode"] == "custom":
+                process_rule = knowledge_config.process_rule
+                if process_rule.mode in ("custom", "hierarchical"):
                     dataset_process_rule = DatasetProcessRule(
                         dataset_id=dataset.id,
-                        mode=process_rule["mode"],
-                        rules=json.dumps(process_rule["rules"]),
+                        mode=process_rule.mode,
+                        rules=process_rule.rules.model_dump_json(),
                         created_by=account.id,
                     )
-                elif process_rule["mode"] == "automatic":
+                elif process_rule.mode == "automatic":
                     dataset_process_rule = DatasetProcessRule(
                         dataset_id=dataset.id,
-                        mode=process_rule["mode"],
+                        mode=process_rule.mode,
                         rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
                         created_by=account.id,
                     )
@@ -786,8 +827,8 @@ class DocumentService:
             position = DocumentService.get_documents_position(dataset.id)
             document_ids = []
             duplicate_document_ids = []
-            if document_data["data_source"]["type"] == "upload_file":
-                upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"]
+            if knowledge_config.data_source.info_list.data_source_type == "upload_file":
+                upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
                 for file_id in upload_file_list:
                     file = (
                         db.session.query(UploadFile)
@@ -804,7 +845,7 @@ class DocumentService:
                         "upload_file_id": file_id,
                     }
 
-                    if document_data.get("duplicate", False):
+                    if knowledge_config.duplicate:
                         document = Document.query.filter_by(
                             dataset_id=dataset.id,
                             tenant_id=current_user.current_tenant_id,
@@ -814,10 +855,10 @@ class DocumentService:
                         ).first()
                         if document:
                             document.dataset_process_rule_id = dataset_process_rule.id
-                            document.updated_at = datetime.datetime.utcnow()
+                            document.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                             document.created_from = created_from
-                            document.doc_form = document_data["doc_form"]
-                            document.doc_language = document_data["doc_language"]
+                            document.doc_form = knowledge_config.doc_form
+                            document.doc_language = knowledge_config.doc_language
                             document.data_source_info = json.dumps(data_source_info)
                             document.batch = batch
                             document.indexing_status = "waiting"
@@ -828,9 +869,9 @@ class DocumentService:
                     document = DocumentService.build_document(
                         dataset,
                         dataset_process_rule.id,
-                        document_data["data_source"]["type"],
-                        document_data["doc_form"],
-                        document_data["doc_language"],
+                        knowledge_config.data_source.info_list.data_source_type,
+                        knowledge_config.doc_form,
+                        knowledge_config.doc_language,
                         data_source_info,
                         created_from,
                         position,
@@ -843,8 +884,8 @@ class DocumentService:
                     document_ids.append(document.id)
                     documents.append(document)
                     position += 1
-            elif document_data["data_source"]["type"] == "notion_import":
-                notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"]
+            elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
+                notion_info_list = knowledge_config.data_source.info_list.notion_info_list
                 exist_page_ids = []
                 exist_document = {}
                 documents = Document.query.filter_by(
@@ -859,7 +900,7 @@ class DocumentService:
                     exist_page_ids.append(data_source_info["notion_page_id"])
                     exist_document[data_source_info["notion_page_id"]] = document.id
                 for notion_info in notion_info_list:
-                    workspace_id = notion_info["workspace_id"]
+                    workspace_id = notion_info.workspace_id
                     data_source_binding = DataSourceOauthBinding.query.filter(
                         db.and_(
                             DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
@@ -870,25 +911,25 @@ class DocumentService:
                     ).first()
                     if not data_source_binding:
                         raise ValueError("Data source binding not found.")
-                    for page in notion_info["pages"]:
-                        if page["page_id"] not in exist_page_ids:
+                    for page in notion_info.pages:
+                        if page.page_id not in exist_page_ids:
                             data_source_info = {
                                 "notion_workspace_id": workspace_id,
-                                "notion_page_id": page["page_id"],
-                                "notion_page_icon": page["page_icon"],
-                                "type": page["type"],
+                                "notion_page_id": page.page_id,
+                                "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,
+                                "type": page.type,
                             }
                             document = DocumentService.build_document(
                                 dataset,
                                 dataset_process_rule.id,
-                                document_data["data_source"]["type"],
-                                document_data["doc_form"],
-                                document_data["doc_language"],
+                                knowledge_config.data_source.info_list.data_source_type,
+                                knowledge_config.doc_form,
+                                knowledge_config.doc_language,
                                 data_source_info,
                                 created_from,
                                 position,
                                 account,
-                                page["page_name"],
+                                page.page_name,
                                 batch,
                             )
                             db.session.add(document)
@@ -897,19 +938,19 @@ class DocumentService:
                             documents.append(document)
                             position += 1
                         else:
-                            exist_document.pop(page["page_id"])
+                            exist_document.pop(page.page_id)
 
                 if len(exist_document) > 0:
                     clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
-            elif document_data["data_source"]["type"] == "website_crawl":
-                website_info = document_data["data_source"]["info_list"]["website_info_list"]
-                urls = website_info["urls"]
+            elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
+                website_info = knowledge_config.data_source.info_list.website_info_list
+                urls = website_info.urls
                 for url in urls:
                     data_source_info = {
                         "url": url,
-                        "provider": website_info["provider"],
-                        "job_id": website_info["job_id"],
-                        "only_main_content": website_info.get("only_main_content", False),
+                        "provider": website_info.provider,
+                        "job_id": website_info.job_id,
+                        "only_main_content": website_info.only_main_content,
                         "mode": "crawl",
                     }
                     if len(url) > 255:
@@ -919,9 +960,9 @@ class DocumentService:
                     document = DocumentService.build_document(
                         dataset,
                         dataset_process_rule.id,
-                        document_data["data_source"]["type"],
-                        document_data["doc_form"],
-                        document_data["doc_language"],
+                        knowledge_config.data_source.info_list.data_source_type,
+                        knowledge_config.doc_form,
+                        knowledge_config.doc_language,
                         data_source_info,
                         created_from,
                         position,
@@ -995,31 +1036,31 @@ class DocumentService:
     @staticmethod
     def update_document_with_dataset_id(
         dataset: Dataset,
-        document_data: dict,
+        document_data: KnowledgeConfig,
        account: Account,
         dataset_process_rule: Optional[DatasetProcessRule] = None,
         created_from: str = "web",
     ):
         DatasetService.check_dataset_model_setting(dataset)
-        document = DocumentService.get_document(dataset.id, document_data["original_document_id"])
+        document = DocumentService.get_document(dataset.id, document_data.original_document_id)
         if document is None:
             raise NotFound("Document not found")
         if document.display_status != "available":
             raise ValueError("Document is not available")
 
-        if document_data.get("process_rule"):
-            process_rule = document_data["process_rule"]
-            if process_rule["mode"] == "custom":
+        if document_data.process_rule:
+            process_rule = document_data.process_rule
+            if process_rule.mode in {"custom", "hierarchical"}:
                 dataset_process_rule = DatasetProcessRule(
                     dataset_id=dataset.id,
-                    mode=process_rule["mode"],
-                    rules=json.dumps(process_rule["rules"]),
+                    mode=process_rule.mode,
+                    rules=process_rule.rules.model_dump_json(),
                     created_by=account.id,
                 )
-            elif process_rule["mode"] == "automatic":
+            elif process_rule.mode == "automatic":
                 dataset_process_rule = DatasetProcessRule(
                     dataset_id=dataset.id,
-                    mode=process_rule["mode"],
+                    mode=process_rule.mode,
                     rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
                     created_by=account.id,
                 )
@@ -1028,11 +1069,11 @@ class DocumentService:
             db.session.commit()
             document.dataset_process_rule_id = dataset_process_rule.id
 
-        if document_data.get("data_source"):
+        if document_data.data_source:
             file_name = ""
             data_source_info = {}
-            if document_data["data_source"]["type"] == "upload_file":
-                upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"]
+            if document_data.data_source.info_list.data_source_type == "upload_file":
+                upload_file_list = document_data.data_source.info_list.file_info_list.file_ids
                 for file_id in upload_file_list:
                     file = (
                         db.session.query(UploadFile)
@@ -1048,10 +1089,10 @@ class DocumentService:
                     data_source_info = {
                         "upload_file_id": file_id,
                     }
-            elif document_data["data_source"]["type"] == "notion_import":
-                notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"]
+            elif document_data.data_source.info_list.data_source_type == "notion_import":
+                notion_info_list = document_data.data_source.info_list.notion_info_list
                 for notion_info in notion_info_list:
-                    workspace_id = notion_info["workspace_id"]
+                    workspace_id = notion_info.workspace_id
                     data_source_binding = DataSourceOauthBinding.query.filter(
                         db.and_(
                             DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
@@ -1062,31 +1103,31 @@ class DocumentService:
                     ).first()
                     if not data_source_binding:
                         raise ValueError("Data source binding not found.")
-                    for page in notion_info["pages"]:
+                    for page in notion_info.pages:
                         data_source_info = {
                             "notion_workspace_id": workspace_id,
-                            "notion_page_id": page["page_id"],
-                            "notion_page_icon": page["page_icon"],
-                            "type": page["type"],
+                            "notion_page_id": page.page_id,
+                            "notion_page_icon": page.page_icon,
+                            "type": page.type,
                         }
-            elif document_data["data_source"]["type"] == "website_crawl":
-                website_info = document_data["data_source"]["info_list"]["website_info_list"]
-                urls = website_info["urls"]
+            elif document_data.data_source.info_list.data_source_type == "website_crawl":
+                website_info = document_data.data_source.info_list.website_info_list
+                urls = website_info.urls
                 for url in urls:
                     data_source_info = {
                         "url": url,
-                        "provider": website_info["provider"],
-                        "job_id": website_info["job_id"],
-                        "only_main_content": website_info.get("only_main_content", False),
+                        "provider": website_info.provider,
+                        "job_id": website_info.job_id,
+                        "only_main_content": website_info.only_main_content,
                        "mode": "crawl",
                     }
-            document.data_source_type = document_data["data_source"]["type"]
+            document.data_source_type = document_data.data_source.info_list.data_source_type
             document.data_source_info = json.dumps(data_source_info)
             document.name = file_name
 
 
-        if document_data.get("name"):
-            document.name = document_data["name"]
+        if document_data.name:
+            document.name = document_data.name
 
         document.indexing_status = "waiting"
         document.completed_at = None
@@ -1096,7 +1137,7 @@ class DocumentService:
         document.splitting_completed_at = None
         document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
         document.created_from = created_from
-        document.doc_form = document_data["doc_form"]
+        document.doc_form = document_data.doc_form
         db.session.add(document)
         db.session.commit()
 
@@ -1108,21 +1149,21 @@ class DocumentService:
         return document
 
     @staticmethod
-    def save_document_without_dataset_id(tenant_id: str, document_data: dict, account: Account):
+    def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account):
         features = FeatureService.get_features(current_user.current_tenant_id)
 
         if features.billing.enabled:
             count = 0
-            if document_data["data_source"]["type"] == "upload_file":
-                upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"]
+            if knowledge_config.data_source.info_list.data_source_type == "upload_file":
+                upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
                 count = len(upload_file_list)
-            elif document_data["data_source"]["type"] == "notion_import":
-                notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"]
+            elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
+                notion_info_list = knowledge_config.data_source.info_list.notion_info_list
                 for notion_info in notion_info_list:
-                    count = count + len(notion_info["pages"])
-            elif document_data["data_source"]["type"] == "website_crawl":
-                website_info = document_data["data_source"]["info_list"]["website_info_list"]
-                count = len(website_info["urls"])
+                    count = count + len(notion_info.pages)
+            elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
+                website_info = knowledge_config.data_source.info_list.website_info_list
+                count = len(website_info.urls)
             batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
             if count > batch_upload_limit:
                 raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@@ -1131,13 +1172,13 @@ class DocumentService:
 
         dataset_collection_binding_id = None
         retrieval_model = None
-        if document_data["indexing_technique"] == "high_quality":
+        if knowledge_config.indexing_technique == "high_quality":
             dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-                document_data["embedding_model_provider"], document_data["embedding_model"]
+                knowledge_config.embedding_model_provider, knowledge_config.embedding_model
             )
             dataset_collection_binding_id = dataset_collection_binding.id
-            if document_data.get("retrieval_model"):
-                retrieval_model = document_data["retrieval_model"]
+            if knowledge_config.retrieval_model:
+                retrieval_model = knowledge_config.retrieval_model
             else:
                 default_retrieval_model = {
                     "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
@@ -1146,24 +1187,24 @@ class DocumentService:
                     "top_k": 2,
                     "score_threshold_enabled": False,
                 }
-                retrieval_model = default_retrieval_model
+                retrieval_model = RetrievalModel(**default_retrieval_model)
 
         dataset = Dataset(
             tenant_id=tenant_id,
             name="",
-            data_source_type=document_data["data_source"]["type"],
-            indexing_technique=document_data.get("indexing_technique", "high_quality"),
+            data_source_type=knowledge_config.data_source.info_list.data_source_type,
+            indexing_technique=knowledge_config.indexing_technique,
             created_by=account.id,
-            embedding_model=document_data.get("embedding_model"),
-            embedding_model_provider=document_data.get("embedding_model_provider"),
+            embedding_model=knowledge_config.embedding_model,
+            embedding_model_provider=knowledge_config.embedding_model_provider,
             collection_binding_id=dataset_collection_binding_id,
-            retrieval_model=retrieval_model,
+            retrieval_model=retrieval_model.model_dump() if retrieval_model else None,
         )
 
-        db.session.add(dataset)
+        db.session.add(dataset)
         db.session.flush()
 
-        documents, batch = DocumentService.save_document_with_dataset_id(dataset, document_data, account)
+        documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account)
 
         cut_length = 18
         cut_name = documents[0].name[:cut_length]
@@ -1174,133 +1215,86 @@ class DocumentService:
         return dataset, documents, batch
 
     @classmethod
-    def document_create_args_validate(cls, args: dict):
-        if "original_document_id" not in args or not args["original_document_id"]:
-            DocumentService.data_source_args_validate(args)
-            DocumentService.process_rule_args_validate(args)
+    def document_create_args_validate(cls, knowledge_config: KnowledgeConfig):
+        if not knowledge_config.data_source and not knowledge_config.process_rule:
+            raise ValueError("Data source or Process rule is required")
         else:
-            if ("data_source" not in args or not args["data_source"]) and (
-                "process_rule" not in args or not args["process_rule"]
-            ):
-                raise ValueError("Data source or Process rule is required")
-            else:
-                if args.get("data_source"):
-                    DocumentService.data_source_args_validate(args)
-                if args.get("process_rule"):
-                    DocumentService.process_rule_args_validate(args)
+            if knowledge_config.data_source:
+                DocumentService.data_source_args_validate(knowledge_config)
+            if knowledge_config.process_rule:
+                DocumentService.process_rule_args_validate(knowledge_config)
 
     @classmethod
-    def data_source_args_validate(cls, args: dict):
-        if "data_source" not in args or not args["data_source"]:
+    def data_source_args_validate(cls, knowledge_config: KnowledgeConfig):
+        if not knowledge_config.data_source:
             raise ValueError("Data source is required")
 
-        if not isinstance(args["data_source"], dict):
-            raise ValueError("Data source is invalid")
-
-        if "type" not in args["data_source"] or not args["data_source"]["type"]:
-            raise ValueError("Data source type is required")
-
-        if args["data_source"]["type"] not in Document.DATA_SOURCES:
+        if knowledge_config.data_source.info_list.data_source_type not in Document.DATA_SOURCES:
             raise ValueError("Data source type is invalid")
 
-        if "info_list" not in args["data_source"] or not args["data_source"]["info_list"]:
+        if not knowledge_config.data_source.info_list:
            raise ValueError("Data source info is required")
 
-        if args["data_source"]["type"] == "upload_file":
-            if (
-                "file_info_list" not in args["data_source"]["info_list"]
-                or not args["data_source"]["info_list"]["file_info_list"]
-            ):
+        if knowledge_config.data_source.info_list.data_source_type == "upload_file":
+            if not knowledge_config.data_source.info_list.file_info_list:
                 raise ValueError("File source info is required")
-        if args["data_source"]["type"] == "notion_import":
-            if (
-                "notion_info_list" not in args["data_source"]["info_list"]
-                or not args["data_source"]["info_list"]["notion_info_list"]
-            ):
+        if knowledge_config.data_source.info_list.data_source_type == "notion_import":
+            if not knowledge_config.data_source.info_list.notion_info_list:
                 raise ValueError("Notion source info is required")
-        if args["data_source"]["type"] == "website_crawl":
-            if (
-                "website_info_list" not in args["data_source"]["info_list"]
-                or not args["data_source"]["info_list"]["website_info_list"]
-            ):
+        if knowledge_config.data_source.info_list.data_source_type == "website_crawl":
+            if not knowledge_config.data_source.info_list.website_info_list:
                 raise ValueError("Website source info is required")
 
     @classmethod
-    def process_rule_args_validate(cls, args: dict):
-        if "process_rule" not in args or not args["process_rule"]:
+    def process_rule_args_validate(cls, knowledge_config: KnowledgeConfig):
+        if not knowledge_config.process_rule:
             raise ValueError("Process rule is required")
 
-        if not isinstance(args["process_rule"], dict):
-            raise ValueError("Process rule is invalid")
-
-        if "mode" not in args["process_rule"] or not args["process_rule"]["mode"]:
+        if not knowledge_config.process_rule.mode:
            raise ValueError("Process rule mode is required")
 
-        if args["process_rule"]["mode"] not in DatasetProcessRule.MODES:
+        if knowledge_config.process_rule.mode not in DatasetProcessRule.MODES:
             raise ValueError("Process rule mode is invalid")
 
-        if args["process_rule"]["mode"] == "automatic":
-            args["process_rule"]["rules"] = {}
+        if knowledge_config.process_rule.mode == "automatic":
+            knowledge_config.process_rule.rules = None
         else:
-            if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]:
+            if not knowledge_config.process_rule.rules:
                 raise ValueError("Process rule rules is required")
 
-            if not isinstance(args["process_rule"]["rules"], dict):
-                raise ValueError("Process rule rules is invalid")
-
-            if (
-                "pre_processing_rules" not in args["process_rule"]["rules"]
-                or args["process_rule"]["rules"]["pre_processing_rules"] is None
-            ):
+            if knowledge_config.process_rule.rules.pre_processing_rules is None:
                 raise ValueError("Process rule pre_processing_rules is required")
 
-            if not isinstance(args["process_rule"]["rules"]["pre_processing_rules"], list):
-                raise ValueError("Process rule pre_processing_rules is invalid")
-
             unique_pre_processing_rule_dicts = {}
-            for pre_processing_rule in args["process_rule"]["rules"]["pre_processing_rules"]:
-                if "id" not in pre_processing_rule or not pre_processing_rule["id"]:
+            for pre_processing_rule in knowledge_config.process_rule.rules.pre_processing_rules:
+                if not pre_processing_rule.id:
                     raise ValueError("Process rule pre_processing_rules id is required")
 
-                if pre_processing_rule["id"] not in DatasetProcessRule.PRE_PROCESSING_RULES:
-                    raise ValueError("Process rule pre_processing_rules id is invalid")
-
-                if "enabled" not in pre_processing_rule or pre_processing_rule["enabled"] is None:
-                    raise ValueError("Process rule pre_processing_rules enabled is required")
-
-                if not isinstance(pre_processing_rule["enabled"], bool):
+                if not isinstance(pre_processing_rule.enabled, bool):
                     raise ValueError("Process rule pre_processing_rules enabled is invalid")
 
-                unique_pre_processing_rule_dicts[pre_processing_rule["id"]] = pre_processing_rule
+                unique_pre_processing_rule_dicts[pre_processing_rule.id] = pre_processing_rule
 
-            args["process_rule"]["rules"]["pre_processing_rules"] = list(unique_pre_processing_rule_dicts.values())
+            knowledge_config.process_rule.rules.pre_processing_rules = list(unique_pre_processing_rule_dicts.values())
 
-            if (
-                "segmentation" not in args["process_rule"]["rules"]
-                or args["process_rule"]["rules"]["segmentation"] is None
-            ):
+            if not knowledge_config.process_rule.rules.segmentation:
                 raise ValueError("Process rule segmentation is required")
 
-            if not isinstance(args["process_rule"]["rules"]["segmentation"], dict):
-                raise ValueError("Process rule segmentation is invalid")
-
-            if (
-                "separator" not in args["process_rule"]["rules"]["segmentation"]
-                or not args["process_rule"]["rules"]["segmentation"]["separator"]
-            ):
+            if not knowledge_config.process_rule.rules.segmentation.separator:
                 raise ValueError("Process rule segmentation separator is required")
 
-            if not isinstance(args["process_rule"]["rules"]["segmentation"]["separator"], str):
+            if not isinstance(knowledge_config.process_rule.rules.segmentation.separator, str):
                 raise ValueError("Process rule segmentation separator is invalid")
 
-            if (
-                "max_tokens" not in args["process_rule"]["rules"]["segmentation"]
-                or not args["process_rule"]["rules"]["segmentation"]["max_tokens"]
+            if not (
+                knowledge_config.process_rule.mode == "hierarchical"
+                and knowledge_config.process_rule.rules.parent_mode == "full-doc"
             ):
-                raise ValueError("Process rule segmentation max_tokens is required")
+                if not knowledge_config.process_rule.rules.segmentation.max_tokens:
+                    raise ValueError("Process rule segmentation max_tokens is required")
 
-            if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int):
-                raise ValueError("Process rule segmentation max_tokens is invalid")
+                if not isinstance(knowledge_config.process_rule.rules.segmentation.max_tokens, int):
+                    raise ValueError("Process rule segmentation max_tokens is invalid")
 
     @classmethod
     def estimate_args_validate(cls, args: dict):
@@ -1447,7 +1441,7 @@ class SegmentService:
 
 
         try:
-            VectorService.create_segments_vector([args["keywords"]], [segment_document], dataset)
+            VectorService.create_segments_vector([args["keywords"]], [segment_document], dataset, document.doc_form)
         except Exception as e:
             logging.exception("create segment index failed")
             segment_document.enabled = False
@@ -1525,7 +1519,7 @@ class SegmentService:
             db.session.add(document)
         try:
 
-            VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset)
+            VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset, document.doc_form)
         except Exception as e:
             logging.exception("create segment index failed")
             for segment_document in segment_data_list:
@@ -1537,14 +1531,13 @@ class SegmentService:
         return segment_data_list
 
     @classmethod
-    def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset):
-        segment_update_entity = SegmentUpdateEntity(**args)
+    def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset):
         indexing_cache_key = "segment_{}_indexing".format(segment.id)
         cache_result = redis_client.get(indexing_cache_key)
         if cache_result is not None:
             raise ValueError("Segment is indexing, please try again later")
-        if segment_update_entity.enabled is not None:
-            action = segment_update_entity.enabled
+        if args.enabled is not None:
+            action = args.enabled
             if segment.enabled != action:
                 if not action:
                     segment.enabled = action
@@ -1557,22 +1550,22 @@ class SegmentService:
                     disable_segment_from_index_task.delay(segment.id)
                     return segment
         if not segment.enabled:
-            if segment_update_entity.enabled is not None:
-                if not segment_update_entity.enabled:
+            if args.enabled is not None:
+                if not args.enabled:
                     raise ValueError("Can't update disabled segment")
             else:
                 raise ValueError("Can't update disabled segment")
         try:
             word_count_change = segment.word_count
-            content = segment_update_entity.content
+            content = args.content
             if segment.content == content:
                 segment.word_count = len(content)
                 if document.doc_form == "qa_model":
-                    segment.answer = segment_update_entity.answer
-                    segment.word_count += len(segment_update_entity.answer or "")
+                    segment.answer = args.answer
+                    segment.word_count += len(args.answer)
                 word_count_change = segment.word_count - word_count_change
-                if segment_update_entity.keywords:
-                    segment.keywords = segment_update_entity.keywords
+                if args.keywords:
+                    segment.keywords = args.keywords
                 segment.enabled = True
                 segment.disabled_at = None
                 segment.disabled_by = None
@@ -1583,9 +1576,38 @@ class SegmentService:
                     document.word_count = max(0, document.word_count + word_count_change)
                     db.session.add(document)
 
-                if segment_update_entity.enabled:
-                    keywords = segment_update_entity.keywords or []
-                    VectorService.create_segments_vector([keywords], [segment], dataset)
+                if args.enabled:
+                    VectorService.create_segments_vector([args.keywords], [segment], dataset)
+                if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
+
+
+                    if dataset.indexing_technique == "high_quality":
+
+                        model_manager = ModelManager()
+
+                        if dataset.embedding_model_provider:
+                            embedding_model_instance = model_manager.get_model_instance(
+                                tenant_id=dataset.tenant_id,
+                                provider=dataset.embedding_model_provider,
+                                model_type=ModelType.TEXT_EMBEDDING,
+                                model=dataset.embedding_model,
+                            )
+                        else:
+                            embedding_model_instance = model_manager.get_default_model_instance(
+                                tenant_id=dataset.tenant_id,
+                                model_type=ModelType.TEXT_EMBEDDING,
+                            )
+                    else:
+                        raise ValueError("The knowledge base index technique is not high quality!")
+
+                    processing_rule = (
+                        db.session.query(DatasetProcessRule)
+                        .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
+                        .first()
+                    )
+                    VectorService.generate_child_chunks(
+                        segment, document, dataset, embedding_model_instance, processing_rule, True
+                    )
             else:
                 segment_hash = helper.generate_text_hash(content)
                 tokens = 0
@@ -1616,8 +1638,8 @@ class SegmentService:
                 segment.disabled_at = None
                 segment.disabled_by = None
                 if document.doc_form == "qa_model":
-                    segment.answer = segment_update_entity.answer
-                    segment.word_count += len(segment_update_entity.answer or "")
+                    segment.answer = args.answer
+                    segment.word_count += len(args.answer)
                 word_count_change = segment.word_count - word_count_change
 
                 if word_count_change != 0:
@@ -1625,8 +1647,38 @@ class SegmentService:
                     db.session.add(document)
                 db.session.add(segment)
                 db.session.commit()
-
-                VectorService.update_segment_vector(segment_update_entity.keywords, segment, dataset)
+                if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
+
+                    if dataset.indexing_technique == "high_quality":
+
+                        model_manager = ModelManager()
+
+                        if dataset.embedding_model_provider:
+                            embedding_model_instance = model_manager.get_model_instance(
+                                tenant_id=dataset.tenant_id,
+                                provider=dataset.embedding_model_provider,
+                                model_type=ModelType.TEXT_EMBEDDING,
+                                model=dataset.embedding_model,
+                            )
+                        else:
+                            embedding_model_instance = model_manager.get_default_model_instance(
+                                tenant_id=dataset.tenant_id,
+                                model_type=ModelType.TEXT_EMBEDDING,
+                            )
+                    else:
+                        raise ValueError("The knowledge base index technique is not high quality!")
+
+                    processing_rule = (
+                        db.session.query(DatasetProcessRule)
+                        .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
+                        .first()
+                    )
+                    VectorService.generate_child_chunks(
+                        segment, document, dataset, embedding_model_instance, processing_rule, True
+                    )
+                elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
+
+                    VectorService.update_segment_vector(args.keywords, segment, dataset)
 
         except Exception as e:
             logging.exception("update segment index failed")
@@ -1649,13 +1701,265 @@ class SegmentService:
         if segment.enabled:
 
             redis_client.setex(indexing_cache_key, 600, 1)
-            delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id)
+            delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id)
         db.session.delete(segment)
 
         document.word_count -= segment.word_count
         db.session.add(document)
         db.session.commit()
 
+    @classmethod
+    def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset):
+        index_node_ids = (
+            DocumentSegment.query.with_entities(DocumentSegment.index_node_id)
+            .filter(
+                DocumentSegment.id.in_(segment_ids),
+                DocumentSegment.dataset_id == dataset.id,
+                DocumentSegment.document_id == document.id,
+                DocumentSegment.tenant_id == current_user.current_tenant_id,
+            )
+            .all()
+        )
+        index_node_ids = [index_node_id[0] for index_node_id in index_node_ids]
+
+        delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id)
+        db.session.query(DocumentSegment).filter(DocumentSegment.id.in_(segment_ids)).delete()
+        db.session.commit()
+
+    @classmethod
+    def update_segments_status(cls, segment_ids: list, action: str, dataset: Dataset, document: Document):
+        if action == "enable":
+            segments = (
+                db.session.query(DocumentSegment)
+                .filter(
+                    DocumentSegment.id.in_(segment_ids),
+                    DocumentSegment.dataset_id == dataset.id,
+                    DocumentSegment.document_id == document.id,
+                    DocumentSegment.enabled == False,
+                )
+                .all()
+            )
+            if not segments:
+                return
+            real_deal_segmment_ids = []
+            for segment in segments:
+                indexing_cache_key = "segment_{}_indexing".format(segment.id)
+                cache_result = redis_client.get(indexing_cache_key)
+                if cache_result is not None:
+                    continue
+                segment.enabled = True
+                segment.disabled_at = None
+                segment.disabled_by = None
+                db.session.add(segment)
+                real_deal_segmment_ids.append(segment.id)
+            db.session.commit()
+
+            enable_segments_to_index_task.delay(real_deal_segmment_ids, dataset.id, document.id)
+        elif action == "disable":
+            segments = (
+                db.session.query(DocumentSegment)
+                .filter(
+                    DocumentSegment.id.in_(segment_ids),
+                    DocumentSegment.dataset_id == dataset.id,
+                    DocumentSegment.document_id == document.id,
+                    DocumentSegment.enabled == True,
+                )
+                .all()
+            )
+            if not segments:
+                return
+            real_deal_segmment_ids = []
+            for segment in segments:
+                indexing_cache_key = "segment_{}_indexing".format(segment.id)
+                cache_result = redis_client.get(indexing_cache_key)
+                if cache_result is not None:
+                    continue
+                segment.enabled = False
+                segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+                segment.disabled_by = current_user.id
+                db.session.add(segment)
+                real_deal_segmment_ids.append(segment.id)
+            db.session.commit()
+
+            disable_segments_from_index_task.delay(real_deal_segmment_ids, dataset.id, document.id)
+        else:
+            raise InvalidActionError()
+
+    @classmethod
+    def create_child_chunk(
+        cls, content: str, segment: DocumentSegment, document: Document, dataset: Dataset
+    ) -> ChildChunk:
+        lock_name = "add_child_lock_{}".format(segment.id)
+        with redis_client.lock(lock_name, timeout=20):
+            index_node_id = str(uuid.uuid4())
+            index_node_hash = helper.generate_text_hash(content)
+            child_chunk_count = (
+                db.session.query(ChildChunk)
+                .filter(
+                    ChildChunk.tenant_id == current_user.current_tenant_id,
+                    ChildChunk.dataset_id == dataset.id,
+                    ChildChunk.document_id == document.id,
+                    ChildChunk.segment_id == segment.id,
+                )
+                .count()
+            )
+            max_position = (
+                db.session.query(func.max(ChildChunk.position))
+                .filter(
+                    ChildChunk.tenant_id == current_user.current_tenant_id,
+                    ChildChunk.dataset_id == dataset.id,
+                    ChildChunk.document_id == document.id,
+                    ChildChunk.segment_id == segment.id,
+                )
+                .scalar()
+            )
+            child_chunk = ChildChunk(
+                tenant_id=current_user.current_tenant_id,
+                dataset_id=dataset.id,
+                document_id=document.id,
+                segment_id=segment.id,
+                position=max_position + 1,
+                index_node_id=index_node_id,
+                index_node_hash=index_node_hash,
+                content=content,
+                word_count=len(content),
+                type="customized",
+                created_by=current_user.id,
+            )
+            db.session.add(child_chunk)
+
+            try:
+                VectorService.create_child_chunk_vector(child_chunk, dataset)
+            except Exception as e:
+                logging.exception("create child chunk index failed")
+                db.session.rollback()
+                raise ChildChunkIndexingError(str(e))
+            db.session.commit()
+
+            return child_chunk
+
+    @classmethod
+    def update_child_chunks(
+        cls,
+        child_chunks_update_args: list[ChildChunkUpdateArgs],
+        segment: DocumentSegment,
+        document: Document,
+        dataset: Dataset,
+    ) -> list[ChildChunk]:
+        child_chunks = (
+            db.session.query(ChildChunk)
+            .filter(
+                ChildChunk.dataset_id == dataset.id,
+                ChildChunk.document_id == document.id,
+                ChildChunk.segment_id == segment.id,
+            )
+            .all()
+        )
+        child_chunks_map = {chunk.id: chunk for chunk in child_chunks}
+
+        new_child_chunks, update_child_chunks, delete_child_chunks, new_child_chunks_args = [], [], [], []
+
+        for child_chunk_update_args in child_chunks_update_args:
+            if child_chunk_update_args.id:
+                child_chunk = child_chunks_map.pop(child_chunk_update_args.id, None)
+                if child_chunk:
+                    if child_chunk.content != child_chunk_update_args.content:
+                        child_chunk.content = child_chunk_update_args.content
+                        child_chunk.word_count = len(child_chunk.content)
+                        child_chunk.updated_by = current_user.id
+                        child_chunk.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+                        child_chunk.type = "customized"
+                        update_child_chunks.append(child_chunk)
+            else:
+                new_child_chunks_args.append(child_chunk_update_args)
+        if child_chunks_map:
+            delete_child_chunks = list(child_chunks_map.values())
+        try:
+            if update_child_chunks:
+                db.session.bulk_save_objects(update_child_chunks)
+
+            if delete_child_chunks:
+                for child_chunk in delete_child_chunks:
+                    db.session.delete(child_chunk)
+            if new_child_chunks_args:
+                child_chunk_count = len(child_chunks)
+                for position, args in enumerate(new_child_chunks_args, start=child_chunk_count + 1):
+                    index_node_id = str(uuid.uuid4())
+                    index_node_hash = helper.generate_text_hash(args.content)
+                    child_chunk = ChildChunk(
+                        tenant_id=current_user.current_tenant_id,
+                        dataset_id=dataset.id,
+                        document_id=document.id,
+                        segment_id=segment.id,
+                        position=position,
+                        index_node_id=index_node_id,
+                        index_node_hash=index_node_hash,
+                        content=args.content,
+                        word_count=len(args.content),
+                        type="customized",
+                        created_by=current_user.id,
+                    )
+
+                    db.session.add(child_chunk)
+                    db.session.flush()
+                    new_child_chunks.append(child_chunk)
+            VectorService.update_child_chunk_vector(new_child_chunks, update_child_chunks, delete_child_chunks, dataset)
+            db.session.commit()
+        except Exception as e:
+            logging.exception("update child chunk index failed")
+            db.session.rollback()
+            raise ChildChunkIndexingError(str(e))
+        return sorted(new_child_chunks + update_child_chunks, key=lambda x: x.position)
+
+    @classmethod
+    def update_child_chunk(
+        cls,
+        content: str,
+        child_chunk: ChildChunk,
+        segment: DocumentSegment,
+        document: Document,
+        dataset: Dataset,
+    ) -> ChildChunk:
+        try:
+            child_chunk.content = content
+            child_chunk.word_count = len(content)
+            child_chunk.updated_by = current_user.id
+            child_chunk.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            child_chunk.type = "customized"
+            db.session.add(child_chunk)
+            VectorService.update_child_chunk_vector([], [child_chunk], [], dataset)
+            db.session.commit()
+        except Exception as e:
+            logging.exception("update child chunk index failed")
+            db.session.rollback()
+            raise ChildChunkIndexingError(str(e))
+        return child_chunk
+
+    @classmethod
+    def delete_child_chunk(cls, child_chunk: ChildChunk, dataset: Dataset):
+        db.session.delete(child_chunk)
+        try:
+            VectorService.delete_child_chunk_vector(child_chunk, dataset)
+        except Exception as e:
+            logging.exception("delete child chunk index failed")
+            db.session.rollback()
+            raise ChildChunkDeleteIndexError(str(e))
+        db.session.commit()
+
+    @classmethod
+    def get_child_chunks(
+        cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None
+    ):
+        query = ChildChunk.query.filter_by(
+            tenant_id=current_user.current_tenant_id,
+            dataset_id=dataset_id,
+            document_id=document_id,
+            segment_id=segment_id,
+        ).order_by(ChildChunk.position.asc())
+        if keyword:
+            query = query.where(ChildChunk.content.ilike(f"%{keyword}%"))
+        return query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
+
 
 class DatasetCollectionBindingService:
     @classmethod