|
@@ -9,6 +9,7 @@ from sqlalchemy import text as sql_text
|
|
|
from sqlalchemy.orm import Session, declarative_base
|
|
|
|
|
|
from configs import dify_config
|
|
|
+from core.rag.datasource.vdb.field import Field
|
|
|
from core.rag.datasource.vdb.vector_base import BaseVector
|
|
|
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
|
|
from core.rag.datasource.vdb.vector_type import VectorType
|
|
@@ -54,14 +55,13 @@ class TiDBVector(BaseVector):
|
|
|
return Table(
|
|
|
self._collection_name,
|
|
|
self._orm_base.metadata,
|
|
|
- Column("id", String(36), primary_key=True, nullable=False),
|
|
|
+ Column(Field.PRIMARY_KEY.value, String(36), primary_key=True, nullable=False),
|
|
|
Column(
|
|
|
- "vector",
|
|
|
+ Field.VECTOR.value,
|
|
|
VectorType(dim),
|
|
|
nullable=False,
|
|
|
- comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})",
|
|
|
),
|
|
|
- Column("text", TEXT, nullable=False),
|
|
|
+ Column(Field.TEXT_KEY.value, TEXT, nullable=False),
|
|
|
Column("meta", JSON, nullable=False),
|
|
|
Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")),
|
|
|
Column(
|
|
@@ -96,6 +96,7 @@ class TiDBVector(BaseVector):
|
|
|
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
|
|
if redis_client.get(collection_exist_cache_key):
|
|
|
return
|
|
|
+ tidb_dist_func = self._get_distance_func()
|
|
|
with Session(self._engine) as session:
|
|
|
session.begin()
|
|
|
create_statement = sql_text(f"""
|
|
@@ -104,14 +105,14 @@ class TiDBVector(BaseVector):
|
|
|
text TEXT NOT NULL,
|
|
|
meta JSON NOT NULL,
|
|
|
doc_id VARCHAR(64) AS (JSON_UNQUOTE(JSON_EXTRACT(meta, '$.doc_id'))) STORED,
|
|
|
- KEY (doc_id),
|
|
|
- vector VECTOR<FLOAT>({dimension}) NOT NULL COMMENT "hnsw(distance={self._distance_func})",
|
|
|
+ vector VECTOR<FLOAT>({dimension}) NOT NULL,
|
|
|
create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
|
|
|
- update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
|
|
|
+ update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
|
|
+ KEY (doc_id),
|
|
|
+ VECTOR INDEX idx_vector (({tidb_dist_func}(vector))) USING HNSW
|
|
|
);
|
|
|
""")
|
|
|
session.execute(create_statement)
|
|
|
- # tidb vector not support 'CREATE/ADD INDEX' now
|
|
|
session.commit()
|
|
|
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
|
|
|
|
@@ -194,23 +195,30 @@ class TiDBVector(BaseVector):
|
|
|
)
|
|
|
|
|
|
docs = []
|
|
|
- if self._distance_func == "l2":
|
|
|
- tidb_func = "Vec_l2_distance"
|
|
|
- elif self._distance_func == "cosine":
|
|
|
- tidb_func = "Vec_Cosine_distance"
|
|
|
- else:
|
|
|
- tidb_func = "Vec_Cosine_distance"
|
|
|
+ tidb_dist_func = self._get_distance_func()
|
|
|
|
|
|
with Session(self._engine) as session:
|
|
|
- select_statement = sql_text(
|
|
|
- f"""SELECT meta, text, distance FROM (
|
|
|
- SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance
|
|
|
- FROM {self._collection_name}
|
|
|
- ORDER BY distance
|
|
|
- LIMIT {top_k}
|
|
|
- ) t WHERE distance < {distance};"""
|
|
|
+ select_statement = sql_text(f"""
|
|
|
+ SELECT meta, text, distance
|
|
|
+ FROM (
|
|
|
+ SELECT
|
|
|
+ meta,
|
|
|
+ text,
|
|
|
+ {tidb_dist_func}(vector, :query_vector_str) AS distance
|
|
|
+ FROM {self._collection_name}
|
|
|
+ ORDER BY distance ASC
|
|
|
+ LIMIT :top_k
|
|
|
+ ) t
|
|
|
+ WHERE distance <= :distance
|
|
|
+ """)
|
|
|
+ res = session.execute(
|
|
|
+ select_statement,
|
|
|
+ params={
|
|
|
+ "query_vector_str": query_vector_str,
|
|
|
+ "distance": distance,
|
|
|
+ "top_k": top_k,
|
|
|
+ },
|
|
|
)
|
|
|
- res = session.execute(select_statement)
|
|
|
results = [(row[0], row[1], row[2]) for row in res]
|
|
|
for meta, text, distance in results:
|
|
|
metadata = json.loads(meta)
|
|
@@ -227,6 +235,16 @@ class TiDBVector(BaseVector):
|
|
|
session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
|
|
|
session.commit()
|
|
|
|
|
|
+ def _get_distance_func(self) -> str:
|
|
|
+ match self._distance_func:
|
|
|
+ case "l2":
|
|
|
+ tidb_dist_func = "VEC_L2_DISTANCE"
|
|
|
+ case "cosine":
|
|
|
+ tidb_dist_func = "VEC_COSINE_DISTANCE"
|
|
|
+ case _:
|
|
|
+ tidb_dist_func = "VEC_COSINE_DISTANCE"
|
|
|
+ return tidb_dist_func
|
|
|
+
|
|
|
|
|
|
class TiDBVectorFactory(AbstractVectorFactory):
|
|
|
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TiDBVector:
|