
Feat/vector db pgvector (#3879)

LiuVaayne, 11 months ago
Parent
Current commit 875249eb00

+ 3 - 1
.github/workflows/api-tests.yml

@@ -50,7 +50,7 @@ jobs:
       - name: Run Workflow
         run: dev/pytest/pytest_workflow.sh
 
-      - name: Set up Vector Stores (Weaviate, Qdrant, Milvus, PgVecto-RS)
+      - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS)
         uses: hoverkraft-tech/compose-action@v2.0.0
         with:
           compose-file: |
@@ -58,6 +58,7 @@ jobs:
             docker/docker-compose.qdrant.yaml
             docker/docker-compose.milvus.yaml
             docker/docker-compose.pgvecto-rs.yaml
+            docker/docker-compose.pgvector.yaml
           services: |
             weaviate
             qdrant
@@ -65,6 +66,7 @@ jobs:
             minio
             milvus-standalone
             pgvecto-rs
+            pgvector
 
       - name: Test Vector Stores
         run: dev/pytest/pytest_vdb.sh

+ 8 - 1
api/.env.example

@@ -65,7 +65,7 @@ GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON=your-google-service-account-json-base64-stri
 WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 
-# Vector database configuration, support: weaviate, qdrant, milvus, relyt, pgvecto_rs
+# Vector database configuration, support: weaviate, qdrant, milvus, relyt, pgvecto_rs, pgvector
 VECTOR_STORE=weaviate
 
 # Weaviate configuration
@@ -102,6 +102,13 @@ PGVECTO_RS_USER=postgres
 PGVECTO_RS_PASSWORD=difyai123456
 PGVECTO_RS_DATABASE=postgres
 
+# PGVector configuration
+PGVECTOR_HOST=127.0.0.1
+PGVECTOR_PORT=5433
+PGVECTOR_USER=postgres
+PGVECTOR_PASSWORD=postgres
+PGVECTOR_DATABASE=postgres
+
 # Upload configuration
 UPLOAD_FILE_SIZE_LIMIT=15
 UPLOAD_FILE_BATCH_LIMIT=5
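
As a quick sanity check, the PGVECTOR_* values above can be probed with psycopg2 before starting the API. A minimal sketch, assuming the defaults shown here and a pgvector-enabled PostgreSQL instance listening on 127.0.0.1:5433:

import psycopg2

# Values mirror the PGVECTOR_* defaults documented above.
conn = psycopg2.connect(
    host="127.0.0.1",
    port=5433,
    user="postgres",
    password="postgres",
    dbname="postgres",
)
with conn.cursor() as cur:
    # Returns a row once `CREATE EXTENSION vector` has been run in this database.
    cur.execute("SELECT extname FROM pg_extension WHERE extname = 'vector'")
    print(cur.fetchone())
conn.close()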

+ 8 - 0
api/commands.py

@@ -305,6 +305,14 @@ def migrate_knowledge_vector_database():
                         "vector_store": {"class_prefix": collection_name}
                     }
                     dataset.index_struct = json.dumps(index_struct_dict)
+                elif vector_type == "pgvector":
+                    dataset_id = dataset.id
+                    collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+                    index_struct_dict = {
+                        "type": 'pgvector',
+                        "vector_store": {"class_prefix": collection_name}
+                    }
+                    dataset.index_struct = json.dumps(index_struct_dict)
                 else:
                     raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
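
For reference, this migration branch leaves each pgvector-backed dataset with an index_struct of the following shape. A minimal sketch: the dataset id is made up, and the collection-name format is only an assumption about what Dataset.gen_collection_name_by_id returns:

import json

dataset_id = "11111111-2222-3333-4444-555555555555"  # illustrative id
# Assumed shape of Dataset.gen_collection_name_by_id(dataset_id); illustrative only.
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + "_Node"

index_struct = json.dumps({
    "type": "pgvector",
    "vector_store": {"class_prefix": collection_name},
})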
 

+ 8 - 1
api/config.py

@@ -222,7 +222,7 @@ class Config:
 
         # ------------------------
         # Vector Store Configurations.
-        # Currently, only support: qdrant, milvus, zilliz, weaviate, relyt
+        # Currently, only support: qdrant, milvus, zilliz, weaviate, relyt, pgvector
         # ------------------------
         self.VECTOR_STORE = get_env('VECTOR_STORE')
         self.KEYWORD_STORE = get_env('KEYWORD_STORE')
@@ -261,6 +261,13 @@ class Config:
         self.PGVECTO_RS_PASSWORD = get_env('PGVECTO_RS_PASSWORD')
         self.PGVECTO_RS_DATABASE = get_env('PGVECTO_RS_DATABASE')
 
+        # pgvector settings
+        self.PGVECTOR_HOST = get_env('PGVECTOR_HOST')
+        self.PGVECTOR_PORT = get_env('PGVECTOR_PORT')
+        self.PGVECTOR_USER = get_env('PGVECTOR_USER')
+        self.PGVECTOR_PASSWORD = get_env('PGVECTOR_PASSWORD')
+        self.PGVECTOR_DATABASE = get_env('PGVECTOR_DATABASE')
+
         # ------------------------
         # Mail Configurations.
         # ------------------------

+ 4 - 5
api/controllers/console/datasets/datasets.py

@@ -476,13 +476,13 @@ class DatasetRetrievalSettingApi(Resource):
     @account_initialization_required
     def get(self):
         vector_type = current_app.config['VECTOR_STORE']
-        if vector_type == 'milvus' or vector_type == 'pgvecto_rs' or vector_type == 'relyt':
+        if vector_type in {"milvus", "relyt", "pgvector", "pgvecto_rs"}:
             return {
                 'retrieval_method': [
                     'semantic_search'
                 ]
             }
-        elif vector_type == 'qdrant' or vector_type == 'weaviate':
+        elif vector_type in {"qdrant", "weaviate"}:
             return {
                 'retrieval_method': [
                     'semantic_search', 'full_text_search', 'hybrid_search'
@@ -497,14 +497,13 @@ class DatasetRetrievalSettingMockApi(Resource):
     @login_required
     @account_initialization_required
     def get(self, vector_type):
-
-        if vector_type == 'milvus' or vector_type == 'relyt':
+        if vector_type in {'milvus', 'relyt', 'pgvector'}:
             return {
                 'retrieval_method': [
                     'semantic_search'
                 ]
             }
-        elif vector_type == 'qdrant' or vector_type == 'weaviate':
+        elif vector_type in {'qdrant', 'weaviate'}:
             return {
                 'retrieval_method': [
                     'semantic_search', 'full_text_search', 'hybrid_search'

+ 0 - 0
api/core/rag/datasource/vdb/pgvector/__init__.py


+ 169 - 0
api/core/rag/datasource/vdb/pgvector/pgvector.py

@@ -0,0 +1,169 @@
+import json
+import uuid
+from contextlib import contextmanager
+from typing import Any
+
+import psycopg2.extras
+import psycopg2.pool
+from pydantic import BaseModel, root_validator
+
+from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
+
+
+class PGVectorConfig(BaseModel):
+    host: str
+    port: int
+    user: str
+    password: str
+    database: str
+
+    @root_validator()
+    def validate_config(cls, values: dict) -> dict:
+        if not values["host"]:
+            raise ValueError("config PGVECTOR_HOST is required")
+        if not values["port"]:
+            raise ValueError("config PGVECTOR_PORT is required")
+        if not values["user"]:
+            raise ValueError("config PGVECTOR_USER is required")
+        if not values["password"]:
+            raise ValueError("config PGVECTOR_PASSWORD is required")
+        if not values["database"]:
+            raise ValueError("config PGVECTOR_DATABASE is required")
+        return values
+
+
+SQL_CREATE_TABLE = """
+CREATE TABLE IF NOT EXISTS {table_name} (
+    id UUID PRIMARY KEY,
+    text TEXT NOT NULL,
+    meta JSONB NOT NULL,
+    embedding vector({dimension}) NOT NULL
+) using heap; 
+"""
+
+
+class PGVector(BaseVector):
+    def __init__(self, collection_name: str, config: PGVectorConfig):
+        super().__init__(collection_name)
+        self.pool = self._create_connection_pool(config)
+        self.table_name = f"embedding_{collection_name}"
+
+    def get_type(self) -> str:
+        return "pgvector"
+
+    def _create_connection_pool(self, config: PGVectorConfig):
+        return psycopg2.pool.SimpleConnectionPool(
+            1,
+            5,
+            host=config.host,
+            port=config.port,
+            user=config.user,
+            password=config.password,
+            database=config.database,
+        )
+
+    @contextmanager
+    def _get_cursor(self):
+        conn = self.pool.getconn()
+        cur = conn.cursor()
+        try:
+            yield cur
+        finally:
+            cur.close()
+            conn.commit()
+            self.pool.putconn(conn)
+
+    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
+        dimension = len(embeddings[0])
+        self._create_collection(dimension)
+        return self.add_texts(texts, embeddings)
+
+    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+        values = []
+        pks = []
+        for i, doc in enumerate(documents):
+            doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
+            pks.append(doc_id)
+            values.append(
+                (
+                    doc_id,
+                    doc.page_content,
+                    json.dumps(doc.metadata),
+                    embeddings[i],
+                )
+            )
+        with self._get_cursor() as cur:
+            psycopg2.extras.execute_values(
+                cur, f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES %s", values
+            )
+        return pks
+
+    def text_exists(self, id: str) -> bool:
+        with self._get_cursor() as cur:
+            cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,))
+            return cur.fetchone() is not None
+
+    def get_by_ids(self, ids: list[str]) -> list[Document]:
+        with self._get_cursor() as cur:
+            cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
+            docs = []
+            for record in cur:
+                docs.append(Document(page_content=record[1], metadata=record[0]))
+        return docs
+
+    def delete_by_ids(self, ids: list[str]) -> None:
+        with self._get_cursor() as cur:
+            cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
+
+    def delete_by_metadata_field(self, key: str, value: str) -> None:
+        with self._get_cursor() as cur:
+            cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
+
+    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+        """
+        Search the nearest neighbors to a vector.
+
+        :param query_vector: The input vector to search for similar items.
+        :param top_k: The number of nearest neighbors to return, default is 5.
+        :return: List of Documents that are nearest to the query vector.
+        """
+        top_k = kwargs.get("top_k", 5)
+
+        with self._get_cursor() as cur:
+            cur.execute(
+                f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name} ORDER BY distance LIMIT {top_k}",
+                (json.dumps(query_vector),),
+            )
+            docs = []
+            score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+            for record in cur:
+                metadata, text, distance = record
+                score = 1 - distance
+                metadata["score"] = score
+                if score > score_threshold:
+                    docs.append(Document(page_content=text, metadata=metadata))
+        return docs
+
+    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        # full-text (BM25) search is not supported by this vector store
+        return []
+
+    def delete(self) -> None:
+        with self._get_cursor() as cur:
+            cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
+
+    def _create_collection(self, dimension: int):
+        cache_key = f"vector_indexing_{self._collection_name}"
+        lock_name = f"{cache_key}_lock"
+        with redis_client.lock(lock_name, timeout=20):
+            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
+            if redis_client.get(collection_exist_cache_key):
+                return
+
+            with self._get_cursor() as cur:
+                cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
+                cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
+                # TODO: create index https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
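
Used directly, the class above can be exercised end to end roughly as follows. A minimal usage sketch, assuming a reachable pgvector instance, a running Redis (redis_client guards table creation with a lock), and tiny 3-dimensional embeddings purely for illustration:

import uuid

from core.rag.datasource.vdb.pgvector.pgvector import PGVector, PGVectorConfig
from core.rag.models.document import Document

store = PGVector(
    collection_name="example_collection",  # illustrative collection name
    config=PGVectorConfig(
        host="127.0.0.1",
        port=5433,
        user="postgres",
        password="difyai123456",
        database="dify",
    ),
)

doc_id = str(uuid.uuid4())  # the id column is a UUID primary key
docs = [Document(page_content="hello pgvector", metadata={"doc_id": doc_id})]
embeddings = [[0.1, 0.2, 0.3]]  # normally produced by the embedding model

store.create(docs, embeddings)                            # creates the table, then inserts
hits = store.search_by_vector([0.1, 0.2, 0.3], top_k=5)   # nearest-neighbour query
assert store.text_exists(doc_id)
store.delete()                                            # drops the table

Note that <=> is pgvector's cosine-distance operator, so the score = 1 - distance mapping in search_by_vector acts as a cosine-similarity score, and search_by_full_text is deliberately a no-op for this store.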

+ 23 - 0
api/core/rag/datasource/vdb/vector_factory.py

@@ -164,6 +164,29 @@ class Vector:
                 ),
                 dim=dim
             )
+        elif vector_type == "pgvector":
+            from core.rag.datasource.vdb.pgvector.pgvector import PGVector, PGVectorConfig
+
+            if self._dataset.index_struct_dict:
+                class_prefix: str = self._dataset.index_struct_dict["vector_store"]["class_prefix"]
+                collection_name = class_prefix
+            else:
+                dataset_id = self._dataset.id
+                collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+                index_struct_dict = {
+                    "type": "pgvector",
+                    "vector_store": {"class_prefix": collection_name},
+                }
+                self._dataset.index_struct = json.dumps(index_struct_dict)
+            return PGVector(
+                collection_name=collection_name,
+                config=PGVectorConfig(
+                    host=config.get("PGVECTOR_HOST"),
+                    port=config.get("PGVECTOR_PORT"),
+                    user=config.get("PGVECTOR_USER"),
+                    password=config.get("PGVECTOR_PASSWORD"),
+                    database=config.get("PGVECTOR_DATABASE"),
+                ),
+            )
         else:
             raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
 

+ 1 - 0
api/requirements.txt

@@ -83,3 +83,4 @@ pydantic~=1.10.0
 pgvecto-rs==0.1.4
 firecrawl-py==0.0.5
 oss2==2.15.0
+pgvector==0.2.5

+ 0 - 0
api/tests/integration_tests/vdb/pgvector/__init__.py


+ 30 - 0
api/tests/integration_tests/vdb/pgvector/test_pgvector.py

@@ -0,0 +1,30 @@
+from core.rag.datasource.vdb.pgvector.pgvector import PGVector, PGVectorConfig
+from core.rag.models.document import Document
+from tests.integration_tests.vdb.test_vector_store import (
+    AbstractVectorTest,
+    get_example_text,
+    setup_mock_redis,
+)
+
+
+class TestPGVector(AbstractVectorTest):
+    def __init__(self):
+        super().__init__()
+        self.vector = PGVector(
+            collection_name=self.collection_name,
+            config=PGVectorConfig(
+                host="localhost",
+                port=5433,
+                user="postgres",
+                password="difyai123456",
+                database="dify",
+            ),
+        )
+
+    def search_by_full_text(self):
+        hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text())
+        assert len(hits_by_full_text) == 0
+
+
+def test_pgvector(setup_mock_redis):
+    TestPGVector().run_all_tests()

+ 24 - 0
docker/docker-compose.pgvector.yaml

@@ -0,0 +1,24 @@
+version: '3'
+services:
+  # The pgvector vector store.
+  pgvector:
+    image: pgvector/pgvector:pg16
+    restart: always
+    environment:
+      PGUSER: postgres
+      # The password for the default postgres user.
+      POSTGRES_PASSWORD: difyai123456
+      # The name of the default postgres database.
+      POSTGRES_DB: dify
+      # postgres data directory
+      PGDATA: /var/lib/postgresql/data/pgdata
+    volumes:
+      - ./volumes/pgvector/data:/var/lib/postgresql/data
+    # expose db(postgresql) port to host
+    ports:
+      - "5433:5432"
+    healthcheck:
+      test: [ "CMD", "pg_isready" ]
+      interval: 1s
+      timeout: 3s
+      retries: 30
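
When this compose file is used (as in the CI job above), the database is reachable from the host on port 5433 with the credentials shown, matching the values in test_pgvector.py. A minimal readiness-check sketch mirroring the healthcheck:

import time

import psycopg2

# Defaults from this compose file: host port 5433 -> container 5432,
# user postgres, POSTGRES_PASSWORD=difyai123456, POSTGRES_DB=dify.
for attempt in range(30):
    try:
        psycopg2.connect(
            host="localhost",
            port=5433,
            user="postgres",
            password="difyai123456",
            dbname="dify",
        ).close()
        print("pgvector is ready")
        break
    except psycopg2.OperationalError:
        time.sleep(1)
else:
    raise RuntimeError("pgvector did not become ready in time")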

+ 38 - 1
docker/docker-compose.yaml

@@ -122,6 +122,12 @@ services:
       RELYT_USER: postgres
       RELYT_PASSWORD: difyai123456
       RELYT_DATABASE: postgres
+      # pgvector configurations
+      PGVECTOR_HOST: pgvector
+      PGVECTOR_PORT: 5432
+      PGVECTOR_USER: postgres
+      PGVECTOR_PASSWORD: difyai123456
+      PGVECTOR_DATABASE: dify
       # Mail configuration, support: resend, smtp
       MAIL_TYPE: ''
       # default send from email address, if not specified
@@ -211,7 +217,7 @@ services:
       AZURE_BLOB_ACCOUNT_KEY: 'difyai'
       AZURE_BLOB_CONTAINER_NAME: 'difyai-container'
       AZURE_BLOB_ACCOUNT_URL: 'https://<your_account_name>.blob.core.windows.net'
-      # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`, `relyt`.
+      # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`, `relyt`, `pgvector`.
       VECTOR_STORE: weaviate
       # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
       WEAVIATE_ENDPOINT: http://weaviate:8080
@@ -251,6 +257,12 @@ services:
       RELYT_USER: postgres
       RELYT_PASSWORD: difyai123456
       RELYT_DATABASE: postgres
+      # pgvector configurations
+      PGVECTOR_HOST: pgvector
+      PGVECTOR_PORT: 5432
+      PGVECTOR_USER: postgres
+      PGVECTOR_PASSWORD: difyai123456
+      PGVECTOR_DATABASE: dify
       # Notion import configuration, support public and internal
       NOTION_INTEGRATION_TYPE: public
       NOTION_CLIENT_SECRET: you-client-secret
@@ -374,6 +386,31 @@ services:
   #   #  - "6333:6333"
   #   #  - "6334:6334"
 
+  # The pgvector vector database.
+  # Uncomment to use pgvector as vector store.
+  # pgvector:
+  #   image: pgvector/pgvector:pg16
+  #   restart: always
+  #   environment:
+  #     PGUSER: postgres
+  #     # The password for the default postgres user.
+  #     POSTGRES_PASSWORD: difyai123456
+  #     # The name of the default postgres database.
+  #     POSTGRES_DB: dify
+  #     # postgres data directory
+  #     PGDATA: /var/lib/postgresql/data/pgdata
+  #   volumes:
+  #     - ./volumes/pgvector/data:/var/lib/postgresql/data
+  #   # uncomment to expose db(postgresql) port to host
+  #   # ports:
+  #   #   - "5433:5432"
+  #   healthcheck:
+  #     test: [ "CMD", "pg_isready" ]
+  #     interval: 1s
+  #     timeout: 3s
+  #     retries: 30
+
+
   # The nginx reverse proxy.
   # used for reverse proxying the API service and Web service.
   nginx: