
Fix/ignore economy dataset (#1043)

Co-authored-by: jyong <jyong@dify.ai>
Jyong 1 year ago
parent
commit
a55ba6e614

+ 17 - 14
api/controllers/console/datasets/datasets.py

@@ -92,11 +92,14 @@ class DatasetListApi(Resource):
             model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
         data = marshal(datasets, dataset_detail_fields)
         for item in data:
-            item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
-            if item_model in model_names:
-                item['embedding_available'] = True
+            if item['indexing_technique'] == 'high_quality':
+                item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
+                if item_model in model_names:
+                    item['embedding_available'] = True
+                else:
+                    item['embedding_available'] = False
             else:
-                item['embedding_available'] = False
+                item['embedding_available'] = True
         response = {
             'data': data,
             'has_more': len(datasets) == limit,
@@ -122,14 +125,6 @@ class DatasetListApi(Resource):
         # The role of the current user in the ta table must be admin or owner
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
 
         try:
             dataset = DatasetService.create_empty_dataset(
@@ -167,6 +162,11 @@ class DatasetApi(Resource):
     @account_initialization_required
     def patch(self, dataset_id):
         dataset_id_str = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id_str)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
 
         parser = reqparse.RequestParser()
         parser.add_argument('name', nullable=False,
@@ -254,6 +254,7 @@ class DatasetIndexingEstimateApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
         parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
+        parser.add_argument('indexing_technique', type=str, required=True, nullable=True, location='json')
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
         parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
         parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
@@ -275,7 +276,8 @@ class DatasetIndexingEstimateApi(Resource):
             try:
                 response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
                                                                   args['process_rule'], args['doc_form'],
-                                                                  args['doc_language'], args['dataset_id'])
+                                                                  args['doc_language'], args['dataset_id'],
+                                                                  args['indexing_technique'])
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
                     f"No Embedding Model available. Please configure a valid provider "
@@ -290,7 +292,8 @@ class DatasetIndexingEstimateApi(Resource):
                 response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
                                                                     args['info_list']['notion_info_list'],
                                                                     args['process_rule'], args['doc_form'],
-                                                                    args['doc_language'], args['dataset_id'])
+                                                                    args['doc_language'], args['dataset_id'],
+                                                                    args['indexing_technique'])
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
                     f"No Embedding Model available. Please configure a valid provider "

+ 23 - 23
api/controllers/console/datasets/datasets_document.py

@@ -285,20 +285,6 @@ class DatasetDocumentListApi(Resource):
         # validate args
         DocumentService.document_create_args_validate(args)
 
-        # check embedding model setting
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
-
         try:
             documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
         except ProviderTokenNotInitError as ex:
@@ -339,15 +325,17 @@ class DatasetInitApi(Resource):
         parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
                             location='json')
         args = parser.parse_args()
-
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
+        if args['indexing_technique'] == 'high_quality':
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=current_user.current_tenant_id
+                )
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)
 
         # validate args
         DocumentService.document_create_args_validate(args)
@@ -729,6 +717,12 @@ class DocumentDeleteApi(DocumentResource):
     def delete(self, dataset_id, document_id):
         dataset_id = str(dataset_id)
         document_id = str(document_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
+
         document = self.get_document(dataset_id, document_id)
 
         try:
@@ -791,6 +785,12 @@ class DocumentStatusApi(DocumentResource):
     def patch(self, dataset_id, document_id, action):
         dataset_id = str(dataset_id)
         document_id = str(document_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
+
         document = self.get_document(dataset_id, document_id)
 
         # The role of the current user in the ta table must be admin or owner
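
Both new hunks in this file install the same guard before a destructive operation: resolve the dataset, return 404 if it is gone, then validate the model setting. A stub sketch of that shared pattern (NotFound stands in for werkzeug.exceptions.NotFound, and the service lookup is a plain dict):

# Stub sketch of the guard shared by DocumentDeleteApi and DocumentStatusApi.
class NotFound(Exception):
    pass

def check_dataset_model_setting(dataset):
    # the real check only bites for high_quality datasets (see dataset_service.py)
    if dataset['indexing_technique'] == 'high_quality' and not dataset.get('embedding_model'):
        raise ValueError('No Embedding Model available.')

def guard_dataset(dataset_id, datasets):
    dataset = datasets.get(dataset_id)
    if dataset is None:
        raise NotFound('Dataset not found.')
    check_dataset_model_setting(dataset)
    return dataset

datasets = {'d1': {'indexing_technique': 'economy'}}
assert guard_dataset('d1', datasets)['indexing_technique'] == 'economy'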

+ 48 - 53
api/controllers/console/datasets/datasets_segments.py

@@ -149,7 +149,8 @@ class DatasetDocumentSegmentApi(Resource):
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
             raise NotFound('Dataset not found.')
-
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
         # The role of the current user in the ta table must be admin or owner
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
@@ -158,20 +159,20 @@ class DatasetDocumentSegmentApi(Resource):
             DatasetService.check_dataset_permission(dataset, current_user)
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
-
-        # check embedding model setting
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
+        if dataset.indexing_technique == 'high_quality':
+            # check embedding model setting
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=current_user.current_tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)
 
         segment = DocumentSegment.query.filter(
             DocumentSegment.id == str(segment_id),
@@ -244,18 +245,19 @@ class DatasetDocumentSegmentAddApi(Resource):
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
         # check embedding model setting
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
+        if dataset.indexing_technique == 'high_quality':
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=current_user.current_tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)
         try:
             DatasetService.check_dataset_permission(dataset, current_user)
         except services.errors.account.NoPermissionError as e:
@@ -284,25 +286,28 @@ class DatasetDocumentSegmentUpdateApi(Resource):
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
             raise NotFound('Dataset not found.')
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
         # check document
         document_id = str(document_id)
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
             raise NotFound('Document not found.')
-        # check embedding model setting
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
-        # check segment
+        if dataset.indexing_technique == 'high_quality':
+            # check embedding model setting
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=current_user.current_tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)
+            # check segment
         segment_id = str(segment_id)
         segment = DocumentSegment.query.filter(
             DocumentSegment.id == str(segment_id),
@@ -339,6 +344,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
             raise NotFound('Dataset not found.')
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
         # check document
         document_id = str(document_id)
         document = DocumentService.get_document(dataset_id, document_id)
@@ -378,18 +385,6 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
             raise NotFound('Document not found.')
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
         # get file from request
         file = request.files['file']
         # check file
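
All three segment endpoints now wrap the provider lookup in the same gated try/except, translating low-level provider errors into an HTTP-level ProviderNotInitializeError and skipping the lookup entirely for economy datasets. A sketch with stub exception classes (the real ones live in core.model_providers.error and the console error module):

# Stub sketch of the gated provider check used by the segment endpoints.
class LLMBadRequestError(Exception):
    pass

class ProviderTokenNotInitError(Exception):
    def __init__(self, description):
        self.description = description

class ProviderNotInitializeError(Exception):
    pass

def get_embedding_model(tenant_id, provider_name, model_name):
    raise LLMBadRequestError()  # pretend no provider is configured

def ensure_embeddings_ready(dataset):
    if dataset['indexing_technique'] != 'high_quality':
        return  # economy datasets skip the check entirely
    try:
        get_embedding_model(dataset['tenant_id'],
                            dataset['embedding_model_provider'],
                            dataset['embedding_model'])
    except LLMBadRequestError:
        raise ProviderNotInitializeError(
            'No Embedding Model available. Please configure a valid provider '
            'in the Settings -> Model Provider.')
    except ProviderTokenNotInitError as ex:
        raise ProviderNotInitializeError(ex.description)

ensure_embeddings_ready({'indexing_technique': 'economy'})  # no-op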

+ 8 - 7
api/core/docstore/dataset_docstore.py

@@ -67,12 +67,13 @@ class DatesetDocumentStore:
 
         if max_position is None:
             max_position = 0
-
-        embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=self._dataset.tenant_id,
-            model_provider_name=self._dataset.embedding_model_provider,
-            model_name=self._dataset.embedding_model
-        )
+        embedding_model = None
+        if self._dataset.indexing_technique == 'high_quality':
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=self._dataset.tenant_id,
+                model_provider_name=self._dataset.embedding_model_provider,
+                model_name=self._dataset.embedding_model
+            )
 
         for doc in docs:
             if not isinstance(doc, Document):
@@ -88,7 +89,7 @@ class DatesetDocumentStore:
                 )
 
             # calc embedding use tokens
-            tokens = embedding_model.get_num_tokens(doc.page_content)
+            tokens = embedding_model.get_num_tokens(doc.page_content) if embedding_model else 0
 
             if not segment_document:
                 max_position += 1
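
The docstore change is the recurring None-safe pattern of this commit: resolve the model once, outside the loop, and fall back to zero tokens when there is none. A toy version (the tokenizer below is a rough stand-in, not the real model's):

# Toy version of the token-count fallback in DatesetDocumentStore.
class FakeEmbeddingModel:
    def get_num_tokens(self, text):
        return max(1, len(text) // 4)  # crude stand-in for a real tokenizer

def count_tokens(embedding_model, page_content):
    return embedding_model.get_num_tokens(page_content) if embedding_model else 0

assert count_tokens(None, 'economy dataset text') == 0
assert count_tokens(FakeEmbeddingModel(), 'high quality dataset text') > 0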

+ 18 - 1
api/core/index/index.py

@@ -1,10 +1,18 @@
+import json
+
 from flask import current_app
+from langchain.embeddings import OpenAIEmbeddings
 
 from core.embedding.cached_embedding import CacheEmbedding
 from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
 from core.index.vector_index.vector_index import VectorIndex
 from core.model_providers.model_factory import ModelFactory
+from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
+from core.model_providers.models.entity.model_params import ModelKwargs
+from core.model_providers.models.llm.openai_model import OpenAIModel
+from core.model_providers.providers.openai_provider import OpenAIProvider
 from models.dataset import Dataset
+from models.provider import Provider, ProviderType
 
 
 class IndexBuilder:
@@ -35,4 +43,13 @@ class IndexBuilder:
                 )
             )
         else:
-            raise ValueError('Unknown indexing technique')
+            raise ValueError('Unknown indexing technique')
+
+    @classmethod
+    def get_default_high_quality_index(cls, dataset: Dataset):
+        embeddings = OpenAIEmbeddings(openai_api_key=' ')
+        return VectorIndex(
+            dataset=dataset,
+            config=current_app.config,
+            embeddings=embeddings
+        )
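
get_default_high_quality_index builds a VectorIndex around OpenAIEmbeddings with a placeholder key (' '). That only works if deletion never issues an embedding call, which is the assumption this toy model makes explicit:

# Toy model of why a placeholder API key suffices for index deletion:
# the embeddings client is only exercised on reads and writes.
class FakeEmbeddings:
    def __init__(self, openai_api_key):
        self.openai_api_key = openai_api_key  # not validated until an embed call

class FakeVectorIndex:
    def __init__(self, embeddings):
        self.embeddings = embeddings

    def delete(self):
        # dropping a collection requires no embedding calls
        return 'collection dropped'

index = FakeVectorIndex(FakeEmbeddings(openai_api_key=' '))
assert index.delete() == 'collection dropped'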

+ 46 - 38
api/core/indexing_runner.py

@@ -217,25 +217,29 @@ class IndexingRunner:
             db.session.commit()
 
     def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,
-                               doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict:
+                               doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
+                               indexing_technique: str = 'economy') -> dict:
         """
         Estimate the indexing for the document.
         """
+        embedding_model = None
         if dataset_id:
             dataset = Dataset.query.filter_by(
                 id=dataset_id
             ).first()
             if not dataset:
                 raise ValueError('Dataset not found.')
-            embedding_model = ModelFactory.get_embedding_model(
-                tenant_id=dataset.tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
+            if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
         else:
-            embedding_model = ModelFactory.get_embedding_model(
-                tenant_id=tenant_id
-            )
+            if indexing_technique == 'high_quality':
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=tenant_id
+                )
         tokens = 0
         preview_texts = []
         total_segments = 0
@@ -263,8 +267,8 @@ class IndexingRunner:
             for document in documents:
                 if len(preview_texts) < 5:
                     preview_texts.append(document.page_content)
-
-                tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))
+                if indexing_technique == 'high_quality' or embedding_model:
+                    tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))
 
         if doc_form and doc_form == 'qa_model':
             text_generation_model = ModelFactory.get_text_generation_model(
@@ -286,32 +290,35 @@ class IndexingRunner:
         return {
             "total_segments": total_segments,
             "tokens": tokens,
-            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
-            "currency": embedding_model.get_currency(),
+            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
+            "currency": embedding_model.get_currency() if embedding_model else 'USD',
             "preview": preview_texts
         }
 
     def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict,
-                                 doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict:
+                                 doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
+                                 indexing_technique: str = 'economy') -> dict:
         """
         Estimate the indexing for the document.
         """
+        embedding_model = None
         if dataset_id:
             dataset = Dataset.query.filter_by(
                 id=dataset_id
             ).first()
             if not dataset:
                 raise ValueError('Dataset not found.')
-            embedding_model = ModelFactory.get_embedding_model(
-                tenant_id=dataset.tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
+            if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
         else:
-            embedding_model = ModelFactory.get_embedding_model(
-                tenant_id=tenant_id
-            )
-
+            if indexing_technique == 'high_quality':
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=tenant_id
+                )
         # load data from notion
         tokens = 0
         preview_texts = []
@@ -356,8 +363,8 @@ class IndexingRunner:
                 for document in documents:
                     if len(preview_texts) < 5:
                         preview_texts.append(document.page_content)
-
-                    tokens += embedding_model.get_num_tokens(document.page_content)
+                    if indexing_technique == 'high_quality' or embedding_model:
+                        tokens += embedding_model.get_num_tokens(document.page_content)
 
         if doc_form and doc_form == 'qa_model':
             text_generation_model = ModelFactory.get_text_generation_model(
@@ -379,8 +386,8 @@ class IndexingRunner:
         return {
             "total_segments": total_segments,
             "tokens": tokens,
-            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
-            "currency": embedding_model.get_currency(),
+            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
+            "currency": embedding_model.get_currency() if embedding_model else 'USD',
             "preview": preview_texts
         }
 
@@ -657,12 +664,13 @@ class IndexingRunner:
         """
         vector_index = IndexBuilder.get_index(dataset, 'high_quality')
         keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
-
-        embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=dataset.tenant_id,
-            model_provider_name=dataset.embedding_model_provider,
-            model_name=dataset.embedding_model
-        )
+        embedding_model = None
+        if dataset.indexing_technique == 'high_quality':
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=dataset.tenant_id,
+                model_provider_name=dataset.embedding_model_provider,
+                model_name=dataset.embedding_model
+            )
 
         # chunk nodes by chunk size
         indexing_start_at = time.perf_counter()
@@ -672,11 +680,11 @@ class IndexingRunner:
             # check document is paused
             self._check_document_paused_status(dataset_document.id)
             chunk_documents = documents[i:i + chunk_size]
-
-            tokens += sum(
-                embedding_model.get_num_tokens(document.page_content)
-                for document in chunk_documents
-            )
+            if dataset.indexing_technique == 'high_quality' or embedding_model:
+                tokens += sum(
+                    embedding_model.get_num_tokens(document.page_content)
+                    for document in chunk_documents
+                )
 
             # save vector index
             if vector_index:
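
Both estimate methods and the indexing loop now treat the embedding model as optional, and the estimate response falls back to constant price and currency when it is absent. A condensed sketch of the response shaping (values invented):

# Condensed sketch of the estimate response shaping for economy datasets.
def build_estimate(embedding_model, tokens, total_segments, preview_texts):
    return {
        'total_segments': total_segments,
        'tokens': tokens,
        'total_price': ('{:f}'.format(embedding_model.calc_tokens_price(tokens))
                        if embedding_model else 0),
        'currency': embedding_model.get_currency() if embedding_model else 'USD',
        'preview': preview_texts,
    }

# With no model (economy), price and currency are constants:
assert build_estimate(None, 0, 12, [])['currency'] == 'USD'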

+ 0 - 1
api/events/event_handlers/create_document_index.py

@@ -1,6 +1,5 @@
 from events.dataset_event import dataset_was_deleted
 from events.event_handlers.document_index_event import document_index_created
-from tasks.clean_dataset_task import clean_dataset_task
 import datetime
 import logging
 import time

+ 46 - 0
api/migrations/versions/4bcffcd64aa4_update_dataset_model_field_null_.py

@@ -0,0 +1,46 @@
+"""update_dataset_model_field_null_available
+
+Revision ID: 4bcffcd64aa4
+Revises: 853f9b9cd3b6
+Create Date: 2023-08-28 20:58:50.077056
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '4bcffcd64aa4'
+down_revision = '853f9b9cd3b6'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('datasets', schema=None) as batch_op:
+        batch_op.alter_column('embedding_model',
+               existing_type=sa.VARCHAR(length=255),
+               nullable=True,
+               existing_server_default=sa.text("'text-embedding-ada-002'::character varying"))
+        batch_op.alter_column('embedding_model_provider',
+               existing_type=sa.VARCHAR(length=255),
+               nullable=True,
+               existing_server_default=sa.text("'openai'::character varying"))
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('datasets', schema=None) as batch_op:
+        batch_op.alter_column('embedding_model_provider',
+               existing_type=sa.VARCHAR(length=255),
+               nullable=False,
+               existing_server_default=sa.text("'openai'::character varying"))
+        batch_op.alter_column('embedding_model',
+               existing_type=sa.VARCHAR(length=255),
+               nullable=False,
+               existing_server_default=sa.text("'text-embedding-ada-002'::character varying"))
+
+    # ### end Alembic commands ###

+ 2 - 4
api/models/dataset.py

@@ -36,10 +36,8 @@ class Dataset(db.Model):
     updated_by = db.Column(UUID, nullable=True)
     updated_at = db.Column(db.DateTime, nullable=False,
                            server_default=db.text('CURRENT_TIMESTAMP(0)'))
-    embedding_model = db.Column(db.String(
-        255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying"))
-    embedding_model_provider = db.Column(db.String(
-        255), nullable=False, server_default=db.text("'openai'::character varying"))
+    embedding_model = db.Column(db.String(255), nullable=True)
+    embedding_model_provider = db.Column(db.String(255), nullable=True)
 
     @property
     def dataset_keyword_table(self):
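
With the migration above applied, both columns become nullable, so an economy dataset simply stores None for its embedding fields. A dataclass stand-in for the ORM model makes the shape concrete:

# Dataclass stand-in for the now-nullable embedding columns on Dataset.
from dataclasses import dataclass
from typing import Optional

@dataclass
class DatasetRow:
    indexing_technique: str
    embedding_model: Optional[str] = None
    embedding_model_provider: Optional[str] = None

economy = DatasetRow(indexing_technique='economy')
high_quality = DatasetRow('high_quality', 'text-embedding-ada-002', 'openai')
assert economy.embedding_model is None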

+ 99 - 55
api/services/dataset_service.py

@@ -10,6 +10,7 @@ from flask import current_app
 from sqlalchemy import func
 
 from core.index.index import IndexBuilder
+from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_providers.model_factory import ModelFactory
 from extensions.ext_redis import redis_client
 from flask_login import current_user
@@ -91,16 +92,18 @@ class DatasetService:
         if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
             raise DatasetNameDuplicateError(
                 f'Dataset with name {name} already exists.')
-        embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=current_user.current_tenant_id
-        )
+        embedding_model = None
+        if indexing_technique == 'high_quality':
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=current_user.current_tenant_id
+            )
         dataset = Dataset(name=name, indexing_technique=indexing_technique)
         # dataset = Dataset(name=name, provider=provider, config=config)
         dataset.created_by = account.id
         dataset.updated_by = account.id
         dataset.tenant_id = tenant_id
-        dataset.embedding_model_provider = embedding_model.model_provider.provider_name
-        dataset.embedding_model = embedding_model.name
+        dataset.embedding_model_provider = embedding_model.model_provider.provider_name if embedding_model else None
+        dataset.embedding_model = embedding_model.name if embedding_model else None
         db.session.add(dataset)
         db.session.commit()
         return dataset
@@ -115,6 +118,23 @@ class DatasetService:
         else:
             return dataset
 
+    @staticmethod
+    def check_dataset_model_setting(dataset):
+        if dataset.indexing_technique == 'high_quality':
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except LLMBadRequestError:
+                raise ValueError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ValueError(f"The dataset in unavailable, due to: "
+                                 f"{ex.description}")
+
     @staticmethod
     def update_dataset(dataset_id, data, user):
         dataset = DatasetService.get_dataset(dataset_id)
@@ -124,6 +144,19 @@ class DatasetService:
             if data['indexing_technique'] == 'economy':
                 deal_dataset_vector_index_task.delay(dataset_id, 'remove')
             elif data['indexing_technique'] == 'high_quality':
+                # check embedding model setting
+                try:
+                    ModelFactory.get_embedding_model(
+                        tenant_id=current_user.current_tenant_id,
+                        model_provider_name=dataset.embedding_model_provider,
+                        model_name=dataset.embedding_model
+                    )
+                except LLMBadRequestError:
+                    raise ValueError(
+                        f"No Embedding Model available. Please configure a valid provider "
+                        f"in the Settings -> Model Provider.")
+                except ProviderTokenNotInitError as ex:
+                    raise ValueError(ex.description)
                 deal_dataset_vector_index_task.delay(dataset_id, 'add')
         filtered_data = {k: v for k, v in data.items() if v is not None or k == 'description'}
 
@@ -397,23 +430,23 @@ class DocumentService:
 
         # check document limit
         if current_app.config['EDITION'] == 'CLOUD':
-            count = 0
-            if document_data["data_source"]["type"] == "upload_file":
-                upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
-                count = len(upload_file_list)
-            elif document_data["data_source"]["type"] == "notion_import":
-                notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
-                for notion_info in notion_info_list:
-                    count = count + len(notion_info['pages'])
-            documents_count = DocumentService.get_tenant_documents_count()
-            total_count = documents_count + count
-            tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
-            if total_count > tenant_document_count:
-                raise ValueError(f"over document limit {tenant_document_count}.")
+            if 'original_document_id' not in document_data or not document_data['original_document_id']:
+                count = 0
+                if document_data["data_source"]["type"] == "upload_file":
+                    upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
+                    count = len(upload_file_list)
+                elif document_data["data_source"]["type"] == "notion_import":
+                    notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
+                    for notion_info in notion_info_list:
+                        count = count + len(notion_info['pages'])
+                documents_count = DocumentService.get_tenant_documents_count()
+                total_count = documents_count + count
+                tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
+                if total_count > tenant_document_count:
+                    raise ValueError(f"over document limit {tenant_document_count}.")
         # if dataset is empty, update dataset data_source_type
         if not dataset.data_source_type:
             dataset.data_source_type = document_data["data_source"]["type"]
-            db.session.commit()
 
         if not dataset.indexing_technique:
             if 'indexing_technique' not in document_data \
@@ -421,6 +454,13 @@ class DocumentService:
                 raise ValueError("Indexing technique is required")
 
             dataset.indexing_technique = document_data["indexing_technique"]
+            if document_data["indexing_technique"] == 'high_quality':
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id
+                )
+                dataset.embedding_model = embedding_model.name
+                dataset.embedding_model_provider = embedding_model.model_provider.provider_name
+
 
         documents = []
         batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
@@ -466,11 +506,11 @@ class DocumentService:
                         "upload_file_id": file_id,
                     }
                     document = DocumentService.build_document(dataset, dataset_process_rule.id,
-                                                             document_data["data_source"]["type"],
-                                                             document_data["doc_form"],
-                                                             document_data["doc_language"],
-                                                             data_source_info, created_from, position,
-                                                             account, file_name, batch)
+                                                              document_data["data_source"]["type"],
+                                                              document_data["doc_form"],
+                                                              document_data["doc_language"],
+                                                              data_source_info, created_from, position,
+                                                              account, file_name, batch)
                     db.session.add(document)
                     db.session.flush()
                     document_ids.append(document.id)
@@ -512,11 +552,11 @@ class DocumentService:
                                 "type": page['type']
                             }
                             document = DocumentService.build_document(dataset, dataset_process_rule.id,
-                                                                     document_data["data_source"]["type"],
-                                                                     document_data["doc_form"],
-                                                                     document_data["doc_language"],
-                                                                     data_source_info, created_from, position,
-                                                                     account, page['page_name'], batch)
+                                                                      document_data["data_source"]["type"],
+                                                                      document_data["doc_form"],
+                                                                      document_data["doc_language"],
+                                                                      data_source_info, created_from, position,
+                                                                      account, page['page_name'], batch)
                             db.session.add(document)
                             db.session.flush()
                             document_ids.append(document.id)
@@ -536,9 +576,9 @@ class DocumentService:
 
     @staticmethod
     def build_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
-                      document_language: str, data_source_info: dict, created_from: str, position: int,
-                      account: Account,
-                      name: str, batch: str):
+                       document_language: str, data_source_info: dict, created_from: str, position: int,
+                       account: Account,
+                       name: str, batch: str):
         document = Document(
             tenant_id=dataset.tenant_id,
             dataset_id=dataset.id,
@@ -567,6 +607,7 @@ class DocumentService:
     def update_document_with_dataset_id(dataset: Dataset, document_data: dict,
                                         account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
                                         created_from: str = 'web'):
+        DatasetService.check_dataset_model_setting(dataset)
         document = DocumentService.get_document(dataset.id, document_data["original_document_id"])
         if document.display_status != 'available':
             raise ValueError("Document is not available")
@@ -674,9 +715,11 @@ class DocumentService:
             tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
             if total_count > tenant_document_count:
                 raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
-        embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=tenant_id
-        )
+        embedding_model = None
+        if document_data['indexing_technique'] == 'high_quality':
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=tenant_id
+            )
         # save dataset
         dataset = Dataset(
             tenant_id=tenant_id,
@@ -684,8 +727,8 @@ class DocumentService:
             data_source_type=document_data["data_source"]["type"],
             indexing_technique=document_data["indexing_technique"],
             created_by=account.id,
-            embedding_model=embedding_model.name,
-            embedding_model_provider=embedding_model.model_provider.provider_name
+            embedding_model=embedding_model.name if embedding_model else None,
+            embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None
         )
 
         db.session.add(dataset)
@@ -903,15 +946,15 @@ class SegmentService:
         content = args['content']
         doc_id = str(uuid.uuid4())
         segment_hash = helper.generate_text_hash(content)
-
-        embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=dataset.tenant_id,
-            model_provider_name=dataset.embedding_model_provider,
-            model_name=dataset.embedding_model
-        )
-
-        # calc embedding use tokens
-        tokens = embedding_model.get_num_tokens(content)
+        tokens = 0
+        if dataset.indexing_technique == 'high_quality':
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=dataset.tenant_id,
+                model_provider_name=dataset.embedding_model_provider,
+                model_name=dataset.embedding_model
+            )
+            # calc embedding use tokens
+            tokens = embedding_model.get_num_tokens(content)
         max_position = db.session.query(func.max(DocumentSegment.position)).filter(
             DocumentSegment.document_id == document.id
         ).scalar()
@@ -973,15 +1016,16 @@ class SegmentService:
                     kw_index.update_segment_keywords_index(segment.index_node_id, segment.keywords)
             else:
                 segment_hash = helper.generate_text_hash(content)
+                tokens = 0
+                if dataset.indexing_technique == 'high_quality':
+                    embedding_model = ModelFactory.get_embedding_model(
+                        tenant_id=dataset.tenant_id,
+                        model_provider_name=dataset.embedding_model_provider,
+                        model_name=dataset.embedding_model
+                    )
 
-                embedding_model = ModelFactory.get_embedding_model(
-                    tenant_id=dataset.tenant_id,
-                    model_provider_name=dataset.embedding_model_provider,
-                    model_name=dataset.embedding_model
-                )
-
-                # calc embedding use tokens
-                tokens = embedding_model.get_num_tokens(content)
+                    # calc embedding use tokens
+                    tokens = embedding_model.get_num_tokens(content)
                 segment.content = content
                 segment.index_node_hash = segment_hash
                 segment.word_count = len(content)
@@ -1013,7 +1057,7 @@ class SegmentService:
         cache_result = redis_client.get(indexing_cache_key)
         if cache_result is not None:
             raise ValueError("Segment is deleting.")
-        
+
         # enabled segment need to delete index
         if segment.enabled:
             # send delete segment index task
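
Buried in the service changes is a behavioral fix: the CLOUD document quota is now enforced only for genuinely new documents, not when original_document_id marks an update to an existing one. Reduced to its predicate:

# The quota check now applies only to genuinely new documents.
def should_check_document_quota(document_data):
    return not document_data.get('original_document_id')

assert should_check_document_quota({'data_source': {'type': 'upload_file'}})
assert not should_check_document_quota({'original_document_id': 'doc-1'})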

+ 7 - 5
api/tasks/batch_create_segment_to_index_task.py

@@ -49,18 +49,20 @@ def batch_create_segment_to_index_task(job_id: str, content: List, dataset_id: s
         if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
             raise ValueError('Document is not available.')
         document_segments = []
-        for segment in content:
-            content = segment['content']
-            doc_id = str(uuid.uuid4())
-            segment_hash = helper.generate_text_hash(content)
+        embedding_model = None
+        if dataset.indexing_technique == 'high_quality':
             embedding_model = ModelFactory.get_embedding_model(
                 tenant_id=dataset.tenant_id,
                 model_provider_name=dataset.embedding_model_provider,
                 model_name=dataset.embedding_model
             )
 
+        for segment in content:
+            content = segment['content']
+            doc_id = str(uuid.uuid4())
+            segment_hash = helper.generate_text_hash(content)
             # calc embedding use tokens
-            tokens = embedding_model.get_num_tokens(content)
+            tokens = embedding_model.get_num_tokens(content) if embedding_model else 0
             max_position = db.session.query(func.max(DocumentSegment.position)).filter(
                 DocumentSegment.document_id == dataset_document.id
             ).scalar()
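
The batch task hoists the model lookup out of the per-segment loop (one provider call per task instead of one per segment) and applies the same zero-token fallback. A sketch of the hoisted shape:

# Sketch of the hoisted lookup: resolve the model once, then reuse it.
def batch_token_counts(segment_texts, embedding_model):
    # embedding_model is resolved once by the caller; None means 'economy'
    return [embedding_model.get_num_tokens(text) if embedding_model else 0
            for text in segment_texts]

assert batch_token_counts(['a', 'bb'], None) == [0, 0]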

+ 4 - 2
api/tasks/clean_dataset_task.py

@@ -3,8 +3,10 @@ import time
 
 import click
 from celery import shared_task
+from flask import current_app
 
 from core.index.index import IndexBuilder
+from core.index.vector_index.vector_index import VectorIndex
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, DatasetQuery, DatasetProcessRule, \
     AppDatasetJoin, Document
@@ -35,11 +37,11 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
         documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all()
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
 
-        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
         kw_index = IndexBuilder.get_index(dataset, 'economy')
 
         # delete from vector index
-        if vector_index:
+        if dataset.indexing_technique == 'high_quality':
+            vector_index = IndexBuilder.get_default_high_quality_index(dataset)
             try:
                 vector_index.delete()
             except Exception:
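
Cleanup previously went through IndexBuilder.get_index, which needs a working provider; it now branches on the stored indexing technique and uses the placeholder-key helper, so deleting a dataset cannot fail merely because the tenant's provider was removed. The control flow, in miniature:

# Miniature of the cleanup control flow after the change.
def clean_vector_index(dataset, build_default_index):
    if dataset['indexing_technique'] != 'high_quality':
        return  # economy datasets have no vector collection to drop
    index = build_default_index(dataset)
    try:
        index.delete()
    except Exception:
        pass  # the real task logs the failure and continues with keyword cleanup

class _UnreachableIndex:
    def delete(self):
        raise RuntimeError('store unreachable')

clean_vector_index({'indexing_technique': 'high_quality'},
                   lambda d: _UnreachableIndex())  # swallowed, as in the task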

+ 2 - 2
api/tasks/deal_dataset_vector_index_task.py

@@ -31,7 +31,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
             raise Exception('Dataset not found')
 
         if action == "remove":
-            index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
+            index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
             index.delete()
         elif action == "add":
             dataset_documents = db.session.query(DatasetDocument).filter(
@@ -43,7 +43,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
 
             if dataset_documents:
                 # save vector index
-                index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
+                index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
                 documents = []
                 for dataset_document in dataset_documents:
                     # delete from vector index
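
Finally, the flag flip in this task means get_index now enforces the high-quality check rather than bypassing it. Under the assumption (about IndexBuilder internals, not shown in this diff) that an enforced check returns None for a dataset whose stored technique is not high_quality, the semantics look like this toy model:

# Toy model of the flag semantics, assuming IndexBuilder.get_index returns
# None for non-high_quality datasets once the check is enforced.
def get_index(dataset, technique, ignore_high_quality_check=False):
    if technique == 'high_quality':
        if (dataset['indexing_technique'] != 'high_quality'
                and not ignore_high_quality_check):
            return None
        return 'vector-index'
    return 'keyword-index'

# Enforced check: only datasets still marked high_quality get a vector index.
assert get_index({'indexing_technique': 'high_quality'}, 'high_quality') == 'vector-index'
assert get_index({'indexing_technique': 'economy'}, 'high_quality') is None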