浏览代码

feat: support openGauss vector database (#15865)

LittleFish-15 1 月之前
父节点
当前提交
223ab5a38f

+ 2 - 1
.github/workflows/expose_service_ports.sh

@@ -10,5 +10,6 @@ yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-com
 yq eval '.services.couchbase-server.ports += ["8091-8096:8091-8096"]' -i docker/docker-compose.yaml
 yq eval '.services.couchbase-server.ports += ["11210:11210"]' -i docker/docker-compose.yaml
 yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/tidb/docker-compose.yaml
+yq eval '.services.opengauss.ports += ["6600:6600"]' -i docker/docker-compose.yaml
 
-echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase"
+echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss"

+ 1 - 0
.github/workflows/vdb-tests.yml

@@ -76,6 +76,7 @@ jobs:
             milvus-standalone
             pgvecto-rs
             pgvector
+            opengauss
             chroma
             elasticsearch
 

+ 10 - 2
api/.env.example

@@ -137,7 +137,7 @@ WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 
 # Vector database configuration
-# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase
+# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, opengauss
 VECTOR_STORE=weaviate
 
 # Weaviate configuration
@@ -298,6 +298,14 @@ OCEANBASE_VECTOR_PASSWORD=difyai123456
 OCEANBASE_VECTOR_DATABASE=test
 OCEANBASE_MEMORY_LIMIT=6G
 
+# openGauss configuration
+OPENGAUSS_HOST=127.0.0.1
+OPENGAUSS_PORT=6600
+OPENGAUSS_USER=postgres
+OPENGAUSS_PASSWORD=Dify@123
+OPENGAUSS_DATABASE=dify
+OPENGAUSS_MIN_CONNECTION=1
+OPENGAUSS_MAX_CONNECTION=5
 
 # Upload configuration
 UPLOAD_FILE_SIZE_LIMIT=15
@@ -445,4 +453,4 @@ CREATE_TIDB_SERVICE_JOB_ENABLED=false
 # Maximum number of submitted thread count in a ThreadPool for parallel node execution
 MAX_SUBMIT_COUNT=100
 # Lockout duration in seconds
-LOGIN_LOCKOUT_DURATION=86400
+LOGIN_LOCKOUT_DURATION=86400

+ 1 - 0
api/commands.py

@@ -267,6 +267,7 @@ def migrate_knowledge_vector_database():
         VectorType.WEAVIATE,
         VectorType.ORACLE,
         VectorType.ELASTICSEARCH,
+        VectorType.OPENGAUSS,
     }
     lower_collection_vector_types = {
         VectorType.ANALYTICDB,

+ 2 - 0
api/configs/middleware/__init__.py

@@ -26,6 +26,7 @@ from .vdb.lindorm_config import LindormConfig
 from .vdb.milvus_config import MilvusConfig
 from .vdb.myscale_config import MyScaleConfig
 from .vdb.oceanbase_config import OceanBaseVectorConfig
+from .vdb.opengauss_config import OpenGaussConfig
 from .vdb.opensearch_config import OpenSearchConfig
 from .vdb.oracle_config import OracleConfig
 from .vdb.pgvector_config import PGVectorConfig
@@ -281,5 +282,6 @@ class MiddlewareConfig(
     LindormConfig,
     OceanBaseVectorConfig,
     BaiduVectorDBConfig,
+    OpenGaussConfig,
 ):
     pass

+ 45 - 0
api/configs/middleware/vdb/opengauss_config.py

@@ -0,0 +1,45 @@
+from typing import Optional
+
+from pydantic import Field, PositiveInt
+from pydantic_settings import BaseSettings
+
+
+class OpenGaussConfig(BaseSettings):
+    """
+    Settings describing how to reach an openGauss server.
+
+    Host, user, password and database default to ``None``; the port and the
+    two connection-pool bounds are positive integers with built-in defaults.
+    """
+
+    # Server address; None unless configured.
+    OPENGAUSS_HOST: Optional[str] = Field(
+        default=None,
+        description="Hostname or IP address of the OpenGauss server(e.g., 'localhost')",
+    )
+
+    # TCP port of the server.
+    OPENGAUSS_PORT: PositiveInt = Field(
+        default=6600,
+        description="Port number on which the OpenGauss server is listening (default is 6600)",
+    )
+
+    # Login role; None unless configured.
+    OPENGAUSS_USER: Optional[str] = Field(
+        default=None,
+        description="Username for authenticating with the OpenGauss database",
+    )
+
+    # Login secret; None unless configured.
+    OPENGAUSS_PASSWORD: Optional[str] = Field(
+        default=None,
+        description="Password for authenticating with the OpenGauss database",
+    )
+
+    # Target database name; None unless configured.
+    OPENGAUSS_DATABASE: Optional[str] = Field(
+        default=None,
+        description="Name of the OpenGauss database to connect to",
+    )
+
+    # Lower bound of the connection pool.
+    OPENGAUSS_MIN_CONNECTION: PositiveInt = Field(
+        default=1,
+        description="Min connection of the OpenGauss database",
+    )
+
+    # Upper bound of the connection pool.
+    OPENGAUSS_MAX_CONNECTION: PositiveInt = Field(
+        default=5,
+        description="Max connection of the OpenGauss database",
+    )

+ 2 - 0
api/controllers/console/datasets/datasets.py

@@ -659,6 +659,7 @@ class DatasetRetrievalSettingApi(Resource):
                 | VectorType.LINDORM
                 | VectorType.COUCHBASE
                 | VectorType.MILVUS
+                | VectorType.OPENGAUSS
             ):
                 return {
                     "retrieval_method": [
@@ -702,6 +703,7 @@ class DatasetRetrievalSettingMockApi(Resource):
                 | VectorType.COUCHBASE
                 | VectorType.PGVECTOR
                 | VectorType.LINDORM
+                | VectorType.OPENGAUSS
             ):
                 return {
                     "retrieval_method": [

+ 0 - 0
api/core/rag/datasource/vdb/opengauss/__init__.py


+ 238 - 0
api/core/rag/datasource/vdb/opengauss/opengauss.py

@@ -0,0 +1,238 @@
+import json
+import uuid
+from contextlib import contextmanager
+from typing import Any
+
+import psycopg2.extras  # type: ignore
+import psycopg2.pool  # type: ignore
+from pydantic import BaseModel, model_validator
+
+from configs import dify_config
+from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.embedding.embedding_base import Embeddings
+from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
+from models.dataset import Dataset
+
+
+class OpenGaussConfig(BaseModel):
+    """
+    Connection parameters for an openGauss server plus connection-pool bounds.
+
+    Every field is required and must be truthy, and ``min_connection`` may not
+    exceed ``max_connection``.
+    """
+
+    host: str
+    port: int
+    user: str
+    password: str
+    database: str
+    min_connection: int
+    max_connection: int
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_config(cls, values: dict) -> dict:
+        """
+        Reject missing or empty settings before field validation runs.
+
+        :param values: raw keyword arguments passed to the model.
+        :return: the same mapping, unchanged, when every check passes.
+        :raises ValueError: naming the first offending OPENGAUSS_* setting.
+        """
+        required_fields = (
+            "host",
+            "port",
+            "user",
+            "password",
+            "database",
+            "min_connection",
+            "max_connection",
+        )
+        # .get() makes an omitted key report the same ValueError as an empty
+        # value instead of escaping as a bare KeyError.
+        for field_name in required_fields:
+            if not values.get(field_name):
+                raise ValueError(f"config OPENGAUSS_{field_name.upper()} is required")
+        # Equal bounds are allowed; only min > max is rejected.
+        if values["min_connection"] > values["max_connection"]:
+            raise ValueError(
+                "config OPENGAUSS_MIN_CONNECTION should be less than or equal to OPENGAUSS_MAX_CONNECTION"
+            )
+        return values
+
+
+# DDL template for the per-collection table; filled in with ``table_name``
+# and the embedding ``dimension`` via str.format().
+SQL_CREATE_TABLE = """
+CREATE TABLE IF NOT EXISTS {table_name} (
+    id UUID PRIMARY KEY,
+    text TEXT NOT NULL,
+    meta JSONB NOT NULL,
+    embedding vector({dimension}) NOT NULL
+);
+"""
+
+# DDL template for the HNSW cosine-distance index on the embedding column;
+# filled in with ``table_name`` via str.format().
+SQL_CREATE_INDEX = """
+CREATE INDEX IF NOT EXISTS embedding_cosine_{table_name}_idx ON {table_name} 
+USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64);
+"""
+
+
+class OpenGauss(BaseVector):
+    """
+    Vector store backed by an openGauss table named ``embedding_<collection_name>``.
+
+    Each row holds a UUID id, the chunk text, its JSON metadata and the
+    embedding. Connections are taken from a psycopg2 ``SimpleConnectionPool``.
+    """
+
+    def __init__(self, collection_name: str, config: OpenGaussConfig):
+        super().__init__(collection_name)
+        self.pool = self._create_connection_pool(config)
+        self.table_name = f"embedding_{collection_name}"
+
+    def get_type(self) -> str:
+        """Return the vector-store type identifier for openGauss."""
+        return VectorType.OPENGAUSS
+
+    def _create_connection_pool(self, config: OpenGaussConfig):
+        """Build a psycopg2 connection pool from ``config``."""
+        return psycopg2.pool.SimpleConnectionPool(
+            config.min_connection,
+            config.max_connection,
+            host=config.host,
+            port=config.port,
+            user=config.user,
+            password=config.password,
+            database=config.database,
+        )
+
+    @staticmethod
+    def _validate_top_k(top_k: Any) -> int:
+        """Return ``top_k`` if it is a positive integer, otherwise raise ValueError."""
+        if not isinstance(top_k, int) or top_k <= 0:
+            raise ValueError("top_k must be a positive integer")
+        return top_k
+
+    @contextmanager
+    def _get_cursor(self):
+        """
+        Yield a cursor on a pooled connection.
+
+        The transaction is committed when the ``with`` body finishes normally
+        and rolled back when it raises. The connection always goes back to the
+        pool, including when creating the cursor itself fails.
+        """
+        conn = self.pool.getconn()
+        try:
+            cur = conn.cursor()
+            try:
+                yield cur
+            finally:
+                cur.close()
+            conn.commit()
+        except Exception:
+            # Never persist a half-finished transaction.
+            conn.rollback()
+            raise
+        finally:
+            self.pool.putconn(conn)
+
+    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
+        """Create the table (sized from the first embedding) and insert ``texts``."""
+        dimension = len(embeddings[0])
+        self._create_collection(dimension)
+        return self.add_texts(texts, embeddings)
+
+    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+        """
+        Insert documents with their embeddings.
+
+        Documents whose ``metadata`` is ``None`` are skipped. The row id is the
+        metadata ``doc_id`` when present, otherwise a fresh UUID.
+
+        :return: ids of the rows that were inserted.
+        """
+        values = []
+        pks = []
+        for i, doc in enumerate(documents):
+            if doc.metadata is not None:
+                doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
+                pks.append(doc_id)
+                values.append(
+                    (
+                        doc_id,
+                        doc.page_content,
+                        json.dumps(doc.metadata),
+                        embeddings[i],
+                    )
+                )
+        with self._get_cursor() as cur:
+            psycopg2.extras.execute_values(
+                cur, f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES %s", values
+            )
+        return pks
+
+    def text_exists(self, id: str) -> bool:
+        """Return True when a row with the given id exists."""
+        with self._get_cursor() as cur:
+            cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,))
+            return cur.fetchone() is not None
+
+    def get_by_ids(self, ids: list[str]) -> list[Document]:
+        """Fetch the documents stored under ``ids``; an empty list yields no query."""
+        # An empty tuple would render as "IN ()", which is a SQL syntax error.
+        if not ids:
+            return []
+        with self._get_cursor() as cur:
+            cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
+            docs = [Document(page_content=record[1], metadata=record[0]) for record in cur]
+        return docs
+
+    def delete_by_ids(self, ids: list[str]) -> None:
+        """Delete the rows stored under ``ids``; an empty list is a no-op."""
+        # Avoiding crashes caused by performing delete operations on empty lists in certain scenarios
+        # Scenario 1: extract a document fails, resulting in a table not being created.
+        # Then clicking the retry button triggers a delete operation on an empty list.
+        if not ids:
+            return
+        with self._get_cursor() as cur:
+            cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
+
+    def delete_by_metadata_field(self, key: str, value: str) -> None:
+        """Delete every row whose metadata field ``key`` equals ``value`` as text."""
+        with self._get_cursor() as cur:
+            cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
+
+    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+        """
+        Search the nearest neighbors to a vector.
+
+        :param query_vector: The input vector to search for similar items.
+        :param top_k: The number of nearest neighbors to return, default is 4.
+        :param score_threshold: Only results whose score (``1 - distance``) is
+            strictly greater than this value are returned, default is 0.0.
+        :return: List of Documents that are nearest to the query vector.
+        :raises ValueError: if ``top_k`` is not a positive integer.
+        """
+        top_k = self._validate_top_k(kwargs.get("top_k", 4))
+
+        with self._get_cursor() as cur:
+            # LIMIT is bound as a parameter rather than formatted into the SQL text.
+            cur.execute(
+                f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}"
+                " ORDER BY distance LIMIT %s",
+                (json.dumps(query_vector), top_k),
+            )
+            docs = []
+            score_threshold = float(kwargs.get("score_threshold") or 0.0)
+            for record in cur:
+                metadata, text, distance = record
+                score = 1 - distance
+                metadata["score"] = score
+                if score > score_threshold:
+                    docs.append(Document(page_content=text, metadata=metadata))
+        return docs
+
+    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        """
+        Rank rows by ``ts_rank`` of their text against ``query``.
+
+        :param query: The free-text query.
+        :param top_k: The maximum number of rows to return, default is 5.
+        :return: Matching Documents, best score first, with the score stored in metadata.
+        :raises ValueError: if ``top_k`` is not a positive integer.
+        """
+        top_k = self._validate_top_k(kwargs.get("top_k", 5))
+
+        with self._get_cursor() as cur:
+            cur.execute(
+                f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score
+                FROM {self.table_name}
+                WHERE to_tsvector(text) @@ plainto_tsquery(%s)
+                ORDER BY score DESC
+                LIMIT %s""",
+                # f"'{query}'" is required in order to account for whitespace in query
+                (f"'{query}'", f"'{query}'", top_k),
+            )
+
+            docs = []
+
+            for record in cur:
+                metadata, text, score = record
+                metadata["score"] = score
+                docs.append(Document(page_content=text, metadata=metadata))
+
+        return docs
+
+    def delete(self) -> None:
+        """Drop the collection table if it exists."""
+        with self._get_cursor() as cur:
+            cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
+
+    def _create_collection(self, dimension: int):
+        """
+        Create the table and its HNSW index once per collection.
+
+        A Redis lock serialises creation, and a Redis key that expires after
+        one hour marks the collection as already created.
+        """
+        cache_key = f"vector_indexing_{self._collection_name}"
+        lock_name = f"{cache_key}_lock"
+        with redis_client.lock(lock_name, timeout=20):
+            if redis_client.get(cache_key):
+                return
+
+            with self._get_cursor() as cur:
+                cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
+                # The index is skipped for embeddings wider than 2000 dimensions
+                # (presumably an index dimension limit; confirm against openGauss docs).
+                if dimension <= 2000:
+                    cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
+            redis_client.set(cache_key, 1, ex=3600)
+
+
+class OpenGaussFactory(AbstractVectorFactory):
+    """Factory that builds an ``OpenGauss`` vector store for a dataset."""
+
+    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> OpenGauss:
+        """
+        Resolve the dataset's collection name and return a configured store.
+
+        When the dataset has no index struct yet, a collection name is derived
+        from the dataset id and a new index struct is written to the dataset.
+        Otherwise the stored ``class_prefix`` is reused as the collection name.
+        """
+        if not dataset.index_struct_dict:
+            collection_name = Dataset.gen_collection_name_by_id(dataset.id)
+            index_struct = self.gen_index_struct_dict(VectorType.OPENGAUSS, collection_name)
+            dataset.index_struct = json.dumps(index_struct)
+        else:
+            collection_name = dataset.index_struct_dict["vector_store"]["class_prefix"]
+
+        # Unset optional settings fall back to local defaults.
+        connection_config = OpenGaussConfig(
+            host=dify_config.OPENGAUSS_HOST or "localhost",
+            port=dify_config.OPENGAUSS_PORT,
+            user=dify_config.OPENGAUSS_USER or "postgres",
+            password=dify_config.OPENGAUSS_PASSWORD or "",
+            database=dify_config.OPENGAUSS_DATABASE or "dify",
+            min_connection=dify_config.OPENGAUSS_MIN_CONNECTION,
+            max_connection=dify_config.OPENGAUSS_MAX_CONNECTION,
+        )
+        return OpenGauss(collection_name=collection_name, config=connection_config)

+ 4 - 0
api/core/rag/datasource/vdb/vector_factory.py

@@ -148,6 +148,10 @@ class Vector:
                 from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory
 
                 return OceanBaseVectorFactory
+            case VectorType.OPENGAUSS:
+                from core.rag.datasource.vdb.opengauss.opengauss import OpenGaussFactory
+
+                return OpenGaussFactory
             case _:
                 raise ValueError(f"Vector store {vector_type} is not supported.")
 

+ 1 - 0
api/core/rag/datasource/vdb/vector_type.py

@@ -24,3 +24,4 @@ class VectorType(StrEnum):
     UPSTASH = "upstash"
     TIDB_ON_QDRANT = "tidb_on_qdrant"
     OCEANBASE = "oceanbase"
+    OPENGAUSS = "opengauss"

+ 0 - 0
api/tests/integration_tests/vdb/opengauss/__init__.py


+ 40 - 0
api/tests/integration_tests/vdb/opengauss/test_opengauss.py

@@ -0,0 +1,40 @@
+import time
+
+from core.rag.datasource.vdb.opengauss.opengauss import OpenGauss, OpenGaussConfig
+from tests.integration_tests.vdb.test_vector_store import (
+    AbstractVectorTest,
+    get_example_text,
+    setup_mock_redis,
+)
+
+
+class OpenGaussTest(AbstractVectorTest):
+    """Integration test harness that runs against a local openGauss instance."""
+
+    def __init__(self):
+        # Imported here because this module's import block does not bring in
+        # psycopg2, which the retry loop below needs for its except clause.
+        import psycopg2  # type: ignore
+
+        super().__init__()
+        config = OpenGaussConfig(
+            host="localhost",
+            port=6600,
+            user="postgres",
+            password="Dify@123",
+            database="dify",
+            min_connection=1,
+            max_connection=5,
+        )
+        max_retries = 5
+        retry_delay = 20
+        # Constructing OpenGauss opens the connection pool, so that is the call
+        # that can fail while the database is still starting; retry it.
+        for attempt in range(1, max_retries + 1):
+            try:
+                self.vector = OpenGauss(
+                    collection_name=self.collection_name,
+                    config=config,
+                )
+                break
+            except psycopg2.OperationalError:
+                if attempt == max_retries:
+                    # Out of retries: surface the connection error.
+                    raise
+                time.sleep(retry_delay)
+
+
+def test_opengauss(setup_mock_redis):
+    """Build the openGauss test harness and run everything in run_all_tests()."""
+    suite = OpenGaussTest()
+    suite.run_all_tests()

+ 10 - 1
docker/.env.example

@@ -383,7 +383,7 @@ SUPABASE_URL=your-server-url
 # ------------------------------
 
 # The type of vector store to use.
-# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `tidb_vector`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`.
+# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `tidb_vector`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`.
 VECTOR_STORE=weaviate
 
 # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
@@ -555,6 +555,15 @@ OCEANBASE_VECTOR_DATABASE=test
 OCEANBASE_CLUSTER_NAME=difyai
 OCEANBASE_MEMORY_LIMIT=6G
 
+# opengauss configurations, only available when VECTOR_STORE is `opengauss`
+OPENGAUSS_HOST=opengauss
+OPENGAUSS_PORT=6600
+OPENGAUSS_USER=postgres
+OPENGAUSS_PASSWORD=Dify@123
+OPENGAUSS_DATABASE=dify
+OPENGAUSS_MIN_CONNECTION=1
+OPENGAUSS_MAX_CONNECTION=5
+
 # Upstash Vector configuration, only available when VECTOR_STORE is `upstash`
 UPSTASH_VECTOR_URL=https://xxx-vector.upstash.io
 UPSTASH_VECTOR_TOKEN=dify

+ 22 - 0
docker/docker-compose-template.yaml

@@ -507,6 +507,28 @@ services:
     depends_on:
       - opensearch
 
+  # opengauss vector database.
+  opengauss:
+    image: opengauss/opengauss:7.0.0-RC1
+    profiles:
+      - opengauss
+    privileged: true
+    restart: always
+    environment:
+      GS_USERNAME: ${OPENGAUSS_USER:-postgres}
+      GS_PASSWORD: ${OPENGAUSS_PASSWORD:-Dify@123}
+      GS_PORT: ${OPENGAUSS_PORT:-6600}
+      GS_DB: ${OPENGAUSS_DATABASE:-dify}
+    volumes:
+      - ./volumes/opengauss/data:/var/lib/opengauss/data
+    healthcheck:
+      test: ["CMD-SHELL", "netstat -lntp | grep tcp6 > /dev/null 2>&1"]
+      interval: 10s
+      timeout: 10s
+      retries: 10
+    ports:
+      - ${OPENGAUSS_PORT:-6600}:${OPENGAUSS_PORT:-6600}
+
   # MyScale vector database
   myscale:
     container_name: myscale

+ 29 - 0
docker/docker-compose.yaml

@@ -252,6 +252,13 @@ x-shared-env: &shared-api-worker-env
   OCEANBASE_VECTOR_DATABASE: ${OCEANBASE_VECTOR_DATABASE:-test}
   OCEANBASE_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai}
   OCEANBASE_MEMORY_LIMIT: ${OCEANBASE_MEMORY_LIMIT:-6G}
+  OPENGAUSS_HOST: ${OPENGAUSS_HOST:-opengauss}
+  OPENGAUSS_PORT: ${OPENGAUSS_PORT:-6600}
+  OPENGAUSS_USER: ${OPENGAUSS_USER:-postgres}
+  OPENGAUSS_PASSWORD: ${OPENGAUSS_PASSWORD:-Dify@123}
+  OPENGAUSS_DATABASE: ${OPENGAUSS_DATABASE:-dify}
+  OPENGAUSS_MIN_CONNECTION: ${OPENGAUSS_MIN_CONNECTION:-1}
+  OPENGAUSS_MAX_CONNECTION: ${OPENGAUSS_MAX_CONNECTION:-5}
   UPSTASH_VECTOR_URL: ${UPSTASH_VECTOR_URL:-https://xxx-vector.upstash.io}
   UPSTASH_VECTOR_TOKEN: ${UPSTASH_VECTOR_TOKEN:-dify}
   UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15}
@@ -929,6 +936,28 @@ services:
     depends_on:
       - opensearch
 
+  # opengauss vector database.
+  opengauss:
+    image: opengauss/opengauss:7.0.0-RC1
+    profiles:
+      - opengauss
+    privileged: true
+    restart: always
+    environment:
+      GS_USERNAME: ${OPENGAUSS_USER:-postgres}
+      GS_PASSWORD: ${OPENGAUSS_PASSWORD:-Dify@123}
+      GS_PORT: ${OPENGAUSS_PORT:-6600}
+      GS_DB: ${OPENGAUSS_DATABASE:-dify}
+    volumes:
+      - ./volumes/opengauss/data:/var/lib/opengauss/data
+    healthcheck:
+      test: ["CMD-SHELL", "netstat -lntp | grep tcp6 > /dev/null 2>&1"]
+      interval: 10s
+      timeout: 10s
+      retries: 10
+    ports:
+      - ${OPENGAUSS_PORT:-6600}:${OPENGAUSS_PORT:-6600}
+
   # MyScale vector database
   myscale:
     container_name: myscale