|
@@ -0,0 +1,903 @@
|
|
|
+from __future__ import annotations
|
|
|
+
|
|
|
+import copy
|
|
|
+import logging
|
|
|
+import re
|
|
|
+from abc import ABC, abstractmethod
|
|
|
+from collections.abc import Callable, Collection, Iterable, Sequence, Set
|
|
|
+from dataclasses import dataclass
|
|
|
+from enum import Enum
|
|
|
+from typing import (
|
|
|
+ Any,
|
|
|
+ Literal,
|
|
|
+ Optional,
|
|
|
+ TypedDict,
|
|
|
+ TypeVar,
|
|
|
+ Union,
|
|
|
+)
|
|
|
+
|
|
|
+from core.rag.models.document import BaseDocumentTransformer, Document
|
|
|
+
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+TS = TypeVar("TS", bound="TextSplitter")
|
|
|
+
|
|
|
+
|
|
|
+def _split_text_with_regex(
|
|
|
+ text: str, separator: str, keep_separator: bool
|
|
|
+) -> list[str]:
|
|
|
+ # Now that we have the separator, split the text
|
|
|
+ if separator:
|
|
|
+ if keep_separator:
|
|
|
+ # The parentheses in the pattern keep the delimiters in the result.
|
|
|
+ _splits = re.split(f"({separator})", text)
|
|
|
+ splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
|
|
|
+ if len(_splits) % 2 == 0:
|
|
|
+ splits += _splits[-1:]
|
|
|
+ splits = [_splits[0]] + splits
|
|
|
+ else:
|
|
|
+ splits = re.split(separator, text)
|
|
|
+ else:
|
|
|
+ splits = list(text)
|
|
|
+ return [s for s in splits if s != ""]
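+
+
+# A minimal sketch of how _split_text_with_regex behaves. The sample string and
+# separator are illustrative assumptions; note that the separator is treated as
+# a regular expression pattern.
+def _demo_split_text_with_regex() -> None:
+    text = "one. two. three"
+    # keep_separator=True attaches each separator to the piece that follows it.
+    print(_split_text_with_regex(text, r"\. ", keep_separator=True))
+    # -> ['one', '. two', '. three']
+    # keep_separator=False drops the separators entirely.
+    print(_split_text_with_regex(text, r"\. ", keep_separator=False))
+    # -> ['one', 'two', 'three']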
|
|
|
+
|
|
|
+
|
|
|
+class TextSplitter(BaseDocumentTransformer, ABC):
|
|
|
+ """Interface for splitting text into chunks."""
|
|
|
+
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ chunk_size: int = 4000,
|
|
|
+ chunk_overlap: int = 200,
|
|
|
+ length_function: Callable[[str], int] = len,
|
|
|
+ keep_separator: bool = False,
|
|
|
+ add_start_index: bool = False,
|
|
|
+ ) -> None:
|
|
|
+ """Create a new TextSplitter.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ chunk_size: Maximum size of chunks to return
|
|
|
+ chunk_overlap: Overlap in characters between chunks
|
|
|
+ length_function: Function that measures the length of given chunks
|
|
|
+ keep_separator: Whether to keep the separator in the chunks
|
|
|
+ add_start_index: If `True`, includes chunk's start index in metadata
|
|
|
+ """
|
|
|
+ if chunk_overlap > chunk_size:
|
|
|
+ raise ValueError(
|
|
|
+ f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
|
|
|
+ f"({chunk_size}), should be smaller."
|
|
|
+ )
|
|
|
+ self._chunk_size = chunk_size
|
|
|
+ self._chunk_overlap = chunk_overlap
|
|
|
+ self._length_function = length_function
|
|
|
+ self._keep_separator = keep_separator
|
|
|
+ self._add_start_index = add_start_index
|
|
|
+
|
|
|
+ @abstractmethod
|
|
|
+ def split_text(self, text: str) -> list[str]:
|
|
|
+ """Split text into multiple components."""
|
|
|
+
|
|
|
+ def create_documents(
|
|
|
+ self, texts: list[str], metadatas: Optional[list[dict]] = None
|
|
|
+ ) -> list[Document]:
|
|
|
+ """Create documents from a list of texts."""
|
|
|
+ _metadatas = metadatas or [{}] * len(texts)
|
|
|
+ documents = []
|
|
|
+ for i, text in enumerate(texts):
|
|
|
+ index = -1
|
|
|
+ for chunk in self.split_text(text):
|
|
|
+ metadata = copy.deepcopy(_metadatas[i])
|
|
|
+ if self._add_start_index:
|
|
|
+ index = text.find(chunk, index + 1)
|
|
|
+ metadata["start_index"] = index
|
|
|
+ new_doc = Document(page_content=chunk, metadata=metadata)
|
|
|
+ documents.append(new_doc)
|
|
|
+ return documents
|
|
|
+
|
|
|
+ def split_documents(self, documents: Iterable[Document]) -> list[Document]:
|
|
|
+ """Split documents."""
|
|
|
+ texts, metadatas = [], []
|
|
|
+ for doc in documents:
|
|
|
+ texts.append(doc.page_content)
|
|
|
+ metadatas.append(doc.metadata)
|
|
|
+ return self.create_documents(texts, metadatas=metadatas)
|
|
|
+
|
|
|
+ def _join_docs(self, docs: list[str], separator: str) -> Optional[str]:
|
|
|
+ text = separator.join(docs)
|
|
|
+ text = text.strip()
|
|
|
+ if text == "":
|
|
|
+ return None
|
|
|
+ else:
|
|
|
+ return text
|
|
|
+
|
|
|
+ def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]:
|
|
|
+ # We now want to combine these smaller pieces into medium size
|
|
|
+ # chunks to send to the LLM.
|
|
|
+ separator_len = self._length_function(separator)
|
|
|
+
|
|
|
+ docs = []
|
|
|
+ current_doc: list[str] = []
|
|
|
+ total = 0
|
|
|
+ for d in splits:
|
|
|
+ _len = self._length_function(d)
|
|
|
+ if (
|
|
|
+ total + _len + (separator_len if len(current_doc) > 0 else 0)
|
|
|
+ > self._chunk_size
|
|
|
+ ):
|
|
|
+ if total > self._chunk_size:
|
|
|
+ logger.warning(
|
|
|
+ f"Created a chunk of size {total}, "
|
|
|
+ f"which is longer than the specified {self._chunk_size}"
|
|
|
+ )
|
|
|
+ if len(current_doc) > 0:
|
|
|
+ doc = self._join_docs(current_doc, separator)
|
|
|
+ if doc is not None:
|
|
|
+ docs.append(doc)
|
|
|
+ # Keep on popping if:
|
|
|
+ # - the running total is larger than the chunk overlap
+ # - or adding this split would still push the chunk past chunk_size
+ #   (and there is something left to pop)
|
|
|
+ while total > self._chunk_overlap or (
|
|
|
+ total + _len + (separator_len if len(current_doc) > 0 else 0)
|
|
|
+ > self._chunk_size
|
|
|
+ and total > 0
|
|
|
+ ):
|
|
|
+ total -= self._length_function(current_doc[0]) + (
|
|
|
+ separator_len if len(current_doc) > 1 else 0
|
|
|
+ )
|
|
|
+ current_doc = current_doc[1:]
|
|
|
+ current_doc.append(d)
|
|
|
+ total += _len + (separator_len if len(current_doc) > 1 else 0)
|
|
|
+ doc = self._join_docs(current_doc, separator)
|
|
|
+ if doc is not None:
|
|
|
+ docs.append(doc)
|
|
|
+ return docs
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
|
|
|
+ """Text splitter that uses HuggingFace tokenizer to count length."""
|
|
|
+ try:
|
|
|
+ from transformers import PreTrainedTokenizerBase
|
|
|
+
|
|
|
+ if not isinstance(tokenizer, PreTrainedTokenizerBase):
|
|
|
+ raise ValueError(
|
|
|
+ "Tokenizer received was not an instance of PreTrainedTokenizerBase"
|
|
|
+ )
|
|
|
+
|
|
|
+ def _huggingface_tokenizer_length(text: str) -> int:
|
|
|
+ return len(tokenizer.encode(text))
|
|
|
+
|
|
|
+ except ImportError:
|
|
|
+ raise ValueError(
|
|
|
+ "Could not import transformers python package. "
|
|
|
+ "Please install it with `pip install transformers`."
|
|
|
+ )
|
|
|
+ return cls(length_function=_huggingface_tokenizer_length, **kwargs)
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def from_tiktoken_encoder(
|
|
|
+ cls: type[TS],
|
|
|
+ encoding_name: str = "gpt2",
|
|
|
+ model_name: Optional[str] = None,
|
|
|
+ allowed_special: Union[Literal["all"], Set[str]] = set(),
|
|
|
+ disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
|
|
+ **kwargs: Any,
|
|
|
+ ) -> TS:
|
|
|
+ """Text splitter that uses tiktoken encoder to count length."""
|
|
|
+ try:
|
|
|
+ import tiktoken
|
|
|
+ except ImportError:
|
|
|
+ raise ImportError(
|
|
|
+ "Could not import tiktoken python package. "
|
|
|
+ "This is needed in order to calculate max_tokens_for_prompt. "
|
|
|
+ "Please install it with `pip install tiktoken`."
|
|
|
+ )
|
|
|
+
|
|
|
+ if model_name is not None:
|
|
|
+ enc = tiktoken.encoding_for_model(model_name)
|
|
|
+ else:
|
|
|
+ enc = tiktoken.get_encoding(encoding_name)
|
|
|
+
|
|
|
+ def _tiktoken_encoder(text: str) -> int:
|
|
|
+ return len(
|
|
|
+ enc.encode(
|
|
|
+ text,
|
|
|
+ allowed_special=allowed_special,
|
|
|
+ disallowed_special=disallowed_special,
|
|
|
+ )
|
|
|
+ )
|
|
|
+
|
|
|
+ if issubclass(cls, TokenTextSplitter):
|
|
|
+ extra_kwargs = {
|
|
|
+ "encoding_name": encoding_name,
|
|
|
+ "model_name": model_name,
|
|
|
+ "allowed_special": allowed_special,
|
|
|
+ "disallowed_special": disallowed_special,
|
|
|
+ }
|
|
|
+ kwargs = {**kwargs, **extra_kwargs}
|
|
|
+
|
|
|
+ return cls(length_function=_tiktoken_encoder, **kwargs)
|
|
|
+
|
|
|
+ def transform_documents(
|
|
|
+ self, documents: Sequence[Document], **kwargs: Any
|
|
|
+ ) -> Sequence[Document]:
|
|
|
+ """Transform sequence of documents by splitting them."""
|
|
|
+ return self.split_documents(list(documents))
|
|
|
+
|
|
|
+ async def atransform_documents(
|
|
|
+ self, documents: Sequence[Document], **kwargs: Any
|
|
|
+ ) -> Sequence[Document]:
|
|
|
+ """Asynchronously transform a sequence of documents by splitting them."""
|
|
|
+ raise NotImplementedError
|
|
|
+
|
|
|
+
|
|
|
+class CharacterTextSplitter(TextSplitter):
|
|
|
+ """Splitting text that looks at characters."""
|
|
|
+
|
|
|
+ def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None:
|
|
|
+ """Create a new TextSplitter."""
|
|
|
+ super().__init__(**kwargs)
|
|
|
+ self._separator = separator
|
|
|
+
|
|
|
+ def split_text(self, text: str) -> list[str]:
|
|
|
+ """Split incoming text and return chunks."""
|
|
|
+ # First we naively split the large input into a bunch of smaller ones.
|
|
|
+ splits = _split_text_with_regex(text, self._separator, self._keep_separator)
|
|
|
+ _separator = "" if self._keep_separator else self._separator
|
|
|
+ return self._merge_splits(splits, _separator)
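+
+
+# A minimal usage sketch for CharacterTextSplitter; the sample text, sizes and
+# metadata are illustrative assumptions chosen to make the chunking visible.
+def _demo_character_text_splitter() -> None:
+    splitter = CharacterTextSplitter(
+        separator="\n\n",
+        chunk_size=40,
+        chunk_overlap=10,
+        add_start_index=True,
+    )
+    sample = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph."
+    # split_text returns plain strings merged up to roughly chunk_size characters.
+    print(splitter.split_text(sample))
+    # create_documents wraps each chunk in a Document and, because
+    # add_start_index=True, records where each chunk starts in the source text.
+    for doc in splitter.create_documents([sample], metadatas=[{"source": "demo"}]):
+        print(doc.metadata, doc.page_content)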
|
|
|
+
|
|
|
+
|
|
|
+class LineType(TypedDict):
|
|
|
+ """Line type as typed dict."""
|
|
|
+
|
|
|
+ metadata: dict[str, str]
|
|
|
+ content: str
|
|
|
+
|
|
|
+
|
|
|
+class HeaderType(TypedDict):
|
|
|
+ """Header type as typed dict."""
|
|
|
+
|
|
|
+ level: int
|
|
|
+ name: str
|
|
|
+ data: str
|
|
|
+
|
|
|
+
|
|
|
+class MarkdownHeaderTextSplitter:
|
|
|
+ """Splitting markdown files based on specified headers."""
|
|
|
+
|
|
|
+ def __init__(
|
|
|
+ self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False
|
|
|
+ ):
|
|
|
+ """Create a new MarkdownHeaderTextSplitter.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ headers_to_split_on: Headers we want to track
|
|
|
+ return_each_line: Return each line w/ associated headers
|
|
|
+ """
|
|
|
+ # Output line-by-line or aggregated into chunks w/ common headers
|
|
|
+ self.return_each_line = return_each_line
|
|
|
+ # Given the headers we want to split on
+ # (e.g., "#", "##"), sort them by length, longest first
|
|
|
+ self.headers_to_split_on = sorted(
|
|
|
+ headers_to_split_on, key=lambda split: len(split[0]), reverse=True
|
|
|
+ )
|
|
|
+
|
|
|
+ def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]:
|
|
|
+ """Combine lines with common metadata into chunks
|
|
|
+ Args:
|
|
|
+ lines: Line of text / associated header metadata
|
|
|
+ """
|
|
|
+ aggregated_chunks: list[LineType] = []
|
|
|
+
|
|
|
+ for line in lines:
|
|
|
+ if (
|
|
|
+ aggregated_chunks
|
|
|
+ and aggregated_chunks[-1]["metadata"] == line["metadata"]
|
|
|
+ ):
|
|
|
+ # If the last line in the aggregated list
|
|
|
+ # has the same metadata as the current line,
|
|
|
+ # append the current content to the last line's content
|
|
|
+ aggregated_chunks[-1]["content"] += " \n" + line["content"]
|
|
|
+ else:
|
|
|
+ # Otherwise, append the current line to the aggregated list
|
|
|
+ aggregated_chunks.append(line)
|
|
|
+
|
|
|
+ return [
|
|
|
+ Document(page_content=chunk["content"], metadata=chunk["metadata"])
|
|
|
+ for chunk in aggregated_chunks
|
|
|
+ ]
|
|
|
+
|
|
|
+ def split_text(self, text: str) -> list[Document]:
|
|
|
+ """Split markdown file
|
|
|
+ Args:
|
|
|
+ text: Markdown file"""
|
|
|
+
|
|
|
+ # Split the input text by newline character ("\n").
|
|
|
+ lines = text.split("\n")
|
|
|
+ # Final output
|
|
|
+ lines_with_metadata: list[LineType] = []
|
|
|
+ # Content and metadata of the chunk currently being processed
|
|
|
+ current_content: list[str] = []
|
|
|
+ current_metadata: dict[str, str] = {}
|
|
|
+ # Keep track of the nested header structure
|
|
|
+ # header_stack: List[Dict[str, Union[int, str]]] = []
|
|
|
+ header_stack: list[HeaderType] = []
|
|
|
+ initial_metadata: dict[str, str] = {}
|
|
|
+
|
|
|
+ for line in lines:
|
|
|
+ stripped_line = line.strip()
|
|
|
+ # Check each line against each of the header types (e.g., #, ##)
|
|
|
+ for sep, name in self.headers_to_split_on:
|
|
|
+ # Check if line starts with a header that we intend to split on
|
|
|
+ if stripped_line.startswith(sep) and (
|
|
|
+ # Header with no text OR header is followed by space
|
|
|
+ # Both are valid conditions that sep is being used as a header
|
|
|
+ len(stripped_line) == len(sep)
|
|
|
+ or stripped_line[len(sep)] == " "
|
|
|
+ ):
|
|
|
+ # Ensure we are tracking the header as metadata
|
|
|
+ if name is not None:
|
|
|
+ # Get the current header level
|
|
|
+ current_header_level = sep.count("#")
|
|
|
+
|
|
|
+ # Pop out headers of lower or same level from the stack
|
|
|
+ while (
|
|
|
+ header_stack
|
|
|
+ and header_stack[-1]["level"] >= current_header_level
|
|
|
+ ):
|
|
|
+ # We have encountered a new header
|
|
|
+ # at the same or higher level
|
|
|
+ popped_header = header_stack.pop()
|
|
|
+ # Clear the metadata for the
|
|
|
+ # popped header in initial_metadata
|
|
|
+ if popped_header["name"] in initial_metadata:
|
|
|
+ initial_metadata.pop(popped_header["name"])
|
|
|
+
|
|
|
+ # Push the current header to the stack
|
|
|
+ header: HeaderType = {
|
|
|
+ "level": current_header_level,
|
|
|
+ "name": name,
|
|
|
+ "data": stripped_line[len(sep):].strip(),
|
|
|
+ }
|
|
|
+ header_stack.append(header)
|
|
|
+ # Update initial_metadata with the current header
|
|
|
+ initial_metadata[name] = header["data"]
|
|
|
+
|
|
|
+ # Flush the content accumulated so far into lines_with_metadata,
+ # but only if current_content is not empty
|
|
|
+ if current_content:
|
|
|
+ lines_with_metadata.append(
|
|
|
+ {
|
|
|
+ "content": "\n".join(current_content),
|
|
|
+ "metadata": current_metadata.copy(),
|
|
|
+ }
|
|
|
+ )
|
|
|
+ current_content.clear()
|
|
|
+
|
|
|
+ break
|
|
|
+ else:
|
|
|
+ if stripped_line:
|
|
|
+ current_content.append(stripped_line)
|
|
|
+ elif current_content:
|
|
|
+ lines_with_metadata.append(
|
|
|
+ {
|
|
|
+ "content": "\n".join(current_content),
|
|
|
+ "metadata": current_metadata.copy(),
|
|
|
+ }
|
|
|
+ )
|
|
|
+ current_content.clear()
|
|
|
+
|
|
|
+ current_metadata = initial_metadata.copy()
|
|
|
+
|
|
|
+ if current_content:
|
|
|
+ lines_with_metadata.append(
|
|
|
+ {"content": "\n".join(current_content), "metadata": current_metadata}
|
|
|
+ )
|
|
|
+
|
|
|
+ # lines_with_metadata has each line with associated header metadata
|
|
|
+ # aggregate these into chunks based on common metadata
|
|
|
+ if not self.return_each_line:
|
|
|
+ return self.aggregate_lines_to_chunks(lines_with_metadata)
|
|
|
+ else:
|
|
|
+ return [
|
|
|
+ Document(page_content=chunk["content"], metadata=chunk["metadata"])
|
|
|
+ for chunk in lines_with_metadata
|
|
|
+ ]
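+
+
+# A minimal usage sketch for MarkdownHeaderTextSplitter; the header mapping and
+# the markdown snippet are illustrative assumptions.
+def _demo_markdown_header_text_splitter() -> None:
+    headers_to_split_on = [("#", "Header 1"), ("##", "Header 2")]
+    splitter = MarkdownHeaderTextSplitter(headers_to_split_on)
+    markdown = "# Intro\n\nSome intro text.\n\n## Details\n\nMore text here."
+    for doc in splitter.split_text(markdown):
+        # Each Document carries the header hierarchy it falls under as metadata,
+        # e.g. {'Header 1': 'Intro', 'Header 2': 'Details'} for the second chunk.
+        print(doc.metadata, "|", doc.page_content)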
|
|
|
+
|
|
|
+
|
|
|
+# On newer Python versions (3.10+) this could be:
|
|
|
+# @dataclass(frozen=True, kw_only=True, slots=True)
|
|
|
+@dataclass(frozen=True)
|
|
|
+class Tokenizer:
|
|
|
+ chunk_overlap: int
|
|
|
+ tokens_per_chunk: int
|
|
|
+ decode: Callable[[list[int]], str]
|
|
|
+ encode: Callable[[str], list[int]]
|
|
|
+
|
|
|
+
|
|
|
+def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
|
|
|
+ """Split incoming text and return chunks using tokenizer."""
|
|
|
+ splits: list[str] = []
|
|
|
+ input_ids = tokenizer.encode(text)
|
|
|
+ start_idx = 0
|
|
|
+ cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
|
|
|
+ chunk_ids = input_ids[start_idx:cur_idx]
|
|
|
+ while start_idx < len(input_ids):
|
|
|
+ splits.append(tokenizer.decode(chunk_ids))
|
|
|
+ start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
|
|
|
+ cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
|
|
|
+ chunk_ids = input_ids[start_idx:cur_idx]
|
|
|
+ return splits
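+
+
+# A minimal sketch of split_text_on_tokens using a toy character-level
+# "tokenizer" (one token per character). Real callers plug in tiktoken or a
+# HuggingFace tokenizer; the sizes here are illustrative assumptions.
+def _demo_split_text_on_tokens() -> None:
+    toy_tokenizer = Tokenizer(
+        chunk_overlap=2,
+        tokens_per_chunk=10,
+        decode=lambda ids: "".join(chr(i) for i in ids),
+        encode=lambda text: [ord(c) for c in text],
+    )
+    # Windows of 10 "tokens" step forward by 10 - 2 = 8, so consecutive chunks
+    # share a 2-character overlap.
+    print(split_text_on_tokens(text="abcdefghijklmnopqrst", tokenizer=toy_tokenizer))
+    # -> ['abcdefghij', 'ijklmnopqr', 'qrst']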
|
|
|
+
|
|
|
+
|
|
|
+class TokenTextSplitter(TextSplitter):
|
|
|
+ """Splitting text to tokens using model tokenizer."""
|
|
|
+
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ encoding_name: str = "gpt2",
|
|
|
+ model_name: Optional[str] = None,
|
|
|
+ allowed_special: Union[Literal["all"], Set[str]] = set(),
|
|
|
+ disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
|
|
+ **kwargs: Any,
|
|
|
+ ) -> None:
|
|
|
+ """Create a new TextSplitter."""
|
|
|
+ super().__init__(**kwargs)
|
|
|
+ try:
|
|
|
+ import tiktoken
|
|
|
+ except ImportError:
|
|
|
+ raise ImportError(
|
|
|
+ "Could not import tiktoken python package. "
|
|
|
+ "This is needed in order to for TokenTextSplitter. "
|
|
|
+ "Please install it with `pip install tiktoken`."
|
|
|
+ )
|
|
|
+
|
|
|
+ if model_name is not None:
|
|
|
+ enc = tiktoken.encoding_for_model(model_name)
|
|
|
+ else:
|
|
|
+ enc = tiktoken.get_encoding(encoding_name)
|
|
|
+ self._tokenizer = enc
|
|
|
+ self._allowed_special = allowed_special
|
|
|
+ self._disallowed_special = disallowed_special
|
|
|
+
|
|
|
+ def split_text(self, text: str) -> list[str]:
|
|
|
+ def _encode(_text: str) -> list[int]:
|
|
|
+ return self._tokenizer.encode(
|
|
|
+ _text,
|
|
|
+ allowed_special=self._allowed_special,
|
|
|
+ disallowed_special=self._disallowed_special,
|
|
|
+ )
|
|
|
+
|
|
|
+ tokenizer = Tokenizer(
|
|
|
+ chunk_overlap=self._chunk_overlap,
|
|
|
+ tokens_per_chunk=self._chunk_size,
|
|
|
+ decode=self._tokenizer.decode,
|
|
|
+ encode=_encode,
|
|
|
+ )
|
|
|
+
|
|
|
+ return split_text_on_tokens(text=text, tokenizer=tokenizer)
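+
+
+# A minimal usage sketch for TokenTextSplitter. It requires the optional
+# tiktoken dependency; the encoding name and sizes are illustrative assumptions.
+def _demo_token_text_splitter() -> None:
+    splitter = TokenTextSplitter(
+        encoding_name="gpt2",
+        chunk_size=256,
+        chunk_overlap=32,
+    )
+    # Chunk boundaries are measured in tokens of the chosen encoding rather
+    # than in characters.
+    chunks = splitter.split_text("some long text " * 200)
+    print(len(chunks), [len(chunk) for chunk in chunks])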
|
|
|
+
|
|
|
+
|
|
|
+class Language(str, Enum):
|
|
|
+ """Enum of the programming languages."""
|
|
|
+
|
|
|
+ CPP = "cpp"
|
|
|
+ GO = "go"
|
|
|
+ JAVA = "java"
|
|
|
+ JS = "js"
|
|
|
+ PHP = "php"
|
|
|
+ PROTO = "proto"
|
|
|
+ PYTHON = "python"
|
|
|
+ RST = "rst"
|
|
|
+ RUBY = "ruby"
|
|
|
+ RUST = "rust"
|
|
|
+ SCALA = "scala"
|
|
|
+ SWIFT = "swift"
|
|
|
+ MARKDOWN = "markdown"
|
|
|
+ LATEX = "latex"
|
|
|
+ HTML = "html"
|
|
|
+ SOL = "sol"
|
|
|
+
|
|
|
+
|
|
|
+class RecursiveCharacterTextSplitter(TextSplitter):
|
|
|
+ """Splitting text by recursively look at characters.
|
|
|
+
|
|
|
+ Recursively tries to split by different characters to find one
|
|
|
+ that works.
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ separators: Optional[list[str]] = None,
|
|
|
+ keep_separator: bool = True,
|
|
|
+ **kwargs: Any,
|
|
|
+ ) -> None:
|
|
|
+ """Create a new TextSplitter."""
|
|
|
+ super().__init__(keep_separator=keep_separator, **kwargs)
|
|
|
+ self._separators = separators or ["\n\n", "\n", " ", ""]
|
|
|
+
|
|
|
+ def _split_text(self, text: str, separators: list[str]) -> list[str]:
|
|
|
+ """Split incoming text and return chunks."""
|
|
|
+ final_chunks = []
|
|
|
+ # Get appropriate separator to use
|
|
|
+ separator = separators[-1]
|
|
|
+ new_separators = []
|
|
|
+ for i, _s in enumerate(separators):
|
|
|
+ if _s == "":
|
|
|
+ separator = _s
|
|
|
+ break
|
|
|
+ if re.search(_s, text):
|
|
|
+ separator = _s
|
|
|
+ new_separators = separators[i + 1:]
|
|
|
+ break
|
|
|
+
|
|
|
+ splits = _split_text_with_regex(text, separator, self._keep_separator)
|
|
|
+ # Now go merging things, recursively splitting longer texts.
|
|
|
+ _good_splits = []
|
|
|
+ _separator = "" if self._keep_separator else separator
|
|
|
+ for s in splits:
|
|
|
+ if self._length_function(s) < self._chunk_size:
|
|
|
+ _good_splits.append(s)
|
|
|
+ else:
|
|
|
+ if _good_splits:
|
|
|
+ merged_text = self._merge_splits(_good_splits, _separator)
|
|
|
+ final_chunks.extend(merged_text)
|
|
|
+ _good_splits = []
|
|
|
+ if not new_separators:
|
|
|
+ final_chunks.append(s)
|
|
|
+ else:
|
|
|
+ other_info = self._split_text(s, new_separators)
|
|
|
+ final_chunks.extend(other_info)
|
|
|
+ if _good_splits:
|
|
|
+ merged_text = self._merge_splits(_good_splits, _separator)
|
|
|
+ final_chunks.extend(merged_text)
|
|
|
+ return final_chunks
|
|
|
+
|
|
|
+ def split_text(self, text: str) -> list[str]:
|
|
|
+ return self._split_text(text, self._separators)
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def from_language(
|
|
|
+ cls, language: Language, **kwargs: Any
|
|
|
+ ) -> RecursiveCharacterTextSplitter:
|
|
|
+ separators = cls.get_separators_for_language(language)
|
|
|
+ return cls(separators=separators, **kwargs)
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def get_separators_for_language(language: Language) -> list[str]:
|
|
|
+ if language == Language.CPP:
|
|
|
+ return [
|
|
|
+ # Split along class definitions
|
|
|
+ "\nclass ",
|
|
|
+ # Split along function definitions
|
|
|
+ "\nvoid ",
|
|
|
+ "\nint ",
|
|
|
+ "\nfloat ",
|
|
|
+ "\ndouble ",
|
|
|
+ # Split along control flow statements
|
|
|
+ "\nif ",
|
|
|
+ "\nfor ",
|
|
|
+ "\nwhile ",
|
|
|
+ "\nswitch ",
|
|
|
+ "\ncase ",
|
|
|
+ # Split by the normal type of lines
|
|
|
+ "\n\n",
|
|
|
+ "\n",
|
|
|
+ " ",
|
|
|
+ "",
|
|
|
+ ]
|
|
|
+ elif language == Language.GO:
|
|
|
+ return [
|
|
|
+ # Split along function definitions
|
|
|
+ "\nfunc ",
|
|
|
+ "\nvar ",
|
|
|
+ "\nconst ",
|
|
|
+ "\ntype ",
|
|
|
+ # Split along control flow statements
|
|
|
+ "\nif ",
|
|
|
+ "\nfor ",
|
|
|
+ "\nswitch ",
|
|
|
+ "\ncase ",
|
|
|
+ # Split by the normal type of lines
|
|
|
+ "\n\n",
|
|
|
+ "\n",
|
|
|
+ " ",
|
|
|
+ "",
|
|
|
+ ]
|
|
|
+ elif language == Language.JAVA:
|
|
|
+ return [
|
|
|
+ # Split along class definitions
|
|
|
+ "\nclass ",
|
|
|
+ # Split along method definitions
|
|
|
+ "\npublic ",
|
|
|
+ "\nprotected ",
|
|
|
+ "\nprivate ",
|
|
|
+ "\nstatic ",
|
|
|
+ # Split along control flow statements
|
|
|
+ "\nif ",
|
|
|
+ "\nfor ",
|
|
|
+ "\nwhile ",
|
|
|
+ "\nswitch ",
|
|
|
+ "\ncase ",
|
|
|
+ # Split by the normal type of lines
|
|
|
+ "\n\n",
|
|
|
+ "\n",
|
|
|
+ " ",
|
|
|
+ "",
|
|
|
+ ]
|
|
|
+ elif language == Language.JS:
|
|
|
+ return [
|
|
|
+ # Split along function definitions
|
|
|
+ "\nfunction ",
|
|
|
+ "\nconst ",
|
|
|
+ "\nlet ",
|
|
|
+ "\nvar ",
|
|
|
+ "\nclass ",
|
|
|
+ # Split along control flow statements
|
|
|
+ "\nif ",
|
|
|
+ "\nfor ",
|
|
|
+ "\nwhile ",
|
|
|
+ "\nswitch ",
|
|
|
+ "\ncase ",
|
|
|
+ "\ndefault ",
|
|
|
+ # Split by the normal type of lines
|
|
|
+ "\n\n",
|
|
|
+ "\n",
|
|
|
+ " ",
|
|
|
+ "",
|
|
|
+ ]
|
|
|
+ elif language == Language.PHP:
|
|
|
+ return [
|
|
|
+ # Split along function definitions
|
|
|
+ "\nfunction ",
|
|
|
+ # Split along class definitions
|
|
|
+ "\nclass ",
|
|
|
+ # Split along control flow statements
|
|
|
+ "\nif ",
|
|
|
+ "\nforeach ",
|
|
|
+ "\nwhile ",
|
|
|
+ "\ndo ",
|
|
|
+ "\nswitch ",
|
|
|
+ "\ncase ",
|
|
|
+ # Split by the normal type of lines
|
|
|
+ "\n\n",
|
|
|
+ "\n",
|
|
|
+ " ",
|
|
|
+ "",
|
|
|
+ ]
|
|
|
+ elif language == Language.PROTO:
|
|
|
+ return [
|
|
|
+ # Split along message definitions
|
|
|
+ "\nmessage ",
|
|
|
+ # Split along service definitions
|
|
|
+ "\nservice ",
|
|
|
+ # Split along enum definitions
|
|
|
+ "\nenum ",
|
|
|
+ # Split along option definitions
|
|
|
+ "\noption ",
|
|
|
+ # Split along import statements
|
|
|
+ "\nimport ",
|
|
|
+ # Split along syntax declarations
|
|
|
+ "\nsyntax ",
|
|
|
+ # Split by the normal type of lines
|
|
|
+ "\n\n",
|
|
|
+ "\n",
|
|
|
+ " ",
|
|
|
+ "",
|
|
|
+ ]
|
|
|
+ elif language == Language.PYTHON:
|
|
|
+ return [
|
|
|
+ # First, try to split along class definitions
|
|
|
+ "\nclass ",
|
|
|
+ "\ndef ",
|
|
|
+ "\n\tdef ",
|
|
|
+ # Now split by the normal type of lines
|
|
|
+ "\n\n",
|
|
|
+ "\n",
|
|
|
+ " ",
|
|
|
+ "",
|
|
|
+ ]
|
|
|
+ elif language == Language.RST:
|
|
|
+ return [
|
|
|
+ # Split along section titles
|
|
|
+ "\n=+\n",
|
|
|
+ "\n-+\n",
|
|
|
+ "\n\*+\n",
|
|
|
+ # Split along directive markers
|
|
|
+ "\n\n.. *\n\n",
|
|
|
+ # Split by the normal type of lines
|
|
|
+ "\n\n",
|
|
|
+ "\n",
|
|
|
+ " ",
|
|
|
+ "",
|
|
|
+ ]
|
|
|
+ elif language == Language.RUBY:
|
|
|
+ return [
|
|
|
+ # Split along method definitions
|
|
|
+ "\ndef ",
|
|
|
+ "\nclass ",
|
|
|
+ # Split along control flow statements
|
|
|
+ "\nif ",
|
|
|
+ "\nunless ",
|
|
|
+ "\nwhile ",
|
|
|
+ "\nfor ",
|
|
|
+ "\ndo ",
|
|
|
+ "\nbegin ",
|
|
|
+ "\nrescue ",
|
|
|
+ # Split by the normal type of lines
|
|
|
+ "\n\n",
|
|
|
+ "\n",
|
|
|
+ " ",
|
|
|
+ "",
|
|
|
+ ]
|
|
|
+ elif language == Language.RUST:
|
|
|
+ return [
|
|
|
+ # Split along function definitions
|
|
|
+ "\nfn ",
|
|
|
+ "\nconst ",
|
|
|
+ "\nlet ",
|
|
|
+ # Split along control flow statements
|
|
|
+ "\nif ",
|
|
|
+ "\nwhile ",
|
|
|
+ "\nfor ",
|
|
|
+ "\nloop ",
|
|
|
+ "\nmatch ",
|
|
|
+ "\nconst ",
|
|
|
+ # Split by the normal type of lines
|
|
|
+ "\n\n",
|
|
|
+ "\n",
|
|
|
+ " ",
|
|
|
+ "",
|
|
|
+ ]
|
|
|
+ elif language == Language.SCALA:
|
|
|
+ return [
|
|
|
+ # Split along class definitions
|
|
|
+ "\nclass ",
|
|
|
+ "\nobject ",
|
|
|
+ # Split along method definitions
|
|
|
+ "\ndef ",
|
|
|
+ "\nval ",
|
|
|
+ "\nvar ",
|
|
|
+ # Split along control flow statements
|
|
|
+ "\nif ",
|
|
|
+ "\nfor ",
|
|
|
+ "\nwhile ",
|
|
|
+ "\nmatch ",
|
|
|
+ "\ncase ",
|
|
|
+ # Split by the normal type of lines
|
|
|
+ "\n\n",
|
|
|
+ "\n",
|
|
|
+ " ",
|
|
|
+ "",
|
|
|
+ ]
|
|
|
+ elif language == Language.SWIFT:
|
|
|
+ return [
|
|
|
+ # Split along function definitions
|
|
|
+ "\nfunc ",
|
|
|
+ # Split along class definitions
|
|
|
+ "\nclass ",
|
|
|
+ "\nstruct ",
|
|
|
+ "\nenum ",
|
|
|
+ # Split along control flow statements
|
|
|
+ "\nif ",
|
|
|
+ "\nfor ",
|
|
|
+ "\nwhile ",
|
|
|
+ "\ndo ",
|
|
|
+ "\nswitch ",
|
|
|
+ "\ncase ",
|
|
|
+ # Split by the normal type of lines
|
|
|
+ "\n\n",
|
|
|
+ "\n",
|
|
|
+ " ",
|
|
|
+ "",
|
|
|
+ ]
|
|
|
+ elif language == Language.MARKDOWN:
|
|
|
+ return [
|
|
|
+ # First, try to split along Markdown headings (levels 1 through 6)
|
|
|
+ "\n#{1,6} ",
|
|
|
+ # Note the alternative syntax for headings (below) is not handled here
|
|
|
+ # Heading level 2
|
|
|
+ # ---------------
|
|
|
+ # End of code block
|
|
|
+ "```\n",
|
|
|
+ # Horizontal lines
|
|
|
+ "\n\*\*\*+\n",
|
|
|
+ "\n---+\n",
|
|
|
+ "\n___+\n",
|
|
|
+ # Note that the patterns above only match rules written as a run of
+ # three or more *, -, or _; variants with spaces in between are not handled
|
|
|
+ "\n\n",
|
|
|
+ "\n",
|
|
|
+ " ",
|
|
|
+ "",
|
|
|
+ ]
|
|
|
+ elif language == Language.LATEX:
|
|
|
+ return [
|
|
|
+ # First, try to split along Latex sections
|
|
|
+ "\n\\\chapter{",
|
|
|
+ "\n\\\section{",
|
|
|
+ "\n\\\subsection{",
|
|
|
+ "\n\\\subsubsection{",
|
|
|
+ # Now split by environments
|
|
|
+ "\n\\\begin{enumerate}",
|
|
|
+ "\n\\\begin{itemize}",
|
|
|
+ "\n\\\begin{description}",
|
|
|
+ "\n\\\begin{list}",
|
|
|
+ "\n\\\begin{quote}",
|
|
|
+ "\n\\\begin{quotation}",
|
|
|
+ "\n\\\begin{verse}",
|
|
|
+ "\n\\\begin{verbatim}",
|
|
|
+ # Now split by math environments
|
|
|
+ "\n\\\begin{align}",
|
|
|
+ "$$",
|
|
|
+ "$",
|
|
|
+ # Now split by the normal type of lines
|
|
|
+ " ",
|
|
|
+ "",
|
|
|
+ ]
|
|
|
+ elif language == Language.HTML:
|
|
|
+ return [
|
|
|
+ # First, try to split along HTML tags
|
|
|
+ "<body",
|
|
|
+ "<div",
|
|
|
+ "<p",
|
|
|
+ "<br",
|
|
|
+ "<li",
|
|
|
+ "<h1",
|
|
|
+ "<h2",
|
|
|
+ "<h3",
|
|
|
+ "<h4",
|
|
|
+ "<h5",
|
|
|
+ "<h6",
|
|
|
+ "<span",
|
|
|
+ "<table",
|
|
|
+ "<tr",
|
|
|
+ "<td",
|
|
|
+ "<th",
|
|
|
+ "<ul",
|
|
|
+ "<ol",
|
|
|
+ "<header",
|
|
|
+ "<footer",
|
|
|
+ "<nav",
|
|
|
+ # Head
|
|
|
+ "<head",
|
|
|
+ "<style",
|
|
|
+ "<script",
|
|
|
+ "<meta",
|
|
|
+ "<title",
|
|
|
+ "",
|
|
|
+ ]
|
|
|
+ elif language == Language.SOL:
|
|
|
+ return [
|
|
|
+ # Split along compiler information definitions
|
|
|
+ "\npragma ",
|
|
|
+ "\nusing ",
|
|
|
+ # Split along contract definitions
|
|
|
+ "\ncontract ",
|
|
|
+ "\ninterface ",
|
|
|
+ "\nlibrary ",
|
|
|
+ # Split along method definitions
|
|
|
+ "\nconstructor ",
|
|
|
+ "\ntype ",
|
|
|
+ "\nfunction ",
|
|
|
+ "\nevent ",
|
|
|
+ "\nmodifier ",
|
|
|
+ "\nerror ",
|
|
|
+ "\nstruct ",
|
|
|
+ "\nenum ",
|
|
|
+ # Split along control flow statements
|
|
|
+ "\nif ",
|
|
|
+ "\nfor ",
|
|
|
+ "\nwhile ",
|
|
|
+ "\ndo while ",
|
|
|
+ "\nassembly ",
|
|
|
+ # Split by the normal type of lines
|
|
|
+ "\n\n",
|
|
|
+ "\n",
|
|
|
+ " ",
|
|
|
+ "",
|
|
|
+ ]
|
|
|
+ else:
|
|
|
+ raise ValueError(
|
|
|
+ f"Language {language} is not supported! "
|
|
|
+ f"Please choose from {list(Language)}"
|
|
|
+ )
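+
+
+# A minimal usage sketch for RecursiveCharacterTextSplitter, including the
+# language-aware separators; the snippets and sizes are illustrative assumptions.
+def _demo_recursive_character_text_splitter() -> None:
+    splitter = RecursiveCharacterTextSplitter(chunk_size=60, chunk_overlap=0)
+    text = "A first paragraph.\n\nA second, noticeably longer paragraph of text."
+    # Tries "\n\n", then "\n", then " ", then "" until the pieces fit chunk_size.
+    print(splitter.split_text(text))
+
+    python_splitter = RecursiveCharacterTextSplitter.from_language(
+        Language.PYTHON, chunk_size=120, chunk_overlap=0
+    )
+    code = "class Foo:\n    def bar(self):\n        return 1\n\n\ndef baz():\n    return 2\n"
+    # Prefers class/def boundaries before falling back to blank lines and spaces.
+    print(python_splitter.split_text(code))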
|