Bladeren bron

code merge error (#8183)

Co-authored-by: crazywoola <427733928@qq.com>
Jyong 7 maanden geleden
bovenliggende
commit
85ff82a694
2 gewijzigde bestanden met 5 toevoegingen en 5 verwijderingen
  1. +4 -0
      api/controllers/console/datasets/datasets_document.py
  2. +1 -5
      api/services/dataset_service.py

+ 4 - 0
api/controllers/console/datasets/datasets_document.py

@@ -302,6 +302,8 @@ class DatasetInitApi(Resource):
             "doc_language", type=str, default="English", required=False, nullable=False, location="json"
         )
         parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
+        parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
+        parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
         args = parser.parse_args()
 
         # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
@@ -309,6 +311,8 @@ class DatasetInitApi(Resource):
             raise Forbidden()
 
         if args["indexing_technique"] == "high_quality":
+            if args["embedding_model"] is None or args["embedding_model_provider"] is None:
+                raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
             try:
                 model_manager = ModelManager()
                 model_manager.get_default_model_instance(

+ 1 - 5
api/services/dataset_service.py

@@ -1057,12 +1057,8 @@ class DocumentService:
         dataset_collection_binding_id = None
         retrieval_model = None
         if document_data["indexing_technique"] == "high_quality":
-            model_manager = ModelManager()
-            embedding_model = model_manager.get_default_model_instance(
-                tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
-            )
             dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-                embedding_model.provider, embedding_model.model
+                document_data["embedding_model_provider"], document_data["embedding_model"]
             )
             dataset_collection_binding_id = dataset_collection_binding.id
             if document_data.get("retrieval_model"):