فهرست منبع

feat: support tidb vector (#4588)

Weaxs 10 ماه پیش
والد
کامیت
0797f9bc05

+ 7 - 0
api/.env.example

@@ -112,6 +112,13 @@ PGVECTOR_USER=postgres
 PGVECTOR_PASSWORD=postgres
 PGVECTOR_DATABASE=postgres
 
+# Tidb Vector configuration
+TIDB_VECTOR_HOST=xxx.eu-central-1.xxx.aws.tidbcloud.com
+TIDB_VECTOR_PORT=4000
+TIDB_VECTOR_USER=xxx.root
+TIDB_VECTOR_PASSWORD=xxxxxx
+TIDB_VECTOR_DATABASE=dify
+
 # Upload configuration
 UPLOAD_FILE_SIZE_LIMIT=15
 UPLOAD_FILE_BATCH_LIMIT=5

+ 7 - 0
api/config.py

@@ -299,6 +299,13 @@ class Config:
         self.PGVECTOR_PASSWORD = get_env('PGVECTOR_PASSWORD')
         self.PGVECTOR_DATABASE = get_env('PGVECTOR_DATABASE')
 
+        # tidb-vector settings
+        self.TIDB_VECTOR_HOST = get_env('TIDB_VECTOR_HOST')
+        self.TIDB_VECTOR_PORT = get_env('TIDB_VECTOR_PORT')
+        self.TIDB_VECTOR_USER = get_env('TIDB_VECTOR_USER')
+        self.TIDB_VECTOR_PASSWORD = get_env('TIDB_VECTOR_PASSWORD')
+        self.TIDB_VECTOR_DATABASE = get_env('TIDB_VECTOR_DATABASE')
+
         # ------------------------
         # Mail Configurations.
         # ------------------------

+ 2 - 2
api/controllers/console/datasets/datasets.py

@@ -476,7 +476,7 @@ class DatasetRetrievalSettingApi(Resource):
     @account_initialization_required
     def get(self):
         vector_type = current_app.config['VECTOR_STORE']
-        if vector_type in {"milvus", "relyt", "pgvector", "pgvecto_rs"}:
+        if vector_type in {"milvus", "relyt", "pgvector", "pgvecto_rs", 'tidb_vector'}:
             return {
                 'retrieval_method': [
                     'semantic_search'
@@ -497,7 +497,7 @@ class DatasetRetrievalSettingMockApi(Resource):
     @login_required
     @account_initialization_required
     def get(self, vector_type):
-        if vector_type in {'milvus', 'relyt', 'pgvector'}:
+        if vector_type in {'milvus', 'relyt', 'pgvector', 'tidb_vector'}:
             return {
                 'retrieval_method': [
                     'semantic_search'

+ 0 - 0
api/core/rag/datasource/vdb/tidb_vector/__init__.py


+ 214 - 0
api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py

@@ -0,0 +1,214 @@
+import json
+import logging
+from typing import Any
+
+import sqlalchemy
+from pydantic import BaseModel, root_validator
+from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
+from sqlalchemy import text as sql_text
+from sqlalchemy.orm import Session, declarative_base
+
+from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
+
+logger = logging.getLogger(__name__)
+
+
+class TiDBVectorConfig(BaseModel):
+    host: str
+    port: int
+    user: str
+    password: str
+    database: str
+
+    @root_validator()
+    def validate_config(cls, values: dict) -> dict:
+        if not values['host']:
+            raise ValueError("config TIDB_VECTOR_HOST is required")
+        if not values['port']:
+            raise ValueError("config TIDB_VECTOR_PORT is required")
+        if not values['user']:
+            raise ValueError("config TIDB_VECTOR_USER is required")
+        if not values['password']:
+            raise ValueError("config TIDB_VECTOR_PASSWORD is required")
+        if not values['database']:
+            raise ValueError("config TIDB_VECTOR_DATABASE is required")
+        return values
+
+
+class TiDBVector(BaseVector):
+
+    def _table(self, dim: int) -> Table:
+        from tidb_vector.sqlalchemy import VectorType
+        return Table(
+            self._collection_name,
+            self._orm_base.metadata,
+            Column('id', String(36), primary_key=True, nullable=False),
+            Column("vector", VectorType(dim), nullable=False, comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})"),
+            Column("text", TEXT, nullable=False),
+            Column("meta", JSON, nullable=False),
+            Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")),
+            Column("update_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")),
+            extend_existing=True
+        )
+
+    def __init__(self, collection_name: str, config: TiDBVectorConfig, distance_func: str = 'cosine'):
+        super().__init__(collection_name)
+        self._client_config = config
+        self._url = (f"mysql+pymysql://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}?"
+                     f"ssl_verify_cert=true&ssl_verify_identity=true")
+        self._distance_func = distance_func.lower()
+        self._engine = create_engine(self._url)
+        self._orm_base = declarative_base()
+        self._dimension = 1536
+
+    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
+        logger.info("create collection and add texts, collection_name: " + self._collection_name)
+        self._create_collection(len(embeddings[0]))
+        self.add_texts(texts, embeddings)
+        self._dimension = len(embeddings[0])
+        pass
+
+    def _create_collection(self, dimension: int):
+        logger.info("_create_collection, collection_name " + self._collection_name)
+        lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
+        with redis_client.lock(lock_name, timeout=20):
+            collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
+            if redis_client.get(collection_exist_cache_key):
+                return
+            with Session(self._engine) as session:
+                session.begin()
+                drop_statement = sql_text(f"""DROP TABLE IF EXISTS {self._collection_name}; """)
+                session.execute(drop_statement)
+                create_statement = sql_text(f"""
+                    CREATE TABLE IF NOT EXISTS {self._collection_name} (
+                        id CHAR(36) PRIMARY KEY,
+                        text TEXT NOT NULL,
+                        meta JSON NOT NULL,
+                        vector VECTOR<FLOAT>({dimension}) NOT NULL COMMENT "hnsw(distance={self._distance_func})",
+                        create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
+                        update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
+                    );
+                """)
+                session.execute(create_statement)
+                # tidb vector not support 'CREATE/ADD INDEX' now
+                session.commit()
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
+
+    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+        table = self._table(len(embeddings[0]))
+        ids = self._get_uuids(documents)
+        metas = [d.metadata for d in documents]
+        texts = [d.page_content for d in documents]
+
+        chunks_table_data = []
+        with self._engine.connect() as conn:
+            with conn.begin():
+                for id, text, meta, embedding in zip(
+                        ids, texts, metas, embeddings
+                ):
+                    chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta})
+
+                    # Execute the batch insert when the batch size is reached
+                    if len(chunks_table_data) == 500:
+                        conn.execute(insert(table).values(chunks_table_data))
+                        # Clear the chunks_table_data list for the next batch
+                        chunks_table_data.clear()
+
+                # Insert any remaining records that didn't make up a full batch
+                if chunks_table_data:
+                    conn.execute(insert(table).values(chunks_table_data))
+        return ids
+
+    def text_exists(self, id: str) -> bool:
+        result = self.get_ids_by_metadata_field('doc_id', id)
+        return len(result) > 0
+
+    def delete_by_ids(self, ids: list[str]) -> None:
+        with Session(self._engine) as session:
+            ids_str = ','.join(f"'{doc_id}'" for doc_id in ids)
+            select_statement = sql_text(
+                f"""SELECT id FROM {self._collection_name} WHERE meta->>'$.doc_id' in ({ids_str}); """
+            )
+            result = session.execute(select_statement).fetchall()
+        if result:
+            ids = [item[0] for item in result]
+            self._delete_by_ids(ids)
+
+    def _delete_by_ids(self, ids: list[str]) -> bool:
+        if ids is None:
+            raise ValueError("No ids provided to delete.")
+        table = self._table(self._dimension)
+        try:
+            with self._engine.connect() as conn:
+                with conn.begin():
+                    delete_condition = table.c.id.in_(ids)
+                    conn.execute(table.delete().where(delete_condition))
+                    return True
+        except Exception as e:
+            print("Delete operation failed:", str(e))
+            return False
+
+    def delete_by_document_id(self, document_id: str):
+        ids = self.get_ids_by_metadata_field('document_id', document_id)
+        if ids:
+            self._delete_by_ids(ids)
+
+    def get_ids_by_metadata_field(self, key: str, value: str):
+        with Session(self._engine) as session:
+            select_statement = sql_text(
+                f"""SELECT id FROM {self._collection_name} WHERE meta->>'$.{key}' = '{value}'; """
+            )
+            result = session.execute(select_statement).fetchall()
+        if result:
+            return [item[0] for item in result]
+        else:
+            return None
+
+    def delete_by_metadata_field(self, key: str, value: str) -> None:
+        ids = self.get_ids_by_metadata_field(key, value)
+        if ids:
+            self._delete_by_ids(ids)
+
+    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+        top_k = kwargs.get("top_k", 5)
+        score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+        filter = kwargs.get('filter')
+        distance = 1 - score_threshold
+
+        query_vector_str = ", ".join(format(x) for x in query_vector)
+        query_vector_str = "[" + query_vector_str + "]"
+        logger.debug(f"_collection_name: {self._collection_name}, score_threshold: {score_threshold}, distance: {distance}")
+
+        docs = []
+        if self._distance_func == 'l2':
+            tidb_func = 'Vec_l2_distance'
+        elif self._distance_func == 'l2':
+            tidb_func = 'Vec_Cosine_distance'
+        else:
+            tidb_func = 'Vec_Cosine_distance'
+
+        with Session(self._engine) as session:
+            select_statement = sql_text(
+                f"""SELECT meta, text FROM (
+                        SELECT meta, text, {tidb_func}(vector, "{query_vector_str}")  as distance 
+                        FROM {self._collection_name} 
+                        ORDER BY distance
+                        LIMIT {top_k}
+                    ) t WHERE distance < {distance};"""
+            )
+            res = session.execute(select_statement)
+            results = [(row[0], row[1]) for row in res]
+            for meta, text in results:
+                docs.append(Document(page_content=text, metadata=json.loads(meta)))
+        return docs
+
+    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        # tidb doesn't support bm25 search
+        return []
+
+    def delete(self) -> None:
+        with Session(self._engine) as session:
+            session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
+            session.commit()

+ 25 - 0
api/core/rag/datasource/vdb/vector_factory.py

@@ -187,6 +187,31 @@ class Vector:
                     database=config.get("PGVECTOR_DATABASE"),
                 ),
             )
+        elif vector_type == "tidb_vector":
+            from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVector, TiDBVectorConfig
+
+            if self._dataset.index_struct_dict:
+                class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
+                collection_name = class_prefix.lower()
+            else:
+                dataset_id = self._dataset.id
+                collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
+                index_struct_dict = {
+                    "type": 'tidb_vector',
+                    "vector_store": {"class_prefix": collection_name}
+                }
+                self._dataset.index_struct = json.dumps(index_struct_dict)
+
+            return TiDBVector(
+                collection_name=collection_name,
+                config=TiDBVectorConfig(
+                    host=config.get('TIDB_VECTOR_HOST'),
+                    port=config.get('TIDB_VECTOR_PORT'),
+                    user=config.get('TIDB_VECTOR_USER'),
+                    password=config.get('TIDB_VECTOR_PASSWORD'),
+                    database=config.get('TIDB_VECTOR_DATABASE'),
+                ),
+            )
         else:
             raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
 

+ 2 - 0
api/requirements.txt

@@ -81,5 +81,7 @@ pgvecto-rs==0.1.4
 firecrawl-py==0.0.5
 oss2==2.18.5
 pgvector==0.2.5
+pymysql==1.1.1
+tidb-vector==0.0.9
 google-cloud-aiplatform==1.49.0
 vanna[postgres,mysql,clickhouse,duckdb]==0.5.5

+ 0 - 0
api/tests/integration_tests/vdb/tidb_vector/__init__.py


+ 63 - 0
api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py

@@ -0,0 +1,63 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVector, TiDBVectorConfig
+from models.dataset import Document
+from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
+
+
+@pytest.fixture
+def tidb_vector():
+    return TiDBVector(
+        collection_name='test_collection',
+        config=TiDBVectorConfig(
+            host="xxx.eu-central-1.xxx.aws.tidbcloud.com",
+            port="4000",
+            user="xxx.root",
+            password="xxxxxx",
+            database="dify"
+        )
+    )
+
+
+class TiDBVectorTest(AbstractVectorTest):
+    def __init__(self, vector):
+        super().__init__()
+        self.vector = vector
+
+    def text_exists(self):
+        exist = self.vector.text_exists(self.example_doc_id)
+        assert exist == False
+
+    def search_by_vector(self):
+        hits_by_vector: list[Document] = self.vector.search_by_vector(query_vector=self.example_embedding)
+        assert len(hits_by_vector) == 0
+
+    def search_by_full_text(self):
+        hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text())
+        assert len(hits_by_full_text) == 0
+
+    def get_ids_by_metadata_field(self):
+        ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id)
+        assert len(ids) == 0
+
+    def delete_by_document_id(self):
+        self.vector.delete_by_document_id(document_id=self.example_doc_id)
+
+
+def test_tidb_vector(setup_mock_redis, setup_tidbvector_mock, tidb_vector, mock_session):
+    TiDBVectorTest(vector=tidb_vector).run_all_tests()
+
+
+@pytest.fixture
+def mock_session():
+    with patch('core.rag.datasource.vdb.tidb_vector.tidb_vector.Session', new_callable=MagicMock) as mock_session:
+        yield mock_session
+
+
+@pytest.fixture
+def setup_tidbvector_mock(tidb_vector, mock_session):
+    with patch('core.rag.datasource.vdb.tidb_vector.tidb_vector.create_engine'):
+        with patch.object(tidb_vector._engine, 'connect'):
+            yield tidb_vector

+ 12 - 0
docker/docker-compose.yaml

@@ -134,6 +134,12 @@ services:
       PGVECTOR_USER: postgres
       PGVECTOR_PASSWORD: difyai123456
       PGVECTOR_DATABASE: dify
+      # tidb vector configurations
+      TIDB_VECTOR_HOST: tidb
+      TIDB_VECTOR_PORT: 4000
+      TIDB_VECTOR_USER: xxx.root
+      TIDB_VECTOR_PASSWORD: xxxxxx
+      TIDB_VECTOR_DATABASE: dify
       # Mail configuration, support: resend, smtp
       MAIL_TYPE: ''
       # default send from email address, if not specified
@@ -289,6 +295,12 @@ services:
       PGVECTOR_USER: postgres
       PGVECTOR_PASSWORD: difyai123456
       PGVECTOR_DATABASE: dify
+      # tidb vector configurations
+      TIDB_VECTOR_HOST: tidb
+      TIDB_VECTOR_PORT: 4000
+      TIDB_VECTOR_USER: xxx.root
+      TIDB_VECTOR_PASSWORD: xxxxxx
+      TIDB_VECTOR_DATABASE: dify
       # Notion import configuration, support public and internal
       NOTION_INTEGRATION_TYPE: public
       NOTION_CLIENT_SECRET: you-client-secret