Explorar o código

Feat/add annotation migrate (#2675)

Co-authored-by: jyong <jyong@dify.ai>
Jyong hai 1 ano
pai
achega
3631e53ff0
Modificáronse 2 ficheiros con 111 adicións e 3 borrados
  1. 109 2
      api/commands.py
  2. 2 1
      api/core/rag/datasource/vdb/milvus/milvus_vector.py

+ 109 - 2
api/commands.py

@@ -15,7 +15,7 @@ from libs.rsa import generate_key_pair
 from models.account import Tenant
 from models.account import Tenant
 from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
 from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
 from models.dataset import Document as DatasetDocument
 from models.dataset import Document as DatasetDocument
-from models.model import Account
+from models.model import Account, App, AppAnnotationSetting, MessageAnnotation
 from models.provider import Provider, ProviderModel
 from models.provider import Provider, ProviderModel
 
 
 
 
@@ -125,7 +125,114 @@ def reset_encrypt_key_pair():
 
 
 
 
 @click.command('vdb-migrate', help='migrate vector db.')
 @click.command('vdb-migrate', help='migrate vector db.')
-def vdb_migrate():
+@click.option('--scope', default='all', prompt=False, help='The scope of vector database to migrate, Default is All.')
+def vdb_migrate(scope: str):
+    if scope in ['knowledge', 'all']:
+        migrate_knowledge_vector_database()
+    if scope in ['annotation', 'all']:
+        migrate_annotation_vector_database()
+
+
+def migrate_annotation_vector_database():
+    """
+    Migrate annotation datas to target vector database .
+    """
+    click.echo(click.style('Start migrate annotation data.', fg='green'))
+    create_count = 0
+    skipped_count = 0
+    total_count = 0
+    page = 1
+    while True:
+        try:
+            # get apps info
+            apps = db.session.query(App).filter(
+                App.status == 'normal'
+            ).order_by(App.created_at.desc()).paginate(page=page, per_page=50)
+        except NotFound:
+            break
+
+        page += 1
+        for app in apps:
+            total_count = total_count + 1
+            click.echo(f'Processing the {total_count} app {app.id}. '
+                       + f'{create_count} created, {skipped_count} skipped.')
+            try:
+                click.echo('Create app annotation index: {}'.format(app.id))
+                app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
+                    AppAnnotationSetting.app_id == app.id
+                ).first()
+
+                if not app_annotation_setting:
+                    skipped_count = skipped_count + 1
+                    click.echo('App annotation setting is disabled: {}'.format(app.id))
+                    continue
+                # get dataset_collection_binding info
+                dataset_collection_binding = db.session.query(DatasetCollectionBinding).filter(
+                    DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id
+                ).first()
+                if not dataset_collection_binding:
+                    click.echo('App annotation collection binding is not exist: {}'.format(app.id))
+                    continue
+                annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all()
+                dataset = Dataset(
+                    id=app.id,
+                    tenant_id=app.tenant_id,
+                    indexing_technique='high_quality',
+                    embedding_model_provider=dataset_collection_binding.provider_name,
+                    embedding_model=dataset_collection_binding.model_name,
+                    collection_binding_id=dataset_collection_binding.id
+                )
+                documents = []
+                if annotations:
+                    for annotation in annotations:
+                        document = Document(
+                            page_content=annotation.question,
+                            metadata={
+                                "annotation_id": annotation.id,
+                                "app_id": app.id,
+                                "doc_id": annotation.id
+                            }
+                        )
+                        documents.append(document)
+
+                vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
+                click.echo(f"Start to migrate annotation, app_id: {app.id}.")
+
+                try:
+                    vector.delete()
+                    click.echo(
+                        click.style(f'Successfully delete vector index for app: {app.id}.',
+                                    fg='green'))
+                except Exception as e:
+                    click.echo(
+                        click.style(f'Failed to delete vector index for app {app.id}.',
+                                    fg='red'))
+                    raise e
+                if documents:
+                    try:
+                        click.echo(click.style(
+                            f'Start to created vector index with {len(documents)} annotations for app {app.id}.',
+                            fg='green'))
+                        vector.create(documents)
+                        click.echo(
+                            click.style(f'Successfully created vector index for app {app.id}.', fg='green'))
+                    except Exception as e:
+                        click.echo(click.style(f'Failed to created vector index for app {app.id}.', fg='red'))
+                        raise e
+                click.echo(f'Successfully migrated app annotation {app.id}.')
+                create_count += 1
+            except Exception as e:
+                click.echo(
+                    click.style('Create app annotation index error: {} {}'.format(e.__class__.__name__, str(e)),
+                                fg='red'))
+                continue
+
+    click.echo(
+        click.style(f'Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.',
+                    fg='green'))
+
+
+def migrate_knowledge_vector_database():
     """
     """
     Migrate vector database datas to target vector database .
     Migrate vector database datas to target vector database .
     """
     """

+ 2 - 1
api/core/rag/datasource/vdb/milvus/milvus_vector.py

@@ -140,7 +140,8 @@ class MilvusVector(BaseVector):
         connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
         connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
 
 
         from pymilvus import utility
         from pymilvus import utility
-        utility.drop_collection(self._collection_name, None, using=alias)
+        if utility.has_collection(self._collection_name, using=alias):
+            utility.drop_collection(self._collection_name, None, using=alias)
 
 
     def text_exists(self, id: str) -> bool:
     def text_exists(self, id: str) -> bool: