Bläddra i källkod

feat:support baidu vector db (#9185)

Shili Cao 6 månader sedan
förälder
incheckning
2ec6ffe478

+ 9 - 0
api/.env.example

@@ -208,6 +208,15 @@ OPENSEARCH_USER=admin
 OPENSEARCH_PASSWORD=admin
 OPENSEARCH_SECURE=true
 
+# Baidu configuration
+BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287
+BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS=30000
+BAIDU_VECTOR_DB_ACCOUNT=root
+BAIDU_VECTOR_DB_API_KEY=dify
+BAIDU_VECTOR_DB_DATABASE=dify
+BAIDU_VECTOR_DB_SHARD=1
+BAIDU_VECTOR_DB_REPLICAS=3
+
 # Upload configuration
 UPLOAD_FILE_SIZE_LIMIT=15
 UPLOAD_FILE_BATCH_LIMIT=5

+ 8 - 0
api/commands.py

@@ -347,6 +347,14 @@ def migrate_knowledge_vector_database():
                     index_name = Dataset.gen_collection_name_by_id(dataset_id)
                     index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}}
                     dataset.index_struct = json.dumps(index_struct_dict)
+                elif vector_type == VectorType.BAIDU:
+                    dataset_id = dataset.id
+                    collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+                    index_struct_dict = {
+                        "type": VectorType.BAIDU,
+                        "vector_store": {"class_prefix": collection_name},
+                    }
+                    dataset.index_struct = json.dumps(index_struct_dict)
                 else:
                     raise ValueError(f"Vector store {vector_type} is not supported.")
 

+ 45 - 0
api/configs/middleware/vdb/baidu_vector_config.py

@@ -0,0 +1,45 @@
+from typing import Optional
+
+from pydantic import Field, NonNegativeInt, PositiveInt
+from pydantic_settings import BaseSettings
+
+
+class BaiduVectorDBConfig(BaseSettings):
+    """
+    Configuration settings for Baidu Vector Database
+    """
+
+    BAIDU_VECTOR_DB_ENDPOINT: Optional[str] = Field(
+        description="URL of the Baidu Vector Database service (e.g., 'http://vdb.bj.baidubce.com')",
+        default=None,
+    )
+
+    BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: PositiveInt = Field(
+        description="Timeout in milliseconds for Baidu Vector Database operations (default is 30000 milliseconds)",
+        default=30000,
+    )
+
+    BAIDU_VECTOR_DB_ACCOUNT: Optional[str] = Field(
+        description="Account for authenticating with the Baidu Vector Database",
+        default=None,
+    )
+
+    BAIDU_VECTOR_DB_API_KEY: Optional[str] = Field(
+        description="API key for authenticating with the Baidu Vector Database service",
+        default=None,
+    )
+
+    BAIDU_VECTOR_DB_DATABASE: Optional[str] = Field(
+        description="Name of the specific Baidu Vector Database to connect to",
+        default=None,
+    )
+
+    BAIDU_VECTOR_DB_SHARD: PositiveInt = Field(
+        description="Number of shards for the Baidu Vector Database (default is 1)",
+        default=1,
+    )
+
+    BAIDU_VECTOR_DB_REPLICAS: NonNegativeInt = Field(
+        description="Number of replicas for the Baidu Vector Database (default is 3)",
+        default=3,
+    )

+ 2 - 0
api/controllers/console/datasets/datasets.py

@@ -617,6 +617,7 @@ class DatasetRetrievalSettingApi(Resource):
                 | VectorType.CHROMA
                 | VectorType.TENCENT
                 | VectorType.PGVECTO_RS
+                | VectorType.BAIDU
             ):
                 return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
             case (
@@ -653,6 +654,7 @@ class DatasetRetrievalSettingMockApi(Resource):
                 | VectorType.CHROMA
                 | VectorType.TENCENT
                 | VectorType.PGVECTO_RS
+                | VectorType.BAIDU
             ):
                 return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
             case (

+ 0 - 0
api/core/rag/datasource/vdb/baidu/__init__.py


+ 272 - 0
api/core/rag/datasource/vdb/baidu/baidu_vector.py

@@ -0,0 +1,272 @@
+import json
+import time
+import uuid
+from typing import Any
+
+from pydantic import BaseModel, model_validator
+from pymochow import MochowClient
+from pymochow.auth.bce_credentials import BceCredentials
+from pymochow.configuration import Configuration
+from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, TableState
+from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex
+from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row
+
+from configs import dify_config
+from core.rag.datasource.entity.embedding import Embeddings
+from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
+from models.dataset import Dataset
+
+
+class BaiduConfig(BaseModel):
+    endpoint: str
+    connection_timeout_in_mills: int = 30 * 1000
+    account: str
+    api_key: str
+    database: str
+    index_type: str = "HNSW"
+    metric_type: str = "L2"
+    shard: int = 1
+    replicas: int = 3
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_config(cls, values: dict) -> dict:
+        if not values["endpoint"]:
+            raise ValueError("config BAIDU_VECTOR_DB_ENDPOINT is required")
+        if not values["account"]:
+            raise ValueError("config BAIDU_VECTOR_DB_ACCOUNT is required")
+        if not values["api_key"]:
+            raise ValueError("config BAIDU_VECTOR_DB_API_KEY is required")
+        if not values["database"]:
+            raise ValueError("config BAIDU_VECTOR_DB_DATABASE is required")
+        return values
+
+
+class BaiduVector(BaseVector):
+    field_id: str = "id"
+    field_vector: str = "vector"
+    field_text: str = "text"
+    field_metadata: str = "metadata"
+    field_app_id: str = "app_id"
+    field_annotation_id: str = "annotation_id"
+    index_vector: str = "vector_idx"
+
+    def __init__(self, collection_name: str, config: BaiduConfig):
+        super().__init__(collection_name)
+        self._client_config = config
+        self._client = self._init_client(config)
+        self._db = self._init_database()
+
+    def get_type(self) -> str:
+        return VectorType.BAIDU
+
+    def to_index_struct(self) -> dict:
+        return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
+
+    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
+        self._create_table(len(embeddings[0]))
+        self.add_texts(texts, embeddings)
+
+    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+        texts = [doc.page_content for doc in documents]
+        metadatas = [doc.metadata for doc in documents]
+        total_count = len(documents)
+        batch_size = 1000
+
+        # upsert texts and embeddings batch by batch
+        table = self._db.table(self._collection_name)
+        for start in range(0, total_count, batch_size):
+            end = min(start + batch_size, total_count)
+            rows = []
+            for i in range(start, end, 1):
+                row = Row(
+                    id=metadatas[i].get("doc_id", str(uuid.uuid4())),
+                    vector=embeddings[i],
+                    text=texts[i],
+                    metadata=json.dumps(metadatas[i]),
+                    app_id=metadatas[i].get("app_id", ""),
+                    annotation_id=metadatas[i].get("annotation_id", ""),
+                )
+                rows.append(row)
+            table.upsert(rows=rows)
+
+        # rebuild vector index after upsert finished
+        table.rebuild_index(self.index_vector)
+        while True:
+            time.sleep(1)
+            index = table.describe_index(self.index_vector)
+            if index.state == IndexState.NORMAL:
+                break
+
+    def text_exists(self, id: str) -> bool:
+        res = self._db.table(self._collection_name).query(primary_key={self.field_id: id})
+        if res and res.code == 0:
+            return True
+        return False
+
+    def delete_by_ids(self, ids: list[str]) -> None:
+        quoted_ids = [f"'{id}'" for id in ids]
+        self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})")
+
+    def delete_by_metadata_field(self, key: str, value: str) -> None:
+        self._db.table(self._collection_name).delete(filter=f"{key} = '{value}'")
+
+    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+        anns = AnnSearch(
+            vector_field=self.field_vector,
+            vector_floats=query_vector,
+            params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
+        )
+        res = self._db.table(self._collection_name).search(
+            anns=anns,
+            projections=[self.field_id, self.field_text, self.field_metadata],
+            retrieve_vector=True,
+        )
+        score_threshold = float(kwargs.get("score_threshold") or 0.0)
+        return self._get_search_res(res, score_threshold)
+
+    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        # baidu vector database doesn't support bm25 search on current version
+        return []
+
+    def _get_search_res(self, res, score_threshold):
+        docs = []
+        for row in res.rows:
+            row_data = row.get("row", {})
+            meta = row_data.get(self.field_metadata)
+            if meta is not None:
+                meta = json.loads(meta)
+            score = row.get("score", 0.0)
+            if score > score_threshold:
+                meta["score"] = score
+                doc = Document(page_content=row_data.get(self.field_text), metadata=meta)
+                docs.append(doc)
+
+        return docs
+
+    def delete(self) -> None:
+        self._db.drop_table(table_name=self._collection_name)
+
+    def _init_client(self, config) -> MochowClient:
+        config = Configuration(credentials=BceCredentials(config.account, config.api_key), endpoint=config.endpoint)
+        client = MochowClient(config)
+        return client
+
+    def _init_database(self):
+        exists = False
+        for db in self._client.list_databases():
+            if db.database_name == self._client_config.database:
+                exists = True
+                break
+        # Create database if not existed
+        if exists:
+            return self._client.database(self._client_config.database)
+        else:
+            return self._client.create_database(database_name=self._client_config.database)
+
+    def _table_existed(self) -> bool:
+        tables = self._db.list_table()
+        return any(table.table_name == self._collection_name for table in tables)
+
+    def _create_table(self, dimension: int) -> None:
+        # Try to grab distributed lock and create table
+        lock_name = "vector_indexing_lock_{}".format(self._collection_name)
+        with redis_client.lock(lock_name, timeout=20):
+            table_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
+            if redis_client.get(table_exist_cache_key):
+                return
+
+            if self._table_existed():
+                return
+
+            self.delete()
+
+            # check IndexType and MetricType
+            index_type = None
+            for k, v in IndexType.__members__.items():
+                if k == self._client_config.index_type:
+                    index_type = v
+            if index_type is None:
+                raise ValueError("unsupported index_type")
+            metric_type = None
+            for k, v in MetricType.__members__.items():
+                if k == self._client_config.metric_type:
+                    metric_type = v
+            if metric_type is None:
+                raise ValueError("unsupported metric_type")
+
+            # Construct field schema
+            fields = []
+            fields.append(
+                Field(
+                    self.field_id,
+                    FieldType.STRING,
+                    primary_key=True,
+                    partition_key=True,
+                    auto_increment=False,
+                    not_null=True,
+                )
+            )
+            fields.append(Field(self.field_metadata, FieldType.STRING, not_null=True))
+            fields.append(Field(self.field_app_id, FieldType.STRING))
+            fields.append(Field(self.field_annotation_id, FieldType.STRING))
+            fields.append(Field(self.field_text, FieldType.TEXT, not_null=True))
+            fields.append(Field(self.field_vector, FieldType.FLOAT_VECTOR, not_null=True, dimension=dimension))
+
+            # Construct vector index params
+            indexes = []
+            indexes.append(
+                VectorIndex(
+                    index_name="vector_idx",
+                    index_type=index_type,
+                    field="vector",
+                    metric_type=metric_type,
+                    params=HNSWParams(m=16, efconstruction=200),
+                )
+            )
+
+            # Create table
+            self._db.create_table(
+                table_name=self._collection_name,
+                replication=self._client_config.replicas,
+                partition=Partition(partition_num=self._client_config.shard),
+                schema=Schema(fields=fields, indexes=indexes),
+                description="Table for Dify",
+            )
+
+            redis_client.set(table_exist_cache_key, 1, ex=3600)
+
+        # Wait for table created
+        while True:
+            time.sleep(1)
+            table = self._db.describe_table(self._collection_name)
+            if table.state == TableState.NORMAL:
+                break
+
+
+class BaiduVectorFactory(AbstractVectorFactory):
+    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaiduVector:
+        if dataset.index_struct_dict:
+            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
+            collection_name = class_prefix.lower()
+        else:
+            dataset_id = dataset.id
+            collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
+            dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.BAIDU, collection_name))
+
+        return BaiduVector(
+            collection_name=collection_name,
+            config=BaiduConfig(
+                endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT,
+                connection_timeout_in_mills=dify_config.BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS,
+                account=dify_config.BAIDU_VECTOR_DB_ACCOUNT,
+                api_key=dify_config.BAIDU_VECTOR_DB_API_KEY,
+                database=dify_config.BAIDU_VECTOR_DB_DATABASE,
+                shard=dify_config.BAIDU_VECTOR_DB_SHARD,
+                replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS,
+            ),
+        )

+ 4 - 0
api/core/rag/datasource/vdb/vector_factory.py

@@ -103,6 +103,10 @@ class Vector:
                 from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory
 
                 return AnalyticdbVectorFactory
+            case VectorType.BAIDU:
+                from core.rag.datasource.vdb.baidu.baidu_vector import BaiduVectorFactory
+
+                return BaiduVectorFactory
             case _:
                 raise ValueError(f"Vector store {vector_type} is not supported.")
 

+ 1 - 0
api/core/rag/datasource/vdb/vector_type.py

@@ -16,3 +16,4 @@ class VectorType(str, Enum):
     TENCENT = "tencent"
     ORACLE = "oracle"
     ELASTICSEARCH = "elasticsearch"
+    BAIDU = "baidu"

+ 34 - 13
api/poetry.lock

@@ -732,7 +732,7 @@ name = "bce-python-sdk"
 version = "0.9.23"
 description = "BCE SDK for python"
 optional = false
-python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, <4"
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,<4,>=2.7"
 files = [
     {file = "bce_python_sdk-0.9.23-py3-none-any.whl", hash = "sha256:8debe21a040e00060f6044877d594765ed7b18bc765c6bf16b878bca864140a3"},
     {file = "bce_python_sdk-0.9.23.tar.gz", hash = "sha256:19739fed5cd0725356fc5ffa2acbdd8fb23f2a81edb91db21a03174551d0cf41"},
@@ -847,7 +847,7 @@ name = "botocore"
 version = "1.35.38"
 description = "Low-level, data-driven core of boto 3."
 optional = false
-python-versions = ">= 3.8"
+python-versions = ">=3.8"
 files = [
     {file = "botocore-1.35.38-py3-none-any.whl", hash = "sha256:2eb17d32fa2d3bb5d475132a83564d28e3acc2161534f24b75a54418a1d51359"},
     {file = "botocore-1.35.38.tar.gz", hash = "sha256:55d9305c44e5ba29476df456120fa4fb919f03f066afa82f2ae400485e7465f4"},
@@ -1068,7 +1068,7 @@ name = "build"
 version = "1.2.2.post1"
 description = "A simple, correct Python build frontend"
 optional = false
-python-versions = ">= 3.8"
+python-versions = ">=3.8"
 files = [
     {file = "build-1.2.2.post1-py3-none-any.whl", hash = "sha256:1d61c0887fa860c01971625baae8bdd338e517b836a2f70dd1f7aa3a6b2fc5b5"},
     {file = "build-1.2.2.post1.tar.gz", hash = "sha256:b36993e92ca9375a219c99e606a122ff365a760a2d4bba0caa09bd5278b608b7"},
@@ -3385,7 +3385,7 @@ name = "gotrue"
 version = "2.9.2"
 description = "Python Client Library for Supabase Auth"
 optional = false
-python-versions = ">=3.8,<4.0"
+python-versions = "<4.0,>=3.8"
 files = [
     {file = "gotrue-2.9.2-py3-none-any.whl", hash = "sha256:fcd5279e8f1cc630f3ac35af5485fe39f8030b23906776920d2c32a4e308cff4"},
     {file = "gotrue-2.9.2.tar.gz", hash = "sha256:57b3245e916c5efbf19a21b1181011a903c1276bb1df2d847558f2f24f29abb2"},
@@ -4415,7 +4415,7 @@ name = "langfuse"
 version = "2.51.5"
 description = "A client library for accessing langfuse"
 optional = false
-python-versions = ">=3.8.1,<4.0"
+python-versions = "<4.0,>=3.8.1"
 files = [
     {file = "langfuse-2.51.5-py3-none-any.whl", hash = "sha256:b95401ca710ef94b521afa6541933b6f93d7cfd4a97523c8fc75bca4d6d219fb"},
     {file = "langfuse-2.51.5.tar.gz", hash = "sha256:55bc37b5c5d3ae133c1a95db09117cfb3117add110ba02ebbf2ce45ac4395c5b"},
@@ -4440,7 +4440,7 @@ name = "langsmith"
 version = "0.1.134"
 description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
 optional = false
-python-versions = ">=3.8.1,<4.0"
+python-versions = "<4.0,>=3.8.1"
 files = [
     {file = "langsmith-0.1.134-py3-none-any.whl", hash = "sha256:ada98ad80ef38807725f32441a472da3dd28394010877751f48f458d3289da04"},
     {file = "langsmith-0.1.134.tar.gz", hash = "sha256:23abee3b508875a0e63c602afafffc02442a19cfd88f9daae05b3e9054fd6b61"},
@@ -6429,7 +6429,7 @@ name = "postgrest"
 version = "0.17.1"
 description = "PostgREST client for Python. This library provides an ORM interface to PostgREST."
 optional = false
-python-versions = ">=3.8,<4.0"
+python-versions = "<4.0,>=3.8"
 files = [
     {file = "postgrest-0.17.1-py3-none-any.whl", hash = "sha256:ec1d00dc8532fe5ffb342cfc7c4e610a1e0e2272eb14f78f9b2b61094f9be510"},
     {file = "postgrest-0.17.1.tar.gz", hash = "sha256:e31d9977dbb80dc5f9fdd4d444014686606692dc4ddb9adc85639e56c6d54c92"},
@@ -7047,6 +7047,22 @@ bulk-writer = ["azure-storage-blob", "minio (>=7.0.0)", "pyarrow (>=12.0.0)", "r
 dev = ["black", "grpcio (==1.62.2)", "grpcio-testing (==1.62.2)", "grpcio-tools (==1.62.2)", "pytest (>=5.3.4)", "pytest-cov (>=2.8.1)", "pytest-timeout (>=1.3.4)", "ruff (>0.4.0)"]
 model = ["milvus-model (>=0.1.0)"]
 
+[[package]]
+name = "pymochow"
+version = "1.3.1"
+description = "Python SDK for mochow"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "pymochow-1.3.1-py3-none-any.whl", hash = "sha256:a7f3b34fd6ea5d1d8413650bb6678365aa148fc396ae945e4ccb4f2365a52327"},
+    {file = "pymochow-1.3.1.tar.gz", hash = "sha256:1693d10cd0bb7bce45327890a90adafb503155922ccc029acb257699a73a20ba"},
+]
+
+[package.dependencies]
+future = "*"
+orjson = "*"
+requests = "*"
+
 [[package]]
 name = "pymysql"
 version = "1.1.1"
@@ -7746,7 +7762,7 @@ name = "realtime"
 version = "2.0.2"
 description = ""
 optional = false
-python-versions = ">=3.9,<4.0"
+python-versions = "<4.0,>=3.9"
 files = [
     {file = "realtime-2.0.2-py3-none-any.whl", hash = "sha256:2634c915bc38807f2013f21e8bcc4d2f79870dfd81460ddb9393883d0489928a"},
     {file = "realtime-2.0.2.tar.gz", hash = "sha256:519da9325b3b8102139d51785013d592f6b2403d81fa21d838a0b0234723ed7d"},
@@ -8173,7 +8189,7 @@ name = "s3transfer"
 version = "0.10.3"
 description = "An Amazon S3 Transfer Manager"
 optional = false
-python-versions = ">= 3.8"
+python-versions = ">=3.8"
 files = [
     {file = "s3transfer-0.10.3-py3-none-any.whl", hash = "sha256:263ed587a5803c6c708d3ce44dc4dfedaab4c1a32e8329bab818933d79ddcf5d"},
     {file = "s3transfer-0.10.3.tar.gz", hash = "sha256:4f50ed74ab84d474ce614475e0b8d5047ff080810aac5d01ea25231cfc944b0c"},
@@ -8417,6 +8433,11 @@ files = [
     {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"},
     {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"},
     {file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"},
+    {file = "scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"},
+    {file = "scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"},
+    {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"},
+    {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"},
+    {file = "scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"},
     {file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"},
     {file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"},
     {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"},
@@ -8836,7 +8857,7 @@ name = "storage3"
 version = "0.8.1"
 description = "Supabase Storage client for Python."
 optional = false
-python-versions = ">=3.8,<4.0"
+python-versions = "<4.0,>=3.8"
 files = [
     {file = "storage3-0.8.1-py3-none-any.whl", hash = "sha256:0b21205f43eaf0d1dd33bde6c6d0612f88524b7865f017d2ae9827e3f63d9cdc"},
     {file = "storage3-0.8.1.tar.gz", hash = "sha256:ea60b68b2221b3868ccc1a7f1294d57d0d9c51642cdc639d8115fe5d0adc8892"},
@@ -8882,7 +8903,7 @@ name = "supabase"
 version = "2.8.1"
 description = "Supabase client for Python."
 optional = false
-python-versions = ">=3.9,<4.0"
+python-versions = "<4.0,>=3.9"
 files = [
     {file = "supabase-2.8.1-py3-none-any.whl", hash = "sha256:dfa8bef89b54129093521d5bba2136ff765baf67cd76d8ad0aa4984d61a7815c"},
     {file = "supabase-2.8.1.tar.gz", hash = "sha256:711c70e6acd9e2ff48ca0dc0b1bb70c01c25378cc5189ec9f5ed9655b30bc41d"},
@@ -8902,7 +8923,7 @@ name = "supafunc"
 version = "0.6.1"
 description = "Library for Supabase Functions"
 optional = false
-python-versions = ">=3.8,<4.0"
+python-versions = "<4.0,>=3.8"
 files = [
     {file = "supafunc-0.6.1-py3-none-any.whl", hash = "sha256:01aeeeb4bf429977664454a32c86418345140faf6d2e6eb0636d52e4547c5fbb"},
     {file = "supafunc-0.6.1.tar.gz", hash = "sha256:3c8761e3999336ccdb7550498a395fd08afc8469382f55ea56f7f640e5a909aa"},
@@ -10615,4 +10636,4 @@ cffi = ["cffi (>=1.11)"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<3.13"
-content-hash = "cc10ee218369eb5576d1e5ac8aeeb72e8927bbcb8bd1ac1594167c45aa9d9a21"
+content-hash = "375ac3a91760513924647e67376cb6018505ec61d967651b254c68af9808d774"

+ 1 - 0
api/pyproject.toml

@@ -242,6 +242,7 @@ oracledb = "~2.2.1"
 pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] }
 pgvector = "0.2.5"
 pymilvus = "~2.4.4"
+pymochow = "1.3.1"
 qdrant-client = "1.7.3"
 tcvectordb = "1.3.2"
 tidb-vector = "0.0.9"

+ 154 - 0
api/tests/integration_tests/vdb/__mock/baiduvectordb.py

@@ -0,0 +1,154 @@
+import os
+
+import pytest
+from _pytest.monkeypatch import MonkeyPatch
+from pymochow import MochowClient
+from pymochow.model.database import Database
+from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState
+from pymochow.model.schema import HNSWParams, VectorIndex
+from pymochow.model.table import Table
+from requests.adapters import HTTPAdapter
+
+
+class MockBaiduVectorDBClass:
+    def mock_vector_db_client(
+        self,
+        config=None,
+        adapter: HTTPAdapter = None,
+    ):
+        self._conn = None
+        self._config = None
+
+    def list_databases(self, config=None) -> list[Database]:
+        return [
+            Database(
+                conn=self._conn,
+                database_name="dify",
+                config=self._config,
+            )
+        ]
+
+    def create_database(self, database_name: str, config=None) -> Database:
+        return Database(conn=self._conn, database_name=database_name, config=config)
+
+    def list_table(self, config=None) -> list[Table]:
+        return []
+
+    def drop_table(self, table_name: str, config=None):
+        return {"code": 0, "msg": "Success"}
+
+    def create_table(
+        self,
+        table_name: str,
+        replication: int,
+        partition: int,
+        schema,
+        enable_dynamic_field=False,
+        description: str = "",
+        config=None,
+    ) -> Table:
+        return Table(self, table_name, replication, partition, schema, enable_dynamic_field, description, config)
+
+    def describe_table(self, table_name: str, config=None) -> Table:
+        return Table(
+            self,
+            table_name,
+            3,
+            1,
+            None,
+            enable_dynamic_field=False,
+            description="table for dify",
+            config=config,
+            state=TableState.NORMAL,
+        )
+
+    def upsert(self, rows, config=None):
+        return {"code": 0, "msg": "operation success", "affectedCount": 1}
+
+    def rebuild_index(self, index_name: str, config=None):
+        return {"code": 0, "msg": "Success"}
+
+    def describe_index(self, index_name: str, config=None):
+        return VectorIndex(
+            index_name=index_name,
+            index_type=IndexType.HNSW,
+            field="vector",
+            metric_type=MetricType.L2,
+            params=HNSWParams(m=16, efconstruction=200),
+            auto_build=False,
+            state=IndexState.NORMAL,
+        )
+
+    def query(
+        self,
+        primary_key,
+        partition_key=None,
+        projections=None,
+        retrieve_vector=False,
+        read_consistency=ReadConsistency.EVENTUAL,
+        config=None,
+    ):
+        return {
+            "row": {
+                "id": "doc_id_001",
+                "vector": [0.23432432, 0.8923744, 0.89238432],
+                "text": "text",
+                "metadata": {"doc_id": "doc_id_001"},
+            },
+            "code": 0,
+            "msg": "Success",
+        }
+
+    def delete(self, primary_key=None, partition_key=None, filter=None, config=None):
+        return {"code": 0, "msg": "Success"}
+
+    def search(
+        self,
+        anns,
+        partition_key=None,
+        projections=None,
+        retrieve_vector=False,
+        read_consistency=ReadConsistency.EVENTUAL,
+        config=None,
+    ):
+        return {
+            "rows": [
+                {
+                    "row": {
+                        "id": "doc_id_001",
+                        "vector": [0.23432432, 0.8923744, 0.89238432],
+                        "text": "text",
+                        "metadata": {"doc_id": "doc_id_001"},
+                    },
+                    "distance": 0.1,
+                    "score": 0.5,
+                }
+            ],
+            "code": 0,
+            "msg": "Success",
+        }
+
+
+MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
+
+
+@pytest.fixture
+def setup_baiduvectordb_mock(request, monkeypatch: MonkeyPatch):
+    if MOCK:
+        monkeypatch.setattr(MochowClient, "__init__", MockBaiduVectorDBClass.mock_vector_db_client)
+        monkeypatch.setattr(MochowClient, "list_databases", MockBaiduVectorDBClass.list_databases)
+        monkeypatch.setattr(MochowClient, "create_database", MockBaiduVectorDBClass.create_database)
+        monkeypatch.setattr(Database, "table", MockBaiduVectorDBClass.describe_table)
+        monkeypatch.setattr(Database, "list_table", MockBaiduVectorDBClass.list_table)
+        monkeypatch.setattr(Database, "create_table", MockBaiduVectorDBClass.create_table)
+        monkeypatch.setattr(Database, "drop_table", MockBaiduVectorDBClass.drop_table)
+        monkeypatch.setattr(Database, "describe_table", MockBaiduVectorDBClass.describe_table)
+        monkeypatch.setattr(Table, "rebuild_index", MockBaiduVectorDBClass.rebuild_index)
+        monkeypatch.setattr(Table, "describe_index", MockBaiduVectorDBClass.describe_index)
+        monkeypatch.setattr(Table, "delete", MockBaiduVectorDBClass.delete)
+        monkeypatch.setattr(Table, "search", MockBaiduVectorDBClass.search)
+
+    yield
+
+    if MOCK:
+        monkeypatch.undo()

+ 0 - 0
api/tests/integration_tests/vdb/baidu/__init__.py


+ 36 - 0
api/tests/integration_tests/vdb/baidu/test_baidu.py

@@ -0,0 +1,36 @@
+from unittest.mock import MagicMock
+
+from core.rag.datasource.vdb.baidu.baidu_vector import BaiduConfig, BaiduVector
+from tests.integration_tests.vdb.__mock.baiduvectordb import setup_baiduvectordb_mock
+from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
+
+mock_client = MagicMock()
+mock_client.list_databases.return_value = [{"name": "test"}]
+
+
+class BaiduVectorTest(AbstractVectorTest):
+    def __init__(self):
+        super().__init__()
+        self.vector = BaiduVector(
+            "dify",
+            BaiduConfig(
+                endpoint="http://127.0.0.1:5287",
+                account="root",
+                api_key="dify",
+                database="dify",
+                shard=1,
+                replicas=3,
+            ),
+        )
+
+    def search_by_vector(self):
+        hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding)
+        assert len(hits_by_vector) == 1
+
+    def search_by_full_text(self):
+        hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
+        assert len(hits_by_full_text) == 0
+
+
+def test_baidu_vector(setup_mock_redis, setup_baiduvectordb_mock):
+    BaiduVectorTest().run_all_tests()

+ 9 - 0
docker/.env.example

@@ -462,6 +462,15 @@ ELASTICSEARCH_PORT=9200
 ELASTICSEARCH_USERNAME=elastic
 ELASTICSEARCH_PASSWORD=elastic
 
+# baidu vector configurations, only available when VECTOR_STORE is `baidu`
+BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287
+BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS=30000
+BAIDU_VECTOR_DB_ACCOUNT=root
+BAIDU_VECTOR_DB_API_KEY=dify
+BAIDU_VECTOR_DB_DATABASE=dify
+BAIDU_VECTOR_DB_SHARD=1
+BAIDU_VECTOR_DB_REPLICAS=3
+
 # ------------------------------
 # Knowledge Configuration
 # ------------------------------

+ 7 - 0
docker/docker-compose.yaml

@@ -165,6 +165,13 @@ x-shared-env: &shared-api-worker-env
   TENCENT_VECTOR_DB_DATABASE: ${TENCENT_VECTOR_DB_DATABASE:-dify}
   TENCENT_VECTOR_DB_SHARD: ${TENCENT_VECTOR_DB_SHARD:-1}
   TENCENT_VECTOR_DB_REPLICAS: ${TENCENT_VECTOR_DB_REPLICAS:-2}
+  BAIDU_VECTOR_DB_ENDPOINT: ${BAIDU_VECTOR_DB_ENDPOINT:-http://127.0.0.1:5287}
+  BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: ${BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS:-30000}
+  BAIDU_VECTOR_DB_ACCOUNT: ${BAIDU_VECTOR_DB_ACCOUNT:-root}
+  BAIDU_VECTOR_DB_API_KEY: ${BAIDU_VECTOR_DB_API_KEY:-dify}
+  BAIDU_VECTOR_DB_DATABASE: ${BAIDU_VECTOR_DB_DATABASE:-dify}
+  BAIDU_VECTOR_DB_SHARD: ${BAIDU_VECTOR_DB_SHARD:-1}
+  BAIDU_VECTOR_DB_REPLICAS: ${BAIDU_VECTOR_DB_REPLICAS:-3}
   UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15}
   UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5}
   ETL_TYPE: ${ETL_TYPE:-dify}