@@ -10,6 +10,7 @@ from flask import current_app
 from sqlalchemy import func

 from core.index.index import IndexBuilder
+from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_providers.model_factory import ModelFactory
 from extensions.ext_redis import redis_client
 from flask_login import current_user
@@ -91,16 +92,18 @@ class DatasetService:
         if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
             raise DatasetNameDuplicateError(
                 f'Dataset with name {name} already exists.')
-        embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=current_user.current_tenant_id
-        )
+        embedding_model = None
+        if indexing_technique == 'high_quality':
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=current_user.current_tenant_id
+            )
         dataset = Dataset(name=name, indexing_technique=indexing_technique)
         # dataset = Dataset(name=name, provider=provider, config=config)
         dataset.created_by = account.id
         dataset.updated_by = account.id
         dataset.tenant_id = tenant_id
-        dataset.embedding_model_provider = embedding_model.model_provider.provider_name
-        dataset.embedding_model = embedding_model.name
+        dataset.embedding_model_provider = embedding_model.model_provider.provider_name if embedding_model else None
+        dataset.embedding_model = embedding_model.name if embedding_model else None
         db.session.add(dataset)
         db.session.commit()
         return dataset
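
The hunk above makes the embedding-model lookup conditional: only datasets created with the 'high_quality' indexing technique resolve an embedding model, while 'economy' datasets skip the provider lookup and store None for both model columns. A minimal sketch of the same guard pattern, with hypothetical stubs standing in for ModelFactory.get_embedding_model and its return value:

from typing import Optional

class EmbeddingModelStub:  # hypothetical stand-in for the real embedding model object
    name = 'embedding-stub'

def get_embedding_model(tenant_id: str) -> EmbeddingModelStub:
    # stand-in for ModelFactory.get_embedding_model; illustration only
    return EmbeddingModelStub()

def resolve_embedding_model(indexing_technique: str, tenant_id: str) -> Optional[EmbeddingModelStub]:
    # 'economy' datasets are keyword-indexed and need no embedding provider;
    # only 'high_quality' (vector) indexing triggers the lookup.
    if indexing_technique == 'high_quality':
        return get_embedding_model(tenant_id=tenant_id)
    return None
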
@@ -115,6 +118,23 @@ class DatasetService:
         else:
             return dataset

+    @staticmethod
+    def check_dataset_model_setting(dataset):
+        if dataset.indexing_technique == 'high_quality':
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except LLMBadRequestError:
+                raise ValueError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ValueError(f"The dataset is unavailable, due to: "
+                                 f"{ex.description}")
+
     @staticmethod
     def update_dataset(dataset_id, data, user):
         dataset = DatasetService.get_dataset(dataset_id)
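
check_dataset_model_setting gives callers a single pre-flight check that the dataset's recorded embedding model is still usable for its tenant; later in this diff, update_document_with_dataset_id calls it before mutating a document. A hedged usage sketch (the stub below is illustrative, not the real SQLAlchemy model):

class _FakeDataset:  # illustrative stub with the four attributes the check reads
    indexing_technique = 'high_quality'
    tenant_id = 'tenant-id'
    embedding_model_provider = 'openai'
    embedding_model = 'text-embedding-ada-002'

try:
    DatasetService.check_dataset_model_setting(_FakeDataset())
except ValueError as e:
    # Either the "No Embedding Model available" hint or the provider's own description.
    print(e)
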
@@ -124,6 +144,19 @@ class DatasetService:
             if data['indexing_technique'] == 'economy':
                 deal_dataset_vector_index_task.delay(dataset_id, 'remove')
             elif data['indexing_technique'] == 'high_quality':
+                # check embedding model setting
+                try:
+                    ModelFactory.get_embedding_model(
+                        tenant_id=current_user.current_tenant_id,
+                        model_provider_name=dataset.embedding_model_provider,
+                        model_name=dataset.embedding_model
+                    )
+                except LLMBadRequestError:
+                    raise ValueError(
+                        f"No Embedding Model available. Please configure a valid provider "
+                        f"in the Settings -> Model Provider.")
+                except ProviderTokenNotInitError as ex:
+                    raise ValueError(ex.description)
                 deal_dataset_vector_index_task.delay(dataset_id, 'add')
         filtered_data = {k: v for k, v in data.items() if v is not None or k == 'description'}

@@ -397,23 +430,23 @@ class DocumentService:

         # check document limit
         if current_app.config['EDITION'] == 'CLOUD':
-            count = 0
-            if document_data["data_source"]["type"] == "upload_file":
-                upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
-                count = len(upload_file_list)
-            elif document_data["data_source"]["type"] == "notion_import":
-                notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
-                for notion_info in notion_info_list:
-                    count = count + len(notion_info['pages'])
-            documents_count = DocumentService.get_tenant_documents_count()
-            total_count = documents_count + count
-            tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
-            if total_count > tenant_document_count:
-                raise ValueError(f"over document limit {tenant_document_count}.")
+            if 'original_document_id' not in document_data or not document_data['original_document_id']:
+                count = 0
+                if document_data["data_source"]["type"] == "upload_file":
+                    upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
+                    count = len(upload_file_list)
+                elif document_data["data_source"]["type"] == "notion_import":
+                    notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
+                    for notion_info in notion_info_list:
+                        count = count + len(notion_info['pages'])
+                documents_count = DocumentService.get_tenant_documents_count()
+                total_count = documents_count + count
+                tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
+                if total_count > tenant_document_count:
+                    raise ValueError(f"over document limit {tenant_document_count}.")
         # if dataset is empty, update dataset data_source_type
         if not dataset.data_source_type:
             dataset.data_source_type = document_data["data_source"]["type"]
-        db.session.commit()

         if not dataset.indexing_technique:
             if 'indexing_technique' not in document_data \
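
Two behavioral changes land in the hunk above: the tenant document quota is now enforced only for brand-new documents (requests carrying an original_document_id are updates and bypass the count), and the early db.session.commit() is dropped so the data_source_type change commits together with the rest of the batch. A small sketch of the counting rule on its own, assuming the same document_data shape as above:

def count_new_documents(document_data: dict) -> int:
    # Mirrors the logic above: one document per uploaded file,
    # or one per imported Notion page; updates count as zero.
    if document_data.get('original_document_id'):
        return 0
    source = document_data["data_source"]
    if source["type"] == "upload_file":
        return len(source["info_list"]['file_info_list']['file_ids'])
    if source["type"] == "notion_import":
        return sum(len(info['pages'])
                   for info in source['info_list']['notion_info_list'])
    return 0
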
@@ -421,6 +454,13 @@ class DocumentService:
                 raise ValueError("Indexing technique is required")

             dataset.indexing_technique = document_data["indexing_technique"]
+            if document_data["indexing_technique"] == 'high_quality':
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id
+                )
+                dataset.embedding_model = embedding_model.name
+                dataset.embedding_model_provider = embedding_model.model_provider.provider_name
+

         documents = []
         batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
@@ -466,11 +506,11 @@ class DocumentService:
                     "upload_file_id": file_id,
                 }
                 document = DocumentService.build_document(dataset, dataset_process_rule.id,
-                                                      document_data["data_source"]["type"],
-                                                      document_data["doc_form"],
-                                                      document_data["doc_language"],
-                                                      data_source_info, created_from, position,
-                                                      account, file_name, batch)
+                                                          document_data["data_source"]["type"],
+                                                          document_data["doc_form"],
+                                                          document_data["doc_language"],
+                                                          data_source_info, created_from, position,
+                                                          account, file_name, batch)
                 db.session.add(document)
                 db.session.flush()
                 document_ids.append(document.id)
@@ -512,11 +552,11 @@ class DocumentService:
                         "type": page['type']
                    }
                    document = DocumentService.build_document(dataset, dataset_process_rule.id,
-                                                          document_data["data_source"]["type"],
-                                                          document_data["doc_form"],
-                                                          document_data["doc_language"],
-                                                          data_source_info, created_from, position,
-                                                          account, page['page_name'], batch)
+                                                              document_data["data_source"]["type"],
+                                                              document_data["doc_form"],
+                                                              document_data["doc_language"],
+                                                              data_source_info, created_from, position,
+                                                              account, page['page_name'], batch)
                    db.session.add(document)
                    db.session.flush()
                    document_ids.append(document.id)
@@ -536,9 +576,9 @@ class DocumentService:

     @staticmethod
     def build_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
-                   document_language: str, data_source_info: dict, created_from: str, position: int,
-                   account: Account,
-                   name: str, batch: str):
+                       document_language: str, data_source_info: dict, created_from: str, position: int,
+                       account: Account,
+                       name: str, batch: str):
         document = Document(
             tenant_id=dataset.tenant_id,
             dataset_id=dataset.id,
@@ -567,6 +607,7 @@ class DocumentService:
     def update_document_with_dataset_id(dataset: Dataset, document_data: dict,
                                         account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
                                         created_from: str = 'web'):
+        DatasetService.check_dataset_model_setting(dataset)
         document = DocumentService.get_document(dataset.id, document_data["original_document_id"])
         if document.display_status != 'available':
             raise ValueError("Document is not available")
@@ -674,9 +715,11 @@ class DocumentService:
             tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
             if total_count > tenant_document_count:
                 raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
-        embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=tenant_id
-        )
+        embedding_model = None
+        if document_data['indexing_technique'] == 'high_quality':
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=tenant_id
+            )
         # save dataset
         dataset = Dataset(
             tenant_id=tenant_id,
@@ -684,8 +727,8 @@ class DocumentService:
             data_source_type=document_data["data_source"]["type"],
             indexing_technique=document_data["indexing_technique"],
             created_by=account.id,
-            embedding_model=embedding_model.name,
-            embedding_model_provider=embedding_model.model_provider.provider_name
+            embedding_model=embedding_model.name if embedding_model else None,
+            embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None
         )

         db.session.add(dataset)
@@ -903,15 +946,15 @@ class SegmentService:
         content = args['content']
         doc_id = str(uuid.uuid4())
         segment_hash = helper.generate_text_hash(content)
-
-        embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=dataset.tenant_id,
-            model_provider_name=dataset.embedding_model_provider,
-            model_name=dataset.embedding_model
-        )
-
-        # calc embedding use tokens
-        tokens = embedding_model.get_num_tokens(content)
+        tokens = 0
+        if dataset.indexing_technique == 'high_quality':
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=dataset.tenant_id,
+                model_provider_name=dataset.embedding_model_provider,
+                model_name=dataset.embedding_model
+            )
+            # calc embedding use tokens
+            tokens = embedding_model.get_num_tokens(content)
         max_position = db.session.query(func.max(DocumentSegment.position)).filter(
             DocumentSegment.document_id == document.id
         ).scalar()
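
Token accounting in create_segment follows the same guard: tokens defaults to 0 and is computed from the embedding model only for 'high_quality' datasets, so 'economy' segments never touch the provider. A minimal sketch, with a whitespace split as a deliberately fake stand-in for embedding_model.get_num_tokens:

def segment_embedding_tokens(indexing_technique: str, content: str) -> int:
    if indexing_technique != 'high_quality':
        return 0  # economy datasets skip embedding-token accounting entirely
    # NOT the real tokenizer; the actual code calls embedding_model.get_num_tokens(content).
    return len(content.split())
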
@@ -973,15 +1016,16 @@ class SegmentService:
                 kw_index.update_segment_keywords_index(segment.index_node_id, segment.keywords)
         else:
             segment_hash = helper.generate_text_hash(content)
+            tokens = 0
+            if dataset.indexing_technique == 'high_quality':
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )

-            embedding_model = ModelFactory.get_embedding_model(
-                tenant_id=dataset.tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
-
-            # calc embedding use tokens
-            tokens = embedding_model.get_num_tokens(content)
+                # calc embedding use tokens
+                tokens = embedding_model.get_num_tokens(content)
             segment.content = content
             segment.index_node_hash = segment_hash
             segment.word_count = len(content)
@@ -1013,7 +1057,7 @@ class SegmentService:
         cache_result = redis_client.get(indexing_cache_key)
         if cache_result is not None:
             raise ValueError("Segment is deleting.")
-
+ 
         # enabled segment need to delete index
         if segment.enabled:
             # send delete segment index task