@@ -31,7 +31,7 @@ from models.dataset import (
     DocumentSegment,
 )
 from models.model import UploadFile
-from models.source import DataSourceBinding
+from models.source import DataSourceOauthBinding
 from services.errors.account import NoPermissionError
 from services.errors.dataset import DatasetInUseError, DatasetNameDuplicateError
 from services.errors.document import DocumentIndexingError
@@ -48,6 +48,7 @@ from tasks.document_indexing_update_task import document_indexing_update_task
 from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
 from tasks.recover_document_indexing_task import recover_document_indexing_task
 from tasks.retry_document_indexing_task import retry_document_indexing_task
+from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task


 class DatasetService:
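
Note: the body of sync_website_document_indexing_task is not part of this patch; only the import and the .delay(dataset_id, document.id) call further down are. As a rough, hypothetical sketch of the task module that import points at (the Celery queue name and the commented flow are assumptions; the signature is inferred from the .delay call):

# tasks/sync_website_document_indexing_task.py -- illustrative sketch, not the real implementation
from celery import shared_task


@shared_task(queue='dataset')  # queue name is an assumption
def sync_website_document_indexing_task(dataset_id: str, document_id: str):
    # Assumed flow: load the document, re-fetch its source URL through the configured
    # website provider, then rebuild the document's index segments from the fresh content.
    ...
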
@@ -508,18 +509,40 @@ class DocumentService:
     @staticmethod
     def retry_document(dataset_id: str, documents: list[Document]):
         for document in documents:
+            # add retry flag
+            retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id)
+            cache_result = redis_client.get(retry_indexing_cache_key)
+            if cache_result is not None:
+                raise ValueError("Document is being retried, please try again later")
             # retry document indexing
             document.indexing_status = 'waiting'
             db.session.add(document)
             db.session.commit()
-            # add retry flag
-            retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id)
+
             redis_client.setex(retry_indexing_cache_key, 600, 1)
         # trigger async task
         document_ids = [document.id for document in documents]
         retry_document_indexing_task.delay(dataset_id, document_ids)

     @staticmethod
+    def sync_website_document(dataset_id: str, document: Document):
+        # add sync flag
+        sync_indexing_cache_key = 'document_{}_is_sync'.format(document.id)
+        cache_result = redis_client.get(sync_indexing_cache_key)
+        if cache_result is not None:
+            raise ValueError("Document is being synced, please try again later")
+        # sync document indexing
+        document.indexing_status = 'waiting'
+        data_source_info = document.data_source_info_dict
+        data_source_info['mode'] = 'scrape'
+        document.data_source_info = json.dumps(data_source_info, ensure_ascii=False)
+        db.session.add(document)
+        db.session.commit()
+
+        redis_client.setex(sync_indexing_cache_key, 600, 1)
+
+        sync_website_document_indexing_task.delay(dataset_id, document.id)
+    @staticmethod
     def get_documents_position(dataset_id):
         document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
         if document:
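
Note: retry_document and the new sync_website_document use the same Redis guard: refuse to start if a per-document flag exists, then set the flag with a 600-second TTL once the status change is committed. A minimal standalone sketch of that pattern (the helper name is invented; redis_client is assumed to be a configured redis.Redis instance, and the real methods commit their DB changes between the check and the setex):

# Illustrative sketch of the per-document "in progress" guard used above (not part of the patch).
def guard_and_mark(redis_client, document_id: str, action: str, ttl_seconds: int = 600) -> None:
    cache_key = 'document_{}_is_{}'.format(document_id, action)
    if redis_client.get(cache_key) is not None:
        # A previous run is still inside its TTL window; refuse to start another one.
        raise ValueError("Document is being {}, please try again later".format(action))
    # Mark the document as in progress; the flag expires on its own after ttl_seconds.
    redis_client.setex(cache_key, ttl_seconds, 1)
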
@@ -545,6 +568,9 @@ class DocumentService:
                 notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
                 for notion_info in notion_info_list:
                     count = count + len(notion_info['pages'])
+            elif document_data["data_source"]["type"] == "website_crawl":
+                website_info = document_data["data_source"]['info_list']['website_info_list']
+                count = len(website_info['urls'])
             batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
             if count > batch_upload_limit:
                 raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
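
Note: per this patch, website_info_list carries at least provider, job_id, only_main_content and urls. A hypothetical example of the document_data["data_source"] shape this branch counts against (all values invented for illustration):

# Hypothetical payload -- key names mirror the ones read in this patch, values are made up.
document_data = {
    "data_source": {
        "type": "website_crawl",
        "info_list": {
            "website_info_list": {
                "provider": "firecrawl",      # assumption: name of the crawl provider
                "job_id": "crawl-job-123",    # id of the finished crawl job
                "only_main_content": True,
                "urls": [
                    "https://example.com/docs/a",
                    "https://example.com/docs/b",
                ],
            }
        },
    },
    "doc_form": "text_model",    # assumption: existing doc_form / doc_language values
    "doc_language": "English",
}
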
@@ -683,12 +709,12 @@ class DocumentService:
                         exist_document[data_source_info['notion_page_id']] = document.id
                 for notion_info in notion_info_list:
                     workspace_id = notion_info['workspace_id']
-                    data_source_binding = DataSourceBinding.query.filter(
+                    data_source_binding = DataSourceOauthBinding.query.filter(
                         db.and_(
-                            DataSourceBinding.tenant_id == current_user.current_tenant_id,
-                            DataSourceBinding.provider == 'notion',
-                            DataSourceBinding.disabled == False,
-                            DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
+                            DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+                            DataSourceOauthBinding.provider == 'notion',
+                            DataSourceOauthBinding.disabled == False,
+                            DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
                         )
                     ).first()
                     if not data_source_binding:
@@ -717,6 +743,28 @@ class DocumentService:
                 # delete not selected documents
                 if len(exist_document) > 0:
                     clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
+            elif document_data["data_source"]["type"] == "website_crawl":
+                website_info = document_data["data_source"]['info_list']['website_info_list']
+                urls = website_info['urls']
+                for url in urls:
+                    data_source_info = {
+                        'url': url,
+                        'provider': website_info['provider'],
+                        'job_id': website_info['job_id'],
+                        'only_main_content': website_info.get('only_main_content', False),
+                        'mode': 'crawl',
+                    }
+                    document = DocumentService.build_document(dataset, dataset_process_rule.id,
+                                                              document_data["data_source"]["type"],
+                                                              document_data["doc_form"],
+                                                              document_data["doc_language"],
+                                                              data_source_info, created_from, position,
+                                                              account, url, batch)
+                    db.session.add(document)
+                    db.session.flush()
+                    document_ids.append(document.id)
+                    documents.append(document)
+                    position += 1
             db.session.commit()

             # trigger async task
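
Note: each crawled URL becomes its own Document row, and its data_source_info is stored as serialized JSON roughly like the sketch below (values invented; keys come from the dict built above). On a later re-sync, sync_website_document rewrites 'mode' from 'crawl' to 'scrape' before dispatching the indexing task.

# Hypothetical example of one website document's stored data_source_info (values are made up).
import json

data_source_info = {
    'url': 'https://example.com/docs/a',
    'provider': 'firecrawl',    # assumption: provider name from website_info_list
    'job_id': 'crawl-job-123',
    'only_main_content': True,
    'mode': 'crawl',            # rewritten to 'scrape' when sync_website_document re-syncs the page
}
print(json.dumps(data_source_info, ensure_ascii=False))
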
@@ -818,12 +866,12 @@ class DocumentService:
                 notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
                 for notion_info in notion_info_list:
                     workspace_id = notion_info['workspace_id']
-                    data_source_binding = DataSourceBinding.query.filter(
+                    data_source_binding = DataSourceOauthBinding.query.filter(
                         db.and_(
-                            DataSourceBinding.tenant_id == current_user.current_tenant_id,
-                            DataSourceBinding.provider == 'notion',
-                            DataSourceBinding.disabled == False,
-                            DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
+                            DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+                            DataSourceOauthBinding.provider == 'notion',
+                            DataSourceOauthBinding.disabled == False,
+                            DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
                         )
                     ).first()
                     if not data_source_binding:
@@ -835,6 +883,17 @@ class DocumentService:
                             "notion_page_icon": page['page_icon'],
                             "type": page['type']
                         }
+            elif document_data["data_source"]["type"] == "website_crawl":
+                website_info = document_data["data_source"]['info_list']['website_info_list']
+                urls = website_info['urls']
+                for url in urls:
+                    data_source_info = {
+                        'url': url,
+                        'provider': website_info['provider'],
+                        'job_id': website_info['job_id'],
+                        'only_main_content': website_info.get('only_main_content', False),
+                        'mode': 'crawl',
+                    }
             document.data_source_type = document_data["data_source"]["type"]
             document.data_source_info = json.dumps(data_source_info)
             document.name = file_name
@@ -873,6 +932,9 @@ class DocumentService:
                 notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
                 for notion_info in notion_info_list:
                     count = count + len(notion_info['pages'])
+            elif document_data["data_source"]["type"] == "website_crawl":
+                website_info = document_data["data_source"]['info_list']['website_info_list']
+                count = len(website_info['urls'])
             batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
             if count > batch_upload_limit:
                 raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@@ -973,6 +1035,10 @@ class DocumentService:
             if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][
                     'notion_info_list']:
                 raise ValueError("Notion source info is required")
+        if args['data_source']['type'] == 'website_crawl':
+            if 'website_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][
+                    'website_info_list']:
+                raise ValueError("Website source info is required")

     @classmethod
     def process_rule_args_validate(cls, args: dict):
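
Note: a quick, hypothetical illustration of how the new check surfaces to callers. The enclosing classmethod name (data_source_args_validate) is an assumption -- only the added lines appear in this patch -- and the example presumes the earlier, unchanged checks accept the payload.

# Hypothetical usage of the validation above; the method name is assumed, not shown in this patch.
args = {
    'data_source': {
        'type': 'website_crawl',
        'info_list': {},    # 'website_info_list' is missing
    }
}
try:
    DocumentService.data_source_args_validate(args)
except ValueError as err:
    print(err)  # expected: "Website source info is required"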