Explorar o código

improve text split (#15719)

Jyong hai 1 mes
pai
achega
a8e8c37fdd
Modificáronse 1 ficheiros con 51 adicións e 18 borrados
  1. 51 18
      api/core/rag/splitter/fixed_text_splitter.py

+ 51 - 18
api/core/rag/splitter/fixed_text_splitter.py

@@ -76,16 +76,20 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
 
     def recursive_split_text(self, text: str) -> list[str]:
         """Split incoming text and return chunks."""
+
         final_chunks = []
-        # Get appropriate separator to use
         separator = self._separators[-1]
-        for _s in self._separators:
+        new_separators = []
+
+        for i, _s in enumerate(self._separators):
             if _s == "":
                 separator = _s
                 break
             if _s in text:
                 separator = _s
+                new_separators = self._separators[i + 1 :]
                 break
+
         # Now that we have the separator, split the text
         if separator:
             if separator == " ":
@@ -94,23 +98,52 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
                 splits = text.split(separator)
         else:
             splits = list(text)
-        # Now go merging things, recursively splitting longer texts.
+        splits = [s for s in splits if (s not in {"", "\n"})]
         _good_splits = []
         _good_splits_lengths = []  # cache the lengths of the splits
+        _separator = "" if self._keep_separator else separator
         s_lens = self._length_function(splits)
-        for s, s_len in zip(splits, s_lens):
-            if s_len < self._chunk_size:
-                _good_splits.append(s)
-                _good_splits_lengths.append(s_len)
-            else:
-                if _good_splits:
-                    merged_text = self._merge_splits(_good_splits, separator, _good_splits_lengths)
-                    final_chunks.extend(merged_text)
-                    _good_splits = []
-                    _good_splits_lengths = []
-                other_info = self.recursive_split_text(s)
-                final_chunks.extend(other_info)
-        if _good_splits:
-            merged_text = self._merge_splits(_good_splits, separator, _good_splits_lengths)
-            final_chunks.extend(merged_text)
+        if _separator != "":
+            for s, s_len in zip(splits, s_lens):
+                if s_len < self._chunk_size:
+                    _good_splits.append(s)
+                    _good_splits_lengths.append(s_len)
+                else:
+                    if _good_splits:
+                        merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
+                        final_chunks.extend(merged_text)
+                        _good_splits = []
+                        _good_splits_lengths = []
+                    if not new_separators:
+                        final_chunks.append(s)
+                    else:
+                        other_info = self._split_text(s, new_separators)
+                        final_chunks.extend(other_info)
+
+            if _good_splits:
+                merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
+                final_chunks.extend(merged_text)
+        else:
+            current_part = ""
+            current_length = 0
+            overlap_part = ""
+            overlap_part_length = 0
+            for s, s_len in zip(splits, s_lens):
+                if current_length + s_len <= self._chunk_size - self._chunk_overlap:
+                    current_part += s
+                    current_length += s_len
+                elif current_length + s_len <= self._chunk_size:
+                    current_part += s
+                    current_length += s_len
+                    overlap_part += s
+                    overlap_part_length += s_len
+                else:
+                    final_chunks.append(current_part)
+                    current_part = overlap_part + s
+                    current_length = s_len + overlap_part_length
+                    overlap_part = ""
+                    overlap_part_length = 0
+            if current_part:
+                final_chunks.append(current_part)
+
         return final_chunks