Переглянути джерело

feat: support opensearch approximate k-NN (#5322)

baojingyu 10 місяців тому
батько
коміт
d160d1ed02

+ 8 - 0
api/commands.py

@@ -327,6 +327,14 @@ def migrate_knowledge_vector_database():
                         "vector_store": {"class_prefix": collection_name}
                     }
                     dataset.index_struct = json.dumps(index_struct_dict)
+                elif vector_type == VectorType.OPENSEARCH:
+                    dataset_id = dataset.id
+                    collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+                    index_struct_dict = {
+                        "type": VectorType.OPENSEARCH,
+                        "vector_store": {"class_prefix": collection_name}
+                    }
+                    dataset.index_struct = json.dumps(index_struct_dict)
                 else:
                     raise ValueError(f"Vector store {vector_type} is not supported.")
 

+ 7 - 0
api/config.py

@@ -282,6 +282,13 @@ class Config:
         self.MILVUS_SECURE = get_env('MILVUS_SECURE')
         self.MILVUS_DATABASE = get_env('MILVUS_DATABASE')
 
+        # OpenSearch settings
+        self.OPENSEARCH_HOST = get_env('OPENSEARCH_HOST')
+        self.OPENSEARCH_PORT = get_env('OPENSEARCH_PORT')
+        self.OPENSEARCH_USER = get_env('OPENSEARCH_USER')
+        self.OPENSEARCH_PASSWORD = get_env('OPENSEARCH_PASSWORD')
+        self.OPENSEARCH_SECURE = get_bool_env('OPENSEARCH_SECURE')
+
         # weaviate settings
         self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT')
         self.WEAVIATE_API_KEY = get_env('WEAVIATE_API_KEY')

+ 2 - 2
api/controllers/console/datasets/datasets.py

@@ -503,7 +503,7 @@ class DatasetRetrievalSettingApi(Resource):
                         'semantic_search'
                     ]
                 }
-            case VectorType.QDRANT | VectorType.WEAVIATE:
+            case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH:
                 return {
                     'retrieval_method': [
                         'semantic_search', 'full_text_search', 'hybrid_search'
@@ -525,7 +525,7 @@ class DatasetRetrievalSettingMockApi(Resource):
                         'semantic_search'
                     ]
                 }
-            case VectorType.QDRANT | VectorType.WEAVIATE:
+            case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH:
                 return {
                     'retrieval_method': [
                         'semantic_search', 'full_text_search', 'hybrid_search'

+ 0 - 0
api/core/rag/datasource/vdb/opensearch/__init__.py


+ 278 - 0
api/core/rag/datasource/vdb/opensearch/opensearch_vector.py

@@ -0,0 +1,278 @@
+import json
+import logging
+import ssl
+from typing import Any, Optional
+from uuid import uuid4
+
+from flask import current_app
+from opensearchpy import OpenSearch, helpers
+from opensearchpy.helpers import BulkIndexError
+from pydantic import BaseModel, model_validator
+
+from core.rag.datasource.entity.embedding import Embeddings
+from core.rag.datasource.vdb.field import Field
+from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
+from models.dataset import Dataset
+
+logger = logging.getLogger(__name__)
+
+
+class OpenSearchConfig(BaseModel):
+    host: str
+    port: int
+    user: Optional[str] = None
+    password: Optional[str] = None
+    secure: bool = False
+
+    @model_validator(mode='before')
+    def validate_config(cls, values: dict) -> dict:
+        if not values.get('host'):
+            raise ValueError("config OPENSEARCH_HOST is required")
+        if not values.get('port'):
+            raise ValueError("config OPENSEARCH_PORT is required")
+        return values
+
+    def create_ssl_context(self) -> ssl.SSLContext:
+        ssl_context = ssl.create_default_context()
+        ssl_context.check_hostname = False
+        ssl_context.verify_mode = ssl.CERT_NONE  # Disable Certificate Validation
+        return ssl_context
+
+    def to_opensearch_params(self) -> dict[str, Any]:
+        params = {
+            'hosts': [{'host': self.host, 'port': self.port}],
+            'use_ssl': self.secure,
+            'verify_certs': self.secure,
+        }
+        if self.user and self.password:
+            params['http_auth'] = (self.user, self.password)
+        if self.secure:
+            params['ssl_context'] = self.create_ssl_context()
+        return params
+
+
+class OpenSearchVector(BaseVector):
+
+    def __init__(self, collection_name: str, config: OpenSearchConfig):
+        super().__init__(collection_name)
+        self._client_config = config
+        self._client = OpenSearch(**config.to_opensearch_params())
+
+    def get_type(self) -> str:
+        return VectorType.OPENSEARCH
+
+    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
+        metadatas = [d.metadata for d in texts]
+        self.create_collection(embeddings, metadatas)
+        self.add_texts(texts, embeddings)
+
+    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+        actions = []
+        for i in range(len(documents)):
+            action = {
+                "_op_type": "index",
+                "_index": self._collection_name.lower(),
+                "_id": uuid4().hex,
+                "_source": {
+                    Field.CONTENT_KEY.value: documents[i].page_content,
+                    Field.VECTOR.value: embeddings[i],  # Make sure you pass an array here
+                    Field.METADATA_KEY.value: documents[i].metadata,
+                }
+            }
+            actions.append(action)
+
+        helpers.bulk(self._client, actions)
+
+    def delete_by_document_id(self, document_id: str):
+        ids = self.get_ids_by_metadata_field('document_id', document_id)
+        if ids:
+            self.delete_by_ids(ids)
+
+    def get_ids_by_metadata_field(self, key: str, value: str):
+        query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}}
+        response = self._client.search(index=self._collection_name.lower(), body=query)
+        if response['hits']['hits']:
+            return [hit['_id'] for hit in response['hits']['hits']]
+        else:
+            return None
+
+    def delete_by_metadata_field(self, key: str, value: str):
+        ids = self.get_ids_by_metadata_field(key, value)
+        if ids:
+            self.delete_by_ids(ids)
+
+    def delete_by_ids(self, ids: list[str]) -> None:
+        index_name = self._collection_name.lower()
+        if not self._client.indices.exists(index=index_name):
+            logger.warning(f"Index {index_name} does not exist")
+            return
+
+        # Obtaining All Actual Documents_ID
+        actual_ids = []
+
+        for doc_id in ids:
+            es_ids = self.get_ids_by_metadata_field('doc_id', doc_id)
+            if es_ids:
+                actual_ids.extend(es_ids)
+            else:
+                logger.warning(f"Document with metadata doc_id {doc_id} not found for deletion")
+
+        if actual_ids:
+            actions = [{"_op_type": "delete", "_index": index_name, "_id": es_id} for es_id in actual_ids]
+            try:
+                helpers.bulk(self._client, actions)
+            except BulkIndexError as e:
+                for error in e.errors:
+                    delete_error = error.get('delete', {})
+                    status = delete_error.get('status')
+                    doc_id = delete_error.get('_id')
+
+                    if status == 404:
+                        logger.warning(f"Document not found for deletion: {doc_id}")
+                    else:
+                        logger.error(f"Error deleting document: {error}")
+
+    def delete(self) -> None:
+        self._client.indices.delete(index=self._collection_name.lower())
+
+    def text_exists(self, id: str) -> bool:
+        try:
+            self._client.get(index=self._collection_name.lower(), id=id)
+            return True
+        except:
+            return False
+
+    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+        # Make sure query_vector is a list
+        if not isinstance(query_vector, list):
+            raise ValueError("query_vector should be a list of floats")
+
+        # Check whether query_vector is a floating-point number list
+        if not all(isinstance(x, float) for x in query_vector):
+            raise ValueError("All elements in query_vector should be floats")
+
+        query = {
+            "size": kwargs.get('top_k', 4),
+            "query": {
+                "knn": {
+                    Field.VECTOR.value: {
+                        Field.VECTOR.value: query_vector,
+                        "k": kwargs.get('top_k', 4)
+                    }
+                }
+            }
+        }
+
+        try:
+            response = self._client.search(index=self._collection_name.lower(), body=query)
+        except Exception as e:
+            logger.error(f"Error executing search: {e}")
+            raise
+
+        docs = []
+        for hit in response['hits']['hits']:
+            metadata = hit['_source'].get(Field.METADATA_KEY.value, {})
+
+            # Make sure metadata is a dictionary
+            if metadata is None:
+                metadata = {}
+
+            metadata['score'] = hit['_score']
+            score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
+            if hit['_score'] > score_threshold:
+                doc = Document(page_content=hit['_source'].get(Field.CONTENT_KEY.value), metadata=metadata)
+                docs.append(doc)
+
+        return docs
+
+    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        full_text_query = {"query": {"match": {Field.CONTENT_KEY.value: query}}}
+
+        response = self._client.search(index=self._collection_name.lower(), body=full_text_query)
+
+        docs = []
+        for hit in response['hits']['hits']:
+            metadata = hit['_source'].get(Field.METADATA_KEY.value)
+            doc = Document(page_content=hit['_source'].get(Field.CONTENT_KEY.value), metadata=metadata)
+            docs.append(doc)
+
+        return docs
+
+    def create_collection(
+            self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
+    ):
+        lock_name = f'vector_indexing_lock_{self._collection_name.lower()}'
+        with redis_client.lock(lock_name, timeout=20):
+            collection_exist_cache_key = f'vector_indexing_{self._collection_name.lower()}'
+            if redis_client.get(collection_exist_cache_key):
+                logger.info(f"Collection {self._collection_name.lower()} already exists.")
+                return
+
+            if not self._client.indices.exists(index=self._collection_name.lower()):
+                index_body = {
+                    "settings": {
+                        "index": {
+                            "knn": True
+                        }
+                    },
+                    "mappings": {
+                        "properties": {
+                            Field.CONTENT_KEY.value: {"type": "text"},
+                            Field.VECTOR.value: {
+                                "type": "knn_vector",
+                                "dimension": len(embeddings[0]),  # Make sure the dimension is correct here
+                                "method": {
+                                    "name": "hnsw",
+                                    "space_type": "l2",
+                                    "engine": "faiss",
+                                    "parameters": {
+                                        "ef_construction": 64,
+                                        "m": 8
+                                    }
+                                }
+                            },
+                            Field.METADATA_KEY.value: {
+                                "type": "object",
+                                "properties": {
+                                    "doc_id": {"type": "keyword"}  # Map doc_id to keyword type
+                                }
+                            }
+                        }
+                    }
+                }
+
+                self._client.indices.create(index=self._collection_name.lower(), body=index_body)
+
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
+
+
+class OpenSearchVectorFactory(AbstractVectorFactory):
+
+    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> OpenSearchVector:
+        if dataset.index_struct_dict:
+            class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
+            collection_name = class_prefix.lower()
+        else:
+            dataset_id = dataset.id
+            collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
+            dataset.index_struct = json.dumps(
+                self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name))
+
+        config = current_app.config
+
+        open_search_config = OpenSearchConfig(
+            host=config.get('OPENSEARCH_HOST'),
+            port=config.get('OPENSEARCH_PORT'),
+            user=config.get('OPENSEARCH_USER'),
+            password=config.get('OPENSEARCH_PASSWORD'),
+            secure=config.get('OPENSEARCH_SECURE'),
+        )
+
+        return OpenSearchVector(
+            collection_name=collection_name,
+            config=open_search_config
+        )

+ 3 - 0
api/core/rag/datasource/vdb/vector_factory.py

@@ -78,6 +78,9 @@ class Vector:
             case VectorType.TENCENT:
                 from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory
                 return TencentVectorFactory
+            case VectorType.OPENSEARCH:
+                from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory
+                return OpenSearchVectorFactory
             case _:
                 raise ValueError(f"Vector store {vector_type} is not supported.")
 

+ 1 - 0
api/core/rag/datasource/vdb/vector_type.py

@@ -10,4 +10,5 @@ class VectorType(str, Enum):
     RELYT = 'relyt'
     TIDB_VECTOR = 'tidb_vector'
     WEAVIATE = 'weaviate'
+    OPENSEARCH = 'opensearch'
     TENCENT = 'tencent'

+ 0 - 0
api/events/__init__.py


+ 36 - 1
api/poetry.lock

@@ -4891,6 +4891,30 @@ files = [
 [package.dependencies]
 et-xmlfile = "*"
 
+[[package]]
+name = "opensearch-py"
+version = "2.4.0"
+description = "Python client for OpenSearch"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, <4"
+files = [
+    {file = "opensearch-py-2.4.0.tar.gz", hash = "sha256:7eba2b6ed2ddcf33225bfebfba2aee026877838cc39f760ec80f27827308cc4b"},
+    {file = "opensearch_py-2.4.0-py2.py3-none-any.whl", hash = "sha256:316077235437c8ceac970232261f3393c65fb92a80f33c5b106f50f1dab24fd9"},
+]
+
+[package.dependencies]
+certifi = ">=2022.12.07"
+python-dateutil = "*"
+requests = ">=2.4.0,<3.0.0"
+six = "*"
+urllib3 = ">=1.26.18"
+
+[package.extras]
+async = ["aiohttp (>=3,<4)"]
+develop = ["black", "botocore", "coverage (<8.0.0)", "jinja2", "mock", "myst-parser", "pytest (>=3.0.0)", "pytest-cov", "pytest-mock (<4.0.0)", "pytz", "pyyaml", "requests (>=2.0.0,<3.0.0)", "sphinx", "sphinx-copybutton", "sphinx-rtd-theme"]
+docs = ["aiohttp (>=3,<4)", "myst-parser", "sphinx", "sphinx-copybutton", "sphinx-rtd-theme"]
+kerberos = ["requests-kerberos"]
+
 [[package]]
 name = "opentelemetry-api"
 version = "1.25.0"
@@ -6414,6 +6438,7 @@ files = [
     {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"},
     {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"},
     {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"},
+    {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"},
     {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"},
     {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"},
     {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"},
@@ -6421,8 +6446,16 @@ files = [
     {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"},
     {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"},
     {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"},
+    {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"},
     {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"},
     {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
+    {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
+    {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
+    {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
+    {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
+    {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
+    {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
+    {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"},
     {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"},
     {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"},
     {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"},
@@ -6439,6 +6472,7 @@ files = [
     {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"},
     {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"},
     {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"},
+    {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"},
     {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"},
     {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"},
     {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"},
@@ -6446,6 +6480,7 @@ files = [
     {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"},
     {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"},
     {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"},
+    {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"},
     {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"},
     {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"},
     {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"},
@@ -8944,4 +8979,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "21360e271c46e0368b8e3bd26287caca73145a73ee73287669f91e7eac6f05b9"
+content-hash = "367a4b0ad745a48263dd44711be28c4c076dee983e3f5d1ac56c22bbb2eed531"

+ 1 - 0
api/pyproject.toml

@@ -185,6 +185,7 @@ chromadb = "~0.5.1"
 tenacity = "~8.3.0"
 cos-python-sdk-v5 = "1.9.30"
 novita-client = "^0.5.6"
+opensearch-py = "2.4.0"
 
 [tool.poetry.group.dev]
 optional = true

+ 2 - 1
api/requirements.txt

@@ -90,4 +90,5 @@ tencentcloud-sdk-python-hunyuan~=3.0.1158
 chromadb~=0.5.1
 novita_client~=0.5.6
 tenacity~=8.3.0
-cos-python-sdk-v5==1.9.30
+opensearch-py==2.4.0
+cos-python-sdk-v5==1.9.30

+ 0 - 0
api/tests/integration_tests/vdb/opensearch/__init__.py


+ 186 - 0
api/tests/integration_tests/vdb/opensearch/test_opensearch.py

@@ -0,0 +1,186 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from core.rag.datasource.vdb.field import Field
+from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchConfig, OpenSearchVector
+from core.rag.models.document import Document
+from extensions import ext_redis
+
+
+def get_example_text() -> str:
+    return "This is a sample text for testing purposes."
+
+
+@pytest.fixture(scope="module")
+def setup_mock_redis():
+    ext_redis.redis_client.get = MagicMock(return_value=None)
+    ext_redis.redis_client.set = MagicMock(return_value=None)
+
+    mock_redis_lock = MagicMock()
+    mock_redis_lock.__enter__ = MagicMock()
+    mock_redis_lock.__exit__ = MagicMock()
+    ext_redis.redis_client.lock = MagicMock(return_value=mock_redis_lock)
+
+
+class TestOpenSearchVector:
+    def setup_method(self):
+        self.collection_name = "test_collection"
+        self.example_doc_id = "example_doc_id"
+        self.vector = OpenSearchVector(
+            collection_name=self.collection_name,
+            config=OpenSearchConfig(
+                host='localhost',
+                port=9200,
+                user='admin',
+                password='password',
+                secure=False
+            )
+        )
+        self.vector._client = MagicMock()
+
+    @pytest.mark.parametrize("search_response, expected_length, expected_doc_id", [
+        ({
+            'hits': {
+                'total': {'value': 1},
+                'hits': [
+                    {'_source': {'page_content': get_example_text(), 'metadata': {"document_id": "example_doc_id"}}}
+                ]
+            }
+        }, 1, "example_doc_id"),
+        ({
+            'hits': {
+                'total': {'value': 0},
+                'hits': []
+            }
+        }, 0, None)
+    ])
+    def test_search_by_full_text(self, search_response, expected_length, expected_doc_id):
+        self.vector._client.search.return_value = search_response
+
+        hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
+        assert len(hits_by_full_text) == expected_length
+        if expected_length > 0:
+            assert hits_by_full_text[0].metadata['document_id'] == expected_doc_id
+
+    def test_search_by_vector(self):
+        vector = [0.1] * 128
+        mock_response = {
+            'hits': {
+                'total': {'value': 1},
+                'hits': [
+                    {
+                        '_source': {
+                            Field.CONTENT_KEY.value: get_example_text(),
+                            Field.METADATA_KEY.value: {"document_id": self.example_doc_id}
+                        },
+                        '_score': 1.0
+                    }
+                ]
+            }
+        }
+        self.vector._client.search.return_value = mock_response
+
+        hits_by_vector = self.vector.search_by_vector(query_vector=vector)
+
+        print("Hits by vector:", hits_by_vector)
+        print("Expected document ID:", self.example_doc_id)
+        print("Actual document ID:", hits_by_vector[0].metadata['document_id'] if hits_by_vector else "No hits")
+
+        assert len(hits_by_vector) > 0, f"Expected at least one hit, got {len(hits_by_vector)}"
+        assert hits_by_vector[0].metadata['document_id'] == self.example_doc_id, \
+            f"Expected document ID {self.example_doc_id}, got {hits_by_vector[0].metadata['document_id']}"
+
+    def test_delete_by_document_id(self):
+        self.vector._client.delete_by_query.return_value = {'deleted': 1}
+
+        doc = Document(page_content="Test content to delete", metadata={"document_id": self.example_doc_id})
+        embedding = [0.1] * 128
+
+        with patch('opensearchpy.helpers.bulk') as mock_bulk:
+            mock_bulk.return_value = ([], [])
+            self.vector.add_texts([doc], [embedding])
+
+        self.vector.delete_by_document_id(document_id=self.example_doc_id)
+
+        self.vector._client.search.return_value = {'hits': {'total': {'value': 0}, 'hits': []}}
+
+        ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id)
+        assert ids is None or len(ids) == 0
+
+    def test_get_ids_by_metadata_field(self):
+        mock_response = {
+            'hits': {
+                'total': {'value': 1},
+                'hits': [{'_id': 'mock_id'}]
+            }
+        }
+        self.vector._client.search.return_value = mock_response
+
+        doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id})
+        embedding = [0.1] * 128
+
+        with patch('opensearchpy.helpers.bulk') as mock_bulk:
+            mock_bulk.return_value = ([], [])
+            self.vector.add_texts([doc], [embedding])
+
+        ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id)
+        assert len(ids) == 1
+        assert ids[0] == 'mock_id'
+
+    def test_add_texts(self):
+        self.vector._client.index.return_value = {'result': 'created'}
+
+        doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id})
+        embedding = [0.1] * 128
+
+        with patch('opensearchpy.helpers.bulk') as mock_bulk:
+            mock_bulk.return_value = ([], [])
+            self.vector.add_texts([doc], [embedding])
+
+        mock_response = {
+            'hits': {
+                'total': {'value': 1},
+                'hits': [{'_id': 'mock_id'}]
+            }
+        }
+        self.vector._client.search.return_value = mock_response
+
+        ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id)
+        assert len(ids) == 1
+        assert ids[0] == 'mock_id'
+
+@pytest.mark.usefixtures("setup_mock_redis")
+class TestOpenSearchVectorWithRedis:
+    def setup_method(self):
+        self.tester = TestOpenSearchVector()
+
+    def test_search_by_full_text(self):
+        self.tester.setup_method()
+        search_response = {
+            'hits': {
+                'total': {'value': 1},
+                'hits': [
+                    {'_source': {'page_content': get_example_text(), 'metadata': {"document_id": "example_doc_id"}}}
+                ]
+            }
+        }
+        expected_length = 1
+        expected_doc_id = "example_doc_id"
+        self.tester.test_search_by_full_text(search_response, expected_length, expected_doc_id)
+
+    def test_delete_by_document_id(self):
+        self.tester.setup_method()
+        self.tester.test_delete_by_document_id()
+
+    def test_get_ids_by_metadata_field(self):
+        self.tester.setup_method()
+        self.tester.test_get_ids_by_metadata_field()
+
+    def test_add_texts(self):
+        self.tester.setup_method()
+        self.tester.test_add_texts()
+
+    def test_search_by_vector(self):
+        self.tester.setup_method()
+        self.tester.test_search_by_vector()

+ 41 - 0
docker/docker-compose.opensearch.yml

@@ -0,0 +1,41 @@
+version: '3'
+services:
+  opensearch: # This is also the hostname of the container within the Docker network (i.e. https://opensearch/)
+    image: opensearchproject/opensearch:latest # Specifying the latest available image - modify if you want a specific version
+    container_name: opensearch
+    environment:
+      - discovery.type=single-node
+      - bootstrap.memory_lock=true # Disable JVM heap memory swapping
+      - "OPENSEARCH_JAVA_OPTS=-Xms512m -Xmx1024m" # Set min and max JVM heap sizes to at least 50% of system RAM
+      - OPENSEARCH_INITIAL_ADMIN_PASSWORD=Qazwsxedc!@#123    # Sets the demo admin user password when using demo configuration, required for OpenSearch 2.12 and later
+    ulimits:
+      memlock:
+        soft: -1 # Set memlock to unlimited (no soft or hard limit)
+        hard: -1
+      nofile:
+        soft: 65536 # Maximum number of open files for the opensearch user - set to at least 65536
+        hard: 65536
+    volumes:
+      - ./volumes/opensearch/data:/usr/share/opensearch/data # Creates volume called opensearch-data1 and mounts it to the container
+    ports:
+      - 9200:9200 # REST API
+      - 9600:9600 # Performance Analyzer
+    networks:
+      - opensearch-net # All of the containers will join the same Docker bridge network
+  opensearch-dashboards:
+    image: opensearchproject/opensearch-dashboards:latest # Make sure the version of opensearch-dashboards matches the version of opensearch installed on other nodes
+    container_name: opensearch-dashboards
+    ports:
+      - 5601:5601 # Map host port 5601 to container port 5601
+    expose:
+      - "5601" # Expose port 5601 for web access to OpenSearch Dashboards
+    environment:
+      OPENSEARCH_HOSTS: '["https://opensearch:9200"]' # Define the OpenSearch nodes that OpenSearch Dashboards will query
+    volumes:
+      - ./volumes/opensearch/opensearch_dashboards.yml:/usr/share/opensearch-dashboards/config/opensearch_dashboards.yml
+    networks:
+      - opensearch-net
+
+networks:
+  opensearch-net:
+    driver: bridge

+ 6 - 0
docker/docker-compose.yaml

@@ -332,6 +332,12 @@ services:
       TENCENT_VECTOR_DB_DATABASE: dify
       TENCENT_VECTOR_DB_SHARD: 1
       TENCENT_VECTOR_DB_REPLICAS: 2
+      # OpenSearch configuration
+      OPENSEARCH_HOST: 127.0.0.1
+      OPENSEARCH_PORT: 9200
+      OPENSEARCH_USER: admin
+      OPENSEARCH_PASSWORD: admin
+      OPENSEARCH_SECURE: 'true'
       # pgvector configurations
       PGVECTOR_HOST: pgvector
       PGVECTOR_PORT: 5432

+ 222 - 0
docker/volumes/opensearch/opensearch_dashboards.yml

@@ -0,0 +1,222 @@
+---
+# Copyright OpenSearch Contributors
+# SPDX-License-Identifier: Apache-2.0
+
+# Description:
+# Default configuration for OpenSearch Dashboards
+
+# OpenSearch Dashboards is served by a back end server. This setting specifies the port to use.
+# server.port: 5601
+
+# Specifies the address to which the OpenSearch Dashboards server will bind. IP addresses and host names are both valid values.
+# The default is 'localhost', which usually means remote machines will not be able to connect.
+# To allow connections from remote users, set this parameter to a non-loopback address.
+# server.host: "localhost"
+
+# Enables you to specify a path to mount OpenSearch Dashboards at if you are running behind a proxy.
+# Use the `server.rewriteBasePath` setting to tell OpenSearch Dashboards if it should remove the basePath
+# from requests it receives, and to prevent a deprecation warning at startup.
+# This setting cannot end in a slash.
+# server.basePath: ""
+
+# Specifies whether OpenSearch Dashboards should rewrite requests that are prefixed with
+# `server.basePath` or require that they are rewritten by your reverse proxy.
+# server.rewriteBasePath: false
+
+# The maximum payload size in bytes for incoming server requests.
+# server.maxPayloadBytes: 1048576
+
+# The OpenSearch Dashboards server's name.  This is used for display purposes.
+# server.name: "your-hostname"
+
+# The URLs of the OpenSearch instances to use for all your queries.
+# opensearch.hosts: ["http://localhost:9200"]
+
+# OpenSearch Dashboards uses an index in OpenSearch to store saved searches, visualizations and
+# dashboards. OpenSearch Dashboards creates a new index if the index doesn't already exist.
+# opensearchDashboards.index: ".opensearch_dashboards"
+
+# The default application to load.
+# opensearchDashboards.defaultAppId: "home"
+
+# Setting for an optimized healthcheck that only uses the local OpenSearch node to do Dashboards healthcheck.
+# This settings should be used for large clusters or for clusters with ingest heavy nodes.
+# It allows Dashboards to only healthcheck using the local OpenSearch node rather than fan out requests across all nodes.
+#
+# It requires the user to create an OpenSearch node attribute with the same name as the value used in the setting
+# This node attribute should assign all nodes of the same cluster an integer value that increments with each new cluster that is spun up
+# e.g. in opensearch.yml file you would set the value to a setting using node.attr.cluster_id:
+# Should only be enabled if there is a corresponding node attribute created in your OpenSearch config that matches the value here
+# opensearch.optimizedHealthcheckId: "cluster_id"
+
+# If your OpenSearch is protected with basic authentication, these settings provide
+# the username and password that the OpenSearch Dashboards server uses to perform maintenance on the OpenSearch Dashboards
+# index at startup. Your OpenSearch Dashboards users still need to authenticate with OpenSearch, which
+# is proxied through the OpenSearch Dashboards server.
+# opensearch.username: "opensearch_dashboards_system"
+# opensearch.password: "pass"
+
+# Enables SSL and paths to the PEM-format SSL certificate and SSL key files, respectively.
+# These settings enable SSL for outgoing requests from the OpenSearch Dashboards server to the browser.
+# server.ssl.enabled: false
+# server.ssl.certificate: /path/to/your/server.crt
+# server.ssl.key: /path/to/your/server.key
+
+# Optional settings that provide the paths to the PEM-format SSL certificate and key files.
+# These files are used to verify the identity of OpenSearch Dashboards to OpenSearch and are required when
+# xpack.security.http.ssl.client_authentication in OpenSearch is set to required.
+# opensearch.ssl.certificate: /path/to/your/client.crt
+# opensearch.ssl.key: /path/to/your/client.key
+
+# Optional setting that enables you to specify a path to the PEM file for the certificate
+# authority for your OpenSearch instance.
+# opensearch.ssl.certificateAuthorities: [ "/path/to/your/CA.pem" ]
+
+# To disregard the validity of SSL certificates, change this setting's value to 'none'.
+# opensearch.ssl.verificationMode: full
+
+# Time in milliseconds to wait for OpenSearch to respond to pings. Defaults to the value of
+# the opensearch.requestTimeout setting.
+# opensearch.pingTimeout: 1500
+
+# Time in milliseconds to wait for responses from the back end or OpenSearch. This value
+# must be a positive integer.
+# opensearch.requestTimeout: 30000
+
+# List of OpenSearch Dashboards client-side headers to send to OpenSearch. To send *no* client-side
+# headers, set this value to [] (an empty list).
+# opensearch.requestHeadersWhitelist: [ authorization ]
+
+# Header names and values that are sent to OpenSearch. Any custom headers cannot be overwritten
+# by client-side headers, regardless of the opensearch.requestHeadersWhitelist configuration.
+# opensearch.customHeaders: {}
+
+# Time in milliseconds for OpenSearch to wait for responses from shards. Set to 0 to disable.
+# opensearch.shardTimeout: 30000
+
+# Logs queries sent to OpenSearch. Requires logging.verbose set to true.
+# opensearch.logQueries: false
+
+# Specifies the path where OpenSearch Dashboards creates the process ID file.
+# pid.file: /var/run/opensearchDashboards.pid
+
+# Enables you to specify a file where OpenSearch Dashboards stores log output.
+# logging.dest: stdout
+
+# Set the value of this setting to true to suppress all logging output.
+# logging.silent: false
+
+# Set the value of this setting to true to suppress all logging output other than error messages.
+# logging.quiet: false
+
+# Set the value of this setting to true to log all events, including system usage information
+# and all requests.
+# logging.verbose: false
+
+# Set the interval in milliseconds to sample system and process performance
+# metrics. Minimum is 100ms. Defaults to 5000.
+# ops.interval: 5000
+
+# Specifies locale to be used for all localizable strings, dates and number formats.
+# Supported languages are the following: English - en , by default , Chinese - zh-CN .
+# i18n.locale: "en"
+
+# Set the allowlist to check input graphite Url. Allowlist is the default check list.
+# vis_type_timeline.graphiteAllowedUrls: ['https://www.hostedgraphite.com/UID/ACCESS_KEY/graphite']
+
+# Set the blocklist to check input graphite Url. Blocklist is an IP list.
+# Below is an example for reference
+# vis_type_timeline.graphiteBlockedIPs: [
+#  //Loopback
+#  '127.0.0.0/8',
+#  '::1/128',
+#  //Link-local Address for IPv6
+#  'fe80::/10',
+#  //Private IP address for IPv4
+#  '10.0.0.0/8',
+#  '172.16.0.0/12',
+#  '192.168.0.0/16',
+#  //Unique local address (ULA)
+#  'fc00::/7',
+#  //Reserved IP address
+#  '0.0.0.0/8',
+#  '100.64.0.0/10',
+#  '192.0.0.0/24',
+#  '192.0.2.0/24',
+#  '198.18.0.0/15',
+#  '192.88.99.0/24',
+#  '198.51.100.0/24',
+#  '203.0.113.0/24',
+#  '224.0.0.0/4',
+#  '240.0.0.0/4',
+#  '255.255.255.255/32',
+#  '::/128',
+#  '2001:db8::/32',
+#  'ff00::/8',
+# ]
+# vis_type_timeline.graphiteBlockedIPs: []
+
+# opensearchDashboards.branding:
+#   logo:
+#     defaultUrl: ""
+#     darkModeUrl: ""
+#   mark:
+#     defaultUrl: ""
+#     darkModeUrl: ""
+#   loadingLogo:
+#     defaultUrl: ""
+#     darkModeUrl: ""
+#   faviconUrl: ""
+#   applicationTitle: ""
+
+# Set the value of this setting to true to capture region blocked warnings and errors
+# for your map rendering services.
+# map.showRegionBlockedWarning: false%
+
+# Set the value of this setting to false to suppress search usage telemetry
+# for reducing the load of OpenSearch cluster.
+# data.search.usageTelemetry.enabled: false
+
+# 2.4 renames 'wizard.enabled: false' to 'vis_builder.enabled: false'
+# Set the value of this setting to false to disable VisBuilder
+# functionality in Visualization.
+# vis_builder.enabled: false
+
+# 2.4 New Experimental Feature
+# Set the value of this setting to true to enable the experimental multiple data source
+# support feature. Use with caution.
+# data_source.enabled: false
+# Set the value of these settings to customize crypto materials to encryption saved credentials
+# in data sources.
+# data_source.encryption.wrappingKeyName: 'changeme'
+# data_source.encryption.wrappingKeyNamespace: 'changeme'
+# data_source.encryption.wrappingKey: [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]
+
+# 2.6 New ML Commons Dashboards Feature
+# Set the value of this setting to true to enable the ml commons dashboards
+# ml_commons_dashboards.enabled: false
+
+# 2.12 New experimental Assistant Dashboards Feature
+# Set the value of this setting to true to enable the assistant dashboards
+# assistant.chat.enabled: false
+
+# 2.13 New Query Assistant Feature
+# Set the value of this setting to false to disable the query assistant
+# observability.query_assist.enabled: false
+
+# 2.14 Enable Ui Metric Collectors in Usage Collector
+# Set the value of this setting to true to enable UI Metric collections
+# usageCollection.uiMetric.enabled: false
+
+opensearch.hosts: [https://localhost:9200]
+opensearch.ssl.verificationMode: none
+opensearch.username: admin
+opensearch.password: 'Qazwsxedc!@#123'
+opensearch.requestHeadersWhitelist: [authorization, securitytenant]
+
+opensearch_security.multitenancy.enabled: true
+opensearch_security.multitenancy.tenants.preferred: [Private, Global]
+opensearch_security.readonly_mode.roles: [kibana_read_only]
+# Use this setting if you are running opensearch-dashboards without https
+opensearch_security.cookie.secure: false
+server.host: '0.0.0.0'