Jyong 3 months ago
parent commit bee32d960a

+ 2 - 1
api/controllers/console/datasets/datasets_document.py

@@ -257,7 +257,8 @@ class DatasetDocumentListApi(Resource):
         parser.add_argument("original_document_id", type=str, required=False, location="json")
         parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
         parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
-
+        parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
+        parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
         parser.add_argument(
             "doc_language", type=str, default="English", required=False, nullable=False, location="json"
         )
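
For context, a minimal sketch of what a client request to this endpoint might look like with the two new optional fields. The URL, auth token, and model names below are illustrative assumptions, not taken from the diff; omitting both fields keeps the previous behavior of falling back to the workspace's default embedding model.

# Hypothetical client call; endpoint path, token, and model names are
# illustrative assumptions, not confirmed by this commit.
import requests

payload = {
    "doc_form": "text_model",
    "retrieval_model": {"search_method": "semantic_search", "top_k": 2},
    # New optional fields added in this commit; both are nullable and
    # may be omitted to use the tenant's default embedding model.
    "embedding_model": "text-embedding-3-small",   # hypothetical model
    "embedding_model_provider": "openai",          # hypothetical provider
    "doc_language": "English",
}
resp = requests.post(
    "https://example.com/console/api/datasets/<dataset_id>/documents",
    json=payload,
    headers={"Authorization": "Bearer <console-token>"},
)
resp.raise_for_status()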

+ 17 - 7
api/services/dataset_service.py

@@ -792,13 +792,19 @@ class DocumentService:
             dataset.indexing_technique = knowledge_config.indexing_technique
             if knowledge_config.indexing_technique == "high_quality":
                 model_manager = ModelManager()
-                embedding_model = model_manager.get_default_model_instance(
-                    tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
-                )
-                dataset.embedding_model = embedding_model.model
-                dataset.embedding_model_provider = embedding_model.provider
+                if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
+                    dataset_embedding_model = knowledge_config.embedding_model
+                    dataset_embedding_model_provider = knowledge_config.embedding_model_provider
+                else:
+                    embedding_model = model_manager.get_default_model_instance(
+                        tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
+                    )
+                    dataset_embedding_model = embedding_model.model
+                    dataset_embedding_model_provider = embedding_model.provider
+                dataset.embedding_model = dataset_embedding_model
+                dataset.embedding_model_provider = dataset_embedding_model_provider
                 dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-                    embedding_model.provider, embedding_model.model
+                    dataset_embedding_model_provider, dataset_embedding_model
                 )
                 dataset.collection_binding_id = dataset_collection_binding.id
                 if not dataset.retrieval_model:
@@ -810,7 +816,11 @@ class DocumentService:
                         "score_threshold_enabled": False,
                     }
 
-                    dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model  # type: ignore
+                    dataset.retrieval_model = (
+                        knowledge_config.retrieval_model.model_dump()
+                        if knowledge_config.retrieval_model
+                        else default_retrieval_model
+                    )  # type: ignore
 
         documents = []
         if knowledge_config.original_document_id:
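
Two behaviors change in this file. First, an explicitly supplied embedding_model/embedding_model_provider now takes precedence over the tenant default. Second, the retrieval-model change fixes a latent bug: the old expression called .model_dump() unconditionally, which raises AttributeError when knowledge_config.retrieval_model is None, so the `or default_retrieval_model` fallback could never be reached. A standalone sketch of the difference (RetrievalModel here is a stand-in for the real pydantic model, not the project's actual class):

# Standalone illustration of the fixed fallback pattern.
from typing import Optional
from pydantic import BaseModel

class RetrievalModel(BaseModel):  # stand-in for the real config model
    search_method: str = "semantic_search"
    top_k: int = 2

default_retrieval_model = {"search_method": "semantic_search", "top_k": 2}
retrieval_model: Optional[RetrievalModel] = None

# Old pattern: retrieval_model.model_dump() or default_retrieval_model
# raises AttributeError when retrieval_model is None.

# New pattern: test for None before dumping.
result = retrieval_model.model_dump() if retrieval_model else default_retrieval_model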

+ 4 - 1
api/tasks/deal_dataset_vector_index_task.py

@@ -28,7 +28,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
 
         if not dataset:
             raise Exception("Dataset not found")
-        index_type = dataset.doc_form
+        index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
         index_processor = IndexProcessorFactory(index_type).init_index_processor()
         if action == "remove":
             index_processor.clean(dataset, None, with_keywords=False)
@@ -157,6 +157,9 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
                             {"indexing_status": "error", "error": str(e)}, synchronize_session=False
                         )
                         db.session.commit()
+            else:
+                # clean collection
+                index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
 
         end_at = time.perf_counter()
         logging.info(
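
With the doc_form fallback in place, datasets that never recorded a doc_form default to the paragraph index instead of failing to build an index processor. A minimal sketch of dispatching the task, assuming the standard Celery shared-task invocation (the dataset id is a placeholder):

# Hypothetical dispatch; assumes deal_dataset_vector_index_task is a
# standard Celery shared task invoked asynchronously via .delay().
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task

# "remove" cleans the dataset's vector collection; other actions re-index.
deal_dataset_vector_index_task.delay("00000000-0000-0000-0000-000000000000", "remove")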