소스 검색

py lint (#12102)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
Jyong 4 달 전
부모
커밋
84ac004772

+ 1 - 1
api/commands.py

@@ -587,7 +587,7 @@ def upgrade_db():
             click.echo(click.style("Starting database migration.", fg="green"))
 
             # run db migration
-            import flask_migrate
+            import flask_migrate  # type: ignore
 
             flask_migrate.upgrade()
 

+ 4 - 3
api/controllers/console/datasets/datasets_document.py

@@ -413,7 +413,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
                 indexing_runner = IndexingRunner()
 
                 try:
-                    response = indexing_runner.indexing_estimate(
+                    estimate_response = indexing_runner.indexing_estimate(
                         current_user.current_tenant_id,
                         [extract_setting],
                         data_process_rule_dict,
@@ -421,6 +421,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
                         "English",
                         dataset_id,
                     )
+                    return estimate_response.model_dump(), 200
                 except LLMBadRequestError:
                     raise ProviderNotInitializeError(
                         "No Embedding Model available. Please configure a valid provider "
@@ -431,7 +432,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
                 except Exception as e:
                     raise IndexingEstimateError(str(e))
 
-        return response.model_dump(), 200
+        return response, 200
 
 
 class DocumentBatchIndexingEstimateApi(DocumentResource):
@@ -521,6 +522,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                     "English",
                     dataset_id,
                 )
+                return response.model_dump(), 200
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
                     "No Embedding Model available. Please configure a valid provider "
@@ -530,7 +532,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                 raise ProviderNotInitializeError(ex.description)
             except Exception as e:
                 raise IndexingEstimateError(str(e))
-        return response.model_dump(), 200
 
 
 class DocumentBatchIndexingStatusApi(DocumentResource):

+ 14 - 8
api/controllers/service_api/dataset/document.py

@@ -22,6 +22,7 @@ from fields.document_fields import document_fields, document_status_fields
 from libs.login import current_user
 from models.dataset import Dataset, Document, DocumentSegment
 from services.dataset_service import DocumentService
+from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
 from services.file_service import FileService
 
 
@@ -67,13 +68,14 @@ class DocumentAddByTextApi(DatasetApiResource):
             "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
         }
         args["data_source"] = data_source
+        knowledge_config = KnowledgeConfig(**args)
         # validate args
-        DocumentService.document_create_args_validate(args)
+        DocumentService.document_create_args_validate(knowledge_config)
 
         try:
             documents, batch = DocumentService.save_document_with_dataset_id(
                 dataset=dataset,
-                document_data=args,
+                knowledge_config=knowledge_config,
                 account=current_user,
                 dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
                 created_from="api",
@@ -122,12 +124,13 @@ class DocumentUpdateByTextApi(DatasetApiResource):
             args["data_source"] = data_source
         # validate args
         args["original_document_id"] = str(document_id)
-        DocumentService.document_create_args_validate(args)
+        knowledge_config = KnowledgeConfig(**args)
+        DocumentService.document_create_args_validate(knowledge_config)
 
         try:
             documents, batch = DocumentService.save_document_with_dataset_id(
                 dataset=dataset,
-                document_data=args,
+                knowledge_config=knowledge_config,
                 account=current_user,
                 dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
                 created_from="api",
@@ -186,12 +189,13 @@ class DocumentAddByFileApi(DatasetApiResource):
         data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
         args["data_source"] = data_source
         # validate args
-        DocumentService.document_create_args_validate(args)
+        knowledge_config = KnowledgeConfig(**args)
+        DocumentService.document_create_args_validate(knowledge_config)
 
         try:
             documents, batch = DocumentService.save_document_with_dataset_id(
                 dataset=dataset,
-                document_data=args,
+                knowledge_config=knowledge_config,
                 account=dataset.created_by_account,
                 dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
                 created_from="api",
@@ -245,12 +249,14 @@ class DocumentUpdateByFileApi(DatasetApiResource):
             args["data_source"] = data_source
         # validate args
         args["original_document_id"] = str(document_id)
-        DocumentService.document_create_args_validate(args)
+
+        knowledge_config = KnowledgeConfig(**args)
+        DocumentService.document_create_args_validate(knowledge_config)
 
         try:
             documents, batch = DocumentService.save_document_with_dataset_id(
                 dataset=dataset,
-                document_data=args,
+                knowledge_config=knowledge_config,
                 account=dataset.created_by_account,
                 dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
                 created_from="api",

+ 7 - 7
api/core/indexing_runner.py

@@ -276,7 +276,7 @@ class IndexingRunner:
                     tenant_id=tenant_id,
                     model_type=ModelType.TEXT_EMBEDDING,
                 )
-        preview_texts = []
+        preview_texts = []  # type: ignore
 
         total_segments = 0
         index_type = doc_form
@@ -300,13 +300,13 @@ class IndexingRunner:
                 if len(preview_texts) < 10:
                     if doc_form and doc_form == "qa_model":
                         preview_detail = QAPreviewDetail(
-                            question=document.page_content, answer=document.metadata.get("answer")
+                            question=document.page_content, answer=document.metadata.get("answer") or ""
                         )
                         preview_texts.append(preview_detail)
                     else:
-                        preview_detail = PreviewDetail(content=document.page_content)
+                        preview_detail = PreviewDetail(content=document.page_content)  # type: ignore
                         if document.children:
-                            preview_detail.child_chunks = [child.page_content for child in document.children]
+                            preview_detail.child_chunks = [child.page_content for child in document.children]  # type: ignore
                         preview_texts.append(preview_detail)
 
                 # delete image files and related db records
@@ -325,7 +325,7 @@ class IndexingRunner:
 
         if doc_form and doc_form == "qa_model":
             return IndexingEstimate(total_segments=total_segments * 20, qa_preview=preview_texts, preview=[])
-        return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
+        return IndexingEstimate(total_segments=total_segments, preview=preview_texts)  # type: ignore
 
     def _extract(
         self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
@@ -454,7 +454,7 @@ class IndexingRunner:
                 embedding_model_instance=embedding_model_instance,
             )
 
-        return character_splitter
+        return character_splitter  # type: ignore
 
     def _split_to_documents_for_estimate(
         self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule
@@ -535,7 +535,7 @@ class IndexingRunner:
             # create keyword index
             create_keyword_thread = threading.Thread(
                 target=self._process_keyword_index,
-                args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents),
+                args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents),  # type: ignore
             )
             create_keyword_thread.start()
 

+ 65 - 64
api/core/rag/datasource/retrieval_service.py

@@ -258,78 +258,79 @@ class RetrievalService:
         include_segment_ids = []
         segment_child_map = {}
         for document in documents:
-            document_id = document.metadata["document_id"]
+            document_id = document.metadata.get("document_id")
             dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
-            if dataset_document and dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
-                child_index_node_id = document.metadata["doc_id"]
-                result = (
-                    db.session.query(ChildChunk, DocumentSegment)
-                    .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
-                    .filter(
-                        ChildChunk.index_node_id == child_index_node_id,
-                        DocumentSegment.dataset_id == dataset_document.dataset_id,
-                        DocumentSegment.enabled == True,
-                        DocumentSegment.status == "completed",
+            if dataset_document:
+                if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                    child_index_node_id = document.metadata.get("doc_id")
+                    result = (
+                        db.session.query(ChildChunk, DocumentSegment)
+                        .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
+                        .filter(
+                            ChildChunk.index_node_id == child_index_node_id,
+                            DocumentSegment.dataset_id == dataset_document.dataset_id,
+                            DocumentSegment.enabled == True,
+                            DocumentSegment.status == "completed",
+                        )
+                        .first()
                     )
-                    .first()
-                )
-                if result:
-                    child_chunk, segment = result
-                    if not segment:
-                        continue
-                    if segment.id not in include_segment_ids:
-                        include_segment_ids.append(segment.id)
-                        child_chunk_detail = {
-                            "id": child_chunk.id,
-                            "content": child_chunk.content,
-                            "position": child_chunk.position,
-                            "score": document.metadata.get("score", 0.0),
-                        }
-                        map_detail = {
-                            "max_score": document.metadata.get("score", 0.0),
-                            "child_chunks": [child_chunk_detail],
-                        }
-                        segment_child_map[segment.id] = map_detail
-                        record = {
-                            "segment": segment,
-                        }
-                        records.append(record)
+                    if result:
+                        child_chunk, segment = result
+                        if not segment:
+                            continue
+                        if segment.id not in include_segment_ids:
+                            include_segment_ids.append(segment.id)
+                            child_chunk_detail = {
+                                "id": child_chunk.id,
+                                "content": child_chunk.content,
+                                "position": child_chunk.position,
+                                "score": document.metadata.get("score", 0.0),
+                            }
+                            map_detail = {
+                                "max_score": document.metadata.get("score", 0.0),
+                                "child_chunks": [child_chunk_detail],
+                            }
+                            segment_child_map[segment.id] = map_detail
+                            record = {
+                                "segment": segment,
+                            }
+                            records.append(record)
+                        else:
+                            child_chunk_detail = {
+                                "id": child_chunk.id,
+                                "content": child_chunk.content,
+                                "position": child_chunk.position,
+                                "score": document.metadata.get("score", 0.0),
+                            }
+                            segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
+                            segment_child_map[segment.id]["max_score"] = max(
+                                segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
+                            )
                     else:
-                        child_chunk_detail = {
-                            "id": child_chunk.id,
-                            "content": child_chunk.content,
-                            "position": child_chunk.position,
-                            "score": document.metadata.get("score", 0.0),
-                        }
-                        segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
-                        segment_child_map[segment.id]["max_score"] = max(
-                            segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
-                        )
+                        continue
                 else:
-                    continue
-            else:
-                index_node_id = document.metadata["doc_id"]
+                    index_node_id = document.metadata["doc_id"]
 
-                segment = (
-                    db.session.query(DocumentSegment)
-                    .filter(
-                        DocumentSegment.dataset_id == dataset_document.dataset_id,
-                        DocumentSegment.enabled == True,
-                        DocumentSegment.status == "completed",
-                        DocumentSegment.index_node_id == index_node_id,
+                    segment = (
+                        db.session.query(DocumentSegment)
+                        .filter(
+                            DocumentSegment.dataset_id == dataset_document.dataset_id,
+                            DocumentSegment.enabled == True,
+                            DocumentSegment.status == "completed",
+                            DocumentSegment.index_node_id == index_node_id,
+                        )
+                        .first()
                     )
-                    .first()
-                )
 
-                if not segment:
-                    continue
-                include_segment_ids.append(segment.id)
-                record = {
-                    "segment": segment,
-                    "score": document.metadata.get("score", None),
-                }
+                    if not segment:
+                        continue
+                    include_segment_ids.append(segment.id)
+                    record = {
+                        "segment": segment,
+                        "score": document.metadata.get("score", None),
+                    }
 
-                records.append(record)
+                    records.append(record)
             for record in records:
                 if record["segment"].id in segment_child_map:
                     record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks", None)

+ 19 - 18
api/core/rag/docstore/dataset_docstore.py

@@ -122,26 +122,27 @@ class DatasetDocumentStore:
                 db.session.add(segment_document)
                 db.session.flush()
                 if save_child:
-                    for postion, child in enumerate(doc.children, start=1):
-                        child_segment = ChildChunk(
-                            tenant_id=self._dataset.tenant_id,
-                            dataset_id=self._dataset.id,
-                            document_id=self._document_id,
-                            segment_id=segment_document.id,
-                            position=postion,
-                            index_node_id=child.metadata["doc_id"],
-                            index_node_hash=child.metadata["doc_hash"],
-                            content=child.page_content,
-                            word_count=len(child.page_content),
-                            type="automatic",
-                            created_by=self._user_id,
-                        )
-                        db.session.add(child_segment)
+                    if doc.children:
+                        for postion, child in enumerate(doc.children, start=1):
+                            child_segment = ChildChunk(
+                                tenant_id=self._dataset.tenant_id,
+                                dataset_id=self._dataset.id,
+                                document_id=self._document_id,
+                                segment_id=segment_document.id,
+                                position=postion,
+                                index_node_id=child.metadata.get("doc_id"),
+                                index_node_hash=child.metadata.get("doc_hash"),
+                                content=child.page_content,
+                                word_count=len(child.page_content),
+                                type="automatic",
+                                created_by=self._user_id,
+                            )
+                            db.session.add(child_segment)
             else:
                 segment_document.content = doc.page_content
                 if doc.metadata.get("answer"):
                     segment_document.answer = doc.metadata.pop("answer", "")
-                segment_document.index_node_hash = doc.metadata["doc_hash"]
+                segment_document.index_node_hash = doc.metadata.get("doc_hash")
                 segment_document.word_count = len(doc.page_content)
                 segment_document.tokens = tokens
                 if save_child and doc.children:
@@ -160,8 +161,8 @@ class DatasetDocumentStore:
                             document_id=self._document_id,
                             segment_id=segment_document.id,
                             position=position,
-                            index_node_id=child.metadata["doc_id"],
-                            index_node_hash=child.metadata["doc_hash"],
+                            index_node_id=child.metadata.get("doc_id"),
+                            index_node_hash=child.metadata.get("doc_hash"),
                             content=child.page_content,
                             word_count=len(child.page_content),
                             type="automatic",

+ 1 - 1
api/core/rag/extractor/excel_extractor.py

@@ -4,7 +4,7 @@ import os
 from typing import Optional, cast
 
 import pandas as pd
-from openpyxl import load_workbook
+from openpyxl import load_workbook  # type: ignore
 
 from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document

+ 1 - 1
api/core/rag/index_processor/index_processor_base.py

@@ -81,4 +81,4 @@ class BaseIndexProcessor(ABC):
                 embedding_model_instance=embedding_model_instance,
             )
 
-        return character_splitter
+        return character_splitter  # type: ignore

+ 6 - 0
api/core/rag/index_processor/processor/paragraph_index_processor.py

@@ -30,12 +30,18 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
 
     def transform(self, documents: list[Document], **kwargs) -> list[Document]:
         process_rule = kwargs.get("process_rule")
+        if not process_rule:
+            raise ValueError("No process rule found.")
         if process_rule.get("mode") == "automatic":
             automatic_rule = DatasetProcessRule.AUTOMATIC_RULES
             rules = Rule(**automatic_rule)
         else:
+            if not process_rule.get("rules"):
+                raise ValueError("No rules found in process rule.")
             rules = Rule(**process_rule.get("rules"))
         # Split the text documents into nodes.
+        if not rules.segmentation:
+            raise ValueError("No segmentation found in rules.")
         splitter = self._get_splitter(
             processing_rule_mode=process_rule.get("mode"),
             max_tokens=rules.segmentation.max_tokens,

+ 7 - 1
api/core/rag/index_processor/processor/parent_child_index_processor.py

@@ -30,8 +30,12 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
 
     def transform(self, documents: list[Document], **kwargs) -> list[Document]:
         process_rule = kwargs.get("process_rule")
+        if not process_rule:
+            raise ValueError("No process rule found.")
+        if not process_rule.get("rules"):
+            raise ValueError("No rules found in process rule.")
         rules = Rule(**process_rule.get("rules"))
-        all_documents = []
+        all_documents = []  # type: ignore
         if rules.parent_mode == ParentMode.PARAGRAPH:
             # Split the text documents into nodes.
             splitter = self._get_splitter(
@@ -161,6 +165,8 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
         process_rule_mode: str,
         embedding_model_instance: Optional[ModelInstance],
     ) -> list[ChildDocument]:
+        if not rules.subchunk_segmentation:
+            raise ValueError("No subchunk segmentation found in rules.")
         child_splitter = self._get_splitter(
             processing_rule_mode=process_rule_mode,
             max_tokens=rules.subchunk_segmentation.max_tokens,

+ 11 - 7
api/core/rag/index_processor/processor/qa_index_processor.py

@@ -37,12 +37,16 @@ class QAIndexProcessor(BaseIndexProcessor):
     def transform(self, documents: list[Document], **kwargs) -> list[Document]:
         preview = kwargs.get("preview")
         process_rule = kwargs.get("process_rule")
+        if not process_rule:
+            raise ValueError("No process rule found.")
+        if not process_rule.get("rules"):
+            raise ValueError("No rules found in process rule.")
         rules = Rule(**process_rule.get("rules"))
         splitter = self._get_splitter(
             processing_rule_mode=process_rule.get("mode"),
-            max_tokens=rules.segmentation.max_tokens,
-            chunk_overlap=rules.segmentation.chunk_overlap,
-            separator=rules.segmentation.separator,
+            max_tokens=rules.segmentation.max_tokens if rules.segmentation else 0,
+            chunk_overlap=rules.segmentation.chunk_overlap if rules.segmentation else 0,
+            separator=rules.segmentation.separator if rules.segmentation else "",
             embedding_model_instance=kwargs.get("embedding_model_instance"),
         )
 
@@ -71,8 +75,8 @@ class QAIndexProcessor(BaseIndexProcessor):
             all_documents.extend(split_documents)
         if preview:
             self._format_qa_document(
-                current_app._get_current_object(),
-                kwargs.get("tenant_id"),
+                current_app._get_current_object(),  # type: ignore
+                kwargs.get("tenant_id"),  # type: ignore
                 all_documents[0],
                 all_qa_documents,
                 kwargs.get("doc_language", "English"),
@@ -85,8 +89,8 @@ class QAIndexProcessor(BaseIndexProcessor):
                     document_format_thread = threading.Thread(
                         target=self._format_qa_document,
                         kwargs={
-                            "flask_app": current_app._get_current_object(),
-                            "tenant_id": kwargs.get("tenant_id"),
+                            "flask_app": current_app._get_current_object(),  # type: ignore
+                            "tenant_id": kwargs.get("tenant_id"),  # type: ignore
                             "document_node": doc,
                             "all_qa_documents": all_qa_documents,
                             "document_language": kwargs.get("doc_language", "English"),

+ 3 - 3
api/core/rag/models/document.py

@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
 from collections.abc import Sequence
 from typing import Any, Optional
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
 
 
 class ChildDocument(BaseModel):
@@ -15,7 +15,7 @@ class ChildDocument(BaseModel):
     """Arbitrary metadata about the page content (e.g., source, relationships to other
         documents, etc.).
     """
-    metadata: Optional[dict] = Field(default_factory=dict)
+    metadata: dict = {}
 
 
 class Document(BaseModel):
@@ -28,7 +28,7 @@ class Document(BaseModel):
     """Arbitrary metadata about the page content (e.g., source, relationships to other
         documents, etc.).
     """
-    metadata: Optional[dict] = Field(default_factory=dict)
+    metadata: dict = {}
 
     provider: Optional[str] = "dify"
 

+ 1 - 1
api/extensions/ext_blueprints.py

@@ -5,7 +5,7 @@ from dify_app import DifyApp
 def init_app(app: DifyApp):
     # register blueprint routers
 
-    from flask_cors import CORS
+    from flask_cors import CORS  # type: ignore
 
     from controllers.console import bp as console_app_bp
     from controllers.files import bp as files_bp

+ 6 - 9
api/schedule/mail_clean_document_notify_task.py

@@ -1,9 +1,9 @@
 import logging
 import time
+from collections import defaultdict
 
 import click
 from celery import shared_task  # type: ignore
-from flask import render_template
 
 from extensions.ext_mail import mail
 from models.account import Account, Tenant, TenantAccountJoin
@@ -27,7 +27,7 @@ def send_document_clean_notify_task():
     try:
         dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(DatasetAutoDisableLog.notified == False).all()
         # group by tenant_id
-        dataset_auto_disable_logs_map = {}
+        dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
         for dataset_auto_disable_log in dataset_auto_disable_logs:
             dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log)
 
@@ -37,11 +37,13 @@ def send_document_clean_notify_task():
             if not tenant:
                 continue
             current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first()
+            if not current_owner_join:
+                continue
             account = Account.query.filter(Account.id == current_owner_join.account_id).first()
             if not account:
                 continue
 
-            dataset_auto_dataset_map = {}
+            dataset_auto_dataset_map = {}  # type: ignore
             for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
                 dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append(
                     dataset_auto_disable_log.document_id
@@ -53,14 +55,9 @@ def send_document_clean_notify_task():
                     document_count = len(document_ids)
                     knowledge_details.append(f"<li>Knowledge base {dataset.name}: {document_count} documents</li>")
 
-        html_content = render_template(
-            "clean_document_job_mail_template-US.html",
-        )
-        mail.send(to=to, subject="立即加入 Dify 工作空间", html=html_content)
-
         end_at = time.perf_counter()
         logging.info(
             click.style("Send document clean notify mail succeeded: latency: {}".format(end_at - start_at), fg="green")
         )
     except Exception:
-        logging.exception("Send invite member mail to {} failed".format(to))
+        logging.exception("Send invite member mail to failed")

+ 2 - 2
api/services/app_dsl_service.py

@@ -4,7 +4,7 @@ from enum import StrEnum
 from typing import Optional, cast
 from uuid import uuid4
 
-import yaml
+import yaml  # type: ignore
 from packaging import version
 from pydantic import BaseModel
 from sqlalchemy import select
@@ -465,7 +465,7 @@ class AppDslService:
         else:
             cls._append_model_config_export_data(export_data, app_model)
 
-        return yaml.dump(export_data, allow_unicode=True)
+        return yaml.dump(export_data, allow_unicode=True)  # type: ignore
 
     @classmethod
     def _append_workflow_export_data(cls, *, export_data: dict, app_model: App, include_secret: bool) -> None:

+ 106 - 77
api/services/dataset_service.py

@@ -41,6 +41,7 @@ from models.source import DataSourceOauthBinding
 from services.entities.knowledge_entities.knowledge_entities import (
     ChildChunkUpdateArgs,
     KnowledgeConfig,
+    RerankingModel,
     RetrievalModel,
     SegmentUpdateArgs,
 )
@@ -548,12 +549,14 @@ class DocumentService:
     }
 
     @staticmethod
-    def get_document(dataset_id: str, document_id: str) -> Optional[Document]:
-        document = (
-            db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
-        )
-
-        return document
+    def get_document(dataset_id: str, document_id: Optional[str] = None) -> Optional[Document]:
+        if document_id:
+            document = (
+                db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
+            )
+            return document
+        else:
+            return None
 
     @staticmethod
     def get_document_by_id(document_id: str) -> Optional[Document]:
@@ -744,25 +747,26 @@ class DocumentService:
         if features.billing.enabled:
             if not knowledge_config.original_document_id:
                 count = 0
-                if knowledge_config.data_source.info_list.data_source_type == "upload_file":
-                    upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
-                    count = len(upload_file_list)
-                elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
-                    notion_info_list = knowledge_config.data_source.info_list.notion_info_list
-                    for notion_info in notion_info_list:
-                        count = count + len(notion_info.pages)
-                elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
-                    website_info = knowledge_config.data_source.info_list.website_info_list
-                    count = len(website_info.urls)
-                batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
-                if count > batch_upload_limit:
-                    raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
-
-                DocumentService.check_documents_upload_quota(count, features)
+                if knowledge_config.data_source:
+                    if knowledge_config.data_source.info_list.data_source_type == "upload_file":
+                        upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
+                        count = len(upload_file_list)
+                    elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
+                        notion_info_list = knowledge_config.data_source.info_list.notion_info_list
+                        for notion_info in notion_info_list:  # type: ignore
+                            count = count + len(notion_info.pages)
+                    elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
+                        website_info = knowledge_config.data_source.info_list.website_info_list
+                        count = len(website_info.urls)  # type: ignore
+                    batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
+                    if count > batch_upload_limit:
+                        raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
+
+                    DocumentService.check_documents_upload_quota(count, features)
 
         # if dataset is empty, update dataset data_source_type
         if not dataset.data_source_type:
-            dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type
+            dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type  # type: ignore
 
         if not dataset.indexing_technique:
             if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
@@ -789,7 +793,7 @@ class DocumentService:
                         "score_threshold_enabled": False,
                     }
 
-                    dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model
+                    dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model  # type: ignore
 
         documents = []
         if knowledge_config.original_document_id:
@@ -801,34 +805,35 @@ class DocumentService:
             # save process rule
             if not dataset_process_rule:
                 process_rule = knowledge_config.process_rule
-                if process_rule.mode in ("custom", "hierarchical"):
-                    dataset_process_rule = DatasetProcessRule(
-                        dataset_id=dataset.id,
-                        mode=process_rule.mode,
-                        rules=process_rule.rules.model_dump_json(),
-                        created_by=account.id,
-                    )
-                elif process_rule.mode == "automatic":
-                    dataset_process_rule = DatasetProcessRule(
-                        dataset_id=dataset.id,
-                        mode=process_rule.mode,
-                        rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
-                        created_by=account.id,
-                    )
-                else:
-                    logging.warn(
-                        f"Invalid process rule mode: {process_rule['mode']}, can not find dataset process rule"
-                    )
-                    return
-                db.session.add(dataset_process_rule)
-                db.session.commit()
+                if process_rule:
+                    if process_rule.mode in ("custom", "hierarchical"):
+                        dataset_process_rule = DatasetProcessRule(
+                            dataset_id=dataset.id,
+                            mode=process_rule.mode,
+                            rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
+                            created_by=account.id,
+                        )
+                    elif process_rule.mode == "automatic":
+                        dataset_process_rule = DatasetProcessRule(
+                            dataset_id=dataset.id,
+                            mode=process_rule.mode,
+                            rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
+                            created_by=account.id,
+                        )
+                    else:
+                        logging.warn(
+                            f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule"
+                        )
+                        return
+                    db.session.add(dataset_process_rule)
+                    db.session.commit()
             lock_name = "add_document_lock_dataset_id_{}".format(dataset.id)
             with redis_client.lock(lock_name, timeout=600):
                 position = DocumentService.get_documents_position(dataset.id)
                 document_ids = []
                 duplicate_document_ids = []
                 if knowledge_config.data_source.info_list.data_source_type == "upload_file":
-                    upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
+                    upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
                     for file_id in upload_file_list:
                         file = (
                             db.session.query(UploadFile)
@@ -854,7 +859,7 @@ class DocumentService:
                                 name=file_name,
                             ).first()
                             if document:
-                                document.dataset_process_rule_id = dataset_process_rule.id
+                                document.dataset_process_rule_id = dataset_process_rule.id  # type: ignore
                                 document.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                                 document.created_from = created_from
                                 document.doc_form = knowledge_config.doc_form
@@ -868,7 +873,7 @@ class DocumentService:
                                 continue
                         document = DocumentService.build_document(
                             dataset,
-                            dataset_process_rule.id,
+                            dataset_process_rule.id,  # type: ignore
                             knowledge_config.data_source.info_list.data_source_type,
                             knowledge_config.doc_form,
                             knowledge_config.doc_language,
@@ -886,6 +891,8 @@ class DocumentService:
                         position += 1
                 elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
                     notion_info_list = knowledge_config.data_source.info_list.notion_info_list
+                    if not notion_info_list:
+                        raise ValueError("No notion info list found.")
                     exist_page_ids = []
                     exist_document = {}
                     documents = Document.query.filter_by(
@@ -921,7 +928,7 @@ class DocumentService:
                                 }
                                 document = DocumentService.build_document(
                                     dataset,
-                                    dataset_process_rule.id,
+                                    dataset_process_rule.id,  # type: ignore
                                     knowledge_config.data_source.info_list.data_source_type,
                                     knowledge_config.doc_form,
                                     knowledge_config.doc_language,
@@ -944,6 +951,8 @@ class DocumentService:
                         clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
                 elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
                     website_info = knowledge_config.data_source.info_list.website_info_list
+                    if not website_info:
+                        raise ValueError("No website info list found.")
                     urls = website_info.urls
                     for url in urls:
                         data_source_info = {
@@ -959,7 +968,7 @@ class DocumentService:
                             document_name = url
                         document = DocumentService.build_document(
                             dataset,
-                            dataset_process_rule.id,
+                            dataset_process_rule.id,  # type: ignore
                             knowledge_config.data_source.info_list.data_source_type,
                             knowledge_config.doc_form,
                             knowledge_config.doc_language,
@@ -1054,7 +1063,7 @@ class DocumentService:
                 dataset_process_rule = DatasetProcessRule(
                     dataset_id=dataset.id,
                     mode=process_rule.mode,
-                    rules=process_rule.rules.model_dump_json(),
+                    rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
                     created_by=account.id,
                 )
             elif process_rule.mode == "automatic":
@@ -1073,6 +1082,8 @@ class DocumentService:
             file_name = ""
             data_source_info = {}
             if document_data.data_source.info_list.data_source_type == "upload_file":
+                if not document_data.data_source.info_list.file_info_list:
+                    raise ValueError("No file info list found.")
                 upload_file_list = document_data.data_source.info_list.file_info_list.file_ids
                 for file_id in upload_file_list:
                     file = (
@@ -1090,6 +1101,8 @@ class DocumentService:
                         "upload_file_id": file_id,
                     }
             elif document_data.data_source.info_list.data_source_type == "notion_import":
+                if not document_data.data_source.info_list.notion_info_list:
+                    raise ValueError("No notion info list found.")
                 notion_info_list = document_data.data_source.info_list.notion_info_list
                 for notion_info in notion_info_list:
                     workspace_id = notion_info.workspace_id
@@ -1107,20 +1120,21 @@ class DocumentService:
                         data_source_info = {
                             "notion_workspace_id": workspace_id,
                             "notion_page_id": page.page_id,
-                            "notion_page_icon": page.page_icon,
+                            "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,  # type: ignore
                             "type": page.type,
                         }
             elif document_data.data_source.info_list.data_source_type == "website_crawl":
                 website_info = document_data.data_source.info_list.website_info_list
-                urls = website_info.urls
-                for url in urls:
-                    data_source_info = {
-                        "url": url,
-                        "provider": website_info.provider,
-                        "job_id": website_info.job_id,
-                        "only_main_content": website_info.only_main_content,
-                        "mode": "crawl",
-                    }
+                if website_info:
+                    urls = website_info.urls
+                    for url in urls:
+                        data_source_info = {
+                            "url": url,
+                            "provider": website_info.provider,
+                            "job_id": website_info.job_id,
+                            "only_main_content": website_info.only_main_content,  # type: ignore
+                            "mode": "crawl",
+                        }
             document.data_source_type = document_data.data_source.info_list.data_source_type
             document.data_source_info = json.dumps(data_source_info)
             document.name = file_name
@@ -1155,15 +1169,21 @@ class DocumentService:
         if features.billing.enabled:
             count = 0
             if knowledge_config.data_source.info_list.data_source_type == "upload_file":
-                upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
+                upload_file_list = (
+                    knowledge_config.data_source.info_list.file_info_list.file_ids
+                    if knowledge_config.data_source.info_list.file_info_list
+                    else []
+                )
                 count = len(upload_file_list)
             elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
                 notion_info_list = knowledge_config.data_source.info_list.notion_info_list
-                for notion_info in notion_info_list:
-                    count = count + len(notion_info.pages)
+                if notion_info_list:
+                    for notion_info in notion_info_list:
+                        count = count + len(notion_info.pages)
             elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
                 website_info = knowledge_config.data_source.info_list.website_info_list
-                count = len(website_info.urls)
+                if website_info:
+                    count = len(website_info.urls)
             batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
             if count > batch_upload_limit:
                 raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@@ -1174,20 +1194,20 @@ class DocumentService:
         retrieval_model = None
         if knowledge_config.indexing_technique == "high_quality":
             dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-                knowledge_config.embedding_model_provider, knowledge_config.embedding_model
+                knowledge_config.embedding_model_provider,  # type: ignore
+                knowledge_config.embedding_model,  # type: ignore
             )
             dataset_collection_binding_id = dataset_collection_binding.id
             if knowledge_config.retrieval_model:
                 retrieval_model = knowledge_config.retrieval_model
             else:
-                default_retrieval_model = {
-                    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
-                    "reranking_enable": False,
-                    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
-                    "top_k": 2,
-                    "score_threshold_enabled": False,
-                }
-                retrieval_model = RetrievalModel(**default_retrieval_model)
+                retrieval_model = RetrievalModel(
+                    search_method=RetrievalMethod.SEMANTIC_SEARCH.value,
+                    reranking_enable=False,
+                    reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""),
+                    top_k=2,
+                    score_threshold_enabled=False,
+                )
         # save dataset
         dataset = Dataset(
             tenant_id=tenant_id,
@@ -1557,12 +1577,12 @@ class SegmentService:
                 raise ValueError("Can't update disabled segment")
         try:
             word_count_change = segment.word_count
-            content = args.content
+            content = args.content or segment.content
             if segment.content == content:
                 segment.word_count = len(content)
                 if document.doc_form == "qa_model":
                     segment.answer = args.answer
-                    segment.word_count += len(args.answer)
+                    segment.word_count += len(args.answer) if args.answer else 0
                 word_count_change = segment.word_count - word_count_change
                 if args.keywords:
                     segment.keywords = args.keywords
@@ -1577,7 +1597,12 @@ class SegmentService:
                     db.session.add(document)
                 # update segment index task
                 if args.enabled:
-                    VectorService.create_segments_vector([args.keywords], [segment], dataset)
+                    VectorService.create_segments_vector(
+                        [args.keywords] if args.keywords else None,
+                        [segment],
+                        dataset,
+                        document.doc_form,
+                    )
                 if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
                     # regenerate child chunks
                     # get embedding model instance
@@ -1605,6 +1630,8 @@ class SegmentService:
                         .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
                         .first()
                     )
+                    if not processing_rule:
+                        raise ValueError("No processing rule found.")
                     VectorService.generate_child_chunks(
                         segment, document, dataset, embedding_model_instance, processing_rule, True
                     )
@@ -1639,7 +1666,7 @@ class SegmentService:
                 segment.disabled_by = None
                 if document.doc_form == "qa_model":
                     segment.answer = args.answer
-                    segment.word_count += len(args.answer)
+                    segment.word_count += len(args.answer) if args.answer else 0
                 word_count_change = segment.word_count - word_count_change
                 # update document word count
                 if word_count_change != 0:
@@ -1673,6 +1700,8 @@ class SegmentService:
                         .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
                         .first()
                     )
+                    if not processing_rule:
+                        raise ValueError("No processing rule found.")
                     VectorService.generate_child_chunks(
                         segment, document, dataset, embedding_model_instance, processing_rule, True
                     )

+ 1 - 1
api/services/entities/knowledge_entities/knowledge_entities.py

@@ -97,7 +97,7 @@ class KnowledgeConfig(BaseModel):
     original_document_id: Optional[str] = None
     duplicate: bool = True
     indexing_technique: Literal["high_quality", "economy"]
-    data_source: Optional[DataSource] = None
+    data_source: DataSource
     process_rule: Optional[ProcessRule] = None
     retrieval_model: Optional[RetrievalModel] = None
     doc_form: str = "text_model"

+ 1 - 1
api/services/hit_testing_service.py

@@ -69,7 +69,7 @@ class HitTestingService:
         db.session.add(dataset_query)
         db.session.commit()
 
-        return cls.compact_retrieve_response(query, all_documents)
+        return cls.compact_retrieve_response(query, all_documents)  # type: ignore
 
     @classmethod
     def external_retrieve(

+ 4 - 2
api/services/vector_service.py

@@ -29,6 +29,8 @@ class VectorService:
                     .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
                     .first()
                 )
+                if not processing_rule:
+                    raise ValueError("No processing rule found.")
                 # get embedding model instance
                 if dataset.indexing_technique == "high_quality":
                     # check embedding model setting
@@ -98,7 +100,7 @@ class VectorService:
     def generate_child_chunks(
         cls,
         segment: DocumentSegment,
-        dataset_document: Document,
+        dataset_document: DatasetDocument,
         dataset: Dataset,
         embedding_model_instance: ModelInstance,
         processing_rule: DatasetProcessRule,
@@ -130,7 +132,7 @@ class VectorService:
             doc_language=dataset_document.doc_language,
         )
         # save child chunks
-        if len(documents) > 0 and len(documents[0].children) > 0:
+        if documents and documents[0].children:
             index_processor.load(dataset, documents)
 
             for position, child_chunk in enumerate(documents[0].children, start=1):

+ 2 - 1
api/tasks/batch_clean_document_task.py

@@ -44,7 +44,8 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
                 for upload_file_id in image_upload_file_ids:
                     image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
                     try:
-                        storage.delete(image_file.key)
+                        if image_file and image_file.key:
+                            storage.delete(image_file.key)
                     except Exception:
                         logging.exception(
                             "Delete image_files failed when storage deleted, \