Przeglądaj źródła

text spliter length method use default embedding model tokenizer (#2011)

Co-authored-by: jyong <jyong@dify.ai>
Jyong 1 rok temu
rodzic
commit
a63a9c7d45
2 zmienionych plików z 69 dodań i 24 usunięć
  1. 46 12
      api/core/indexing_runner.py
  2. 23 12
      api/core/spiltter/fixed_text_splitter.py

+ 46 - 12
api/core/indexing_runner.py

@@ -13,7 +13,7 @@ from core.docstore.dataset_docstore import DatasetDocumentStore
 from core.errors.error import ProviderTokenNotInitError
 from core.generator.llm_generator import LLMGenerator
 from core.index.index import IndexBuilder
-from core.model_manager import ModelManager
+from core.model_manager import ModelManager, ModelInstance
 from core.model_runtime.entities.model_entities import ModelType, PriceType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
@@ -61,8 +61,24 @@ class IndexingRunner:
                 # load file
                 text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')
 
+                # get embedding model instance
+                embedding_model_instance = None
+                if dataset.indexing_technique == 'high_quality':
+                    if dataset.embedding_model_provider:
+                        embedding_model_instance = self.model_manager.get_model_instance(
+                            tenant_id=dataset.tenant_id,
+                            provider=dataset.embedding_model_provider,
+                            model_type=ModelType.TEXT_EMBEDDING,
+                            model=dataset.embedding_model
+                        )
+                    else:
+                        embedding_model_instance = self.model_manager.get_default_model_instance(
+                            tenant_id=dataset.tenant_id,
+                            model_type=ModelType.TEXT_EMBEDDING,
+                        )
+
                 # get splitter
-                splitter = self._get_splitter(processing_rule)
+                splitter = self._get_splitter(processing_rule, embedding_model_instance)
 
                 # split to documents
                 documents = self._step_split(
@@ -121,8 +137,24 @@ class IndexingRunner:
             # load file
             text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')
 
+            # get embedding model instance
+            embedding_model_instance = None
+            if dataset.indexing_technique == 'high_quality':
+                if dataset.embedding_model_provider:
+                    embedding_model_instance = self.model_manager.get_model_instance(
+                        tenant_id=dataset.tenant_id,
+                        provider=dataset.embedding_model_provider,
+                        model_type=ModelType.TEXT_EMBEDDING,
+                        model=dataset.embedding_model
+                    )
+                else:
+                    embedding_model_instance = self.model_manager.get_default_model_instance(
+                        tenant_id=dataset.tenant_id,
+                        model_type=ModelType.TEXT_EMBEDDING,
+                    )
+
             # get splitter
-            splitter = self._get_splitter(processing_rule)
+            splitter = self._get_splitter(processing_rule, embedding_model_instance)
 
             # split to documents
             documents = self._step_split(
@@ -253,7 +285,7 @@ class IndexingRunner:
             text_docs = FileExtractor.load(file_detail, is_automatic=processing_rule.mode == 'automatic')
 
             # get splitter
-            splitter = self._get_splitter(processing_rule)
+            splitter = self._get_splitter(processing_rule, embedding_model_instance)
 
             # split to documents
             documents = self._split_to_documents_for_estimate(
@@ -384,7 +416,7 @@ class IndexingRunner:
                 )
 
                 # get splitter
-                splitter = self._get_splitter(processing_rule)
+                splitter = self._get_splitter(processing_rule, embedding_model_instance)
 
                 # split to documents
                 documents = self._split_to_documents_for_estimate(
@@ -502,7 +534,8 @@ class IndexingRunner:
         text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]', '', text)
         return text
 
-    def _get_splitter(self, processing_rule: DatasetProcessRule) -> TextSplitter:
+    def _get_splitter(self, processing_rule: DatasetProcessRule,
+                      embedding_model_instance: Optional[ModelInstance]) -> TextSplitter:
         """
         Get the NodeParser object according to the processing rule.
         """
@@ -517,19 +550,20 @@ class IndexingRunner:
             if separator:
                 separator = separator.replace('\\n', '\n')
 
-
-            character_splitter = FixedRecursiveCharacterTextSplitter.from_gpt2_encoder(
+            character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
                 chunk_size=segmentation["max_tokens"],
                 chunk_overlap=0,
                 fixed_separator=separator,
-                separators=["\n\n", "。", ".", " ", ""]
+                separators=["\n\n", "。", ".", " ", ""],
+                embedding_model_instance=embedding_model_instance
             )
         else:
             # Automatic segmentation
-            character_splitter = EnhanceRecursiveCharacterTextSplitter.from_gpt2_encoder(
+            character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
                 chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
                 chunk_overlap=0,
-                separators=["\n\n", "。", ".", " ", ""]
+                separators=["\n\n", "。", ".", " ", ""],
+                embedding_model_instance=embedding_model_instance
             )
 
         return character_splitter
@@ -714,7 +748,7 @@ class IndexingRunner:
         return text
 
     def format_split_text(self, text):
-        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)" 
+        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
         matches = re.findall(regex, text, re.UNICODE)
 
         return [

+ 23 - 12
api/core/spiltter/fixed_text_splitter.py

@@ -1,8 +1,10 @@
 """Functionality for splitting text."""
 from __future__ import annotations
 
-from typing import Any, List, Optional
+from typing import Any, List, Optional, cast
 
+from core.model_manager import ModelInstance
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
 from langchain.text_splitter import (TS, AbstractSet, Collection, Literal, RecursiveCharacterTextSplitter,
                                      TokenTextSplitter, Type, Union)
@@ -12,22 +14,30 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
     """
         This class is used to implement from_gpt2_encoder, to prevent using of tiktoken
     """
+
     @classmethod
-    def from_gpt2_encoder(
-        cls: Type[TS],
-        encoding_name: str = "gpt2",
-        model_name: Optional[str] = None,
-        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
-        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
-        **kwargs: Any,
+    def from_encoder(
+            cls: Type[TS],
+            embedding_model_instance: Optional[ModelInstance],
+            allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
+            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+            **kwargs: Any,
     ):
         def _token_encoder(text: str) -> int:
-            return GPT2Tokenizer.get_num_tokens(text)
+            if embedding_model_instance:
+                embedding_model_type_instance = embedding_model_instance.model_type_instance
+                embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
+                return embedding_model_type_instance.get_num_tokens(
+                    model=embedding_model_instance.model,
+                    credentials=embedding_model_instance.credentials,
+                    texts=[text]
+                )
+            else:
+                return GPT2Tokenizer.get_num_tokens(text)
 
         if issubclass(cls, TokenTextSplitter):
             extra_kwargs = {
-                "encoding_name": encoding_name,
-                "model_name": model_name,
+                "model_name": embedding_model_instance.model if embedding_model_instance else 'gpt2',
                 "allowed_special": allowed_special,
                 "disallowed_special": disallowed_special,
             }
@@ -35,6 +45,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
 
         return cls(length_function=_token_encoder, **kwargs)
 
+
 class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
     def __init__(self, fixed_separator: str = "\n\n", separators: Optional[List[str]] = None, **kwargs: Any):
         """Create a new TextSplitter."""
@@ -90,4 +101,4 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
         if _good_splits:
             merged_text = self._merge_splits(_good_splits, separator)
             final_chunks.extend(merged_text)
-        return final_chunks
+        return final_chunks