@@ -4,6 +4,7 @@ import math
 import random
 import string
 import time
+import uuid
 
 import click
 from tqdm import tqdm
@@ -23,7 +24,7 @@ from libs.helper import email as email_validate
 from extensions.ext_database import db
 from libs.rsa import generate_key_pair
 from models.account import InvitationCode, Tenant, TenantAccountJoin
-from models.dataset import Dataset, DatasetQuery, Document
+from models.dataset import Dataset, DatasetQuery, Document, DatasetCollectionBinding
 from models.model import Account, AppModelConfig, App
 import secrets
 import base64
@@ -239,7 +240,10 @@ def clean_unused_dataset_indexes():
                         kw_index = IndexBuilder.get_index(dataset, 'economy')
                         # delete from vector index
                         if vector_index:
-                            vector_index.delete()
+                            if dataset.collection_binding_id:
+                                vector_index.delete_by_group_id(dataset.id)
+                            else:
+                                vector_index.delete()
                         kw_index.delete()
                         # update document
                         update_params = {
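
Note on the hunk above: once datasets share a single Qdrant collection (signalled by a non-null collection_binding_id), dropping the whole collection would also erase every other dataset's vectors, so cleanup has to fall back to a delete scoped to the dataset. delete_by_group_id is internal to Dify's vector index classes and its body is not part of this patch; the following is a minimal sketch of what such a scoped delete can look like against the qdrant-client API, assuming each point carries its owning dataset's id in a group_id payload field (the key name and the helper's signature are illustrative assumptions, not taken from this diff):

from qdrant_client import QdrantClient
from qdrant_client.http import models

def delete_by_group_id(client: QdrantClient, collection_name: str, group_id: str) -> None:
    # Delete only the points tagged with this dataset's id,
    # leaving the rest of the shared collection untouched.
    client.delete(
        collection_name=collection_name,
        points_selector=models.FilterSelector(
            filter=models.Filter(
                must=[
                    models.FieldCondition(
                        key="group_id",
                        match=models.MatchValue(value=group_id),
                    )
                ]
            )
        ),
    )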
@@ -346,7 +353,8 @@ def create_qdrant_indexes():
                                 is_valid=True,
                             )
                             model_provider = OpenAIProvider(provider=provider)
-                            embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
+                            embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
+                                                              model_provider=model_provider)
                         embeddings = CacheEmbedding(embedding_model)
 
                         from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
@@ -364,7 +372,8 @@ def create_qdrant_indexes():
                             index.create_qdrant_dataset(dataset)
                             index_struct = {
                                 "type": 'qdrant',
-                                "vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
+                                "vector_store": {
+                                    "class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
                             }
                             dataset.index_struct = json.dumps(index_struct)
                             db.session.commit()
@@ -373,7 +382,8 @@ def create_qdrant_indexes():
                             click.echo('passed.')
                     except Exception as e:
                         click.echo(
-                            click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
+                            click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
+                                        fg='red'))
                         continue
 
     click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
@@ -414,7 +424,8 @@ def update_qdrant_indexes():
                                 is_valid=True,
                             )
                             model_provider = OpenAIProvider(provider=provider)
-                            embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
+                            embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
+                                                              model_provider=model_provider)
                         embeddings = CacheEmbedding(embedding_model)
 
                         from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
@@ -435,11 +446,104 @@ def update_qdrant_indexes():
                             click.echo('passed.')
                     except Exception as e:
                         click.echo(
-                            click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
+                            click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
+                                        fg='red'))
                         continue
 
     click.echo(click.style('Congratulations! Update {} dataset indexes.'.format(create_count), fg='green'))
 
+
+@click.command('normalization-collections', help='restore all collections into one')
+def normalization_collections():
+    click.echo(click.style('Start normalizing collections.', fg='green'))
+    normalization_count = 0
+
+    page = 1
+    while True:
+        try:
+            datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
+                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
+        except NotFound:
+            break
+
+        page += 1
+        for dataset in datasets:
+            if not dataset.collection_binding_id:
+                try:
+                    click.echo('restore dataset index: {}'.format(dataset.id))
+                    try:
+                        embedding_model = ModelFactory.get_embedding_model(
+                            tenant_id=dataset.tenant_id,
+                            model_provider_name=dataset.embedding_model_provider,
+                            model_name=dataset.embedding_model
+                        )
+                    except Exception:
+                        provider = Provider(
+                            id='provider_id',
+                            tenant_id=dataset.tenant_id,
+                            provider_name='openai',
+                            provider_type=ProviderType.CUSTOM.value,
+                            encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
+                            is_valid=True,
+                        )
+                        model_provider = OpenAIProvider(provider=provider)
+                        embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
+                                                          model_provider=model_provider)
+                    embeddings = CacheEmbedding(embedding_model)
+                    dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
+                        filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
+                               DatasetCollectionBinding.model_name == embedding_model.name). \
+                        order_by(DatasetCollectionBinding.created_at). \
+                        first()
+
+                    if not dataset_collection_binding:
+                        dataset_collection_binding = DatasetCollectionBinding(
+                            provider_name=embedding_model.model_provider.provider_name,
+                            model_name=embedding_model.name,
+                            collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
+                        )
+                        db.session.add(dataset_collection_binding)
+                        db.session.commit()
+
+                    from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
+
+                    index = QdrantVectorIndex(
+                        dataset=dataset,
+                        config=QdrantConfig(
+                            endpoint=current_app.config.get('QDRANT_URL'),
+                            api_key=current_app.config.get('QDRANT_API_KEY'),
+                            root_path=current_app.root_path
+                        ),
+                        embeddings=embeddings
+                    )
+                    if index:
+                        index.restore_dataset_in_one(dataset, dataset_collection_binding)
+                    else:
+                        click.echo('passed.')
+
+                    original_index = QdrantVectorIndex(
+                        dataset=dataset,
+                        config=QdrantConfig(
+                            endpoint=current_app.config.get('QDRANT_URL'),
+                            api_key=current_app.config.get('QDRANT_API_KEY'),
+                            root_path=current_app.root_path
+                        ),
+                        embeddings=embeddings
+                    )
+                    if original_index:
+                        original_index.delete_original_collection(dataset, dataset_collection_binding)
+                        normalization_count += 1
+                    else:
+                        click.echo('passed.')
+                except Exception as e:
+                    click.echo(
+                        click.style('Restore dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
+                                    fg='red'))
+                    continue
+
+    click.echo(click.style('Congratulations! Restored {} dataset indexes.'.format(normalization_count), fg='green'))
+
+
 @click.command('update_app_model_configs', help='Migrate data to support paragraph variable.')
 @click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
 def update_app_model_configs(batch_size):
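
restore_dataset_in_one and delete_original_collection are likewise internal to QdrantVectorIndex and their bodies are not part of this patch. Conceptually, normalization copies a dataset's points out of its dedicated collection into the shared collection recorded on the DatasetCollectionBinding, stamping each point with the dataset id as group_id, and then drops the now-empty dedicated collection. A rough sketch of that copy step with qdrant-client, under the same group_id payload assumption as above (the function name and page size are illustrative):

from qdrant_client import QdrantClient
from qdrant_client.http import models

def copy_dataset_into_shared_collection(client: QdrantClient, source_collection: str,
                                        shared_collection: str, group_id: str) -> None:
    # Page through the dataset's dedicated collection and re-insert every
    # point into the shared collection, tagging it with the dataset's id.
    offset = None
    while True:
        records, offset = client.scroll(
            collection_name=source_collection,
            limit=100,
            offset=offset,
            with_payload=True,
            with_vectors=True,
        )
        if not records:
            break
        client.upsert(
            collection_name=shared_collection,
            points=[
                models.PointStruct(
                    id=record.id,
                    vector=record.vector,
                    payload={**(record.payload or {}), "group_id": group_id},
                )
                for record in records
            ],
        )
        if offset is None:
            # scroll returns no next-page offset once the collection is exhausted
            break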
@@ -473,7 +577,7 @@ def update_app_model_configs(batch_size):
         .join(App, App.app_model_config_id == AppModelConfig.id) \
         .filter(App.mode == 'completion') \
         .count()
-
+
     if total_records == 0:
         click.secho("No data to migrate.", fg='green')
         return
@@ -485,14 +589,14 @@ def update_app_model_configs(batch_size):
             offset = i * batch_size
             limit = min(batch_size, total_records - offset)
 
-            click.secho(f"Fetching batch {i+1}/{num_batches} from source database...", fg='green')
-
+            click.secho(f"Fetching batch {i + 1}/{num_batches} from source database...", fg='green')
+
             data_batch = db.session.query(AppModelConfig) \
                 .join(App, App.app_model_config_id == AppModelConfig.id) \
                 .filter(App.mode == 'completion') \
                 .order_by(App.created_at) \
                 .offset(offset).limit(limit).all()
-
+
             if not data_batch:
                 click.secho("No more data to migrate.", fg='green')
                 break
@@ -512,7 +616,7 @@ def update_app_model_configs(batch_size):
                 app_data = db.session.query(App) \
                     .filter(App.id == data.app_id) \
                     .one()
-
+
                 account_data = db.session.query(Account) \
                     .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) \
                     .filter(TenantAccountJoin.role == 'owner') \
@@ -534,13 +638,15 @@ def update_app_model_configs(batch_size):
                     db.session.commit()
 
                 except Exception as e:
-                    click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}", fg='red')
+                    click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
+                                fg='red')
                     continue
-
-            click.secho(f"Successfully migrated batch {i+1}/{num_batches}.", fg='green')
-
+
+            click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')
+
             pbar.update(len(data_batch))
 
+
 def register_commands(app):
     app.cli.add_command(reset_password)
     app.cli.add_command(reset_email)
@@ -551,4 +657,5 @@ def register_commands(app):
     app.cli.add_command(clean_unused_dataset_indexes)
     app.cli.add_command(create_qdrant_indexes)
     app.cli.add_command(update_qdrant_indexes)
-    app.cli.add_command(update_app_model_configs)
+    app.cli.add_command(update_app_model_configs)
+    app.cli.add_command(normalization_collections)
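
Once registered, the new command is run through the Flask CLI like the other maintenance commands (flask normalization-collections from the api service). Datasets that already have a collection_binding_id are skipped by the `if not dataset.collection_binding_id` guard, so re-running after a partial failure should only touch datasets that were not yet migrated. A quick way to exercise it in a test, assuming an application factory create_app exists as in a typical Flask setup (the import path here is illustrative):

from app import create_app

app = create_app()
runner = app.test_cli_runner()
result = runner.invoke(args=['normalization-collections'])
print(result.output)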