소스 검색

feat: optimize split rule when use custom split segment identifier (#35)

John Wang 1 년 전
부모
커밋
815f794eef
2개의 변경된 파일73개의 추가작업 그리고 6개의 파일을 삭제
  1. 68 0
      api/core/index/spiltter/fixed_text_splitter.py
  2. 5 6
      api/core/indexing_runner.py

+ 68 - 0
api/core/index/spiltter/fixed_text_splitter.py

@@ -0,0 +1,68 @@
+"""Functionality for splitting text."""
+from __future__ import annotations
+
+from typing import (
+    Any,
+    List,
+    Optional,
+)
+
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+
+class FixedRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
+    def __init__(self, fixed_separator: str = "\n\n", separators: Optional[List[str]] = None, **kwargs: Any):
+        """Create a new TextSplitter."""
+        super().__init__(**kwargs)
+        self._fixed_separator = fixed_separator
+        self._separators = separators or ["\n\n", "\n", " ", ""]
+
+    def split_text(self, text: str) -> List[str]:
+        """Split incoming text and return chunks."""
+        if self._fixed_separator:
+            chunks = text.split(self._fixed_separator)
+        else:
+            chunks = list(text)
+
+        final_chunks = []
+        for chunk in chunks:
+            if self._length_function(chunk) > self._chunk_size:
+                final_chunks.extend(self.recursive_split_text(chunk))
+            else:
+                final_chunks.append(chunk)
+
+        return final_chunks
+
+    def recursive_split_text(self, text: str) -> List[str]:
+        """Split incoming text and return chunks."""
+        final_chunks = []
+        # Get appropriate separator to use
+        separator = self._separators[-1]
+        for _s in self._separators:
+            if _s == "":
+                separator = _s
+                break
+            if _s in text:
+                separator = _s
+                break
+        # Now that we have the separator, split the text
+        if separator:
+            splits = text.split(separator)
+        else:
+            splits = list(text)
+        # Now go merging things, recursively splitting longer texts.
+        _good_splits = []
+        for s in splits:
+            if self._length_function(s) < self._chunk_size:
+                _good_splits.append(s)
+            else:
+                if _good_splits:
+                    merged_text = self._merge_splits(_good_splits, separator)
+                    final_chunks.extend(merged_text)
+                    _good_splits = []
+                other_info = self.recursive_split_text(s)
+                final_chunks.extend(other_info)
+        if _good_splits:
+            merged_text = self._merge_splits(_good_splits, separator)
+            final_chunks.extend(merged_text)
+        return final_chunks

+ 5 - 6
api/core/indexing_runner.py

@@ -18,6 +18,7 @@ from core.docstore.dataset_docstore import DatesetDocumentStore
 from core.index.keyword_table_index import KeywordTableIndex
 from core.index.readers.html_parser import HTMLParser
 from core.index.readers.pdf_parser import PDFParser
+from core.index.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
 from core.index.vector_index import VectorIndex
 from core.llm.token_calculator import TokenCalculator
 from extensions.ext_database import db
@@ -267,16 +268,14 @@ class IndexingRunner:
                 raise ValueError("Custom segment length should be between 50 and 1000.")
 
             separator = segmentation["separator"]
-            if not separator:
-                separators = ["\n\n", "。", ".", " ", ""]
-            else:
+            if separator:
                 separator = separator.replace('\\n', '\n')
-                separators = [separator, ""]
 
-            character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
+            character_splitter = FixedRecursiveCharacterTextSplitter.from_tiktoken_encoder(
                 chunk_size=segmentation["max_tokens"],
                 chunk_overlap=0,
-                separators=separators
+                fixed_separator=separator,
+                separators=["\n\n", "。", ".", " ", ""]
             )
         else:
             # Automatic segmentation