|
@@ -0,0 +1,214 @@
|
|
|
+import json
|
|
|
+import logging
|
|
|
+from typing import Any
|
|
|
+
|
|
|
+import sqlalchemy
|
|
|
+from pydantic import BaseModel, root_validator
|
|
|
+from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
|
|
|
+from sqlalchemy import text as sql_text
|
|
|
+from sqlalchemy.orm import Session, declarative_base
|
|
|
+
|
|
|
+from core.rag.datasource.vdb.vector_base import BaseVector
|
|
|
+from core.rag.models.document import Document
|
|
|
+from extensions.ext_redis import redis_client
|
|
|
+
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+
|
|
|
class TiDBVectorConfig(BaseModel):
    """Connection settings for a TiDB vector store.

    All fields are mandatory; ``validate_config`` rejects missing or empty
    values with an error naming the corresponding environment variable.
    """

    host: str      # TIDB_VECTOR_HOST
    port: int      # TIDB_VECTOR_PORT
    user: str      # TIDB_VECTOR_USER
    password: str  # TIDB_VECTOR_PASSWORD
    database: str  # TIDB_VECTOR_DATABASE

    @root_validator()
    def validate_config(cls, values: dict) -> dict:
        """Ensure every connection parameter is present and non-empty.

        Uses ``values.get`` instead of direct indexing so a missing key
        raises the descriptive ValueError below rather than an opaque
        KeyError.
        """
        required = [
            ('host', 'TIDB_VECTOR_HOST'),
            ('port', 'TIDB_VECTOR_PORT'),
            ('user', 'TIDB_VECTOR_USER'),
            ('password', 'TIDB_VECTOR_PASSWORD'),
            ('database', 'TIDB_VECTOR_DATABASE'),
        ]
        for field, env_name in required:
            if not values.get(field):
                raise ValueError(f"config {env_name} is required")
        return values
|
|
|
+
|
|
|
+
|
|
|
class TiDBVector(BaseVector):
    """Vector store backed by a TiDB table with a ``VECTOR<FLOAT>`` column.

    One collection maps to one table (named after the collection) holding the
    document id, raw text, JSON metadata, the embedding vector and
    create/update timestamps.  The distance function is fixed per instance:
    'cosine' (default) or 'l2'.
    """

    def _table(self, dim: int) -> Table:
        """Return the SQLAlchemy ``Table`` for this collection.

        :param dim: dimensionality of the embedding vectors stored in the
            ``vector`` column.
        """
        # Imported lazily so the module can load without tidb_vector installed.
        from tidb_vector.sqlalchemy import VectorType
        return Table(
            self._collection_name,
            self._orm_base.metadata,
            Column('id', String(36), primary_key=True, nullable=False),
            # The column comment tells TiDB which distance metric the HNSW
            # index should use.
            Column("vector", VectorType(dim), nullable=False,
                   comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})"),
            Column("text", TEXT, nullable=False),
            Column("meta", JSON, nullable=False),
            Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")),
            Column("update_time", DateTime,
                   server_default=sqlalchemy.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")),
            # The same table may be registered on the shared metadata more than once.
            extend_existing=True
        )

    def __init__(self, collection_name: str, config: TiDBVectorConfig, distance_func: str = 'cosine'):
        """Create a client for one collection.

        :param collection_name: table name used for this collection.
        :param config: validated TiDB connection settings.
        :param distance_func: 'cosine' (default) or 'l2'; lower-cased and used
            both in the index comment and when querying.
        """
        super().__init__(collection_name)
        self._client_config = config
        self._url = (f"mysql+pymysql://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}?"
                     f"ssl_verify_cert=true&ssl_verify_identity=true")
        self._distance_func = distance_func.lower()
        self._engine = create_engine(self._url)
        self._orm_base = declarative_base()
        # Default dimension; overwritten by create() once real embeddings arrive.
        self._dimension = 1536

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        """Create the backing table and insert the initial batch of documents."""
        logger.info("create collection and add texts, collection_name: " + self._collection_name)
        self._create_collection(len(embeddings[0]))
        self.add_texts(texts, embeddings)
        self._dimension = len(embeddings[0])

    def _create_collection(self, dimension: int):
        """(Re)create the collection table, guarded by a Redis lock.

        NOTE: any existing table of the same name is dropped first, so this
        is destructive on re-creation.  A Redis flag (1h TTL) short-circuits
        repeated creation attempts across workers.
        """
        logger.info("_create_collection, collection_name " + self._collection_name)
        lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
            # Another worker created the table recently; nothing to do.
            if redis_client.get(collection_exist_cache_key):
                return
            with Session(self._engine) as session:
                session.begin()
                # NOTE(review): the collection name is interpolated into SQL as
                # an identifier; it must come from trusted configuration, never
                # from end-user input.
                drop_statement = sql_text(f"""DROP TABLE IF EXISTS {self._collection_name}; """)
                session.execute(drop_statement)
                create_statement = sql_text(f"""
                    CREATE TABLE IF NOT EXISTS {self._collection_name} (
                        id CHAR(36) PRIMARY KEY,
                        text TEXT NOT NULL,
                        meta JSON NOT NULL,
                        vector VECTOR<FLOAT>({dimension}) NOT NULL COMMENT "hnsw(distance={self._distance_func})",
                        create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
                        update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
                    );
                """)
                session.execute(create_statement)
                # tidb vector not support 'CREATE/ADD INDEX' now
                session.commit()
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        """Insert documents and their embeddings in batches of 500.

        :returns: the list of generated row ids.
        """
        table = self._table(len(embeddings[0]))
        ids = self._get_uuids(documents)
        metas = [d.metadata for d in documents]
        texts = [d.page_content for d in documents]

        chunks_table_data = []
        with self._engine.connect() as conn:
            with conn.begin():
                # `row_id` (not `id`) to avoid shadowing the builtin.
                for row_id, text, meta, embedding in zip(ids, texts, metas, embeddings):
                    chunks_table_data.append({"id": row_id, "vector": embedding, "text": text, "meta": meta})

                    # Execute the batch insert when the batch size is reached
                    if len(chunks_table_data) == 500:
                        conn.execute(insert(table).values(chunks_table_data))
                        # Clear the chunks_table_data list for the next batch
                        chunks_table_data.clear()

                # Insert any remaining records that didn't make up a full batch
                if chunks_table_data:
                    conn.execute(insert(table).values(chunks_table_data))
        return ids

    def text_exists(self, id: str) -> bool:
        """Return True if a row whose meta ``doc_id`` equals ``id`` exists.

        Fixed: get_ids_by_metadata_field returns None when nothing matches,
        so the original ``len(result) > 0`` raised TypeError on a miss.
        """
        return bool(self.get_ids_by_metadata_field('doc_id', id))

    def delete_by_ids(self, ids: list[str]) -> None:
        """Delete all rows whose meta ``doc_id`` is in ``ids``."""
        if not ids:
            return
        with Session(self._engine) as session:
            # Bind each doc_id as a named parameter instead of string
            # interpolation to avoid SQL injection.
            params = {f"id_{i}": doc_id for i, doc_id in enumerate(ids)}
            placeholders = ','.join(f":{name}" for name in params)
            select_statement = sql_text(
                f"""SELECT id FROM {self._collection_name} WHERE meta->>'$.doc_id' in ({placeholders}); """
            )
            result = session.execute(select_statement, params).fetchall()
        if result:
            row_ids = [item[0] for item in result]
            self._delete_by_ids(row_ids)

    def _delete_by_ids(self, ids: list[str]) -> bool:
        """Delete rows by primary-key id.

        :returns: True on success, False if the delete raised.
        :raises ValueError: if ``ids`` is None.
        """
        if ids is None:
            raise ValueError("No ids provided to delete.")
        table = self._table(self._dimension)
        try:
            with self._engine.connect() as conn:
                with conn.begin():
                    conn.execute(table.delete().where(table.c.id.in_(ids)))
            return True
        except Exception:
            # Fixed: was print(); failures now go to the module logger with a
            # traceback instead of stdout.
            logger.exception("Delete operation failed")
            return False

    def delete_by_document_id(self, document_id: str):
        """Delete every chunk row belonging to one source document."""
        ids = self.get_ids_by_metadata_field('document_id', document_id)
        if ids:
            self._delete_by_ids(ids)

    def get_ids_by_metadata_field(self, key: str, value: str):
        """Return row ids whose meta JSON field ``key`` equals ``value``.

        :returns: list of ids, or None when there is no match (None is kept
            for caller compatibility — callers test truthiness).
        """
        with Session(self._engine) as session:
            # `value` is bound as a parameter to avoid SQL injection.
            # NOTE(review): `key` is still interpolated into the JSON path and
            # must come from trusted internal callers only.
            select_statement = sql_text(
                f"""SELECT id FROM {self._collection_name} WHERE meta->>'$.{key}' = :value; """
            )
            result = session.execute(select_statement, {"value": value}).fetchall()
        if result:
            return [item[0] for item in result]
        return None

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        """Delete all rows whose meta JSON field ``key`` equals ``value``."""
        ids = self.get_ids_by_metadata_field(key, value)
        if ids:
            self._delete_by_ids(ids)

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        """Return the documents nearest to ``query_vector``.

        Supported kwargs: ``top_k`` (default 5), ``score_threshold``
        (similarity, default 0.0), ``filter`` (accepted but not applied).
        """
        top_k = int(kwargs.get("top_k", 5))  # coerce: interpolated into SQL below
        score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
        filter = kwargs.get('filter')  # NOTE(review): read but never applied — confirm intent
        # Convert the similarity threshold into a distance cutoff.
        # NOTE(review): distance = 1 - similarity only holds for cosine; verify for l2.
        distance = 1 - score_threshold

        query_vector_str = ", ".join(format(x) for x in query_vector)
        query_vector_str = "[" + query_vector_str + "]"
        logger.debug(f"_collection_name: {self._collection_name}, score_threshold: {score_threshold}, distance: {distance}")

        docs = []
        # Fixed: the original had a duplicated `elif self._distance_func == 'l2'`
        # branch that was unreachable; cosine is the explicit default here.
        if self._distance_func == 'l2':
            tidb_func = 'Vec_l2_distance'
        else:
            tidb_func = 'Vec_Cosine_distance'

        with Session(self._engine) as session:
            select_statement = sql_text(
                f"""SELECT meta, text FROM (
                    SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance
                    FROM {self._collection_name}
                    ORDER BY distance
                    LIMIT {top_k}
                ) t WHERE distance < {distance};"""
            )
            res = session.execute(select_statement)
            results = [(row[0], row[1]) for row in res]
            for meta, text in results:
                docs.append(Document(page_content=text, metadata=json.loads(meta)))
        return docs

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        """Full-text search is unsupported (TiDB has no BM25); always []."""
        # tidb doesn't support bm25 search
        return []

    def delete(self) -> None:
        """Drop the whole collection table."""
        with Session(self._engine) as session:
            session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
            session.commit()
|