|
@@ -3,12 +3,13 @@ import json
|
|
|
import math
|
|
|
import random
|
|
|
import string
|
|
|
+import threading
|
|
|
import time
|
|
|
import uuid
|
|
|
|
|
|
import click
|
|
|
from tqdm import tqdm
|
|
|
-from flask import current_app
|
|
|
+from flask import current_app, Flask
|
|
|
from langchain.embeddings import OpenAIEmbeddings
|
|
|
from werkzeug.exceptions import NotFound
|
|
|
|
|
@@ -456,92 +457,92 @@ def update_qdrant_indexes():
|
|
|
@click.command('normalization-collections', help='restore all collections in one')
|
|
|
def normalization_collections():
|
|
|
click.echo(click.style('Start normalization collections.', fg='green'))
|
|
|
- normalization_count = 0
|
|
|
-
|
|
|
+ normalization_count = []
|
|
|
page = 1
|
|
|
while True:
|
|
|
try:
|
|
|
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
|
|
|
- .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
|
|
|
+ .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=100)
|
|
|
except NotFound:
|
|
|
break
|
|
|
-
|
|
|
+ datasets_result = datasets.items
|
|
|
page += 1
|
|
|
- for dataset in datasets:
|
|
|
- if not dataset.collection_binding_id:
|
|
|
- try:
|
|
|
- click.echo('restore dataset index: {}'.format(dataset.id))
|
|
|
- try:
|
|
|
- embedding_model = ModelFactory.get_embedding_model(
|
|
|
- tenant_id=dataset.tenant_id,
|
|
|
- model_provider_name=dataset.embedding_model_provider,
|
|
|
- model_name=dataset.embedding_model
|
|
|
- )
|
|
|
- except Exception:
|
|
|
- provider = Provider(
|
|
|
- id='provider_id',
|
|
|
- tenant_id=dataset.tenant_id,
|
|
|
- provider_name='openai',
|
|
|
- provider_type=ProviderType.CUSTOM.value,
|
|
|
- encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
|
|
|
- is_valid=True,
|
|
|
- )
|
|
|
- model_provider = OpenAIProvider(provider=provider)
|
|
|
- embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
|
|
|
- model_provider=model_provider)
|
|
|
- embeddings = CacheEmbedding(embedding_model)
|
|
|
- dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
|
|
|
- filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
|
|
|
- DatasetCollectionBinding.model_name == embedding_model.name). \
|
|
|
- order_by(DatasetCollectionBinding.created_at). \
|
|
|
- first()
|
|
|
-
|
|
|
- if not dataset_collection_binding:
|
|
|
- dataset_collection_binding = DatasetCollectionBinding(
|
|
|
- provider_name=embedding_model.model_provider.provider_name,
|
|
|
- model_name=embedding_model.name,
|
|
|
- collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
|
|
|
- )
|
|
|
- db.session.add(dataset_collection_binding)
|
|
|
- db.session.commit()
|
|
|
+ for i in range(0, len(datasets_result), 5):
|
|
|
+ threads = []
|
|
|
+ sub_datasets = datasets_result[i:i + 5]
|
|
|
+ for dataset in sub_datasets:
|
|
|
+ document_format_thread = threading.Thread(target=deal_dataset_vector, kwargs={
|
|
|
+ 'flask_app': current_app._get_current_object(),
|
|
|
+ 'dataset': dataset,
|
|
|
+ 'normalization_count': normalization_count
|
|
|
+ })
|
|
|
+ threads.append(document_format_thread)
|
|
|
+ document_format_thread.start()
|
|
|
+ for thread in threads:
|
|
|
+ thread.join()
|
|
|
+
|
|
|
+ click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(len(normalization_count)), fg='green'))
|
|
|
+
|
|
|
+
|
|
|
+def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: list):
|
|
|
+ with flask_app.app_context():
|
|
|
+ try:
|
|
|
+ click.echo('restore dataset index: {}'.format(dataset.id))
|
|
|
+ try:
|
|
|
+ embedding_model = ModelFactory.get_embedding_model(
|
|
|
+ tenant_id=dataset.tenant_id,
|
|
|
+ model_provider_name=dataset.embedding_model_provider,
|
|
|
+ model_name=dataset.embedding_model
|
|
|
+ )
|
|
|
+ except Exception:
|
|
|
+ provider = Provider(
|
|
|
+ id='provider_id',
|
|
|
+ tenant_id=dataset.tenant_id,
|
|
|
+ provider_name='openai',
|
|
|
+ provider_type=ProviderType.CUSTOM.value,
|
|
|
+ encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
|
|
|
+ is_valid=True,
|
|
|
+ )
|
|
|
+ model_provider = OpenAIProvider(provider=provider)
|
|
|
+ embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
|
|
|
+ model_provider=model_provider)
|
|
|
+ embeddings = CacheEmbedding(embedding_model)
|
|
|
+ dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
|
|
|
+ filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
|
|
|
+ DatasetCollectionBinding.model_name == embedding_model.name). \
|
|
|
+ order_by(DatasetCollectionBinding.created_at). \
|
|
|
+ first()
|
|
|
+
|
|
|
+ if not dataset_collection_binding:
|
|
|
+ dataset_collection_binding = DatasetCollectionBinding(
|
|
|
+ provider_name=embedding_model.model_provider.provider_name,
|
|
|
+ model_name=embedding_model.name,
|
|
|
+ collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
|
|
|
+ )
|
|
|
+ db.session.add(dataset_collection_binding)
|
|
|
+ db.session.commit()
|
|
|
|
|
|
- from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
|
|
|
-
|
|
|
- index = QdrantVectorIndex(
|
|
|
- dataset=dataset,
|
|
|
- config=QdrantConfig(
|
|
|
- endpoint=current_app.config.get('QDRANT_URL'),
|
|
|
- api_key=current_app.config.get('QDRANT_API_KEY'),
|
|
|
- root_path=current_app.root_path
|
|
|
- ),
|
|
|
- embeddings=embeddings
|
|
|
- )
|
|
|
- if index:
|
|
|
- index.restore_dataset_in_one(dataset, dataset_collection_binding)
|
|
|
- else:
|
|
|
- click.echo('passed.')
|
|
|
-
|
|
|
- original_index = QdrantVectorIndex(
|
|
|
- dataset=dataset,
|
|
|
- config=QdrantConfig(
|
|
|
- endpoint=current_app.config.get('QDRANT_URL'),
|
|
|
- api_key=current_app.config.get('QDRANT_API_KEY'),
|
|
|
- root_path=current_app.root_path
|
|
|
- ),
|
|
|
- embeddings=embeddings
|
|
|
- )
|
|
|
- if original_index:
|
|
|
- original_index.delete_original_collection(dataset, dataset_collection_binding)
|
|
|
- normalization_count += 1
|
|
|
- else:
|
|
|
- click.echo('passed.')
|
|
|
- except Exception as e:
|
|
|
- click.echo(
|
|
|
- click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
|
|
|
- fg='red'))
|
|
|
- continue
|
|
|
-
|
|
|
- click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(normalization_count), fg='green'))
|
|
|
+ from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
|
|
|
+
|
|
|
+ index = QdrantVectorIndex(
|
|
|
+ dataset=dataset,
|
|
|
+ config=QdrantConfig(
|
|
|
+ endpoint=current_app.config.get('QDRANT_URL'),
|
|
|
+ api_key=current_app.config.get('QDRANT_API_KEY'),
|
|
|
+ root_path=current_app.root_path
|
|
|
+ ),
|
|
|
+ embeddings=embeddings
|
|
|
+ )
|
|
|
+ if index:
|
|
|
+ # index.delete_by_group_id(dataset.id)
|
|
|
+ index.restore_dataset_in_one(dataset, dataset_collection_binding)
|
|
|
+ else:
|
|
|
+ click.echo('passed.')
|
|
|
+ normalization_count.append(1)
|
|
|
+ except Exception as e:
|
|
|
+ click.echo(
|
|
|
+ click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
|
|
|
+ fg='red'))
|
|
|
|
|
|
|
|
|
@click.command('update_app_model_configs', help='Migrate data to support paragraph variable.')
|