Przeglądaj źródła

Fix/langchain document schema (#2539)

Co-authored-by: jyong <jyong@dify.ai>
Jyong 1 rok temu
rodzic
commit
91ea6fe4ee

+ 1 - 2
api/core/callback_handler/index_tool_callback_handler.py

@@ -1,8 +1,7 @@
 
-from langchain.schema import Document
-
 from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.entities.application_entities import InvokeFrom
+from core.rag.models.document import Document
 from extensions.ext_database import db
 from models.dataset import DatasetQuery, DocumentSegment
 from models.model import DatasetRetrieverResource

+ 1 - 1
api/core/indexing_runner.py

@@ -9,7 +9,6 @@ from typing import Optional, cast
 
 from flask import Flask, current_app
 from flask_login import current_user
-from langchain.text_splitter import TextSplitter
 from sqlalchemy.orm.exc import ObjectDeletedError
 
 from core.docstore.dataset_docstore import DatasetDocumentStore
@@ -24,6 +23,7 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.rag.models.document import Document
 from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter
+from core.splitter.text_splitter import TextSplitter
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage

+ 1 - 2
api/core/rag/data_post_processor/reorder.py

@@ -1,5 +1,4 @@
-
-from langchain.schema import Document
+from core.rag.models.document import Document
 
 
 class ReorderRunner:

+ 1 - 1
api/core/rag/extractor/notion_extractor.py

@@ -5,9 +5,9 @@ from typing import Any, Optional
 import requests
 from flask import current_app
 from flask_login import current_user
-from langchain.schema import Document
 
 from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
 from extensions.ext_database import db
 from models.dataset import Document as DocumentModel
 from models.source import DataSourceBinding

+ 1 - 2
api/core/rag/index_processor/index_processor_base.py

@@ -2,12 +2,11 @@
 from abc import ABC, abstractmethod
 from typing import Optional
 
-from langchain.text_splitter import TextSplitter
-
 from core.model_manager import ModelInstance
 from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.models.document import Document
 from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter
+from core.splitter.text_splitter import TextSplitter
 from models.dataset import Dataset, DatasetProcessRule
 
 

+ 64 - 1
api/core/rag/models/document.py

@@ -1,4 +1,6 @@
-from typing import Optional
+from abc import ABC, abstractmethod
+from collections.abc import Sequence
+from typing import Any, Optional
 
 from pydantic import BaseModel, Field
 
@@ -14,3 +16,64 @@ class Document(BaseModel):
     metadata: Optional[dict] = Field(default_factory=dict)
 
 
+class BaseDocumentTransformer(ABC):
+    """Abstract base class for document transformation systems.
+
+    A document transformation system takes a sequence of Documents and returns a
+    sequence of transformed Documents.
+
+    Example:
+        .. code-block:: python
+
+            class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
+                embeddings: Embeddings
+                similarity_fn: Callable = cosine_similarity
+                similarity_threshold: float = 0.95
+
+                class Config:
+                    arbitrary_types_allowed = True
+
+                def transform_documents(
+                    self, documents: Sequence[Document], **kwargs: Any
+                ) -> Sequence[Document]:
+                    stateful_documents = get_stateful_documents(documents)
+                    embedded_documents = _get_embeddings_from_stateful_docs(
+                        self.embeddings, stateful_documents
+                    )
+                    included_idxs = _filter_similar_embeddings(
+                        embedded_documents, self.similarity_fn, self.similarity_threshold
+                    )
+                    return [stateful_documents[i] for i in sorted(included_idxs)]
+
+                async def atransform_documents(
+                    self, documents: Sequence[Document], **kwargs: Any
+                ) -> Sequence[Document]:
+                    raise NotImplementedError
+
+    """  # noqa: E501
+
+    @abstractmethod
+    def transform_documents(
+        self, documents: Sequence[Document], **kwargs: Any
+    ) -> Sequence[Document]:
+        """Transform a list of documents.
+
+        Args:
+            documents: A sequence of Documents to be transformed.
+
+        Returns:
+            A list of transformed Documents.
+        """
+
+    @abstractmethod
+    async def atransform_documents(
+        self, documents: Sequence[Document], **kwargs: Any
+    ) -> Sequence[Document]:
+        """Asynchronously transform a list of documents.
+
+        Args:
+            documents: A sequence of Documents to be transformed.
+
+        Returns:
+            A list of transformed Documents.
+        """

+ 1 - 2
api/core/rerank/rerank.py

@@ -1,8 +1,7 @@
 from typing import Optional
 
-from langchain.schema import Document
-
 from core.model_manager import ModelInstance
+from core.rag.models.document import Document
 
 
 class RerankRunner:

+ 4 - 5
api/core/splitter/fixed_text_splitter.py

@@ -3,7 +3,10 @@ from __future__ import annotations
 
 from typing import Any, Optional, cast
 
-from langchain.text_splitter import (
+from core.model_manager import ModelInstance
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
+from core.splitter.text_splitter import (
     TS,
     AbstractSet,
     Collection,
@@ -14,10 +17,6 @@ from langchain.text_splitter import (
     Union,
 )
 
-from core.model_manager import ModelInstance
-from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
-from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
-
 
 class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
     """

+ 903 - 0
api/core/splitter/text_splitter.py

@@ -0,0 +1,903 @@
+from __future__ import annotations
+
+import copy
+import logging
+import re
+from abc import ABC, abstractmethod
+from collections.abc import Callable, Collection, Iterable, Sequence, Set
+from dataclasses import dataclass
+from enum import Enum
+from typing import (
+    Any,
+    Literal,
+    Optional,
+    TypedDict,
+    TypeVar,
+    Union,
+)
+
+from core.rag.models.document import BaseDocumentTransformer, Document
+
+logger = logging.getLogger(__name__)
+
+TS = TypeVar("TS", bound="TextSplitter")
+
+
+def _split_text_with_regex(
+        text: str, separator: str, keep_separator: bool
+) -> list[str]:
+    # Now that we have the separator, split the text
+    if separator:
+        if keep_separator:
+            # The parentheses in the pattern keep the delimiters in the result.
+            _splits = re.split(f"({separator})", text)
+            splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
+            if len(_splits) % 2 == 0:
+                splits += _splits[-1:]
+            splits = [_splits[0]] + splits
+        else:
+            splits = re.split(separator, text)
+    else:
+        splits = list(text)
+    return [s for s in splits if s != ""]
+
+
+class TextSplitter(BaseDocumentTransformer, ABC):
+    """Interface for splitting text into chunks."""
+
+    def __init__(
+            self,
+            chunk_size: int = 4000,
+            chunk_overlap: int = 200,
+            length_function: Callable[[str], int] = len,
+            keep_separator: bool = False,
+            add_start_index: bool = False,
+    ) -> None:
+        """Create a new TextSplitter.
+
+        Args:
+            chunk_size: Maximum size of chunks to return
+            chunk_overlap: Overlap in characters between chunks
+            length_function: Function that measures the length of given chunks
+            keep_separator: Whether to keep the separator in the chunks
+            add_start_index: If `True`, includes chunk's start index in metadata
+        """
+        if chunk_overlap > chunk_size:
+            raise ValueError(
+                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
+                f"({chunk_size}), should be smaller."
+            )
+        self._chunk_size = chunk_size
+        self._chunk_overlap = chunk_overlap
+        self._length_function = length_function
+        self._keep_separator = keep_separator
+        self._add_start_index = add_start_index
+
+    @abstractmethod
+    def split_text(self, text: str) -> list[str]:
+        """Split text into multiple components."""
+
+    def create_documents(
+            self, texts: list[str], metadatas: Optional[list[dict]] = None
+    ) -> list[Document]:
+        """Create documents from a list of texts."""
+        _metadatas = metadatas or [{}] * len(texts)
+        documents = []
+        for i, text in enumerate(texts):
+            index = -1
+            for chunk in self.split_text(text):
+                metadata = copy.deepcopy(_metadatas[i])
+                if self._add_start_index:
+                    index = text.find(chunk, index + 1)
+                    metadata["start_index"] = index
+                new_doc = Document(page_content=chunk, metadata=metadata)
+                documents.append(new_doc)
+        return documents
+
+    def split_documents(self, documents: Iterable[Document]) -> list[Document]:
+        """Split documents."""
+        texts, metadatas = [], []
+        for doc in documents:
+            texts.append(doc.page_content)
+            metadatas.append(doc.metadata)
+        return self.create_documents(texts, metadatas=metadatas)
+
+    def _join_docs(self, docs: list[str], separator: str) -> Optional[str]:
+        text = separator.join(docs)
+        text = text.strip()
+        if text == "":
+            return None
+        else:
+            return text
+
+    def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]:
+        # We now want to combine these smaller pieces into medium size
+        # chunks to send to the LLM.
+        separator_len = self._length_function(separator)
+
+        docs = []
+        current_doc: list[str] = []
+        total = 0
+        for d in splits:
+            _len = self._length_function(d)
+            if (
+                    total + _len + (separator_len if len(current_doc) > 0 else 0)
+                    > self._chunk_size
+            ):
+                if total > self._chunk_size:
+                    logger.warning(
+                        f"Created a chunk of size {total}, "
+                        f"which is longer than the specified {self._chunk_size}"
+                    )
+                if len(current_doc) > 0:
+                    doc = self._join_docs(current_doc, separator)
+                    if doc is not None:
+                        docs.append(doc)
+                    # Keep on popping if:
+                    # - we have a larger chunk than in the chunk overlap
+                    # - or if we still have any chunks and the length is long
+                    while total > self._chunk_overlap or (
+                            total + _len + (separator_len if len(current_doc) > 0 else 0)
+                            > self._chunk_size
+                            and total > 0
+                    ):
+                        total -= self._length_function(current_doc[0]) + (
+                            separator_len if len(current_doc) > 1 else 0
+                        )
+                        current_doc = current_doc[1:]
+            current_doc.append(d)
+            total += _len + (separator_len if len(current_doc) > 1 else 0)
+        doc = self._join_docs(current_doc, separator)
+        if doc is not None:
+            docs.append(doc)
+        return docs
+
+    @classmethod
+    def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
+        """Text splitter that uses HuggingFace tokenizer to count length."""
+        try:
+            from transformers import PreTrainedTokenizerBase
+
+            if not isinstance(tokenizer, PreTrainedTokenizerBase):
+                raise ValueError(
+                    "Tokenizer received was not an instance of PreTrainedTokenizerBase"
+                )
+
+            def _huggingface_tokenizer_length(text: str) -> int:
+                return len(tokenizer.encode(text))
+
+        except ImportError:
+            raise ValueError(
+                "Could not import transformers python package. "
+                "Please install it with `pip install transformers`."
+            )
+        return cls(length_function=_huggingface_tokenizer_length, **kwargs)
+
+    @classmethod
+    def from_tiktoken_encoder(
+            cls: type[TS],
+            encoding_name: str = "gpt2",
+            model_name: Optional[str] = None,
+            allowed_special: Union[Literal["all"], Set[str]] = set(),
+            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+            **kwargs: Any,
+    ) -> TS:
+        """Text splitter that uses tiktoken encoder to count length."""
+        try:
+            import tiktoken
+        except ImportError:
+            raise ImportError(
+                "Could not import tiktoken python package. "
+                "This is needed in order to calculate max_tokens_for_prompt. "
+                "Please install it with `pip install tiktoken`."
+            )
+
+        if model_name is not None:
+            enc = tiktoken.encoding_for_model(model_name)
+        else:
+            enc = tiktoken.get_encoding(encoding_name)
+
+        def _tiktoken_encoder(text: str) -> int:
+            return len(
+                enc.encode(
+                    text,
+                    allowed_special=allowed_special,
+                    disallowed_special=disallowed_special,
+                )
+            )
+
+        if issubclass(cls, TokenTextSplitter):
+            extra_kwargs = {
+                "encoding_name": encoding_name,
+                "model_name": model_name,
+                "allowed_special": allowed_special,
+                "disallowed_special": disallowed_special,
+            }
+            kwargs = {**kwargs, **extra_kwargs}
+
+        return cls(length_function=_tiktoken_encoder, **kwargs)
+
+    def transform_documents(
+            self, documents: Sequence[Document], **kwargs: Any
+    ) -> Sequence[Document]:
+        """Transform sequence of documents by splitting them."""
+        return self.split_documents(list(documents))
+
+    async def atransform_documents(
+            self, documents: Sequence[Document], **kwargs: Any
+    ) -> Sequence[Document]:
+        """Asynchronously transform a sequence of documents by splitting them."""
+        raise NotImplementedError
+
+
+class CharacterTextSplitter(TextSplitter):
+    """Splitting text that looks at characters."""
+
+    def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None:
+        """Create a new TextSplitter."""
+        super().__init__(**kwargs)
+        self._separator = separator
+
+    def split_text(self, text: str) -> list[str]:
+        """Split incoming text and return chunks."""
+        # First we naively split the large input into a bunch of smaller ones.
+        splits = _split_text_with_regex(text, self._separator, self._keep_separator)
+        _separator = "" if self._keep_separator else self._separator
+        return self._merge_splits(splits, _separator)
+
+
+class LineType(TypedDict):
+    """Line type as typed dict."""
+
+    metadata: dict[str, str]
+    content: str
+
+
+class HeaderType(TypedDict):
+    """Header type as typed dict."""
+
+    level: int
+    name: str
+    data: str
+
+
+class MarkdownHeaderTextSplitter:
+    """Splitting markdown files based on specified headers."""
+
+    def __init__(
+            self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False
+    ):
+        """Create a new MarkdownHeaderTextSplitter.
+
+        Args:
+            headers_to_split_on: Headers we want to track
+            return_each_line: Return each line w/ associated headers
+        """
+        # Output line-by-line or aggregated into chunks w/ common headers
+        self.return_each_line = return_each_line
+        # Given the headers we want to split on,
+        # (e.g., "#, ##, etc") order by length
+        self.headers_to_split_on = sorted(
+            headers_to_split_on, key=lambda split: len(split[0]), reverse=True
+        )
+
+    def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]:
+        """Combine lines with common metadata into chunks
+        Args:
+            lines: Line of text / associated header metadata
+        """
+        aggregated_chunks: list[LineType] = []
+
+        for line in lines:
+            if (
+                    aggregated_chunks
+                    and aggregated_chunks[-1]["metadata"] == line["metadata"]
+            ):
+                # If the last line in the aggregated list
+                # has the same metadata as the current line,
+                # append the current content to the last lines's content
+                aggregated_chunks[-1]["content"] += "  \n" + line["content"]
+            else:
+                # Otherwise, append the current line to the aggregated list
+                aggregated_chunks.append(line)
+
+        return [
+            Document(page_content=chunk["content"], metadata=chunk["metadata"])
+            for chunk in aggregated_chunks
+        ]
+
+    def split_text(self, text: str) -> list[Document]:
+        """Split markdown file
+        Args:
+            text: Markdown file"""
+
+        # Split the input text by newline character ("\n").
+        lines = text.split("\n")
+        # Final output
+        lines_with_metadata: list[LineType] = []
+        # Content and metadata of the chunk currently being processed
+        current_content: list[str] = []
+        current_metadata: dict[str, str] = {}
+        # Keep track of the nested header structure
+        # header_stack: List[Dict[str, Union[int, str]]] = []
+        header_stack: list[HeaderType] = []
+        initial_metadata: dict[str, str] = {}
+
+        for line in lines:
+            stripped_line = line.strip()
+            # Check each line against each of the header types (e.g., #, ##)
+            for sep, name in self.headers_to_split_on:
+                # Check if line starts with a header that we intend to split on
+                if stripped_line.startswith(sep) and (
+                        # Header with no text OR header is followed by space
+                        # Both are valid conditions that sep is being used a header
+                        len(stripped_line) == len(sep)
+                        or stripped_line[len(sep)] == " "
+                ):
+                    # Ensure we are tracking the header as metadata
+                    if name is not None:
+                        # Get the current header level
+                        current_header_level = sep.count("#")
+
+                        # Pop out headers of lower or same level from the stack
+                        while (
+                                header_stack
+                                and header_stack[-1]["level"] >= current_header_level
+                        ):
+                            # We have encountered a new header
+                            # at the same or higher level
+                            popped_header = header_stack.pop()
+                            # Clear the metadata for the
+                            # popped header in initial_metadata
+                            if popped_header["name"] in initial_metadata:
+                                initial_metadata.pop(popped_header["name"])
+
+                        # Push the current header to the stack
+                        header: HeaderType = {
+                            "level": current_header_level,
+                            "name": name,
+                            "data": stripped_line[len(sep):].strip(),
+                        }
+                        header_stack.append(header)
+                        # Update initial_metadata with the current header
+                        initial_metadata[name] = header["data"]
+
+                    # Add the previous line to the lines_with_metadata
+                    # only if current_content is not empty
+                    if current_content:
+                        lines_with_metadata.append(
+                            {
+                                "content": "\n".join(current_content),
+                                "metadata": current_metadata.copy(),
+                            }
+                        )
+                        current_content.clear()
+
+                    break
+            else:
+                if stripped_line:
+                    current_content.append(stripped_line)
+                elif current_content:
+                    lines_with_metadata.append(
+                        {
+                            "content": "\n".join(current_content),
+                            "metadata": current_metadata.copy(),
+                        }
+                    )
+                    current_content.clear()
+
+            current_metadata = initial_metadata.copy()
+
+        if current_content:
+            lines_with_metadata.append(
+                {"content": "\n".join(current_content), "metadata": current_metadata}
+            )
+
+        # lines_with_metadata has each line with associated header metadata
+        # aggregate these into chunks based on common metadata
+        if not self.return_each_line:
+            return self.aggregate_lines_to_chunks(lines_with_metadata)
+        else:
+            return [
+                Document(page_content=chunk["content"], metadata=chunk["metadata"])
+                for chunk in lines_with_metadata
+            ]
+
+
+# should be in newer Python versions (3.10+)
+# @dataclass(frozen=True, kw_only=True, slots=True)
+@dataclass(frozen=True)
+class Tokenizer:
+    chunk_overlap: int
+    tokens_per_chunk: int
+    decode: Callable[[list[int]], str]
+    encode: Callable[[str], list[int]]
+
+
+def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
+    """Split incoming text and return chunks using tokenizer."""
+    splits: list[str] = []
+    input_ids = tokenizer.encode(text)
+    start_idx = 0
+    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
+    chunk_ids = input_ids[start_idx:cur_idx]
+    while start_idx < len(input_ids):
+        splits.append(tokenizer.decode(chunk_ids))
+        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
+        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
+        chunk_ids = input_ids[start_idx:cur_idx]
+    return splits
+
+
+class TokenTextSplitter(TextSplitter):
+    """Splitting text to tokens using model tokenizer."""
+
+    def __init__(
+            self,
+            encoding_name: str = "gpt2",
+            model_name: Optional[str] = None,
+            allowed_special: Union[Literal["all"], Set[str]] = set(),
+            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+            **kwargs: Any,
+    ) -> None:
+        """Create a new TextSplitter."""
+        super().__init__(**kwargs)
+        try:
+            import tiktoken
+        except ImportError:
+            raise ImportError(
+                "Could not import tiktoken python package. "
+                "This is needed in order to for TokenTextSplitter. "
+                "Please install it with `pip install tiktoken`."
+            )
+
+        if model_name is not None:
+            enc = tiktoken.encoding_for_model(model_name)
+        else:
+            enc = tiktoken.get_encoding(encoding_name)
+        self._tokenizer = enc
+        self._allowed_special = allowed_special
+        self._disallowed_special = disallowed_special
+
+    def split_text(self, text: str) -> list[str]:
+        def _encode(_text: str) -> list[int]:
+            return self._tokenizer.encode(
+                _text,
+                allowed_special=self._allowed_special,
+                disallowed_special=self._disallowed_special,
+            )
+
+        tokenizer = Tokenizer(
+            chunk_overlap=self._chunk_overlap,
+            tokens_per_chunk=self._chunk_size,
+            decode=self._tokenizer.decode,
+            encode=_encode,
+        )
+
+        return split_text_on_tokens(text=text, tokenizer=tokenizer)
+
+
+class Language(str, Enum):
+    """Enum of the programming languages."""
+
+    CPP = "cpp"
+    GO = "go"
+    JAVA = "java"
+    JS = "js"
+    PHP = "php"
+    PROTO = "proto"
+    PYTHON = "python"
+    RST = "rst"
+    RUBY = "ruby"
+    RUST = "rust"
+    SCALA = "scala"
+    SWIFT = "swift"
+    MARKDOWN = "markdown"
+    LATEX = "latex"
+    HTML = "html"
+    SOL = "sol"
+
+
+class RecursiveCharacterTextSplitter(TextSplitter):
+    """Splitting text by recursively look at characters.
+
+    Recursively tries to split by different characters to find one
+    that works.
+    """
+
+    def __init__(
+            self,
+            separators: Optional[list[str]] = None,
+            keep_separator: bool = True,
+            **kwargs: Any,
+    ) -> None:
+        """Create a new TextSplitter."""
+        super().__init__(keep_separator=keep_separator, **kwargs)
+        self._separators = separators or ["\n\n", "\n", " ", ""]
+
+    def _split_text(self, text: str, separators: list[str]) -> list[str]:
+        """Split incoming text and return chunks."""
+        final_chunks = []
+        # Get appropriate separator to use
+        separator = separators[-1]
+        new_separators = []
+        for i, _s in enumerate(separators):
+            if _s == "":
+                separator = _s
+                break
+            if re.search(_s, text):
+                separator = _s
+                new_separators = separators[i + 1:]
+                break
+
+        splits = _split_text_with_regex(text, separator, self._keep_separator)
+        # Now go merging things, recursively splitting longer texts.
+        _good_splits = []
+        _separator = "" if self._keep_separator else separator
+        for s in splits:
+            if self._length_function(s) < self._chunk_size:
+                _good_splits.append(s)
+            else:
+                if _good_splits:
+                    merged_text = self._merge_splits(_good_splits, _separator)
+                    final_chunks.extend(merged_text)
+                    _good_splits = []
+                if not new_separators:
+                    final_chunks.append(s)
+                else:
+                    other_info = self._split_text(s, new_separators)
+                    final_chunks.extend(other_info)
+        if _good_splits:
+            merged_text = self._merge_splits(_good_splits, _separator)
+            final_chunks.extend(merged_text)
+        return final_chunks
+
+    def split_text(self, text: str) -> list[str]:
+        return self._split_text(text, self._separators)
+
+    @classmethod
+    def from_language(
+            cls, language: Language, **kwargs: Any
+    ) -> RecursiveCharacterTextSplitter:
+        separators = cls.get_separators_for_language(language)
+        return cls(separators=separators, **kwargs)
+
+    @staticmethod
+    def get_separators_for_language(language: Language) -> list[str]:
+        if language == Language.CPP:
+            return [
+                # Split along class definitions
+                "\nclass ",
+                # Split along function definitions
+                "\nvoid ",
+                "\nint ",
+                "\nfloat ",
+                "\ndouble ",
+                # Split along control flow statements
+                "\nif ",
+                "\nfor ",
+                "\nwhile ",
+                "\nswitch ",
+                "\ncase ",
+                # Split by the normal type of lines
+                "\n\n",
+                "\n",
+                " ",
+                "",
+            ]
+        elif language == Language.GO:
+            return [
+                # Split along function definitions
+                "\nfunc ",
+                "\nvar ",
+                "\nconst ",
+                "\ntype ",
+                # Split along control flow statements
+                "\nif ",
+                "\nfor ",
+                "\nswitch ",
+                "\ncase ",
+                # Split by the normal type of lines
+                "\n\n",
+                "\n",
+                " ",
+                "",
+            ]
+        elif language == Language.JAVA:
+            return [
+                # Split along class definitions
+                "\nclass ",
+                # Split along method definitions
+                "\npublic ",
+                "\nprotected ",
+                "\nprivate ",
+                "\nstatic ",
+                # Split along control flow statements
+                "\nif ",
+                "\nfor ",
+                "\nwhile ",
+                "\nswitch ",
+                "\ncase ",
+                # Split by the normal type of lines
+                "\n\n",
+                "\n",
+                " ",
+                "",
+            ]
+        elif language == Language.JS:
+            return [
+                # Split along function definitions
+                "\nfunction ",
+                "\nconst ",
+                "\nlet ",
+                "\nvar ",
+                "\nclass ",
+                # Split along control flow statements
+                "\nif ",
+                "\nfor ",
+                "\nwhile ",
+                "\nswitch ",
+                "\ncase ",
+                "\ndefault ",
+                # Split by the normal type of lines
+                "\n\n",
+                "\n",
+                " ",
+                "",
+            ]
+        elif language == Language.PHP:
+            return [
+                # Split along function definitions
+                "\nfunction ",
+                # Split along class definitions
+                "\nclass ",
+                # Split along control flow statements
+                "\nif ",
+                "\nforeach ",
+                "\nwhile ",
+                "\ndo ",
+                "\nswitch ",
+                "\ncase ",
+                # Split by the normal type of lines
+                "\n\n",
+                "\n",
+                " ",
+                "",
+            ]
+        elif language == Language.PROTO:
+            return [
+                # Split along message definitions
+                "\nmessage ",
+                # Split along service definitions
+                "\nservice ",
+                # Split along enum definitions
+                "\nenum ",
+                # Split along option definitions
+                "\noption ",
+                # Split along import statements
+                "\nimport ",
+                # Split along syntax declarations
+                "\nsyntax ",
+                # Split by the normal type of lines
+                "\n\n",
+                "\n",
+                " ",
+                "",
+            ]
+        elif language == Language.PYTHON:
+            return [
+                # First, try to split along class definitions
+                "\nclass ",
+                "\ndef ",
+                "\n\tdef ",
+                # Now split by the normal type of lines
+                "\n\n",
+                "\n",
+                " ",
+                "",
+            ]
+        elif language == Language.RST:
+            return [
+                # Split along section titles
+                "\n=+\n",
+                "\n-+\n",
+                "\n\*+\n",
+                # Split along directive markers
+                "\n\n.. *\n\n",
+                # Split by the normal type of lines
+                "\n\n",
+                "\n",
+                " ",
+                "",
+            ]
+        elif language == Language.RUBY:
+            return [
+                # Split along method definitions
+                "\ndef ",
+                "\nclass ",
+                # Split along control flow statements
+                "\nif ",
+                "\nunless ",
+                "\nwhile ",
+                "\nfor ",
+                "\ndo ",
+                "\nbegin ",
+                "\nrescue ",
+                # Split by the normal type of lines
+                "\n\n",
+                "\n",
+                " ",
+                "",
+            ]
+        elif language == Language.RUST:
+            return [
+                # Split along function definitions
+                "\nfn ",
+                "\nconst ",
+                "\nlet ",
+                # Split along control flow statements
+                "\nif ",
+                "\nwhile ",
+                "\nfor ",
+                "\nloop ",
+                "\nmatch ",
+                "\nconst ",
+                # Split by the normal type of lines
+                "\n\n",
+                "\n",
+                " ",
+                "",
+            ]
+        elif language == Language.SCALA:
+            return [
+                # Split along class definitions
+                "\nclass ",
+                "\nobject ",
+                # Split along method definitions
+                "\ndef ",
+                "\nval ",
+                "\nvar ",
+                # Split along control flow statements
+                "\nif ",
+                "\nfor ",
+                "\nwhile ",
+                "\nmatch ",
+                "\ncase ",
+                # Split by the normal type of lines
+                "\n\n",
+                "\n",
+                " ",
+                "",
+            ]
+        elif language == Language.SWIFT:
+            return [
+                # Split along function definitions
+                "\nfunc ",
+                # Split along class definitions
+                "\nclass ",
+                "\nstruct ",
+                "\nenum ",
+                # Split along control flow statements
+                "\nif ",
+                "\nfor ",
+                "\nwhile ",
+                "\ndo ",
+                "\nswitch ",
+                "\ncase ",
+                # Split by the normal type of lines
+                "\n\n",
+                "\n",
+                " ",
+                "",
+            ]
+        elif language == Language.MARKDOWN:
+            return [
+                # First, try to split along Markdown headings (starting with level 2)
+                "\n#{1,6} ",
+                # Note the alternative syntax for headings (below) is not handled here
+                # Heading level 2
+                # ---------------
+                # End of code block
+                "```\n",
+                # Horizontal lines
+                "\n\*\*\*+\n",
+                "\n---+\n",
+                "\n___+\n",
+                # Note that this splitter doesn't handle horizontal lines defined
+                # by *three or more* of ***, ---, or ___, but this is not handled
+                "\n\n",
+                "\n",
+                " ",
+                "",
+            ]
+        elif language == Language.LATEX:
+            return [
+                # First, try to split along Latex sections
+                "\n\\\chapter{",
+                "\n\\\section{",
+                "\n\\\subsection{",
+                "\n\\\subsubsection{",
+                # Now split by environments
+                "\n\\\begin{enumerate}",
+                "\n\\\begin{itemize}",
+                "\n\\\begin{description}",
+                "\n\\\begin{list}",
+                "\n\\\begin{quote}",
+                "\n\\\begin{quotation}",
+                "\n\\\begin{verse}",
+                "\n\\\begin{verbatim}",
+                # Now split by math environments
+                "\n\\\begin{align}",
+                "$$",
+                "$",
+                # Now split by the normal type of lines
+                " ",
+                "",
+            ]
+        elif language == Language.HTML:
+            return [
+                # First, try to split along HTML tags
+                "<body",
+                "<div",
+                "<p",
+                "<br",
+                "<li",
+                "<h1",
+                "<h2",
+                "<h3",
+                "<h4",
+                "<h5",
+                "<h6",
+                "<span",
+                "<table",
+                "<tr",
+                "<td",
+                "<th",
+                "<ul",
+                "<ol",
+                "<header",
+                "<footer",
+                "<nav",
+                # Head
+                "<head",
+                "<style",
+                "<script",
+                "<meta",
+                "<title",
+                "",
+            ]
+        elif language == Language.SOL:
+            return [
+                # Split along compiler information definitions
+                "\npragma ",
+                "\nusing ",
+                # Split along contract definitions
+                "\ncontract ",
+                "\ninterface ",
+                "\nlibrary ",
+                # Split along method definitions
+                "\nconstructor ",
+                "\ntype ",
+                "\nfunction ",
+                "\nevent ",
+                "\nmodifier ",
+                "\nerror ",
+                "\nstruct ",
+                "\nenum ",
+                # Split along control flow statements
+                "\nif ",
+                "\nfor ",
+                "\nwhile ",
+                "\ndo while ",
+                "\nassembly ",
+                # Split by the normal type of lines
+                "\n\n",
+                "\n",
+                " ",
+                "",
+            ]
+        else:
+            raise ValueError(
+                f"Language {language} is not supported! "
+                f"Please choose from {list(Language)}"
+            )

+ 1 - 1
api/core/tool/web_reader_tool.py

@@ -13,7 +13,6 @@ import requests
 from bs4 import BeautifulSoup, CData, Comment, NavigableString
 from langchain.chains import RefineDocumentsChain
 from langchain.chains.summarize import refine_prompts
-from langchain.schema import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.tools.base import BaseTool
 from newspaper import Article
@@ -24,6 +23,7 @@ from core.chain.llm_chain import LLMChain
 from core.entities.application_entities import ModelConfigEntity
 from core.rag.extractor import extract_processor
 from core.rag.extractor.extract_processor import ExtractProcessor
+from core.rag.models.document import Document
 
 FULL_TEMPLATE = """
 TITLE: {title}

+ 1 - 1
api/core/tools/utils/web_reader_tool.py

@@ -13,7 +13,6 @@ import requests
 from bs4 import BeautifulSoup, CData, Comment, NavigableString
 from langchain.chains import RefineDocumentsChain
 from langchain.chains.summarize import refine_prompts
-from langchain.schema import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.tools.base import BaseTool
 from newspaper import Article
@@ -24,6 +23,7 @@ from core.chain.llm_chain import LLMChain
 from core.entities.application_entities import ModelConfigEntity
 from core.rag.extractor import extract_processor
 from core.rag.extractor.extract_processor import ExtractProcessor
+from core.rag.models.document import Document
 
 FULL_TEMPLATE = """
 TITLE: {title}

+ 1 - 1
api/tasks/annotation/add_annotation_to_index_task.py

@@ -3,9 +3,9 @@ import time
 
 import click
 from celery import shared_task
-from langchain.schema import Document
 
 from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.models.document import Document
 from models.dataset import Dataset
 from services.dataset_service import DatasetCollectionBindingService
 

+ 1 - 1
api/tasks/annotation/batch_import_annotations_task.py

@@ -3,10 +3,10 @@ import time
 
 import click
 from celery import shared_task
-from langchain.schema import Document
 from werkzeug.exceptions import NotFound
 
 from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.models.document import Document
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset

+ 1 - 1
api/tasks/annotation/enable_annotation_reply_task.py

@@ -4,10 +4,10 @@ import time
 
 import click
 from celery import shared_task
-from langchain.schema import Document
 from werkzeug.exceptions import NotFound
 
 from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.models.document import Document
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset

+ 1 - 1
api/tasks/annotation/update_annotation_to_index_task.py

@@ -3,9 +3,9 @@ import time
 
 import click
 from celery import shared_task
-from langchain.schema import Document
 
 from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.models.document import Document
 from models.dataset import Dataset
 from services.dataset_service import DatasetCollectionBindingService