Browse Source

Feat/add annotation migrate (#2675)

Co-authored-by: jyong <jyong@dify.ai>
Jyong 1 year ago
parent
commit
3631e53ff0
2 changed files with 111 additions and 3 deletions
  1. 109 2
      api/commands.py
  2. 2 1
      api/core/rag/datasource/vdb/milvus/milvus_vector.py

+ 109 - 2
api/commands.py

@@ -15,7 +15,7 @@ from libs.rsa import generate_key_pair
 from models.account import Tenant
 from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
 from models.dataset import Document as DatasetDocument
-from models.model import Account
+from models.model import Account, App, AppAnnotationSetting, MessageAnnotation
 from models.provider import Provider, ProviderModel
 
 
@@ -125,7 +125,114 @@ def reset_encrypt_key_pair():
 
 
 @click.command('vdb-migrate', help='migrate vector db.')
-def vdb_migrate():
+@click.option('--scope', default='all', prompt=False, help='The scope of vector database to migrate, Default is All.')
+def vdb_migrate(scope: str):
+    if scope in ['knowledge', 'all']:
+        migrate_knowledge_vector_database()
+    if scope in ['annotation', 'all']:
+        migrate_annotation_vector_database()
+
+
+def migrate_annotation_vector_database():
+    """
+    Migrate annotation datas to target vector database .
+    """
+    click.echo(click.style('Start migrate annotation data.', fg='green'))
+    create_count = 0
+    skipped_count = 0
+    total_count = 0
+    page = 1
+    while True:
+        try:
+            # get apps info
+            apps = db.session.query(App).filter(
+                App.status == 'normal'
+            ).order_by(App.created_at.desc()).paginate(page=page, per_page=50)
+        except NotFound:
+            break
+
+        page += 1
+        for app in apps:
+            total_count = total_count + 1
+            click.echo(f'Processing the {total_count} app {app.id}. '
+                       + f'{create_count} created, {skipped_count} skipped.')
+            try:
+                click.echo('Create app annotation index: {}'.format(app.id))
+                app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
+                    AppAnnotationSetting.app_id == app.id
+                ).first()
+
+                if not app_annotation_setting:
+                    skipped_count = skipped_count + 1
+                    click.echo('App annotation setting is disabled: {}'.format(app.id))
+                    continue
+                # get dataset_collection_binding info
+                dataset_collection_binding = db.session.query(DatasetCollectionBinding).filter(
+                    DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id
+                ).first()
+                if not dataset_collection_binding:
+                    click.echo('App annotation collection binding is not exist: {}'.format(app.id))
+                    continue
+                annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all()
+                dataset = Dataset(
+                    id=app.id,
+                    tenant_id=app.tenant_id,
+                    indexing_technique='high_quality',
+                    embedding_model_provider=dataset_collection_binding.provider_name,
+                    embedding_model=dataset_collection_binding.model_name,
+                    collection_binding_id=dataset_collection_binding.id
+                )
+                documents = []
+                if annotations:
+                    for annotation in annotations:
+                        document = Document(
+                            page_content=annotation.question,
+                            metadata={
+                                "annotation_id": annotation.id,
+                                "app_id": app.id,
+                                "doc_id": annotation.id
+                            }
+                        )
+                        documents.append(document)
+
+                vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
+                click.echo(f"Start to migrate annotation, app_id: {app.id}.")
+
+                try:
+                    vector.delete()
+                    click.echo(
+                        click.style(f'Successfully delete vector index for app: {app.id}.',
+                                    fg='green'))
+                except Exception as e:
+                    click.echo(
+                        click.style(f'Failed to delete vector index for app {app.id}.',
+                                    fg='red'))
+                    raise e
+                if documents:
+                    try:
+                        click.echo(click.style(
+                            f'Start to created vector index with {len(documents)} annotations for app {app.id}.',
+                            fg='green'))
+                        vector.create(documents)
+                        click.echo(
+                            click.style(f'Successfully created vector index for app {app.id}.', fg='green'))
+                    except Exception as e:
+                        click.echo(click.style(f'Failed to created vector index for app {app.id}.', fg='red'))
+                        raise e
+                click.echo(f'Successfully migrated app annotation {app.id}.')
+                create_count += 1
+            except Exception as e:
+                click.echo(
+                    click.style('Create app annotation index error: {} {}'.format(e.__class__.__name__, str(e)),
+                                fg='red'))
+                continue
+
+    click.echo(
+        click.style(f'Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.',
+                    fg='green'))
+
+
+def migrate_knowledge_vector_database():
     """
     Migrate vector database datas to target vector database .
     """

+ 2 - 1
api/core/rag/datasource/vdb/milvus/milvus_vector.py

@@ -140,7 +140,8 @@ class MilvusVector(BaseVector):
         connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
 
         from pymilvus import utility
-        utility.drop_collection(self._collection_name, None, using=alias)
+        if utility.has_collection(self._collection_name, using=alias):
+            utility.drop_collection(self._collection_name, None, using=alias)
 
     def text_exists(self, id: str) -> bool: