Ver código fonte

refactor: improve handling of leading punctuation removal (#10761)

Zane 5 meses atrás
pai
commit
14f3d44c37

+ 2 - 5
api/core/indexing_runner.py

@@ -29,6 +29,7 @@ from core.rag.splitter.fixed_text_splitter import (
     FixedRecursiveCharacterTextSplitter,
 )
 from core.rag.splitter.text_splitter import TextSplitter
+from core.tools.utils.text_processing_utils import remove_leading_symbols
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
@@ -500,11 +501,7 @@ class IndexingRunner:
                     document_node.metadata["doc_hash"] = hash
                     # delete Splitter character
                     page_content = document_node.page_content
-                    if page_content.startswith(".") or page_content.startswith("。"):
-                        page_content = page_content[1:]
-                    else:
-                        page_content = page_content
-                    document_node.page_content = page_content
+                    document_node.page_content = remove_leading_symbols(page_content)
 
                     if document_node.page_content:
                         split_documents.append(document_node)

+ 2 - 5
api/core/rag/index_processor/processor/paragraph_index_processor.py

@@ -11,6 +11,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.extractor.extract_processor import ExtractProcessor
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.models.document import Document
+from core.tools.utils.text_processing_utils import remove_leading_symbols
 from libs import helper
 from models.dataset import Dataset
 
@@ -43,11 +44,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
                     document_node.metadata["doc_id"] = doc_id
                     document_node.metadata["doc_hash"] = hash
                     # delete Splitter character
-                    page_content = document_node.page_content
-                    if page_content.startswith(".") or page_content.startswith("。"):
-                        page_content = page_content[1:].strip()
-                    else:
-                        page_content = page_content
+                    page_content = remove_leading_symbols(document_node.page_content).strip()
                     if len(page_content) > 0:
                         document_node.page_content = page_content
                         split_documents.append(document_node)

+ 2 - 5
api/core/rag/index_processor/processor/qa_index_processor.py

@@ -18,6 +18,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.extractor.extract_processor import ExtractProcessor
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.models.document import Document
+from core.tools.utils.text_processing_utils import remove_leading_symbols
 from libs import helper
 from models.dataset import Dataset
 
@@ -53,11 +54,7 @@ class QAIndexProcessor(BaseIndexProcessor):
                     document_node.metadata["doc_hash"] = hash
                     # delete Splitter character
                     page_content = document_node.page_content
-                    if page_content.startswith(".") or page_content.startswith("。"):
-                        page_content = page_content[1:]
-                    else:
-                        page_content = page_content
-                    document_node.page_content = page_content
+                    document_node.page_content = remove_leading_symbols(page_content)
                     split_documents.append(document_node)
             all_documents.extend(split_documents)
         for i in range(0, len(all_documents), 10):

+ 16 - 0
api/core/tools/utils/text_processing_utils.py

@@ -0,0 +1,16 @@
+import re
+
+
+def remove_leading_symbols(text: str) -> str:
+    """
+    Remove leading punctuation or symbols from the given text.
+
+    Args:
+        text (str): The input text to process.
+
+    Returns:
+        str: The text with leading punctuation or symbols removed.
+    """
+    # Match Unicode ranges for punctuation and symbols
+    pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,\-./:;<=>?@\[\]^_`{|}~]+"
+    return re.sub(pattern, "", text)

+ 20 - 0
api/tests/unit_tests/utils/test_text_processing.py

@@ -0,0 +1,20 @@
+from textwrap import dedent
+
+import pytest
+
+from core.tools.utils.text_processing_utils import remove_leading_symbols
+
+
+@pytest.mark.parametrize(
+    ("input_text", "expected_output"),
+    [
+        ("...Hello, World!", "Hello, World!"),
+        ("。测试中文标点", "测试中文标点"),
+        ("!@#Test symbols", "Test symbols"),
+        ("Hello, World!", "Hello, World!"),
+        ("", ""),
+        ("   ", "   "),
+    ],
+)
+def test_remove_leading_symbols(input_text, expected_output):
+    assert remove_leading_symbols(input_text) == expected_output