|
@@ -25,6 +25,7 @@ class OpenGaussConfig(BaseModel):
|
|
|
database: str
|
|
|
min_connection: int
|
|
|
max_connection: int
|
|
|
+ enable_pq: bool = False # Enable PQ acceleration
|
|
|
|
|
|
@model_validator(mode="before")
|
|
|
@classmethod
|
|
@@ -57,6 +58,11 @@ CREATE TABLE IF NOT EXISTS {table_name} (
|
|
|
);
|
|
|
"""
|
|
|
|
|
|
# DDL template for the HNSW cosine index with product quantization (PQ) switched on.
# Placeholders filled via str.format: {table_name} (embedding table) and {pq_m}
# (value passed to the index's pq_m storage parameter).
SQL_CREATE_INDEX_PQ = """
CREATE INDEX IF NOT EXISTS embedding_{table_name}_pq_idx ON {table_name}
USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64, enable_pq=on, pq_m={pq_m});
"""
|
|
|
+
|
|
|
SQL_CREATE_INDEX = """
|
|
|
CREATE INDEX IF NOT EXISTS embedding_cosine_{table_name}_idx ON {table_name}
|
|
|
USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64);
|
|
@@ -68,6 +74,7 @@ class OpenGauss(BaseVector):
|
|
|
super().__init__(collection_name)
|
|
|
self.pool = self._create_connection_pool(config)
|
|
|
self.table_name = f"embedding_{collection_name}"
|
|
|
+ self.pq_enabled = config.enable_pq
|
|
|
|
|
|
def get_type(self) -> str:
    """Return the vector-store type identifier for this backend, ``VectorType.OPENGAUSS``."""
    return VectorType.OPENGAUSS
|
|
@@ -97,7 +104,26 @@ class OpenGauss(BaseVector):
|
|
|
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
    """Create the collection table, insert the documents, then build the vector index.

    The index is built after the rows are inserted rather than at table-creation
    time, so the HNSW (optionally PQ) index is constructed over existing data.

    Args:
        texts: Documents to store.
        embeddings: One embedding per document; the length of the first
            embedding is used as the vector dimension.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        Whatever ``add_texts`` returns, so callers that relied on the previous
        ``return self.add_texts(...)`` behaviour keep receiving that value.

    Raises:
        IndexError: If ``embeddings`` is empty.
    """
    dimension = len(embeddings[0])
    self._create_collection(dimension)
    # Keep the insert result so it can still be returned once the index exists.
    inserted = self.add_texts(texts, embeddings)
    self._create_index(dimension)
    return inserted
|
|
|
+
|
|
|
def _create_index(self, dimension: int):
    """Create the HNSW cosine index for this collection's table, guarded by Redis.

    A Redis lock serialises concurrent callers and a Redis flag (one-hour TTL)
    records that index creation has already been attempted, so repeated calls
    inside that window return without touching the database.

    Args:
        dimension: Dimension of the stored embedding vectors. No index is
            created when it exceeds 2000; the flag is still set in that case.
    """
    index_cache_key = f"vector_index_{self._collection_name}"
    lock_name = f"{index_cache_key}_lock"
    with redis_client.lock(lock_name, timeout=60):
        if redis_client.get(index_cache_key):
            return

        # Only open a cursor when there is DDL to run.
        if dimension <= 2000:
            with self._get_cursor() as cur:
                if self.pq_enabled:
                    # pq_m is a quarter of the vector dimension, rounded down.
                    cur.execute(SQL_CREATE_INDEX_PQ.format(table_name=self.table_name, pq_m=dimension // 4))
                    # NOTE(review): SET presumably affects only the session behind
                    # this cursor - confirm it has the intended scope with the pool.
                    cur.execute("SET hnsw_earlystop_threshold = 320")
                else:
                    cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))

        # Reached only if the statements above did not raise.
        redis_client.set(index_cache_key, 1, ex=3600)
|
|
|
|
|
|
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
|
|
values = []
|
|
@@ -211,8 +237,6 @@ class OpenGauss(BaseVector):
|
|
|
|
|
|
with self._get_cursor() as cur:
|
|
|
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
|
|
|
- if dimension <= 2000:
|
|
|
- cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
|
|
|
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
|
|
|
|
|
|
|
@@ -236,5 +260,6 @@ class OpenGaussFactory(AbstractVectorFactory):
|
|
|
database=dify_config.OPENGAUSS_DATABASE or "dify",
|
|
|
min_connection=dify_config.OPENGAUSS_MIN_CONNECTION,
|
|
|
max_connection=dify_config.OPENGAUSS_MAX_CONNECTION,
|
|
|
+ enable_pq=dify_config.OPENGAUSS_ENABLE_PQ or False,
|
|
|
),
|
|
|
)
|