Pārlūkot izejas kodu

feat: AnalyticDB vector store supports invocation via SQL. (#10802)

Co-authored-by: 璟义 <yangshangpo.ysp@alibaba-inc.com>
8bitpd 5 mēneši atpakaļ
vecāks
revīzija
873e9720e9

+ 4 - 0
api/.env.example

@@ -234,6 +234,10 @@ ANALYTICDB_ACCOUNT=testaccount
 ANALYTICDB_PASSWORD=testpassword
 ANALYTICDB_NAMESPACE=dify
 ANALYTICDB_NAMESPACE_PASSWORD=difypassword
+ANALYTICDB_HOST=gp-test.aliyuncs.com
+ANALYTICDB_PORT=5432
+ANALYTICDB_MIN_CONNECTION=1
+ANALYTICDB_MAX_CONNECTION=5
 
 # OpenSearch configuration
 OPENSEARCH_HOST=127.0.0.1

+ 9 - 1
api/configs/middleware/vdb/analyticdb_config.py

@@ -1,6 +1,6 @@
 from typing import Optional
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, PositiveInt
 
 
 class AnalyticdbConfig(BaseModel):
@@ -40,3 +40,11 @@ class AnalyticdbConfig(BaseModel):
         description="The password for accessing the specified namespace within the AnalyticDB instance"
         " (if namespace feature is enabled).",
     )
+    ANALYTICDB_HOST: Optional[str] = Field(
+        default=None, description="The host of the AnalyticDB instance you want to connect to."
+    )
+    ANALYTICDB_PORT: PositiveInt = Field(
+        default=5432, description="The port of the AnalyticDB instance you want to connect to."
+    )
+    ANALYTICDB_MIN_CONNECTION: PositiveInt = Field(default=1, description="Min connection of the AnalyticDB database.")
+    ANALYTICDB_MAX_CONNECTION: PositiveInt = Field(default=5, description="Max connection of the AnalyticDB database.")

+ 45 - 293
api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py

@@ -1,310 +1,62 @@
 import json
 from typing import Any
 
-from pydantic import BaseModel
-
-_import_err_msg = (
-    "`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
-    "please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
-)
-
 from configs import dify_config
+from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import (
+    AnalyticdbVectorOpenAPI,
+    AnalyticdbVectorOpenAPIConfig,
+)
+from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.embedding.embedding_base import Embeddings
 from core.rag.models.document import Document
-from extensions.ext_redis import redis_client
 from models.dataset import Dataset
 
 
-class AnalyticdbConfig(BaseModel):
-    access_key_id: str
-    access_key_secret: str
-    region_id: str
-    instance_id: str
-    account: str
-    account_password: str
-    namespace: str = ("dify",)
-    namespace_password: str = (None,)
-    metrics: str = ("cosine",)
-    read_timeout: int = 60000
-
-    def to_analyticdb_client_params(self):
-        return {
-            "access_key_id": self.access_key_id,
-            "access_key_secret": self.access_key_secret,
-            "region_id": self.region_id,
-            "read_timeout": self.read_timeout,
-        }
-
-
 class AnalyticdbVector(BaseVector):
-    def __init__(self, collection_name: str, config: AnalyticdbConfig):
-        self._collection_name = collection_name.lower()
-        try:
-            from alibabacloud_gpdb20160503.client import Client
-            from alibabacloud_tea_openapi import models as open_api_models
-        except:
-            raise ImportError(_import_err_msg)
-        self.config = config
-        self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
-        self._client = Client(self._client_config)
-        self._initialize()
-
-    def _initialize(self) -> None:
-        cache_key = f"vector_indexing_{self.config.instance_id}"
-        lock_name = f"{cache_key}_lock"
-        with redis_client.lock(lock_name, timeout=20):
-            collection_exist_cache_key = f"vector_indexing_{self.config.instance_id}"
-            if redis_client.get(collection_exist_cache_key):
-                return
-            self._initialize_vector_database()
-            self._create_namespace_if_not_exists()
-            redis_client.set(collection_exist_cache_key, 1, ex=3600)
-
-    def _initialize_vector_database(self) -> None:
-        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
-
-        request = gpdb_20160503_models.InitVectorDatabaseRequest(
-            dbinstance_id=self.config.instance_id,
-            region_id=self.config.region_id,
-            manager_account=self.config.account,
-            manager_account_password=self.config.account_password,
-        )
-        self._client.init_vector_database(request)
-
-    def _create_namespace_if_not_exists(self) -> None:
-        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
-        from Tea.exceptions import TeaException
-
-        try:
-            request = gpdb_20160503_models.DescribeNamespaceRequest(
-                dbinstance_id=self.config.instance_id,
-                region_id=self.config.region_id,
-                namespace=self.config.namespace,
-                manager_account=self.config.account,
-                manager_account_password=self.config.account_password,
-            )
-            self._client.describe_namespace(request)
-        except TeaException as e:
-            if e.statusCode == 404:
-                request = gpdb_20160503_models.CreateNamespaceRequest(
-                    dbinstance_id=self.config.instance_id,
-                    region_id=self.config.region_id,
-                    manager_account=self.config.account,
-                    manager_account_password=self.config.account_password,
-                    namespace=self.config.namespace,
-                    namespace_password=self.config.namespace_password,
-                )
-                self._client.create_namespace(request)
-            else:
-                raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
-
-    def _create_collection_if_not_exists(self, embedding_dimension: int):
-        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
-        from Tea.exceptions import TeaException
-
-        cache_key = f"vector_indexing_{self._collection_name}"
-        lock_name = f"{cache_key}_lock"
-        with redis_client.lock(lock_name, timeout=20):
-            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
-            if redis_client.get(collection_exist_cache_key):
-                return
-            try:
-                request = gpdb_20160503_models.DescribeCollectionRequest(
-                    dbinstance_id=self.config.instance_id,
-                    region_id=self.config.region_id,
-                    namespace=self.config.namespace,
-                    namespace_password=self.config.namespace_password,
-                    collection=self._collection_name,
-                )
-                self._client.describe_collection(request)
-            except TeaException as e:
-                if e.statusCode == 404:
-                    metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}'
-                    full_text_retrieval_fields = "page_content"
-                    request = gpdb_20160503_models.CreateCollectionRequest(
-                        dbinstance_id=self.config.instance_id,
-                        region_id=self.config.region_id,
-                        manager_account=self.config.account,
-                        manager_account_password=self.config.account_password,
-                        namespace=self.config.namespace,
-                        collection=self._collection_name,
-                        dimension=embedding_dimension,
-                        metrics=self.config.metrics,
-                        metadata=metadata,
-                        full_text_retrieval_fields=full_text_retrieval_fields,
-                    )
-                    self._client.create_collection(request)
-                else:
-                    raise ValueError(f"failed to create collection {self._collection_name}: {e}")
-            redis_client.set(collection_exist_cache_key, 1, ex=3600)
+    def __init__(
+        self, collection_name: str, api_config: AnalyticdbVectorOpenAPIConfig, sql_config: AnalyticdbVectorBySqlConfig
+    ):
+        super().__init__(collection_name)
+        if api_config is not None:
+            self.analyticdb_vector = AnalyticdbVectorOpenAPI(collection_name, api_config)
+        else:
+            self.analyticdb_vector = AnalyticdbVectorBySql(collection_name, sql_config)
 
     def get_type(self) -> str:
         return VectorType.ANALYTICDB
 
     def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
         dimension = len(embeddings[0])
-        self._create_collection_if_not_exists(dimension)
-        self.add_texts(texts, embeddings)
+        self.analyticdb_vector._create_collection_if_not_exists(dimension)
+        self.analyticdb_vector.add_texts(texts, embeddings)
 
-    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
-        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
-
-        rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
-        for doc, embedding in zip(documents, embeddings, strict=True):
-            metadata = {
-                "ref_doc_id": doc.metadata["doc_id"],
-                "page_content": doc.page_content,
-                "metadata_": json.dumps(doc.metadata),
-            }
-            rows.append(
-                gpdb_20160503_models.UpsertCollectionDataRequestRows(
-                    vector=embedding,
-                    metadata=metadata,
-                )
-            )
-        request = gpdb_20160503_models.UpsertCollectionDataRequest(
-            dbinstance_id=self.config.instance_id,
-            region_id=self.config.region_id,
-            namespace=self.config.namespace,
-            namespace_password=self.config.namespace_password,
-            collection=self._collection_name,
-            rows=rows,
-        )
-        self._client.upsert_collection_data(request)
+    def add_texts(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
+        self.analyticdb_vector.add_texts(texts, embeddings)
 
     def text_exists(self, id: str) -> bool:
-        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
-
-        request = gpdb_20160503_models.QueryCollectionDataRequest(
-            dbinstance_id=self.config.instance_id,
-            region_id=self.config.region_id,
-            namespace=self.config.namespace,
-            namespace_password=self.config.namespace_password,
-            collection=self._collection_name,
-            metrics=self.config.metrics,
-            include_values=True,
-            vector=None,
-            content=None,
-            top_k=1,
-            filter=f"ref_doc_id='{id}'",
-        )
-        response = self._client.query_collection_data(request)
-        return len(response.body.matches.match) > 0
+        return self.analyticdb_vector.text_exists(id)
 
     def delete_by_ids(self, ids: list[str]) -> None:
-        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
-
-        ids_str = ",".join(f"'{id}'" for id in ids)
-        ids_str = f"({ids_str})"
-        request = gpdb_20160503_models.DeleteCollectionDataRequest(
-            dbinstance_id=self.config.instance_id,
-            region_id=self.config.region_id,
-            namespace=self.config.namespace,
-            namespace_password=self.config.namespace_password,
-            collection=self._collection_name,
-            collection_data=None,
-            collection_data_filter=f"ref_doc_id IN {ids_str}",
-        )
-        self._client.delete_collection_data(request)
+        self.analyticdb_vector.delete_by_ids(ids)
 
     def delete_by_metadata_field(self, key: str, value: str) -> None:
-        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
-
-        request = gpdb_20160503_models.DeleteCollectionDataRequest(
-            dbinstance_id=self.config.instance_id,
-            region_id=self.config.region_id,
-            namespace=self.config.namespace,
-            namespace_password=self.config.namespace_password,
-            collection=self._collection_name,
-            collection_data=None,
-            collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
-        )
-        self._client.delete_collection_data(request)
+        self.analyticdb_vector.delete_by_metadata_field(key, value)
 
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
-        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
-
-        score_threshold = kwargs.get("score_threshold") or 0.0
-        request = gpdb_20160503_models.QueryCollectionDataRequest(
-            dbinstance_id=self.config.instance_id,
-            region_id=self.config.region_id,
-            namespace=self.config.namespace,
-            namespace_password=self.config.namespace_password,
-            collection=self._collection_name,
-            include_values=kwargs.pop("include_values", True),
-            metrics=self.config.metrics,
-            vector=query_vector,
-            content=None,
-            top_k=kwargs.get("top_k", 4),
-            filter=None,
-        )
-        response = self._client.query_collection_data(request)
-        documents = []
-        for match in response.body.matches.match:
-            if match.score > score_threshold:
-                metadata = json.loads(match.metadata.get("metadata_"))
-                metadata["score"] = match.score
-                doc = Document(
-                    page_content=match.metadata.get("page_content"),
-                    metadata=metadata,
-                )
-                documents.append(doc)
-        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
-        return documents
+        return self.analyticdb_vector.search_by_vector(query_vector)
 
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
-        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
-
-        score_threshold = float(kwargs.get("score_threshold") or 0.0)
-        request = gpdb_20160503_models.QueryCollectionDataRequest(
-            dbinstance_id=self.config.instance_id,
-            region_id=self.config.region_id,
-            namespace=self.config.namespace,
-            namespace_password=self.config.namespace_password,
-            collection=self._collection_name,
-            include_values=kwargs.pop("include_values", True),
-            metrics=self.config.metrics,
-            vector=None,
-            content=query,
-            top_k=kwargs.get("top_k", 4),
-            filter=None,
-        )
-        response = self._client.query_collection_data(request)
-        documents = []
-        for match in response.body.matches.match:
-            if match.score > score_threshold:
-                metadata = json.loads(match.metadata.get("metadata_"))
-                metadata["score"] = match.score
-                doc = Document(
-                    page_content=match.metadata.get("page_content"),
-                    vector=match.metadata.get("vector"),
-                    metadata=metadata,
-                )
-                documents.append(doc)
-        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
-        return documents
+        return self.analyticdb_vector.search_by_full_text(query, **kwargs)
 
     def delete(self) -> None:
-        try:
-            from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
-
-            request = gpdb_20160503_models.DeleteCollectionRequest(
-                collection=self._collection_name,
-                dbinstance_id=self.config.instance_id,
-                namespace=self.config.namespace,
-                namespace_password=self.config.namespace_password,
-                region_id=self.config.region_id,
-            )
-            self._client.delete_collection(request)
-        except Exception as e:
-            raise e
+        self.analyticdb_vector.delete()
 
 
 class AnalyticdbVectorFactory(AbstractVectorFactory):
-    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings):
+    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> AnalyticdbVector:
         if dataset.index_struct_dict:
             class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
             collection_name = class_prefix.lower()
@@ -313,26 +65,9 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
             collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
             dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name))
 
-        # handle optional params
-        if dify_config.ANALYTICDB_KEY_ID is None:
-            raise ValueError("ANALYTICDB_KEY_ID should not be None")
-        if dify_config.ANALYTICDB_KEY_SECRET is None:
-            raise ValueError("ANALYTICDB_KEY_SECRET should not be None")
-        if dify_config.ANALYTICDB_REGION_ID is None:
-            raise ValueError("ANALYTICDB_REGION_ID should not be None")
-        if dify_config.ANALYTICDB_INSTANCE_ID is None:
-            raise ValueError("ANALYTICDB_INSTANCE_ID should not be None")
-        if dify_config.ANALYTICDB_ACCOUNT is None:
-            raise ValueError("ANALYTICDB_ACCOUNT should not be None")
-        if dify_config.ANALYTICDB_PASSWORD is None:
-            raise ValueError("ANALYTICDB_PASSWORD should not be None")
-        if dify_config.ANALYTICDB_NAMESPACE is None:
-            raise ValueError("ANALYTICDB_NAMESPACE should not be None")
-        if dify_config.ANALYTICDB_NAMESPACE_PASSWORD is None:
-            raise ValueError("ANALYTICDB_NAMESPACE_PASSWORD should not be None")
-        return AnalyticdbVector(
-            collection_name,
-            AnalyticdbConfig(
+        if dify_config.ANALYTICDB_HOST is None:
+            # implemented through OpenAPI
+            apiConfig = AnalyticdbVectorOpenAPIConfig(
                 access_key_id=dify_config.ANALYTICDB_KEY_ID,
                 access_key_secret=dify_config.ANALYTICDB_KEY_SECRET,
                 region_id=dify_config.ANALYTICDB_REGION_ID,
@@ -341,5 +76,22 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
                 account_password=dify_config.ANALYTICDB_PASSWORD,
                 namespace=dify_config.ANALYTICDB_NAMESPACE,
                 namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD,
-            ),
+            )
+            sqlConfig = None
+        else:
+            # implemented through sql
+            sqlConfig = AnalyticdbVectorBySqlConfig(
+                host=dify_config.ANALYTICDB_HOST,
+                port=dify_config.ANALYTICDB_PORT,
+                account=dify_config.ANALYTICDB_ACCOUNT,
+                account_password=dify_config.ANALYTICDB_PASSWORD,
+                min_connection=dify_config.ANALYTICDB_MIN_CONNECTION,
+                max_connection=dify_config.ANALYTICDB_MAX_CONNECTION,
+                namespace=dify_config.ANALYTICDB_NAMESPACE,
+            )
+            apiConfig = None
+        return AnalyticdbVector(
+            collection_name,
+            apiConfig,
+            sqlConfig,
         )

+ 309 - 0
api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py

@@ -0,0 +1,309 @@
+import json
+from typing import Any
+
+from pydantic import BaseModel, model_validator
+
+_import_err_msg = (
+    "`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
+    "please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
+)
+
+from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
+
+
+class AnalyticdbVectorOpenAPIConfig(BaseModel):
+    access_key_id: str
+    access_key_secret: str
+    region_id: str
+    instance_id: str
+    account: str
+    account_password: str
+    namespace: str = "dify"
+    namespace_password: str = (None,)
+    metrics: str = "cosine"
+    read_timeout: int = 60000
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_config(cls, values: dict) -> dict:
+        if not values["access_key_id"]:
+            raise ValueError("config ANALYTICDB_KEY_ID is required")
+        if not values["access_key_secret"]:
+            raise ValueError("config ANALYTICDB_KEY_SECRET is required")
+        if not values["region_id"]:
+            raise ValueError("config ANALYTICDB_REGION_ID is required")
+        if not values["instance_id"]:
+            raise ValueError("config ANALYTICDB_INSTANCE_ID is required")
+        if not values["account"]:
+            raise ValueError("config ANALYTICDB_ACCOUNT is required")
+        if not values["account_password"]:
+            raise ValueError("config ANALYTICDB_PASSWORD is required")
+        if not values["namespace_password"]:
+            raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required")
+        return values
+
+    def to_analyticdb_client_params(self):
+        return {
+            "access_key_id": self.access_key_id,
+            "access_key_secret": self.access_key_secret,
+            "region_id": self.region_id,
+            "read_timeout": self.read_timeout,
+        }
+
+
+class AnalyticdbVectorOpenAPI:
+    def __init__(self, collection_name: str, config: AnalyticdbVectorOpenAPIConfig):
+        try:
+            from alibabacloud_gpdb20160503.client import Client
+            from alibabacloud_tea_openapi import models as open_api_models
+        except:
+            raise ImportError(_import_err_msg)
+        self._collection_name = collection_name.lower()
+        self.config = config
+        self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
+        self._client = Client(self._client_config)
+        self._initialize()
+
+    def _initialize(self) -> None:
+        cache_key = f"vector_initialize_{self.config.instance_id}"
+        lock_name = f"{cache_key}_lock"
+        with redis_client.lock(lock_name, timeout=20):
+            database_exist_cache_key = f"vector_initialize_{self.config.instance_id}"
+            if redis_client.get(database_exist_cache_key):
+                return
+            self._initialize_vector_database()
+            self._create_namespace_if_not_exists()
+            redis_client.set(database_exist_cache_key, 1, ex=3600)
+
+    def _initialize_vector_database(self) -> None:
+        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
+
+        request = gpdb_20160503_models.InitVectorDatabaseRequest(
+            dbinstance_id=self.config.instance_id,
+            region_id=self.config.region_id,
+            manager_account=self.config.account,
+            manager_account_password=self.config.account_password,
+        )
+        self._client.init_vector_database(request)
+
+    def _create_namespace_if_not_exists(self) -> None:
+        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
+        from Tea.exceptions import TeaException
+
+        try:
+            request = gpdb_20160503_models.DescribeNamespaceRequest(
+                dbinstance_id=self.config.instance_id,
+                region_id=self.config.region_id,
+                namespace=self.config.namespace,
+                manager_account=self.config.account,
+                manager_account_password=self.config.account_password,
+            )
+            self._client.describe_namespace(request)
+        except TeaException as e:
+            if e.statusCode == 404:
+                request = gpdb_20160503_models.CreateNamespaceRequest(
+                    dbinstance_id=self.config.instance_id,
+                    region_id=self.config.region_id,
+                    manager_account=self.config.account,
+                    manager_account_password=self.config.account_password,
+                    namespace=self.config.namespace,
+                    namespace_password=self.config.namespace_password,
+                )
+                self._client.create_namespace(request)
+            else:
+                raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
+
+    def _create_collection_if_not_exists(self, embedding_dimension: int):
+        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
+        from Tea.exceptions import TeaException
+
+        cache_key = f"vector_indexing_{self._collection_name}"
+        lock_name = f"{cache_key}_lock"
+        with redis_client.lock(lock_name, timeout=20):
+            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
+            if redis_client.get(collection_exist_cache_key):
+                return
+            try:
+                request = gpdb_20160503_models.DescribeCollectionRequest(
+                    dbinstance_id=self.config.instance_id,
+                    region_id=self.config.region_id,
+                    namespace=self.config.namespace,
+                    namespace_password=self.config.namespace_password,
+                    collection=self._collection_name,
+                )
+                self._client.describe_collection(request)
+            except TeaException as e:
+                if e.statusCode == 404:
+                    metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}'
+                    full_text_retrieval_fields = "page_content"
+                    request = gpdb_20160503_models.CreateCollectionRequest(
+                        dbinstance_id=self.config.instance_id,
+                        region_id=self.config.region_id,
+                        manager_account=self.config.account,
+                        manager_account_password=self.config.account_password,
+                        namespace=self.config.namespace,
+                        collection=self._collection_name,
+                        dimension=embedding_dimension,
+                        metrics=self.config.metrics,
+                        metadata=metadata,
+                        full_text_retrieval_fields=full_text_retrieval_fields,
+                    )
+                    self._client.create_collection(request)
+                else:
+                    raise ValueError(f"failed to create collection {self._collection_name}: {e}")
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
+
+    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
+
+        rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
+        for doc, embedding in zip(documents, embeddings, strict=True):
+            metadata = {
+                "ref_doc_id": doc.metadata["doc_id"],
+                "page_content": doc.page_content,
+                "metadata_": json.dumps(doc.metadata),
+            }
+            rows.append(
+                gpdb_20160503_models.UpsertCollectionDataRequestRows(
+                    vector=embedding,
+                    metadata=metadata,
+                )
+            )
+        request = gpdb_20160503_models.UpsertCollectionDataRequest(
+            dbinstance_id=self.config.instance_id,
+            region_id=self.config.region_id,
+            namespace=self.config.namespace,
+            namespace_password=self.config.namespace_password,
+            collection=self._collection_name,
+            rows=rows,
+        )
+        self._client.upsert_collection_data(request)
+
+    def text_exists(self, id: str) -> bool:
+        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
+
+        request = gpdb_20160503_models.QueryCollectionDataRequest(
+            dbinstance_id=self.config.instance_id,
+            region_id=self.config.region_id,
+            namespace=self.config.namespace,
+            namespace_password=self.config.namespace_password,
+            collection=self._collection_name,
+            metrics=self.config.metrics,
+            include_values=True,
+            vector=None,
+            content=None,
+            top_k=1,
+            filter=f"ref_doc_id='{id}'",
+        )
+        response = self._client.query_collection_data(request)
+        return len(response.body.matches.match) > 0
+
+    def delete_by_ids(self, ids: list[str]) -> None:
+        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
+
+        ids_str = ",".join(f"'{id}'" for id in ids)
+        ids_str = f"({ids_str})"
+        request = gpdb_20160503_models.DeleteCollectionDataRequest(
+            dbinstance_id=self.config.instance_id,
+            region_id=self.config.region_id,
+            namespace=self.config.namespace,
+            namespace_password=self.config.namespace_password,
+            collection=self._collection_name,
+            collection_data=None,
+            collection_data_filter=f"ref_doc_id IN {ids_str}",
+        )
+        self._client.delete_collection_data(request)
+
+    def delete_by_metadata_field(self, key: str, value: str) -> None:
+        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
+
+        request = gpdb_20160503_models.DeleteCollectionDataRequest(
+            dbinstance_id=self.config.instance_id,
+            region_id=self.config.region_id,
+            namespace=self.config.namespace,
+            namespace_password=self.config.namespace_password,
+            collection=self._collection_name,
+            collection_data=None,
+            collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
+        )
+        self._client.delete_collection_data(request)
+
+    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
+
+        score_threshold = kwargs.get("score_threshold") or 0.0
+        request = gpdb_20160503_models.QueryCollectionDataRequest(
+            dbinstance_id=self.config.instance_id,
+            region_id=self.config.region_id,
+            namespace=self.config.namespace,
+            namespace_password=self.config.namespace_password,
+            collection=self._collection_name,
+            include_values=kwargs.pop("include_values", True),
+            metrics=self.config.metrics,
+            vector=query_vector,
+            content=None,
+            top_k=kwargs.get("top_k", 4),
+            filter=None,
+        )
+        response = self._client.query_collection_data(request)
+        documents = []
+        for match in response.body.matches.match:
+            if match.score > score_threshold:
+                metadata = json.loads(match.metadata.get("metadata_"))
+                metadata["score"] = match.score
+                doc = Document(
+                    page_content=match.metadata.get("page_content"),
+                    vector=match.values.value,
+                    metadata=metadata,
+                )
+                documents.append(doc)
+        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
+        return documents
+
+    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
+
+        score_threshold = float(kwargs.get("score_threshold") or 0.0)
+        request = gpdb_20160503_models.QueryCollectionDataRequest(
+            dbinstance_id=self.config.instance_id,
+            region_id=self.config.region_id,
+            namespace=self.config.namespace,
+            namespace_password=self.config.namespace_password,
+            collection=self._collection_name,
+            include_values=kwargs.pop("include_values", True),
+            metrics=self.config.metrics,
+            vector=None,
+            content=query,
+            top_k=kwargs.get("top_k", 4),
+            filter=None,
+        )
+        response = self._client.query_collection_data(request)
+        documents = []
+        for match in response.body.matches.match:
+            if match.score > score_threshold:
+                metadata = json.loads(match.metadata.get("metadata_"))
+                metadata["score"] = match.score
+                doc = Document(
+                    page_content=match.metadata.get("page_content"),
+                    vector=match.values.value,
+                    metadata=metadata,
+                )
+                documents.append(doc)
+        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
+        return documents
+
+    def delete(self) -> None:
+        try:
+            from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
+
+            request = gpdb_20160503_models.DeleteCollectionRequest(
+                collection=self._collection_name,
+                dbinstance_id=self.config.instance_id,
+                namespace=self.config.namespace,
+                namespace_password=self.config.namespace_password,
+                region_id=self.config.region_id,
+            )
+            self._client.delete_collection(request)
+        except Exception as e:
+            raise e

+ 245 - 0
api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py

@@ -0,0 +1,245 @@
+import json
+import uuid
+from contextlib import contextmanager
+from typing import Any
+
+import psycopg2.extras
+import psycopg2.pool
+from pydantic import BaseModel, model_validator
+
+from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
+
+
+class AnalyticdbVectorBySqlConfig(BaseModel):
+    host: str
+    port: int
+    account: str
+    account_password: str
+    min_connection: int
+    max_connection: int
+    namespace: str = "dify"
+    metrics: str = "cosine"
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_config(cls, values: dict) -> dict:
+        if not values["host"]:
+            raise ValueError("config ANALYTICDB_HOST is required")
+        if not values["port"]:
+            raise ValueError("config ANALYTICDB_PORT is required")
+        if not values["account"]:
+            raise ValueError("config ANALYTICDB_ACCOUNT is required")
+        if not values["account_password"]:
+            raise ValueError("config ANALYTICDB_PASSWORD is required")
+        if not values["min_connection"]:
+            raise ValueError("config ANALYTICDB_MIN_CONNECTION is required")
+        if not values["max_connection"]:
+            raise ValueError("config ANALYTICDB_MAX_CONNECTION is required")
+        if values["min_connection"] > values["max_connection"]:
+            raise ValueError("config ANALYTICDB_MIN_CONNECTION should less than ANALYTICDB_MAX_CONNECTION")
+        return values
+
+
+class AnalyticdbVectorBySql:
+    def __init__(self, collection_name: str, config: AnalyticdbVectorBySqlConfig):
+        self._collection_name = collection_name.lower()
+        self.databaseName = "knowledgebase"
+        self.config = config
+        self.table_name = f"{self.config.namespace}.{self._collection_name}"
+        self.pool = None
+        self._initialize()
+        if not self.pool:
+            self.pool = self._create_connection_pool()
+
+    def _initialize(self) -> None:
+        cache_key = f"vector_initialize_{self.config.host}"
+        lock_name = f"{cache_key}_lock"
+        with redis_client.lock(lock_name, timeout=20):
+            database_exist_cache_key = f"vector_initialize_{self.config.host}"
+            if redis_client.get(database_exist_cache_key):
+                return
+            self._initialize_vector_database()
+            redis_client.set(database_exist_cache_key, 1, ex=3600)
+
+    def _create_connection_pool(self):
+        return psycopg2.pool.SimpleConnectionPool(
+            self.config.min_connection,
+            self.config.max_connection,
+            host=self.config.host,
+            port=self.config.port,
+            user=self.config.account,
+            password=self.config.account_password,
+            database=self.databaseName,
+        )
+
+    @contextmanager
+    def _get_cursor(self):
+        conn = self.pool.getconn()
+        cur = conn.cursor()
+        try:
+            yield cur
+        finally:
+            cur.close()
+            conn.commit()
+            self.pool.putconn(conn)
+
+    def _initialize_vector_database(self) -> None:
+        conn = psycopg2.connect(
+            host=self.config.host,
+            port=self.config.port,
+            user=self.config.account,
+            password=self.config.account_password,
+            database="postgres",
+        )
+        conn.autocommit = True
+        cur = conn.cursor()
+        try:
+            cur.execute(f"CREATE DATABASE {self.databaseName}")
+        except Exception as e:
+            if "already exists" in str(e):
+                return
+            raise e
+        finally:
+            cur.close()
+            conn.close()
+        self.pool = self._create_connection_pool()
+        with self._get_cursor() as cur:
+            try:
+                cur.execute("CREATE TEXT SEARCH CONFIGURATION zh_cn (PARSER = zhparser)")
+                cur.execute("ALTER TEXT SEARCH CONFIGURATION zh_cn ADD MAPPING FOR n,v,a,i,e,l,x WITH simple")
+            except Exception as e:
+                if "already exists" not in str(e):
+                    raise e
+            cur.execute(
+                "CREATE OR REPLACE FUNCTION "
+                "public.to_tsquery_from_text(txt text, lang regconfig DEFAULT 'english'::regconfig) "
+                "RETURNS tsquery LANGUAGE sql IMMUTABLE STRICT AS $function$ "
+                "SELECT to_tsquery(lang, COALESCE(string_agg(split_part(word, ':', 1), ' | '), '')) "
+                "FROM (SELECT unnest(string_to_array(to_tsvector(lang, txt)::text, ' ')) AS word) "
+                "AS words_only;$function$"
+            )
+            cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.config.namespace}")
+
+    def _create_collection_if_not_exists(self, embedding_dimension: int):
+        cache_key = f"vector_indexing_{self._collection_name}"
+        lock_name = f"{cache_key}_lock"
+        with redis_client.lock(lock_name, timeout=20):
+            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
+            if redis_client.get(collection_exist_cache_key):
+                return
+            with self._get_cursor() as cur:
+                cur.execute(
+                    f"CREATE TABLE IF NOT EXISTS {self.table_name}("
+                    f"id text PRIMARY KEY,"
+                    f"vector real[], ref_doc_id text, page_content text, metadata_ jsonb, "
+                    f"to_tsvector TSVECTOR"
+                    f") WITH (fillfactor=70) DISTRIBUTED BY (id);"
+                )
+                if embedding_dimension is not None:
+                    index_name = f"{self._collection_name}_embedding_idx"
+                    cur.execute(f"ALTER TABLE {self.table_name} ALTER COLUMN vector SET STORAGE PLAIN")
+                    cur.execute(
+                        f"CREATE INDEX {index_name} ON {self.table_name} USING ann(vector) "
+                        f"WITH(dim='{embedding_dimension}', distancemeasure='{self.config.metrics}', "
+                        f"pq_enable=0, external_storage=0)"
+                    )
+                    cur.execute(f"CREATE INDEX ON {self.table_name} USING gin(to_tsvector)")
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
+
+    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+        values = []
+        id_prefix = str(uuid.uuid4()) + "_"
+        sql = f"""
+                INSERT INTO {self.table_name} 
+                (id, ref_doc_id, vector, page_content, metadata_, to_tsvector) 
+                VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn',  %s));
+            """
+        for i, doc in enumerate(documents):
+            values.append(
+                (
+                    id_prefix + str(i),
+                    doc.metadata.get("doc_id", str(uuid.uuid4())),
+                    embeddings[i],
+                    doc.page_content,
+                    json.dumps(doc.metadata),
+                    doc.page_content,
+                )
+            )
+        with self._get_cursor() as cur:
+            psycopg2.extras.execute_batch(cur, sql, values)
+
+    def text_exists(self, id: str) -> bool:
+        with self._get_cursor() as cur:
+            cur.execute(f"SELECT id FROM {self.table_name} WHERE ref_doc_id = %s", (id,))
+            return cur.fetchone() is not None
+
+    def delete_by_ids(self, ids: list[str]) -> None:
+        with self._get_cursor() as cur:
+            try:
+                cur.execute(f"DELETE FROM {self.table_name} WHERE ref_doc_id IN %s", (tuple(ids),))
+            except Exception as e:
+                if "does not exist" not in str(e):
+                    raise e
+
+    def delete_by_metadata_field(self, key: str, value: str) -> None:
+        with self._get_cursor() as cur:
+            try:
+                cur.execute(f"DELETE FROM {self.table_name} WHERE metadata_->>%s = %s", (key, value))
+            except Exception as e:
+                if "does not exist" not in str(e):
+                    raise e
+
+    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+        top_k = kwargs.get("top_k", 4)
+        score_threshold = float(kwargs.get("score_threshold") or 0.0)
+        with self._get_cursor() as cur:
+            query_vector_str = json.dumps(query_vector)
+            query_vector_str = "{" + query_vector_str[1:-1] + "}"
+            cur.execute(
+                f"SELECT t.id AS id, t.vector AS vector, (1.0 - t.score) AS score, "
+                f"t.page_content as page_content, t.metadata_ AS metadata_ "
+                f"FROM (SELECT id, vector, page_content, metadata_, vector <=> %s AS score "
+                f"FROM {self.table_name} ORDER BY score LIMIT {top_k} ) t",
+                (query_vector_str,),
+            )
+            documents = []
+            for record in cur:
+                id, vector, score, page_content, metadata = record
+                if score > score_threshold:
+                    metadata["score"] = score
+                    doc = Document(
+                        page_content=page_content,
+                        vector=vector,
+                        metadata=metadata,
+                    )
+                    documents.append(doc)
+        return documents
+
+    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        top_k = kwargs.get("top_k", 4)
+        with self._get_cursor() as cur:
+            cur.execute(
+                f"""SELECT id, vector, page_content, metadata_, 
+                ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
+                FROM {self.table_name}
+                WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn')
+                ORDER BY score DESC
+                LIMIT {top_k}""",
+                (f"'{query}'", f"'{query}'"),
+            )
+            documents = []
+            for record in cur:
+                id, vector, page_content, metadata, score = record
+                metadata["score"] = score
+                doc = Document(
+                    page_content=page_content,
+                    vector=vector,
+                    metadata=metadata,
+                )
+                documents.append(doc)
+        return documents
+
+    def delete(self) -> None:
+        with self._get_cursor() as cur:
+            cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")

+ 33 - 16
api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py

@@ -1,27 +1,43 @@
 from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbConfig, AnalyticdbVector
+from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
+from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
 from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, setup_mock_redis
 
 
 class AnalyticdbVectorTest(AbstractVectorTest):
-    def __init__(self):
+    def __init__(self, config_type: str):
         super().__init__()
         # Analyticdb requires collection_name length less than 60.
         # it's ok for normal usage.
         self.collection_name = self.collection_name.replace("_test", "")
-        self.vector = AnalyticdbVector(
-            collection_name=self.collection_name,
-            config=AnalyticdbConfig(
-                access_key_id="test_key_id",
-                access_key_secret="test_key_secret",
-                region_id="test_region",
-                instance_id="test_id",
-                account="test_account",
-                account_password="test_passwd",
-                namespace="difytest_namespace",
-                collection="difytest_collection",
-                namespace_password="test_passwd",
-            ),
-        )
+        if config_type == "sql":
+            self.vector = AnalyticdbVector(
+                collection_name=self.collection_name,
+                sql_config=AnalyticdbVectorBySqlConfig(
+                    host="test_host",
+                    port=5432,
+                    account="test_account",
+                    account_password="test_passwd",
+                    namespace="difytest_namespace",
+                ),
+                api_config=None,
+            )
+        else:
+            self.vector = AnalyticdbVector(
+                collection_name=self.collection_name,
+                sql_config=None,
+                api_config=AnalyticdbVectorOpenAPIConfig(
+                    access_key_id="test_key_id",
+                    access_key_secret="test_key_secret",
+                    region_id="test_region",
+                    instance_id="test_id",
+                    account="test_account",
+                    account_password="test_passwd",
+                    namespace="difytest_namespace",
+                    collection="difytest_collection",
+                    namespace_password="test_passwd",
+                ),
+            )
 
     def run_all_tests(self):
         self.vector.delete()
@@ -29,4 +45,5 @@ class AnalyticdbVectorTest(AbstractVectorTest):
 
 
 def test_chroma_vector(setup_mock_redis):
-    AnalyticdbVectorTest().run_all_tests()
+    AnalyticdbVectorTest("api").run_all_tests()
+    AnalyticdbVectorTest("sql").run_all_tests()

+ 4 - 0
docker/.env.example

@@ -450,6 +450,10 @@ ANALYTICDB_ACCOUNT=testaccount
 ANALYTICDB_PASSWORD=testpassword
 ANALYTICDB_NAMESPACE=dify
 ANALYTICDB_NAMESPACE_PASSWORD=difypassword
+ANALYTICDB_HOST=gp-test.aliyuncs.com
+ANALYTICDB_PORT=5432
+ANALYTICDB_MIN_CONNECTION=1
+ANALYTICDB_MAX_CONNECTION=5
 
 # TiDB vector configurations, only available when VECTOR_STORE is `tidb`
 TIDB_VECTOR_HOST=tidb

+ 4 - 0
docker/docker-compose.yaml

@@ -185,6 +185,10 @@ x-shared-env: &shared-api-worker-env
   ANALYTICDB_PASSWORD: ${ANALYTICDB_PASSWORD:-}
   ANALYTICDB_NAMESPACE: ${ANALYTICDB_NAMESPACE:-dify}
   ANALYTICDB_NAMESPACE_PASSWORD: ${ANALYTICDB_NAMESPACE_PASSWORD:-}
+  ANALYTICDB_HOST: ${ANALYTICDB_HOST:-}
+  ANALYTICDB_PORT: ${ANALYTICDB_PORT:-5432}
+  ANALYTICDB_MIN_CONNECTION: ${ANALYTICDB_MIN_CONNECTION:-1}
+  ANALYTICDB_MAX_CONNECTION: ${ANALYTICDB_MAX_CONNECTION:-5}
   OPENSEARCH_HOST: ${OPENSEARCH_HOST:-opensearch}
   OPENSEARCH_PORT: ${OPENSEARCH_PORT:-9200}
   OPENSEARCH_USER: ${OPENSEARCH_USER:-admin}