瀏覽代碼

fix segment enable service api (#10445)

Jyong 5 月之前
父節點
當前提交
888d7e6422
共有 2 個文件被更改,包括 24 次插入25 次删除
  1. 14 25
      api/services/dataset_service.py
  2. 10 0
      api/services/entities/knowledge_entities/knowledge_entities.py

+ 14 - 25
api/services/dataset_service.py

@@ -14,8 +14,6 @@ from configs import dify_config
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
-from core.rag.datasource.keyword.keyword_factory import Keyword
-from core.rag.models.document import Document as RAGDocument
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from events.dataset_event import dataset_was_deleted
 from events.document_event import document_was_deleted
@@ -37,6 +35,7 @@ from models.dataset import (
 )
 from models.model import UploadFile
 from models.source import DataSourceOauthBinding
+from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateEntity
 from services.errors.account import NoPermissionError
 from services.errors.dataset import DatasetNameDuplicateError
 from services.errors.document import DocumentIndexingError
@@ -1503,12 +1502,13 @@ class SegmentService:
 
     @classmethod
     def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset):
+        segment_update_entity = SegmentUpdateEntity(**args)
         indexing_cache_key = "segment_{}_indexing".format(segment.id)
         cache_result = redis_client.get(indexing_cache_key)
         if cache_result is not None:
             raise ValueError("Segment is indexing, please try again later")
-        if "enabled" in args and args["enabled"] is not None:
-            action = args["enabled"]
+        if segment_update_entity.enabled is not None:
+            action = segment_update_entity.enabled
             if segment.enabled != action:
                 if not action:
                     segment.enabled = action
@@ -1521,37 +1521,26 @@ class SegmentService:
                     disable_segment_from_index_task.delay(segment.id)
                     return segment
         if not segment.enabled:
-            if "enabled" in args and args["enabled"] is not None:
-                if not args["enabled"]:
+            if segment_update_entity.enabled is not None:
+                if not segment_update_entity.enabled:
                     raise ValueError("Can't update disabled segment")
             else:
                 raise ValueError("Can't update disabled segment")
         try:
-            content = args["content"]
+            content = segment_update_entity.content
             if segment.content == content:
                 if document.doc_form == "qa_model":
-                    segment.answer = args["answer"]
-                if args.get("keywords"):
-                    segment.keywords = args["keywords"]
+                    segment.answer = segment_update_entity.answer
+                if segment_update_entity.keywords:
+                    segment.keywords = segment_update_entity.keywords
                 segment.enabled = True
                 segment.disabled_at = None
                 segment.disabled_by = None
                 db.session.add(segment)
                 db.session.commit()
                 # update segment index task
-                if "keywords" in args:
-                    keyword = Keyword(dataset)
-                    keyword.delete_by_ids([segment.index_node_id])
-                    document = RAGDocument(
-                        page_content=segment.content,
-                        metadata={
-                            "doc_id": segment.index_node_id,
-                            "doc_hash": segment.index_node_hash,
-                            "document_id": segment.document_id,
-                            "dataset_id": segment.dataset_id,
-                        },
-                    )
-                    keyword.add_texts([document], keywords_list=[args["keywords"]])
+                if segment_update_entity.enabled:
+                    VectorService.create_segments_vector([segment_update_entity.keywords], [segment], dataset)
             else:
                 segment_hash = helper.generate_text_hash(content)
                 tokens = 0
@@ -1579,11 +1568,11 @@ class SegmentService:
                 segment.disabled_at = None
                 segment.disabled_by = None
                 if document.doc_form == "qa_model":
-                    segment.answer = args["answer"]
+                    segment.answer = segment_update_entity.answer
                 db.session.add(segment)
                 db.session.commit()
                 # update segment vector index
-                VectorService.update_segment_vector(args["keywords"], segment, dataset)
+                VectorService.update_segment_vector(segment_update_entity.keywords, segment, dataset)
 
         except Exception as e:
             logging.exception("update segment index failed")

+ 10 - 0
api/services/entities/knowledge_entities/knowledge_entities.py

@@ -0,0 +1,10 @@
+from typing import Optional
+
+from pydantic import BaseModel
+
+
+class SegmentUpdateEntity(BaseModel):
+    content: str
+    answer: Optional[str] = None
+    keywords: Optional[list[str]] = None
+    enabled: Optional[bool] = None