@@ -41,6 +41,7 @@ from models.source import DataSourceOauthBinding
 from services.entities.knowledge_entities.knowledge_entities import (
     ChildChunkUpdateArgs,
     KnowledgeConfig,
+    RerankingModel,
     RetrievalModel,
     SegmentUpdateArgs,
 )
@@ -548,12 +549,14 @@ class DocumentService:
     }

     @staticmethod
-    def get_document(dataset_id: str, document_id: str) -> Optional[Document]:
-        document = (
-            db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
-        )
-
-        return document
+    def get_document(dataset_id: str, document_id: Optional[str] = None) -> Optional[Document]:
+        if document_id:
+            document = (
+                db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
+            )
+            return document
+        else:
+            return None

     @staticmethod
     def get_document_by_id(document_id: str) -> Optional[Document]:
@@ -744,25 +747,26 @@ class DocumentService:
         if features.billing.enabled:
             if not knowledge_config.original_document_id:
                 count = 0
-                if knowledge_config.data_source.info_list.data_source_type == "upload_file":
-                    upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
-                    count = len(upload_file_list)
-                elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
-                    notion_info_list = knowledge_config.data_source.info_list.notion_info_list
-                    for notion_info in notion_info_list:
-                        count = count + len(notion_info.pages)
-                elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
-                    website_info = knowledge_config.data_source.info_list.website_info_list
-                    count = len(website_info.urls)
-                batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
-                if count > batch_upload_limit:
-                    raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
-
-                DocumentService.check_documents_upload_quota(count, features)
+                if knowledge_config.data_source:
+                    if knowledge_config.data_source.info_list.data_source_type == "upload_file":
+                        upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
+                        count = len(upload_file_list)
+                    elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
+                        notion_info_list = knowledge_config.data_source.info_list.notion_info_list
+                        for notion_info in notion_info_list:  # type: ignore
+                            count = count + len(notion_info.pages)
+                    elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
+                        website_info = knowledge_config.data_source.info_list.website_info_list
+                        count = len(website_info.urls)  # type: ignore
+                    batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
+                    if count > batch_upload_limit:
+                        raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
+
+                    DocumentService.check_documents_upload_quota(count, features)

         # if dataset is empty, update dataset data_source_type
         if not dataset.data_source_type:
-            dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type
+            dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type  # type: ignore

         if not dataset.indexing_technique:
             if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
@@ -789,7 +793,7 @@ class DocumentService:
                 "score_threshold_enabled": False,
             }

-            dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model
+            dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model  # type: ignore

         documents = []
         if knowledge_config.original_document_id:
@@ -801,34 +805,35 @@ class DocumentService:
             # save process rule
             if not dataset_process_rule:
                 process_rule = knowledge_config.process_rule
-                if process_rule.mode in ("custom", "hierarchical"):
-                    dataset_process_rule = DatasetProcessRule(
-                        dataset_id=dataset.id,
-                        mode=process_rule.mode,
-                        rules=process_rule.rules.model_dump_json(),
-                        created_by=account.id,
-                    )
-                elif process_rule.mode == "automatic":
-                    dataset_process_rule = DatasetProcessRule(
-                        dataset_id=dataset.id,
-                        mode=process_rule.mode,
-                        rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
-                        created_by=account.id,
-                    )
-                else:
-                    logging.warn(
-                        f"Invalid process rule mode: {process_rule['mode']}, can not find dataset process rule"
-                    )
-                    return
-                db.session.add(dataset_process_rule)
-                db.session.commit()
+                if process_rule:
+                    if process_rule.mode in ("custom", "hierarchical"):
+                        dataset_process_rule = DatasetProcessRule(
+                            dataset_id=dataset.id,
+                            mode=process_rule.mode,
+                            rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
+                            created_by=account.id,
+                        )
+                    elif process_rule.mode == "automatic":
+                        dataset_process_rule = DatasetProcessRule(
+                            dataset_id=dataset.id,
+                            mode=process_rule.mode,
+                            rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
+                            created_by=account.id,
+                        )
+                    else:
+                        logging.warning(
+                            f"Invalid process rule mode: {process_rule.mode}, cannot find dataset process rule"
+                        )
+                        return
+                    db.session.add(dataset_process_rule)
+                    db.session.commit()
             lock_name = "add_document_lock_dataset_id_{}".format(dataset.id)
             with redis_client.lock(lock_name, timeout=600):
                 position = DocumentService.get_documents_position(dataset.id)
                 document_ids = []
                 duplicate_document_ids = []
                 if knowledge_config.data_source.info_list.data_source_type == "upload_file":
-                    upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
+                    upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
                     for file_id in upload_file_list:
                         file = (
                             db.session.query(UploadFile)
@@ -854,7 +859,7 @@ class DocumentService:
                                 name=file_name,
                             ).first()
                             if document:
-                                document.dataset_process_rule_id = dataset_process_rule.id
+                                document.dataset_process_rule_id = dataset_process_rule.id  # type: ignore
                                 document.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                                 document.created_from = created_from
                                 document.doc_form = knowledge_config.doc_form
@@ -868,7 +873,7 @@ class DocumentService:
                                 continue
                         document = DocumentService.build_document(
                             dataset,
-                            dataset_process_rule.id,
+                            dataset_process_rule.id,  # type: ignore
                             knowledge_config.data_source.info_list.data_source_type,
                             knowledge_config.doc_form,
                             knowledge_config.doc_language,
@@ -886,6 +891,8 @@ class DocumentService:
                         position += 1
                 elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
                     notion_info_list = knowledge_config.data_source.info_list.notion_info_list
+                    if not notion_info_list:
+                        raise ValueError("No notion info list found.")
                     exist_page_ids = []
                     exist_document = {}
                     documents = Document.query.filter_by(
@@ -921,7 +928,7 @@ class DocumentService:
                             }
                             document = DocumentService.build_document(
                                 dataset,
-                                dataset_process_rule.id,
+                                dataset_process_rule.id,  # type: ignore
                                 knowledge_config.data_source.info_list.data_source_type,
                                 knowledge_config.doc_form,
                                 knowledge_config.doc_language,
@@ -944,6 +951,8 @@ class DocumentService:
                     clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
                 elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
                     website_info = knowledge_config.data_source.info_list.website_info_list
+                    if not website_info:
+                        raise ValueError("No website info list found.")
                     urls = website_info.urls
                     for url in urls:
                         data_source_info = {
@@ -959,7 +968,7 @@ class DocumentService:
                             document_name = url
                         document = DocumentService.build_document(
                             dataset,
-                            dataset_process_rule.id,
+                            dataset_process_rule.id,  # type: ignore
                             knowledge_config.data_source.info_list.data_source_type,
                             knowledge_config.doc_form,
                             knowledge_config.doc_language,
@@ -1054,7 +1063,7 @@ class DocumentService:
                 dataset_process_rule = DatasetProcessRule(
                     dataset_id=dataset.id,
                     mode=process_rule.mode,
-                    rules=process_rule.rules.model_dump_json(),
+                    rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
                     created_by=account.id,
                 )
             elif process_rule.mode == "automatic":
@@ -1073,6 +1082,8 @@ class DocumentService:
         file_name = ""
         data_source_info = {}
         if document_data.data_source.info_list.data_source_type == "upload_file":
+            if not document_data.data_source.info_list.file_info_list:
+                raise ValueError("No file info list found.")
             upload_file_list = document_data.data_source.info_list.file_info_list.file_ids
             for file_id in upload_file_list:
                 file = (
@@ -1090,6 +1101,8 @@ class DocumentService:
                     "upload_file_id": file_id,
                 }
         elif document_data.data_source.info_list.data_source_type == "notion_import":
+            if not document_data.data_source.info_list.notion_info_list:
+                raise ValueError("No notion info list found.")
             notion_info_list = document_data.data_source.info_list.notion_info_list
             for notion_info in notion_info_list:
                 workspace_id = notion_info.workspace_id
@@ -1107,20 +1120,21 @@ class DocumentService:
                     data_source_info = {
                         "notion_workspace_id": workspace_id,
                         "notion_page_id": page.page_id,
-                        "notion_page_icon": page.page_icon,
+                        "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,  # type: ignore
                         "type": page.type,
                     }
         elif document_data.data_source.info_list.data_source_type == "website_crawl":
             website_info = document_data.data_source.info_list.website_info_list
-            urls = website_info.urls
-            for url in urls:
-                data_source_info = {
-                    "url": url,
-                    "provider": website_info.provider,
-                    "job_id": website_info.job_id,
-                    "only_main_content": website_info.only_main_content,
-                    "mode": "crawl",
-                }
+            if website_info:
+                urls = website_info.urls
+                for url in urls:
+                    data_source_info = {
+                        "url": url,
+                        "provider": website_info.provider,
+                        "job_id": website_info.job_id,
+                        "only_main_content": website_info.only_main_content,  # type: ignore
+                        "mode": "crawl",
+                    }
         document.data_source_type = document_data.data_source.info_list.data_source_type
         document.data_source_info = json.dumps(data_source_info)
         document.name = file_name
@@ -1155,15 +1169,21 @@ class DocumentService:
         if features.billing.enabled:
             count = 0
             if knowledge_config.data_source.info_list.data_source_type == "upload_file":
-                upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
+                upload_file_list = (
+                    knowledge_config.data_source.info_list.file_info_list.file_ids
+                    if knowledge_config.data_source.info_list.file_info_list
+                    else []
+                )
                 count = len(upload_file_list)
             elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
                 notion_info_list = knowledge_config.data_source.info_list.notion_info_list
-                for notion_info in notion_info_list:
-                    count = count + len(notion_info.pages)
+                if notion_info_list:
+                    for notion_info in notion_info_list:
+                        count = count + len(notion_info.pages)
             elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
                 website_info = knowledge_config.data_source.info_list.website_info_list
-                count = len(website_info.urls)
+                if website_info:
+                    count = len(website_info.urls)
             batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
             if count > batch_upload_limit:
                 raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@@ -1174,20 +1194,20 @@ class DocumentService:
         retrieval_model = None
         if knowledge_config.indexing_technique == "high_quality":
             dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-                knowledge_config.embedding_model_provider, knowledge_config.embedding_model
+                knowledge_config.embedding_model_provider,  # type: ignore
+                knowledge_config.embedding_model,  # type: ignore
             )
             dataset_collection_binding_id = dataset_collection_binding.id
             if knowledge_config.retrieval_model:
                 retrieval_model = knowledge_config.retrieval_model
             else:
-                default_retrieval_model = {
-                    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
-                    "reranking_enable": False,
-                    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
-                    "top_k": 2,
-                    "score_threshold_enabled": False,
-                }
-                retrieval_model = RetrievalModel(**default_retrieval_model)
+                retrieval_model = RetrievalModel(
+                    search_method=RetrievalMethod.SEMANTIC_SEARCH.value,
+                    reranking_enable=False,
+                    reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""),
+                    top_k=2,
+                    score_threshold_enabled=False,
+                )
         # save dataset
         dataset = Dataset(
             tenant_id=tenant_id,
@@ -1557,12 +1577,12 @@ class SegmentService:
             raise ValueError("Can't update disabled segment")
         try:
             word_count_change = segment.word_count
-            content = args.content
+            content = args.content or segment.content
            if segment.content == content:
                 segment.word_count = len(content)
                 if document.doc_form == "qa_model":
                     segment.answer = args.answer
-                    segment.word_count += len(args.answer)
+                    segment.word_count += len(args.answer) if args.answer else 0
                 word_count_change = segment.word_count - word_count_change
                 if args.keywords:
                     segment.keywords = args.keywords
@@ -1577,7 +1597,12 @@ class SegmentService:
                 db.session.add(document)
                 # update segment index task
                 if args.enabled:
-                    VectorService.create_segments_vector([args.keywords], [segment], dataset)
+                    VectorService.create_segments_vector(
+                        [args.keywords] if args.keywords else None,
+                        [segment],
+                        dataset,
+                        document.doc_form,
+                    )
                 if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
                     # regenerate child chunks
                     # get embedding model instance
@@ -1605,6 +1630,8 @@ class SegmentService:
                         .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
                         .first()
                     )
+                    if not processing_rule:
+                        raise ValueError("No processing rule found.")
                     VectorService.generate_child_chunks(
                         segment, document, dataset, embedding_model_instance, processing_rule, True
                     )
@@ -1639,7 +1666,7 @@ class SegmentService:
                 segment.disabled_by = None
                 if document.doc_form == "qa_model":
                     segment.answer = args.answer
-                    segment.word_count += len(args.answer)
+                    segment.word_count += len(args.answer) if args.answer else 0
                 word_count_change = segment.word_count - word_count_change
                 # update document word count
                 if word_count_change != 0:
@@ -1673,6 +1700,8 @@ class SegmentService:
                     .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
                     .first()
                 )
+                if not processing_rule:
+                    raise ValueError("No processing rule found.")
                 VectorService.generate_child_chunks(
                     segment, document, dataset, embedding_model_instance, processing_rule, True
                 )
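
For readers skimming the patch, the recurring fix is the same in every hunk: fields the type checker flags as Optional are now guarded (fall back, raise, or skip) before being dereferenced, and the default retrieval configuration is built as a typed Pydantic model instead of a raw dict. Below is a minimal, runnable sketch of that pattern; it is not part of the patch. The class names mirror the RetrievalModel and RerankingModel entities imported in the first hunk, but their fields here are simplified assumptions, and Pydantic v2 is assumed for model_dump().

# Sketch only: simplified stand-ins for the real knowledge_entities models.
from typing import Optional

from pydantic import BaseModel


class RerankingModel(BaseModel):
    reranking_provider_name: str = ""
    reranking_model_name: str = ""


class RetrievalModel(BaseModel):
    search_method: str
    reranking_enable: bool = False
    reranking_model: Optional[RerankingModel] = None
    top_k: int = 2
    score_threshold_enabled: bool = False


def default_retrieval_model() -> RetrievalModel:
    # Typed construction (as in the @@ -1174 hunk) makes a mistyped field name
    # fail at construction time instead of slipping through a raw dict.
    return RetrievalModel(
        search_method="semantic_search",  # RetrievalMethod.SEMANTIC_SEARCH.value in the patch
        reranking_enable=False,
        reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""),
        top_k=2,
        score_threshold_enabled=False,
    )


def segment_word_count(content: str, answer: Optional[str]) -> int:
    # Same guard style as the SegmentService hunks: only count the answer
    # when it is actually present.
    return len(content) + (len(answer) if answer else 0)


if __name__ == "__main__":
    print(default_retrieval_model().model_dump())
    print(segment_word_count("hello", None))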