
Add Volcengine VikingDB as new vector provider (#9287)

ice yao · 6 months ago
commit d15ba3939d

+ 10 - 1
api/.env.example

@@ -111,7 +111,7 @@ SUPABASE_URL=your-server-url
 WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 
-# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, chroma, opensearch, tidb_vector
+# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, chroma, opensearch, tidb_vector, vikingdb
 VECTOR_STORE=weaviate
 
 # Weaviate configuration
@@ -220,6 +220,15 @@ BAIDU_VECTOR_DB_DATABASE=dify
 BAIDU_VECTOR_DB_SHARD=1
 BAIDU_VECTOR_DB_REPLICAS=3
 
+# VikingDB configuration
+VIKINGDB_ACCESS_KEY=your-ak
+VIKINGDB_SECRET_KEY=your-sk
+VIKINGDB_REGION=cn-shanghai
+VIKINGDB_HOST=api-vikingdb.xxx.volces.com
+VIKINGDB_SCHEME=http
+VIKINGDB_CONNECTION_TIMEOUT=30
+VIKINGDB_SOCKET_TIMEOUT=30
+
 # Upload configuration
 UPLOAD_FILE_SIZE_LIMIT=15
 UPLOAD_FILE_BATCH_LIMIT=5

+ 2 - 0
api/configs/middleware/__init__.py

@@ -28,6 +28,7 @@ from configs.middleware.vdb.qdrant_config import QdrantConfig
 from configs.middleware.vdb.relyt_config import RelytConfig
 from configs.middleware.vdb.tencent_vector_config import TencentVectorDBConfig
 from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig
+from configs.middleware.vdb.vikingdb_config import VikingDBConfig
 from configs.middleware.vdb.weaviate_config import WeaviateConfig
 
 
@@ -243,5 +244,6 @@ class MiddlewareConfig(
     WeaviateConfig,
     ElasticsearchConfig,
     InternalTestConfig,
+    VikingDBConfig,
 ):
     pass

+ 37 - 0
api/configs/middleware/vdb/vikingdb_config.py

@@ -0,0 +1,37 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field
+
+
+class VikingDBConfig(BaseModel):
+    """
+    Configuration for connecting to Volcengine VikingDB.
+    Refer to the following documentation for details on obtaining credentials:
+    https://www.volcengine.com/docs/6291/65568
+    """
+
+    VIKINGDB_ACCESS_KEY: Optional[str] = Field(
+        default=None, description="The Access Key provided by Volcengine VikingDB for API authentication."
+    )
+    VIKINGDB_SECRET_KEY: Optional[str] = Field(
+        default=None, description="The Secret Key provided by Volcengine VikingDB for API authentication."
+    )
+    VIKINGDB_REGION: Optional[str] = Field(
+        default="cn-shanghai",
+        description="The region of the Volcengine VikingDB service (e.g., 'cn-shanghai', 'cn-beijing').",
+    )
+    VIKINGDB_HOST: Optional[str] = Field(
+        default="api-vikingdb.mlp.cn-shanghai.volces.com",
+        description="The host of the Volcengine VikingDB service "
+        "(e.g., 'api-vikingdb.volces.com', 'api-vikingdb.mlp.cn-shanghai.volces.com').",
+    )
+    VIKINGDB_SCHEME: Optional[str] = Field(
+        default="http",
+        description="The scheme of the Volcengine VikingDB service (e.g., 'http', 'https').",
+    )
+    VIKINGDB_CONNECTION_TIMEOUT: Optional[int] = Field(
+        default=30, description="The connection timeout of the Volcengine VikingDB service, in seconds."
+    )
+    VIKINGDB_SOCKET_TIMEOUT: Optional[int] = Field(
+        default=30, description="The socket timeout of the Volcengine VikingDB service, in seconds."
+    )

+ 2 - 0
api/controllers/console/datasets/datasets.py

@@ -618,6 +618,7 @@ class DatasetRetrievalSettingApi(Resource):
                 | VectorType.TENCENT
                 | VectorType.PGVECTO_RS
                 | VectorType.BAIDU
+                | VectorType.VIKINGDB
             ):
                 return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
             case (
@@ -655,6 +656,7 @@ class DatasetRetrievalSettingMockApi(Resource):
                 | VectorType.TENCENT
                 | VectorType.PGVECTO_RS
                 | VectorType.BAIDU
+                | VectorType.VIKINGDB
             ):
                 return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
             case (

+ 4 - 0
api/core/rag/datasource/vdb/vector_factory.py

@@ -107,6 +107,10 @@ class Vector:
                 from core.rag.datasource.vdb.baidu.baidu_vector import BaiduVectorFactory
 
                 return BaiduVectorFactory
+            case VectorType.VIKINGDB:
+                from core.rag.datasource.vdb.vikingdb.vikingdb_vector import VikingDBVectorFactory
+
+                return VikingDBVectorFactory
             case _:
                 raise ValueError(f"Vector store {vector_type} is not supported.")
 
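
To see how this dispatch is exercised, a hedged sketch under the assumption (suggested by the hunk header) that the match statement lives in a static Vector.get_vector_factory method, and that VECTOR_STORE=vikingdb routes requests here:

    # Hedged sketch: resolving the new factory by vector type.
    from core.rag.datasource.vdb.vector_factory import Vector
    from core.rag.datasource.vdb.vector_type import VectorType

    factory_cls = Vector.get_vector_factory(VectorType.VIKINGDB)
    factory = factory_cls()  # VikingDBVectorFactory; init_vector() builds the client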

+ 1 - 0
api/core/rag/datasource/vdb/vector_type.py

@@ -17,3 +17,4 @@ class VectorType(str, Enum):
     ORACLE = "oracle"
     ELASTICSEARCH = "elasticsearch"
     BAIDU = "baidu"
+    VIKINGDB = "vikingdb"
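
Because VectorType subclasses str, the enum member and its string value compare equal, which is what lets the plain VECTOR_STORE=vikingdb environment value select this member. A small sketch:

    # Hedged sketch: the string value round-trips through the str-backed enum.
    from core.rag.datasource.vdb.vector_type import VectorType

    assert VectorType("vikingdb") is VectorType.VIKINGDB
    assert VectorType.VIKINGDB == "vikingdb"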

+ 0 - 0
api/core/rag/datasource/vdb/vikingdb/__init__.py


+ 239 - 0
api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py

@@ -0,0 +1,239 @@
+import json
+from typing import Any
+
+from pydantic import BaseModel
+from volcengine.viking_db import (
+    Data,
+    DistanceType,
+    Field,
+    FieldType,
+    IndexType,
+    QuantType,
+    VectorIndexParams,
+    VikingDBService,
+)
+
+from configs import dify_config
+from core.rag.datasource.entity.embedding import Embeddings
+from core.rag.datasource.vdb.field import Field as vdb_Field
+from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
+from models.dataset import Dataset
+
+
+class VikingDBConfig(BaseModel):
+    access_key: str
+    secret_key: str
+    host: str
+    region: str
+    scheme: str
+    connection_timeout: int
+    socket_timeout: int
+    index_type: str = IndexType.HNSW
+    distance: str = DistanceType.L2
+    quant: str = QuantType.Float
+
+
+class VikingDBVector(BaseVector):
+    def __init__(self, collection_name: str, group_id: str, config: VikingDBConfig):
+        super().__init__(collection_name)
+        self._group_id = group_id
+        self._client_config = config
+        self._index_name = f"{self._collection_name}_idx"
+        self._client = VikingDBService(
+            host=config.host,
+            region=config.region,
+            scheme=config.scheme,
+            connection_timeout=config.connection_timeout,
+            socket_timeout=config.socket_timeout,
+            ak=config.access_key,
+            sk=config.secret_key,
+        )
+
+    def _has_collection(self) -> bool:
+        try:
+            self._client.get_collection(self._collection_name)
+        except Exception:
+            return False
+        return True
+
+    def _has_index(self) -> bool:
+        try:
+            self._client.get_index(self._collection_name, self._index_name)
+        except Exception:
+            return False
+        return True
+
+    def _create_collection(self, dimension: int):
+        lock_name = f"vector_indexing_lock_{self._collection_name}"
+        with redis_client.lock(lock_name, timeout=20):
+            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
+            if redis_client.get(collection_exist_cache_key):
+                return
+
+            if not self._has_collection():
+                fields = [
+                    Field(field_name=vdb_Field.PRIMARY_KEY.value, field_type=FieldType.String, is_primary_key=True),
+                    Field(field_name=vdb_Field.METADATA_KEY.value, field_type=FieldType.String),
+                    Field(field_name=vdb_Field.GROUP_KEY.value, field_type=FieldType.String),
+                    Field(field_name=vdb_Field.CONTENT_KEY.value, field_type=FieldType.Text),
+                    Field(field_name=vdb_Field.VECTOR.value, field_type=FieldType.Vector, dim=dimension),
+                ]
+
+                self._client.create_collection(
+                    collection_name=self._collection_name,
+                    fields=fields,
+                    description="Collection For Dify",
+                )
+
+            if not self._has_index():
+                vector_index = VectorIndexParams(
+                    distance=self._client_config.distance,
+                    index_type=self._client_config.index_type,
+                    quant=self._client_config.quant,
+                )
+
+                self._client.create_index(
+                    collection_name=self._collection_name,
+                    index_name=self._index_name,
+                    vector_index=vector_index,
+                    partition_by=vdb_Field.GROUP_KEY.value,
+                    description="Index For Dify",
+                )
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
+
+    def get_type(self) -> str:
+        return VectorType.VIKINGDB
+
+    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
+        dimension = len(embeddings[0])
+        self._create_collection(dimension)
+        self.add_texts(texts, embeddings, **kwargs)
+
+    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+        page_contents = [doc.page_content for doc in documents]
+        metadatas = [doc.metadata for doc in documents]
+        docs = []
+
+        for i, page_content in enumerate(page_contents):
+            # VikingDB has no JSON field type, so metadata is serialized to a JSON string below
+            metadata = dict(metadatas[i]) if metadatas[i] is not None else {}
+            doc = Data(
+                {
+                    vdb_Field.PRIMARY_KEY.value: metadata["doc_id"],
+                    vdb_Field.VECTOR.value: embeddings[i] if embeddings else None,
+                    vdb_Field.CONTENT_KEY.value: page_content,
+                    vdb_Field.METADATA_KEY.value: json.dumps(metadata),
+                    vdb_Field.GROUP_KEY.value: self._group_id,
+                }
+            )
+            docs.append(doc)
+
+        self._client.get_collection(self._collection_name).upsert_data(docs)
+
+    def text_exists(self, id: str) -> bool:
+        docs = self._client.get_collection(self._collection_name).fetch_data(id)
+        not_exists_str = "data does not exist"
+        if docs is not None and not_exists_str not in docs.fields.get("message", ""):
+            return True
+        return False
+
+    def delete_by_ids(self, ids: list[str]) -> None:
+        if not ids:
+            return
+        self._client.get_collection(self._collection_name).delete_data(ids)
+
+    def get_ids_by_metadata_field(self, key: str, value: str):
+        # Note: metadata is stored as a JSON string because VikingDB fields do not
+        # support a JSON type, so we fetch by group and filter on the client side.
+        results = self._client.get_index(self._collection_name, self._index_name).search(
+            filter={"op": "must", "field": vdb_Field.GROUP_KEY.value, "conds": [self._group_id]},
+            limit=5000,  # 5000 is the maximum limit the search API accepts
+        )
+
+        if not results:
+            return []
+
+        ids = []
+        for result in results:
+            metadata = result.fields.get(vdb_Field.METADATA_KEY.value)
+            if metadata is not None:
+                metadata = json.loads(metadata)
+                if metadata.get(key) == value:
+                    ids.append(result.id)
+        return ids
+
+    def delete_by_metadata_field(self, key: str, value: str) -> None:
+        ids = self.get_ids_by_metadata_field(key, value)
+        self.delete_by_ids(ids)
+
+    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+        results = self._client.get_index(self._collection_name, self._index_name).search_by_vector(
+            query_vector, limit=kwargs.get("top_k", 50)
+        )
+        score_threshold = float(kwargs.get("score_threshold") or 0.0)
+        return self._get_search_res(results, score_threshold)
+
+    def _get_search_res(self, results, score_threshold):
+        if len(results) == 0:
+            return []
+
+        docs = []
+        for result in results:
+            metadata = result.fields.get(vdb_Field.METADATA_KEY.value)
+            # default to an empty dict so the score assignment below cannot hit None
+            metadata = json.loads(metadata) if metadata is not None else {}
+            if result.score > score_threshold:
+                metadata["score"] = result.score
+                doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata)
+                docs.append(doc)
+        docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
+        return docs
+
+    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        return []
+
+    def delete(self) -> None:
+        if self._has_index():
+            self._client.drop_index(self._collection_name, self._index_name)
+        if self._has_collection():
+            self._client.drop_collection(self._collection_name)
+
+
+class VikingDBVectorFactory(AbstractVectorFactory):
+    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> VikingDBVector:
+        if dataset.index_struct_dict:
+            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
+            collection_name = class_prefix.lower()
+        else:
+            dataset_id = dataset.id
+            collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
+            dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.VIKINGDB, collection_name))
+
+        if dify_config.VIKINGDB_ACCESS_KEY is None:
+            raise ValueError("VIKINGDB_ACCESS_KEY should not be None")
+        if dify_config.VIKINGDB_SECRET_KEY is None:
+            raise ValueError("VIKINGDB_SECRET_KEY should not be None")
+        if dify_config.VIKINGDB_HOST is None:
+            raise ValueError("VIKINGDB_HOST should not be None")
+        if dify_config.VIKINGDB_REGION is None:
+            raise ValueError("VIKINGDB_REGION should not be None")
+        if dify_config.VIKINGDB_SCHEME is None:
+            raise ValueError("VIKINGDB_SCHEME should not be None")
+        return VikingDBVector(
+            collection_name=collection_name,
+            group_id=dataset.id,
+            config=VikingDBConfig(
+                access_key=dify_config.VIKINGDB_ACCESS_KEY,
+                secret_key=dify_config.VIKINGDB_SECRET_KEY,
+                host=dify_config.VIKINGDB_HOST,
+                region=dify_config.VIKINGDB_REGION,
+                scheme=dify_config.VIKINGDB_SCHEME,
+                connection_timeout=dify_config.VIKINGDB_CONNECTION_TIMEOUT,
+                socket_timeout=dify_config.VIKINGDB_SOCKET_TIMEOUT,
+            ),
+        )

+ 72 - 1
api/poetry.lock

@@ -2038,6 +2038,17 @@ packaging = ">=17.0"
 pandas = ">=0.24.2"
 pyarrow = ">=3.0.0"
 
+[[package]]
+name = "decorator"
+version = "5.1.1"
+description = "Decorators for Humans"
+optional = false
+python-versions = ">=3.5"
+files = [
+    {file = "decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"},
+    {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"},
+]
+
 [[package]]
 name = "defusedxml"
 version = "0.7.1"
@@ -3027,6 +3038,20 @@ files = [
 docs = ["sphinx (>=4)", "sphinx-rtd-theme (>=1)"]
 tests = ["cython", "hypothesis", "mpmath", "pytest", "setuptools"]
 
+[[package]]
+name = "google"
+version = "3.0.0"
+description = "Python bindings to the Google search engine."
+optional = false
+python-versions = "*"
+files = [
+    {file = "google-3.0.0-py2.py3-none-any.whl", hash = "sha256:889cf695f84e4ae2c55fbc0cfdaf4c1e729417fa52ab1db0485202ba173e4935"},
+    {file = "google-3.0.0.tar.gz", hash = "sha256:143530122ee5130509ad5e989f0512f7cb218b2d4eddbafbad40fd10e8d8ccbe"},
+]
+
+[package.dependencies]
+beautifulsoup4 = "*"
+
 [[package]]
 name = "google-ai-generativelanguage"
 version = "0.6.9"
@@ -6670,6 +6695,17 @@ files = [
     {file = "psycopg2_binary-2.9.9-cp39-cp39-win_amd64.whl", hash = "sha256:f7ae5d65ccfbebdfa761585228eb4d0df3a8b15cfb53bd953e713e09fbb12957"},
 ]
 
+[[package]]
+name = "py"
+version = "1.11.0"
+description = "library with cross-python path, ini-parsing, io, code, log facilities"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
+files = [
+    {file = "py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"},
+    {file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"},
+]
+
 [[package]]
 name = "py-cpuinfo"
 version = "9.0.0"
@@ -8012,6 +8048,21 @@ files = [
 [package.dependencies]
 requests = "2.31.0"
 
+[[package]]
+name = "retry"
+version = "0.9.2"
+description = "Easy to use retry decorator."
+optional = false
+python-versions = "*"
+files = [
+    {file = "retry-0.9.2-py2.py3-none-any.whl", hash = "sha256:ccddf89761fa2c726ab29391837d4327f819ea14d244c232a1d24c67a2f98606"},
+    {file = "retry-0.9.2.tar.gz", hash = "sha256:f8bfa8b99b69c4506d6f5bd3b0aabf77f98cdb17f3c9fc3f5ca820033336fba4"},
+]
+
+[package.dependencies]
+decorator = ">=3.4.2"
+py = ">=1.4.26,<2.0.0"
+
 [[package]]
 name = "rich"
 version = "13.9.2"
@@ -9829,6 +9880,26 @@ files = [
     {file = "vine-5.1.0.tar.gz", hash = "sha256:8b62e981d35c41049211cf62a0a1242d8c1ee9bd15bb196ce38aefd6799e61e0"},
 ]
 
+[[package]]
+name = "volcengine-compat"
+version = "1.0.156"
+description = "Be Compatible with the Volcengine SDK for Python, The version of package dependencies has been modified. like pycryptodome, pytz."
+optional = false
+python-versions = "*"
+files = [
+    {file = "volcengine_compat-1.0.156-py3-none-any.whl", hash = "sha256:4abc149a7601ebad8fa2d28fab50c7945145cf74daecb71bca797b0bdc82c5a5"},
+    {file = "volcengine_compat-1.0.156.tar.gz", hash = "sha256:e357d096828e31a202dc6047bbc5bf6fff3f54a98cd35a99ab5f965ea741a267"},
+]
+
+[package.dependencies]
+google = ">=3.0.0"
+protobuf = ">=3.18.3"
+pycryptodome = ">=3.9.9"
+pytz = ">=2020.5"
+requests = ">=2.25.1"
+retry = ">=0.9.2"
+six = ">=1.0"
+
 [[package]]
 name = "volcengine-python-sdk"
 version = "1.0.103"
@@ -10636,4 +10707,4 @@ cffi = ["cffi (>=1.11)"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<3.13"
-content-hash = "375ac3a91760513924647e67376cb6018505ec61d967651b254c68af9808d774"
+content-hash = "edb5e3b0d50e84a239224cc77f3f615fdbdd6b504bce5b1075b29363f3054957"

+ 1 - 0
api/pyproject.toml

@@ -246,6 +246,7 @@ pymochow = "1.3.1"
 qdrant-client = "1.7.3"
 tcvectordb = "1.3.2"
 tidb-vector = "0.0.9"
+volcengine-compat = "~1.0.156"
 weaviate-client = "~3.21.0"
 
 ############################################################

+ 215 - 0
api/tests/integration_tests/vdb/__mock/vikingdb.py

@@ -0,0 +1,215 @@
+import os
+from typing import Union
+from unittest.mock import MagicMock
+
+import pytest
+from _pytest.monkeypatch import MonkeyPatch
+from volcengine.viking_db import (
+    Collection,
+    Data,
+    DistanceType,
+    Field,
+    FieldType,
+    Index,
+    IndexType,
+    QuantType,
+    VectorIndexParams,
+    VikingDBService,
+)
+
+from core.rag.datasource.vdb.field import Field as vdb_Field
+
+
+class MockVikingDBClass:
+    def __init__(
+        self,
+        host="api-vikingdb.volces.com",
+        region="cn-north-1",
+        ak="",
+        sk="",
+        scheme="http",
+        connection_timeout=30,
+        socket_timeout=30,
+        proxy=None,
+    ):
+        self._viking_db_service = MagicMock()
+        self._viking_db_service.get_exception = MagicMock(return_value='{"data": {"primary_key": "test_id"}}')
+
+    def get_collection(self, collection_name) -> Collection:
+        return Collection(
+            collection_name=collection_name,
+            description="Collection For Dify",
+            viking_db_service=self._viking_db_service,
+            primary_key=vdb_Field.PRIMARY_KEY.value,
+            fields=[
+                Field(field_name=vdb_Field.PRIMARY_KEY.value, field_type=FieldType.String, is_primary_key=True),
+                Field(field_name=vdb_Field.METADATA_KEY.value, field_type=FieldType.String),
+                Field(field_name=vdb_Field.GROUP_KEY.value, field_type=FieldType.String),
+                Field(field_name=vdb_Field.CONTENT_KEY.value, field_type=FieldType.Text),
+                Field(field_name=vdb_Field.VECTOR.value, field_type=FieldType.Vector, dim=768),
+            ],
+            indexes=[
+                Index(
+                    collection_name=collection_name,
+                    index_name=f"{collection_name}_idx",
+                    vector_index=VectorIndexParams(
+                        distance=DistanceType.L2,
+                        index_type=IndexType.HNSW,
+                        quant=QuantType.Float,
+                    ),
+                    scalar_index=None,
+                    stat=None,
+                    viking_db_service=self._viking_db_service,
+                )
+            ],
+        )
+
+    def drop_collection(self, collection_name):
+        assert collection_name != ""
+
+    def create_collection(self, collection_name, fields, description="") -> Collection:
+        return Collection(
+            collection_name=collection_name,
+            description=description,
+            primary_key=vdb_Field.PRIMARY_KEY.value,
+            viking_db_service=self._viking_db_service,
+            fields=fields,
+        )
+
+    def get_index(self, collection_name, index_name) -> Index:
+        return Index(
+            collection_name=collection_name,
+            index_name=index_name,
+            viking_db_service=self._viking_db_service,
+            stat=None,
+            scalar_index=None,
+            vector_index=VectorIndexParams(
+                distance=DistanceType.L2,
+                index_type=IndexType.HNSW,
+                quant=QuantType.Float,
+            ),
+        )
+
+    def create_index(
+        self,
+        collection_name,
+        index_name,
+        vector_index=None,
+        cpu_quota=2,
+        description="",
+        partition_by="",
+        scalar_index=None,
+        shard_count=None,
+        shard_policy=None,
+    ):
+        return Index(
+            collection_name=collection_name,
+            index_name=index_name,
+            vector_index=vector_index,
+            cpu_quota=cpu_quota,
+            description=description,
+            partition_by=partition_by,
+            scalar_index=scalar_index,
+            shard_count=shard_count,
+            shard_policy=shard_policy,
+            viking_db_service=self._viking_db_service,
+            stat=None,
+        )
+
+    def drop_index(self, collection_name, index_name):
+        assert collection_name != ""
+        assert index_name != ""
+
+    def upsert_data(self, data: Union[Data, list[Data]]):
+        assert data is not None
+
+    def fetch_data(self, id: Union[str, list[str], int, list[int]]):
+        return Data(
+            fields={
+                vdb_Field.GROUP_KEY.value: "test_group",
+                vdb_Field.METADATA_KEY.value: "{}",
+                vdb_Field.CONTENT_KEY.value: "content",
+                vdb_Field.PRIMARY_KEY.value: id,
+                vdb_Field.VECTOR.value: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398],
+            },
+            id=id,
+        )
+
+    def delete_data(self, id: Union[str, list[str], int, list[int]]):
+        assert id is not None
+
+    def search_by_vector(
+        self,
+        vector,
+        sparse_vectors=None,
+        filter=None,
+        limit=10,
+        output_fields=None,
+        partition="default",
+        dense_weight=None,
+    ) -> list[Data]:
+        return [
+            Data(
+                fields={
+                    vdb_Field.GROUP_KEY.value: "test_group",
+                    vdb_Field.METADATA_KEY.value: '{"source": "/var/folders/ml/xxx/xxx.txt",'
+                    ' "document_id": "test_document_id",'
+                    ' "dataset_id": "test_dataset_id",'
+                    ' "doc_id": "test_id",'
+                    ' "doc_hash": "test_hash"}',
+                    vdb_Field.CONTENT_KEY.value: "content",
+                    vdb_Field.PRIMARY_KEY.value: "test_id",
+                    vdb_Field.VECTOR.value: vector,
+                },
+                id="test_id",
+                score=0.10,
+            )
+        ]
+
+    def search(
+        self, order=None, filter=None, limit=10, output_fields=None, partition="default", dense_weight=None
+    ) -> list[Data]:
+        return [
+            Data(
+                fields={
+                    vdb_Field.GROUP_KEY.value: "test_group",
+                    vdb_Field.METADATA_KEY.value: '{"source": "/var/folders/ml/xxx/xxx.txt",'
+                    ' "document_id": "test_document_id",'
+                    ' "dataset_id": "test_dataset_id",'
+                    ' "doc_id": "test_id",'
+                    ' "doc_hash": "test_hash"}',
+                    vdb_Field.CONTENT_KEY.value: "content",
+                    vdb_Field.PRIMARY_KEY.value: "test_id",
+                    vdb_Field.VECTOR.value: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398],
+                },
+                id="test_id",
+                score=0.10,
+            )
+        ]
+
+
+MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
+
+
+@pytest.fixture
+def setup_vikingdb_mock(monkeypatch: MonkeyPatch):
+    if MOCK:
+        monkeypatch.setattr(VikingDBService, "__init__", MockVikingDBClass.__init__)
+        monkeypatch.setattr(VikingDBService, "get_collection", MockVikingDBClass.get_collection)
+        monkeypatch.setattr(VikingDBService, "create_collection", MockVikingDBClass.create_collection)
+        monkeypatch.setattr(VikingDBService, "drop_collection", MockVikingDBClass.drop_collection)
+        monkeypatch.setattr(VikingDBService, "get_index", MockVikingDBClass.get_index)
+        monkeypatch.setattr(VikingDBService, "create_index", MockVikingDBClass.create_index)
+        monkeypatch.setattr(VikingDBService, "drop_index", MockVikingDBClass.drop_index)
+        monkeypatch.setattr(Collection, "upsert_data", MockVikingDBClass.upsert_data)
+        monkeypatch.setattr(Collection, "fetch_data", MockVikingDBClass.fetch_data)
+        monkeypatch.setattr(Collection, "delete_data", MockVikingDBClass.delete_data)
+        monkeypatch.setattr(Index, "search_by_vector", MockVikingDBClass.search_by_vector)
+        monkeypatch.setattr(Index, "search", MockVikingDBClass.search)
+
+    yield
+
+    if MOCK:
+        monkeypatch.undo()
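
One operational note on this fixture: MOCK is evaluated once at import time, so MOCK_SWITCH must be set in the environment before the module is imported (e.g., in the CI job environment or a conftest), not inside a test body. A hedged sketch:

    import os

    os.environ["MOCK_SWITCH"] = "true"  # must precede the import below

    from tests.integration_tests.vdb.__mock.vikingdb import setup_vikingdb_mock  # noqa: E402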

+ 0 - 0
api/tests/integration_tests/vdb/vikingdb/__init__.py


+ 37 - 0
api/tests/integration_tests/vdb/vikingdb/test_vikingdb.py

@@ -0,0 +1,37 @@
+from core.rag.datasource.vdb.vikingdb.vikingdb_vector import VikingDBConfig, VikingDBVector
+from tests.integration_tests.vdb.__mock.vikingdb import setup_vikingdb_mock
+from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
+
+
+class VikingDBVectorTest(AbstractVectorTest):
+    def __init__(self):
+        super().__init__()
+        self.vector = VikingDBVector(
+            "test_collection",
+            "test_group",
+            config=VikingDBConfig(
+                access_key="test_access_key",
+                host="test_host",
+                region="test_region",
+                scheme="test_scheme",
+                secret_key="test_secret_key",
+                connection_timeout=30,
+                socket_timeout=30,
+            ),
+        )
+
+    def search_by_vector(self):
+        hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding)
+        assert len(hits_by_vector) == 1
+
+    def search_by_full_text(self):
+        hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
+        assert len(hits_by_full_text) == 0
+
+    def get_ids_by_metadata_field(self):
+        ids = self.vector.get_ids_by_metadata_field(key="document_id", value="test_document_id")
+        assert len(ids) > 0
+
+
+def test_vikingdb_vector(setup_mock_redis, setup_vikingdb_mock):
+    VikingDBVectorTest().run_all_tests()

+ 2 - 1
dev/pytest/pytest_vdb.sh

@@ -7,4 +7,5 @@ pytest api/tests/integration_tests/vdb/chroma \
   api/tests/integration_tests/vdb/pgvector \
   api/tests/integration_tests/vdb/qdrant \
   api/tests/integration_tests/vdb/weaviate \
-  api/tests/integration_tests/vdb/elasticsearch
+  api/tests/integration_tests/vdb/elasticsearch \
+  api/tests/integration_tests/vdb/vikingdb

+ 5 - 0
docker/docker-compose.yaml

@@ -173,6 +173,11 @@ x-shared-env: &shared-api-worker-env
   BAIDU_VECTOR_DB_DATABASE: ${BAIDU_VECTOR_DB_DATABASE:-dify}
   BAIDU_VECTOR_DB_SHARD: ${BAIDU_VECTOR_DB_SHARD:-1}
   BAIDU_VECTOR_DB_REPLICAS: ${BAIDU_VECTOR_DB_REPLICAS:-3}
+  VIKINGDB_ACCESS_KEY: ${VIKINGDB_ACCESS_KEY:-dify}
+  VIKINGDB_SECRET_KEY: ${VIKINGDB_SECRET_KEY:-dify}
+  VIKINGDB_REGION: ${VIKINGDB_REGION:-cn-shanghai}
+  VIKINGDB_HOST: ${VIKINGDB_HOST:-api-vikingdb.xxx.volces.com}
  VIKINGDB_SCHEME: ${VIKINGDB_SCHEME:-http}
   UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15}
   UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5}
   ETL_TYPE: ${ETL_TYPE:-dify}