Prechádzať zdrojové kódy

update qdrant migrate command (#2260)

Co-authored-by: jyong <jyong@dify.ai>
Jyong 1 rok pred
rodič
commit
409e0c8e1c
1 zmenil súbory, kde vykonal 16 pridanie a 49 odobranie
  1. 16 49
      api/commands.py

+ 16 - 49
api/commands.py

@@ -339,26 +339,7 @@ def create_qdrant_indexes():
 
                             )
                         except Exception:
-                            try:
-                                embedding_model = model_manager.get_default_model_instance(
-                                    tenant_id=dataset.tenant_id,
-                                    model_type=ModelType.TEXT_EMBEDDING,
-                                )
-                                dataset.embedding_model = embedding_model.model
-                                dataset.embedding_model_provider = embedding_model.provider
-                            except Exception:
-
-                                provider = Provider(
-                                    id='provider_id',
-                                    tenant_id=dataset.tenant_id,
-                                    provider_name='openai',
-                                    provider_type=ProviderType.SYSTEM.value,
-                                    encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
-                                    is_valid=True,
-                                )
-                                model_provider = OpenAIProvider(provider=provider)
-                                embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
-                                                                  model_provider=model_provider)
+                            continue
                         embeddings = CacheEmbedding(embedding_model)
 
                         from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex
@@ -405,7 +386,7 @@ def update_qdrant_indexes():
                 .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
         except NotFound:
             break
-
+        model_manager = ModelManager()
         page += 1
         for dataset in datasets:
             if dataset.index_struct_dict:
@@ -413,23 +394,15 @@ def update_qdrant_indexes():
                     try:
                         click.echo('Update dataset qdrant index: {}'.format(dataset.id))
                         try:
-                            embedding_model = ModelFactory.get_embedding_model(
+                            embedding_model = model_manager.get_model_instance(
                                 tenant_id=dataset.tenant_id,
-                                model_provider_name=dataset.embedding_model_provider,
-                                model_name=dataset.embedding_model
+                                provider=dataset.embedding_model_provider,
+                                model_type=ModelType.TEXT_EMBEDDING,
+                                model=dataset.embedding_model
+
                             )
                         except Exception:
-                            provider = Provider(
-                                id='provider_id',
-                                tenant_id=dataset.tenant_id,
-                                provider_name='openai',
-                                provider_type=ProviderType.CUSTOM.value,
-                                encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
-                                is_valid=True,
-                            )
-                            model_provider = OpenAIProvider(provider=provider)
-                            embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
-                                                              model_provider=model_provider)
+                            continue
                         embeddings = CacheEmbedding(embedding_model)
 
                         from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex
@@ -524,23 +497,17 @@ def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count:
         try:
             click.echo('restore dataset index: {}'.format(dataset.id))
             try:
-                embedding_model = ModelFactory.get_embedding_model(
+                model_manager = ModelManager()
+
+                embedding_model = model_manager.get_model_instance(
                     tenant_id=dataset.tenant_id,
-                    model_provider_name=dataset.embedding_model_provider,
-                    model_name=dataset.embedding_model
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model
+
                 )
             except Exception:
-                provider = Provider(
-                    id='provider_id',
-                    tenant_id=dataset.tenant_id,
-                    provider_name='openai',
-                    provider_type=ProviderType.CUSTOM.value,
-                    encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
-                    is_valid=True,
-                )
-                model_provider = OpenAIProvider(provider=provider)
-                embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
-                                                  model_provider=model_provider)
+                pass
             embeddings = CacheEmbedding(embedding_model)
             dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
                 filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,