Parcourir la source

fix missing splitter length (#7987)

Jyong il y a 7 mois
Parent
commit
0e71f6db84

+ 7 - 3
api/core/rag/splitter/fixed_text_splitter.py

@@ -93,17 +93,21 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
             splits = list(text)
         # Now go merging things, recursively splitting longer texts.
         _good_splits = []
+        _good_splits_lengths = []  # cache the lengths of the splits
         for s in splits:
-            if self._length_function(s) < self._chunk_size:
+            s_len = self._length_function(s)
+            if s_len < self._chunk_size:
                 _good_splits.append(s)
+                _good_splits_lengths.append(s_len)
             else:
                 if _good_splits:
-                    merged_text = self._merge_splits(_good_splits, separator)
+                    merged_text = self._merge_splits(_good_splits, separator, _good_splits_lengths)
                     final_chunks.extend(merged_text)
                     _good_splits = []
+                    _good_splits_lengths = []
                 other_info = self.recursive_split_text(s)
                 final_chunks.extend(other_info)
         if _good_splits:
-            merged_text = self._merge_splits(_good_splits, separator)
+            merged_text = self._merge_splits(_good_splits, separator, _good_splits_lengths)
             final_chunks.extend(merged_text)
         return final_chunks

+ 4 - 1
api/core/rag/splitter/text_splitter.py

@@ -243,7 +243,10 @@ class CharacterTextSplitter(TextSplitter):
         # First we naively split the large input into a bunch of smaller ones.
         splits = _split_text_with_regex(text, self._separator, self._keep_separator)
         _separator = "" if self._keep_separator else self._separator
-        return self._merge_splits(splits, _separator)
+        _good_splits_lengths = []  # cache the lengths of the splits
+        for split in splits:
+            _good_splits_lengths.append(self._length_function(split))
+        return self._merge_splits(splits, _separator, _good_splits_lengths)
 
 
 class LineType(TypedDict):