Przeglądaj źródła

feat(vdb): add HNSW vector index for TiDB vector store with TiFlash (#12043)

Bowen Liang 2 miesięcy temu
rodzic
commit
0751ad1eeb

+ 1 - 1
.github/workflows/expose_service_ports.sh

@@ -9,6 +9,6 @@ yq eval '.services["pgvecto-rs"].ports += ["5431:5432"]' -i docker/docker-compos
 yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-compose.yaml
 yq eval '.services.couchbase-server.ports += ["8091-8096:8091-8096"]' -i docker/docker-compose.yaml
 yq eval '.services.couchbase-server.ports += ["11210:11210"]' -i docker/docker-compose.yaml
-yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/docker-compose.yaml
+yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/tidb/docker-compose.yaml
 
 echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase"

+ 12 - 2
.github/workflows/vdb-tests.yml

@@ -54,7 +54,15 @@ jobs:
       - name: Expose Service Ports
         run: sh .github/workflows/expose_service_ports.sh
 
-      - name: Set up Vector Stores (TiDB, Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase)
+      - name: Set up Vector Store (TiDB)
+        uses: hoverkraft-tech/compose-action@v2.0.2
+        with:
+          compose-file: docker/tidb/docker-compose.yaml
+          services: |
+            tidb
+            tiflash
+
+      - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase)
         uses: hoverkraft-tech/compose-action@v2.0.2
         with:
           compose-file: |
@@ -70,7 +78,9 @@ jobs:
             pgvector
             chroma
             elasticsearch
-            tidb
+
+      - name: Check TiDB Ready
+        run: poetry run -P api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
 
       - name: Test Vector Stores
         run: poetry run -P api bash dev/pytest/pytest_vdb.sh

+ 1 - 0
.gitignore

@@ -163,6 +163,7 @@ docker/volumes/db/data/*
 docker/volumes/redis/data/*
 docker/volumes/weaviate/*
 docker/volumes/qdrant/*
+docker/tidb/volumes/*
 docker/volumes/etcd/*
 docker/volumes/minio/*
 docker/volumes/milvus/*

+ 40 - 22
api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py

@@ -9,6 +9,7 @@ from sqlalchemy import text as sql_text
 from sqlalchemy.orm import Session, declarative_base
 
 from configs import dify_config
+from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_type import VectorType
@@ -54,14 +55,13 @@ class TiDBVector(BaseVector):
         return Table(
             self._collection_name,
             self._orm_base.metadata,
-            Column("id", String(36), primary_key=True, nullable=False),
+            Column(Field.PRIMARY_KEY.value, String(36), primary_key=True, nullable=False),
             Column(
-                "vector",
+                Field.VECTOR.value,
                 VectorType(dim),
                 nullable=False,
-                comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})",
             ),
-            Column("text", TEXT, nullable=False),
+            Column(Field.TEXT_KEY.value, TEXT, nullable=False),
             Column("meta", JSON, nullable=False),
             Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")),
             Column(
@@ -96,6 +96,7 @@ class TiDBVector(BaseVector):
             collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
             if redis_client.get(collection_exist_cache_key):
                 return
+            tidb_dist_func = self._get_distance_func()
             with Session(self._engine) as session:
                 session.begin()
                 create_statement = sql_text(f"""
@@ -104,14 +105,14 @@ class TiDBVector(BaseVector):
                         text TEXT NOT NULL,
                         meta JSON NOT NULL,
                         doc_id VARCHAR(64) AS (JSON_UNQUOTE(JSON_EXTRACT(meta, '$.doc_id'))) STORED,
-                        KEY (doc_id),
-                        vector VECTOR<FLOAT>({dimension}) NOT NULL COMMENT "hnsw(distance={self._distance_func})",
+                        vector VECTOR<FLOAT>({dimension}) NOT NULL,
                         create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
-                        update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
+                        update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+                        KEY (doc_id),
+                        VECTOR INDEX idx_vector (({tidb_dist_func}(vector))) USING HNSW
                     );
                 """)
                 session.execute(create_statement)
-                # tidb vector not support 'CREATE/ADD INDEX' now
                 session.commit()
             redis_client.set(collection_exist_cache_key, 1, ex=3600)
 
@@ -194,23 +195,30 @@ class TiDBVector(BaseVector):
         )
 
         docs = []
-        if self._distance_func == "l2":
-            tidb_func = "Vec_l2_distance"
-        elif self._distance_func == "cosine":
-            tidb_func = "Vec_Cosine_distance"
-        else:
-            tidb_func = "Vec_Cosine_distance"
+        tidb_dist_func = self._get_distance_func()
 
         with Session(self._engine) as session:
-            select_statement = sql_text(
-                f"""SELECT meta, text, distance FROM (
-                        SELECT meta, text, {tidb_func}(vector, "{query_vector_str}")  as distance
-                        FROM {self._collection_name}
-                        ORDER BY distance
-                        LIMIT {top_k}
-                    ) t WHERE distance < {distance};"""
+            select_statement = sql_text(f"""
+                SELECT meta, text, distance 
+                FROM (
+                  SELECT 
+                    meta,
+                    text,
+                    {tidb_dist_func}(vector, :query_vector_str) AS distance
+                  FROM {self._collection_name}
+                  ORDER BY distance ASC
+                  LIMIT :top_k
+                ) t
+                WHERE distance <= :distance
+                """)
+            res = session.execute(
+                select_statement,
+                params={
+                    "query_vector_str": query_vector_str,
+                    "distance": distance,
+                    "top_k": top_k,
+                },
             )
-            res = session.execute(select_statement)
             results = [(row[0], row[1], row[2]) for row in res]
             for meta, text, distance in results:
                 metadata = json.loads(meta)
@@ -227,6 +235,16 @@ class TiDBVector(BaseVector):
             session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
             session.commit()
 
+    def _get_distance_func(self) -> str:
+        match self._distance_func:
+            case "l2":
+                tidb_dist_func = "VEC_L2_DISTANCE"
+            case "cosine":
+                tidb_dist_func = "VEC_COSINE_DISTANCE"
+            case _:
+                tidb_dist_func = "VEC_COSINE_DISTANCE"
+        return tidb_dist_func
+
 
 class TiDBVectorFactory(AbstractVectorFactory):
     def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TiDBVector:

+ 59 - 0
api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py

@@ -0,0 +1,59 @@
+import time
+
+import pymysql
+
+
+def check_tiflash_ready() -> bool:
+    try:
+        connection = pymysql.connect(
+            host="localhost",
+            port=4000,
+            user="root",
+            password="",
+        )
+
+        with connection.cursor() as cursor:
+            # Doc reference:
+            # https://docs.pingcap.com/zh/tidb/stable/information-schema-cluster-hardware
+            select_tiflash_query = """
+            SELECT * FROM information_schema.cluster_hardware
+            WHERE TYPE='tiflash'
+            LIMIT 1;
+            """
+            cursor.execute(select_tiflash_query)
+            result = cursor.fetchall()
+            return result is not None and len(result) > 0
+    except Exception as e:
+        print(f"TiFlash is not ready. Exception: {e}")
+        return False
+    finally:
+        if connection:
+            connection.close()
+
+
+def main():
+    max_attempts = 30
+    retry_interval_seconds = 2
+    is_tiflash_ready = False
+    for attempt in range(max_attempts):
+        try:
+            is_tiflash_ready = check_tiflash_ready()
+        except Exception as e:
+            print(f"TiFlash is not ready. Exception: {e}")
+            is_tiflash_ready = False
+
+        if is_tiflash_ready:
+            break
+        else:
+            print(f"Attempt {attempt + 1} failed,retry in {retry_interval_seconds} seconds...")
+            time.sleep(retry_interval_seconds)
+
+    if is_tiflash_ready:
+        print("TiFlash is ready in TiDB.")
+    else:
+        print(f"TiFlash is not ready in TiDB after {max_attempts} attempting checks.")
+        exit(1)
+
+
+if __name__ == "__main__":
+    main()

+ 0 - 10
docker/docker-compose-template.yaml

@@ -199,16 +199,6 @@ services:
       - '${EXPOSE_NGINX_PORT:-80}:${NGINX_PORT:-80}'
       - '${EXPOSE_NGINX_SSL_PORT:-443}:${NGINX_SSL_PORT:-443}'
 
-  # The TiDB vector store.
-  # For production use, please refer to https://github.com/pingcap/tidb-docker-compose
-  tidb:
-    image: pingcap/tidb:v8.4.0
-    profiles:
-      - tidb
-    command:
-      - --store=unistore
-    restart: always
-
   # The Weaviate vector store.
   weaviate:
     image: semitechnologies/weaviate:1.19.0

+ 0 - 10
docker/docker-compose.yaml

@@ -594,16 +594,6 @@ services:
       - '${EXPOSE_NGINX_PORT:-80}:${NGINX_PORT:-80}'
       - '${EXPOSE_NGINX_SSL_PORT:-443}:${NGINX_SSL_PORT:-443}'
 
-  # The TiDB vector store.
-  # For production use, please refer to https://github.com/pingcap/tidb-docker-compose
-  tidb:
-    image: pingcap/tidb:v8.4.0
-    profiles:
-      - tidb
-    command:
-      - --store=unistore
-    restart: always
-
   # The Weaviate vector store.
   weaviate:
     image: semitechnologies/weaviate:1.19.0

+ 4 - 0
docker/tidb/config/pd.toml

@@ -0,0 +1,4 @@
+# PD Configuration File reference:
+# https://docs.pingcap.com/tidb/stable/pd-configuration-file#pd-configuration-file
+#
+# Mounted read-only into the pd0 container as /pd.toml and loaded through
+# --config=/pd.toml (see docker/tidb/docker-compose.yaml).
+[replication]
+# NOTE(review): a single replica presumably matches the single tikv service in
+# the compose file; confirm before reusing this outside local/CI setups.
+max-replicas = 1

+ 13 - 0
docker/tidb/config/tiflash-learner.toml

@@ -0,0 +1,13 @@
+# TiFlash tiflash-learner.toml Configuration File reference:
+# https://docs.pingcap.com/tidb/stable/tiflash-configuration#configure-the-tiflash-learnertoml-file
+#
+# Referenced from tiflash.toml ([flash.proxy] config) and mounted read-only
+# into the tiflash container as /tiflash-learner.toml.
+
+# Written under /logs, which the compose file bind-mounts from ./volumes/logs.
+log-file = "/logs/tiflash_tikv.log"
+
+[server]
+# Same host:port as flash.service_addr in tiflash.toml.
+engine-addr = "tiflash:4030"
+# Listen on all interfaces; peers are told to use the compose service name.
+addr = "0.0.0.0:20280"
+advertise-addr = "tiflash:20280"
+status-addr = "tiflash:20292"
+
+[storage]
+# Under /data, which the compose file bind-mounts from ./volumes/data.
+data-dir = "/data/flash"

+ 19 - 0
docker/tidb/config/tiflash.toml

@@ -0,0 +1,19 @@
+# TiFlash tiflash.toml Configuration File reference:
+# https://docs.pingcap.com/tidb/stable/tiflash-configuration#configure-the-tiflashtoml-file
+#
+# Mounted read-only into the tiflash container as /tiflash.toml and loaded
+# through --config=/tiflash.toml (see docker/tidb/docker-compose.yaml).
+
+listen_host = "0.0.0.0"
+# /data is bind-mounted from ./volumes/data by the compose file.
+path = "/data"
+
+[flash]
+# Points at the `tidb` compose service.
+# NOTE(review): 10080 is presumably TiDB's status port — confirm.
+tidb_status_addr = "tidb:10080"
+# Must stay in sync with server.engine-addr in tiflash-learner.toml.
+service_addr = "tiflash:4030"
+
+[flash.proxy]
+# Container path of the tiflash-learner.toml mount.
+config = "/tiflash-learner.toml"
+
+[logger]
+errorlog = "/logs/tiflash_error.log"
+log = "/logs/tiflash.log"
+
+[raft]
+# Matches the client URL advertised by the pd0 service (http://pd0:2379).
+pd_addr = "pd0:2379"

+ 62 - 0
docker/tidb/docker-compose.yaml

@@ -0,0 +1,62 @@
+# Compose stack for the TiDB vector store: one PD, one TiKV, one TiDB and one
+# TiFlash container. All services keep their state under ./volumes (data and
+# logs), each in its own sub-path.
+services:
+  # PD node; the other services reach it at pd0:2379.
+  pd0:
+    image: pingcap/pd:v8.5.1
+    # ports:
+    #  - "2379"
+    volumes:
+      - ./config/pd.toml:/pd.toml:ro
+      - ./volumes/data:/data
+      - ./volumes/logs:/logs
+    command:
+      - --name=pd0
+      - --client-urls=http://0.0.0.0:2379
+      - --peer-urls=http://0.0.0.0:2380
+      - --advertise-client-urls=http://pd0:2379
+      - --advertise-peer-urls=http://pd0:2380
+      - --initial-cluster=pd0=http://pd0:2380
+      - --data-dir=/data/pd
+      - --config=/pd.toml
+      - --log-file=/logs/pd.log
+    restart: on-failure
+  # TiKV storage node; registers with PD through --pd=pd0:2379.
+  tikv:
+    image: pingcap/tikv:v8.5.1
+    volumes:
+      - ./volumes/data:/data
+      - ./volumes/logs:/logs
+    command:
+      - --addr=0.0.0.0:20160
+      - --advertise-addr=tikv:20160
+      - --status-addr=tikv:20180
+      - --data-dir=/data/tikv
+      - --pd=pd0:2379
+      - --log-file=/logs/tikv.log
+    # depends_on only orders container start-up; it does not wait for
+    # the dependency to be ready.
+    depends_on:
+      - "pd0"
+    restart: on-failure
+  # TiDB SQL layer backed by TiKV (--store=tikv), located via PD (--path).
+  tidb:
+    image: pingcap/tidb:v8.5.1
+    # Port 4000 is not published by default; the CI helper
+    # .github/workflows/expose_service_ports.sh adds "4000:4000" to this file.
+    # ports:
+    #   - "4000:4000"
+    volumes:
+      - ./volumes/logs:/logs
+    command:
+      - --advertise-address=tidb
+      - --store=tikv
+      - --path=pd0:2379
+      - --log-file=/logs/tidb.log
+    depends_on:
+      - "tikv"
+    restart: on-failure
+  # TiFlash node, configured by the two mounted TOML files.
+  tiflash:
+    image: pingcap/tiflash:v8.5.1
+    volumes:
+      - ./config/tiflash.toml:/tiflash.toml:ro
+      - ./config/tiflash-learner.toml:/tiflash-learner.toml:ro
+      - ./volumes/data:/data
+      - ./volumes/logs:/logs
+    command:
+      - --config=/tiflash.toml
+    # Started after tikv and tidb; readiness is checked separately in CI by
+    # api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py.
+    depends_on:
+      - "tikv"
+      - "tidb"
+    restart: on-failure