@@ -9,6 +9,7 @@ from typing import Optional, List
 from flask import current_app
 from sqlalchemy import func
 
+from core.index.index import IndexBuilder
 from core.model_providers.model_factory import ModelFactory
 from extensions.ext_redis import redis_client
 from flask_login import current_user
@@ -25,14 +26,16 @@ from services.errors.account import NoPermissionError
 from services.errors.dataset import DatasetNameDuplicateError
 from services.errors.document import DocumentIndexingError
 from services.errors.file import FileNotExistsError
+from services.vector_service import VectorService
 from tasks.clean_notion_document_task import clean_notion_document_task
 from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
 from tasks.document_indexing_task import document_indexing_task
 from tasks.document_indexing_update_task import document_indexing_update_task
 from tasks.create_segment_to_index_task import create_segment_to_index_task
 from tasks.update_segment_index_task import update_segment_index_task
-from tasks.update_segment_keyword_index_task\
-    import update_segment_keyword_index_task
+from tasks.recover_document_indexing_task import recover_document_indexing_task
+from tasks.update_segment_keyword_index_task import update_segment_keyword_index_task
+from tasks.delete_segment_from_index_task import delete_segment_from_index_task
 
 
 class DatasetService:
@@ -88,12 +91,16 @@ class DatasetService:
         if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
             raise DatasetNameDuplicateError(
                 f'Dataset with name {name} already exists.')
-
+        embedding_model = ModelFactory.get_embedding_model(
+            tenant_id=current_user.current_tenant_id
+        )
         dataset = Dataset(name=name, indexing_technique=indexing_technique)
         # dataset = Dataset(name=name, provider=provider, config=config)
         dataset.created_by = account.id
         dataset.updated_by = account.id
         dataset.tenant_id = tenant_id
+        dataset.embedding_model_provider = embedding_model.model_provider.provider_name
+        dataset.embedding_model = embedding_model.name
         db.session.add(dataset)
         db.session.commit()
         return dataset
@@ -372,7 +379,7 @@ class DocumentService:
         indexing_cache_key = 'document_{}_is_paused'.format(document.id)
         redis_client.delete(indexing_cache_key)
         # trigger async task
-        document_indexing_task.delay(document.dataset_id, document.id)
+        recover_document_indexing_task.delay(document.dataset_id, document.id)
 
     @staticmethod
     def get_documents_position(dataset_id):
@@ -450,6 +457,7 @@ class DocumentService:
                 document = DocumentService.save_document(dataset, dataset_process_rule.id,
                                                          document_data["data_source"]["type"],
                                                          document_data["doc_form"],
+                                                         document_data["doc_language"],
                                                          data_source_info, created_from, position,
                                                          account, file_name, batch)
                 db.session.add(document)
@@ -495,20 +503,11 @@ class DocumentService:
                     document = DocumentService.save_document(dataset, dataset_process_rule.id,
                                                              document_data["data_source"]["type"],
                                                              document_data["doc_form"],
+                                                             document_data["doc_language"],
                                                              data_source_info, created_from, position,
                                                              account, page['page_name'], batch)
-                    # if page['type'] == 'database':
-                    # document.splitting_completed_at = datetime.datetime.utcnow()
-                    # document.cleaning_completed_at = datetime.datetime.utcnow()
-                    # document.parsing_completed_at = datetime.datetime.utcnow()
-                    # document.completed_at = datetime.datetime.utcnow()
-                    # document.indexing_status = 'completed'
-                    # document.word_count = 0
-                    # document.tokens = 0
-                    # document.indexing_latency = 0
                     db.session.add(document)
                     db.session.flush()
-                    # if page['type'] != 'database':
                     document_ids.append(document.id)
                     documents.append(document)
                     position += 1
@@ -520,15 +519,15 @@ class DocumentService:
         db.session.commit()
 
         # trigger async task
-        #document_index_created.send(dataset.id, document_ids=document_ids)
         document_indexing_task.delay(dataset.id, document_ids)
 
         return documents, batch
 
     @staticmethod
     def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
-                      data_source_info: dict, created_from: str, position: int, account: Account, name: str,
-                      batch: str):
+                      document_language: str, data_source_info: dict, created_from: str, position: int,
+                      account: Account,
+                      name: str, batch: str):
         document = Document(
             tenant_id=dataset.tenant_id,
             dataset_id=dataset.id,
@@ -540,7 +539,8 @@ class DocumentService:
             name=name,
             created_from=created_from,
             created_by=account.id,
-            doc_form=document_form
+            doc_form=document_form,
+            doc_language=document_language
         )
         return document
 
@@ -654,13 +654,18 @@ class DocumentService:
             tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
             if documents_count > tenant_document_count:
                 raise ValueError(f"over document limit {tenant_document_count}.")
+        embedding_model = ModelFactory.get_embedding_model(
+            tenant_id=tenant_id
+        )
         # save dataset
         dataset = Dataset(
             tenant_id=tenant_id,
             name='',
             data_source_type=document_data["data_source"]["type"],
             indexing_technique=document_data["indexing_technique"],
-            created_by=account.id
+            created_by=account.id,
+            embedding_model=embedding_model.name,
+            embedding_model_provider=embedding_model.model_provider.provider_name
         )
 
         db.session.add(dataset)
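Context note for the DatasetService hunks above: a dataset now records the embedding model chosen at creation time, and the SegmentService hunks below read those fields back instead of resolving a tenant-wide default. A minimal sketch of that lookup pattern, mirroring the calls in this diff (it assumes a loaded `dataset` and a `content` string in scope; this is illustrative, not code from the PR):

    embedding_model = ModelFactory.get_embedding_model(
        tenant_id=dataset.tenant_id,
        model_provider_name=dataset.embedding_model_provider,
        model_name=dataset.embedding_model
    )
    tokens = embedding_model.get_num_tokens(content)  # token count stored on the segment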
@@ -870,13 +875,15 @@ class SegmentService:
                 raise ValueError("Answer is required")
 
     @classmethod
-    def create_segment(cls, args: dict, document: Document):
+    def create_segment(cls, args: dict, document: Document, dataset: Dataset):
         content = args['content']
         doc_id = str(uuid.uuid4())
         segment_hash = helper.generate_text_hash(content)
 
         embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=document.tenant_id
+            tenant_id=dataset.tenant_id,
+            model_provider_name=dataset.embedding_model_provider,
+            model_name=dataset.embedding_model
         )
 
         # calc embedding use tokens
@@ -894,6 +901,9 @@ class SegmentService:
             content=content,
             word_count=len(content),
             tokens=tokens,
+            status='completed',
+            indexing_at=datetime.datetime.utcnow(),
+            completed_at=datetime.datetime.utcnow(),
             created_by=current_user.id
         )
         if document.doc_form == 'qa_model':
@@ -901,49 +911,88 @@ class SegmentService:
 
         db.session.add(segment_document)
         db.session.commit()
-        indexing_cache_key = 'segment_{}_indexing'.format(segment_document.id)
-        redis_client.setex(indexing_cache_key, 600, 1)
-        create_segment_to_index_task.delay(segment_document.id, args['keywords'])
-        return segment_document
+
+        # save vector index
+        try:
+            VectorService.create_segment_vector(args['keywords'], segment_document, dataset)
+        except Exception as e:
+            logging.exception("create segment index failed")
+            segment_document.enabled = False
+            segment_document.disabled_at = datetime.datetime.utcnow()
+            segment_document.status = 'error'
+            segment_document.error = str(e)
+        db.session.commit()
+        segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first()
+        return segment
 
     @classmethod
-    def update_segment(cls, args: dict, segment: DocumentSegment, document: Document):
+    def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset):
         indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
         cache_result = redis_client.get(indexing_cache_key)
         if cache_result is not None:
             raise ValueError("Segment is indexing, please try again later")
-        content = args['content']
-        if segment.content == content:
-            if document.doc_form == 'qa_model':
-                segment.answer = args['answer']
-            if args['keywords']:
-                segment.keywords = args['keywords']
-            db.session.add(segment)
-            db.session.commit()
-            # update segment index task
-            redis_client.setex(indexing_cache_key, 600, 1)
-            update_segment_keyword_index_task.delay(segment.id)
-        else:
-            segment_hash = helper.generate_text_hash(content)
-
-            embedding_model = ModelFactory.get_embedding_model(
-                tenant_id=document.tenant_id
-            )
-
-            # calc embedding use tokens
-            tokens = embedding_model.get_num_tokens(content)
-            segment.content = content
-            segment.index_node_hash = segment_hash
-            segment.word_count = len(content)
-            segment.tokens = tokens
-            segment.status = 'updating'
-            segment.updated_by = current_user.id
-            segment.updated_at = datetime.datetime.utcnow()
-            if document.doc_form == 'qa_model':
-                segment.answer = args['answer']
-            db.session.add(segment)
+        try:
+            content = args['content']
+            if segment.content == content:
+                if document.doc_form == 'qa_model':
+                    segment.answer = args['answer']
+                if args['keywords']:
+                    segment.keywords = args['keywords']
+                db.session.add(segment)
+                db.session.commit()
+                # update segment index task
+                if args['keywords']:
+                    kw_index = IndexBuilder.get_index(dataset, 'economy')
+                    # delete from keyword index
+                    kw_index.delete_by_ids([segment.index_node_id])
+                    # save keyword index
+                    kw_index.update_segment_keywords_index(segment.index_node_id, segment.keywords)
+            else:
+                segment_hash = helper.generate_text_hash(content)
+
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+
+                # calc embedding use tokens
+                tokens = embedding_model.get_num_tokens(content)
+                segment.content = content
+                segment.index_node_hash = segment_hash
+                segment.word_count = len(content)
+                segment.tokens = tokens
+                segment.status = 'completed'
+                segment.indexing_at = datetime.datetime.utcnow()
+                segment.completed_at = datetime.datetime.utcnow()
+                segment.updated_by = current_user.id
+                segment.updated_at = datetime.datetime.utcnow()
+                if document.doc_form == 'qa_model':
+                    segment.answer = args['answer']
+                db.session.add(segment)
+                db.session.commit()
+                # update segment vector index
+                VectorService.create_segment_vector(args['keywords'], segment, dataset)
+        except Exception as e:
+            logging.exception("update segment index failed")
+            segment.enabled = False
+            segment.disabled_at = datetime.datetime.utcnow()
+            segment.status = 'error'
+            segment.error = str(e)
         db.session.commit()
-        # update segment index task
-        redis_client.setex(indexing_cache_key, 600, 1)
-        update_segment_index_task.delay(segment.id, args['keywords'])
+        segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first()
         return segment
+
+    @classmethod
+    def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: Dataset):
+        indexing_cache_key = 'segment_{}_delete_indexing'.format(segment.id)
+        cache_result = redis_client.get(indexing_cache_key)
+        if cache_result is not None:
+            raise ValueError("Segment is deleting.")
+        # send delete segment index task
+        redis_client.setex(indexing_cache_key, 600, 1)
+        # enabled segment need to delete index
+        if segment.enabled:
+            delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id)
+        db.session.delete(segment)
+        db.session.commit()
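For reference, a hypothetical caller sketch for the reworked SegmentService surface shown above (create/update/delete now take the owning Dataset). The controller context, the lookup helpers, and the `segment_args` dict are illustrative assumptions, not part of this diff:

    from services.dataset_service import DatasetService, DocumentService, SegmentService

    dataset = DatasetService.get_dataset(dataset_id)                    # assumed existing helper
    document = DocumentService.get_document(dataset.id, document_id)    # assumed existing helper

    segment_args = {'content': 'some segment text', 'answer': None, 'keywords': []}

    segment = SegmentService.create_segment(segment_args, document, dataset)
    segment = SegmentService.update_segment(segment_args, segment, document, dataset)
    SegmentService.delete_segment(segment, document, dataset)  # raises ValueError if a delete is already in flight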