Explorar o código

Lindorm vdb (#11574)

Co-authored-by: jiangzhijie <jiangzhijie.jzj@alibaba-inc.com>
Jiang hai 4 meses
pai
achega
0d04cdc323

+ 1 - 0
api/.env.example

@@ -294,6 +294,7 @@ VIKINGDB_SOCKET_TIMEOUT=30
 LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070
 LINDORM_USERNAME=admin
 LINDORM_PASSWORD=admin
+USING_UGC_INDEX=False
 
 # OceanBase Vector configuration
 OCEANBASE_VECTOR_HOST=127.0.0.1

+ 11 - 0
api/configs/middleware/vdb/lindorm_config.py

@@ -21,3 +21,14 @@ class LindormConfig(BaseSettings):
         description="Lindorm password",
         default=None,
     )
+    DEFAULT_INDEX_TYPE: Optional[str] = Field(
+        description="Lindorm Vector Index Type, hnsw or flat is available in dify",
+        default="hnsw",
+    )
+    DEFAULT_DISTANCE_TYPE: Optional[str] = Field(
+        description="Vector Distance Type, support l2, cosinesimil, innerproduct", default="l2"
+    )
+    USING_UGC_INDEX: Optional[bool] = Field(
+        description="Using UGC index will store the same type of Index in a single index but can retrieve separately.",
+        default=False,
+    )

+ 119 - 120
api/core/rag/datasource/vdb/lindorm/lindorm_vector.py

@@ -1,13 +1,10 @@
 import copy
 import json
 import logging
-from collections.abc import Iterable
 from typing import Any, Optional
 
 from opensearchpy import OpenSearch
-from opensearchpy.helpers import bulk
 from pydantic import BaseModel, model_validator
-from tenacity import retry, stop_after_attempt, wait_fixed
 
 from configs import dify_config
 from core.rag.datasource.vdb.field import Field
@@ -23,11 +20,15 @@ logger = logging.getLogger(__name__)
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 logging.getLogger("lindorm").setLevel(logging.WARN)
 
+ROUTING_FIELD = "routing_field"
+UGC_INDEX_PREFIX = "ugc_index"
+
 
 class LindormVectorStoreConfig(BaseModel):
     hosts: str
     username: Optional[str] = None
     password: Optional[str] = None
+    using_ugc: Optional[bool] = False
 
     @model_validator(mode="before")
     @classmethod
@@ -41,9 +42,7 @@ class LindormVectorStoreConfig(BaseModel):
         return values
 
     def to_opensearch_params(self) -> dict[str, Any]:
-        params = {
-            "hosts": self.hosts,
-        }
+        params = {"hosts": self.hosts}
         if self.username and self.password:
             params["http_auth"] = (self.username, self.password)
         return params
@@ -51,9 +50,21 @@ class LindormVectorStoreConfig(BaseModel):
 
 class LindormVectorStore(BaseVector):
     def __init__(self, collection_name: str, config: LindormVectorStoreConfig, **kwargs):
-        super().__init__(collection_name.lower())
+        self._routing = None
+        self._routing_field = None
+        if config.using_ugc:
+            routing_value: str = kwargs.get("routing_value")
+            if routing_value is None:
+                raise ValueError("UGC index should init vector with valid 'routing_value' parameter value")
+            self._routing = routing_value.lower()
+            self._routing_field = ROUTING_FIELD
+            ugc_index_name = collection_name
+            super().__init__(ugc_index_name.lower())
+        else:
+            super().__init__(collection_name.lower())
         self._client_config = config
         self._client = OpenSearch(**config.to_opensearch_params())
+        self._using_ugc = config.using_ugc
         self.kwargs = kwargs
 
     def get_type(self) -> str:
@@ -66,89 +77,37 @@ class LindormVectorStore(BaseVector):
     def refresh(self):
         self._client.indices.refresh(index=self._collection_name)
 
-    def __filter_existed_ids(
-        self,
-        texts: list[str],
-        metadatas: list[dict],
-        ids: list[str],
-        bulk_size: int = 1024,
-    ) -> tuple[Iterable[str], Optional[list[dict]], Optional[list[str]]]:
-        @retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
-        def __fetch_existing_ids(batch_ids: list[str]) -> set[str]:
-            try:
-                existing_docs = self._client.mget(index=self._collection_name, body={"ids": batch_ids}, _source=False)
-                return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
-            except Exception as e:
-                logger.exception(f"Error fetching batch {batch_ids}")
-                return set()
-
-        @retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
-        def __fetch_existing_routing_ids(batch_ids: list[str], route_ids: list[str]) -> set[str]:
-            try:
-                existing_docs = self._client.mget(
-                    body={
-                        "docs": [
-                            {"_index": self._collection_name, "_id": id, "routing": routing}
-                            for id, routing in zip(batch_ids, route_ids)
-                        ]
-                    },
-                    _source=False,
-                )
-                return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
-            except Exception as e:
-                logger.exception(f"Error fetching batch ids: {batch_ids}")
-                return set()
-
-        if ids is None:
-            return texts, metadatas, ids
-
-        if len(texts) != len(ids):
-            raise RuntimeError(f"texts {len(texts)} != {ids}")
-
-        filtered_texts = []
-        filtered_metadatas = []
-        filtered_ids = []
-
-        def batch(iterable, n):
-            length = len(iterable)
-            for idx in range(0, length, n):
-                yield iterable[idx : min(idx + n, length)]
-
-        for ids_batch, texts_batch, metadatas_batch in zip(
-            batch(ids, bulk_size),
-            batch(texts, bulk_size),
-            batch(metadatas, bulk_size) if metadatas is not None else batch([None] * len(ids), bulk_size),
-        ):
-            existing_ids_set = __fetch_existing_ids(ids_batch)
-            for text, metadata, doc_id in zip(texts_batch, metadatas_batch, ids_batch):
-                if doc_id not in existing_ids_set:
-                    filtered_texts.append(text)
-                    filtered_ids.append(doc_id)
-                    if metadatas is not None:
-                        filtered_metadatas.append(metadata)
-
-        return filtered_texts, metadatas if metadatas is None else filtered_metadatas, filtered_ids
-
     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
         actions = []
         uuids = self._get_uuids(documents)
         for i in range(len(documents)):
-            action = {
-                "_op_type": "index",
-                "_index": self._collection_name.lower(),
-                "_id": uuids[i],
-                "_source": {
-                    Field.CONTENT_KEY.value: documents[i].page_content,
-                    Field.VECTOR.value: embeddings[i],  # Make sure you pass an array here
-                    Field.METADATA_KEY.value: documents[i].metadata,
-                },
+            action_header = {
+                "index": {
+                    "_index": self.collection_name.lower(),
+                    "_id": uuids[i],
+                }
+            }
+            action_values = {
+                Field.CONTENT_KEY.value: documents[i].page_content,
+                Field.VECTOR.value: embeddings[i],  # Make sure you pass an array here
+                Field.METADATA_KEY.value: documents[i].metadata,
             }
-            actions.append(action)
-        bulk(self._client, actions)
-        self.refresh()
+            if self._using_ugc:
+                action_header["index"]["routing"] = self._routing
+                action_values[self._routing_field] = self._routing
+            actions.append(action_header)
+            actions.append(action_values)
+        response = self._client.bulk(actions)
+        if response["errors"]:
+            for item in response["items"]:
+                print(f"{item['index']['status']}: {item['index']['error']['type']}")
+        else:
+            self.refresh()
 
     def get_ids_by_metadata_field(self, key: str, value: str):
-        query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}}
+        query = {"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}}
+        if self._using_ugc:
+            query["query"]["bool"]["must"].append({"term": {f"{self._routing_field}.keyword": self._routing}})
         response = self._client.search(index=self._collection_name, body=query)
         if response["hits"]["hits"]:
             return [hit["_id"] for hit in response["hits"]["hits"]]
@@ -156,50 +115,62 @@ class LindormVectorStore(BaseVector):
             return None
 
     def delete_by_metadata_field(self, key: str, value: str):
-        query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
-        results = self._client.search(index=self._collection_name, body=query_str)
-        ids = [hit["_id"] for hit in results["hits"]["hits"]]
+        ids = self.get_ids_by_metadata_field(key, value)
         if ids:
             self.delete_by_ids(ids)
 
     def delete_by_ids(self, ids: list[str]) -> None:
+        params = {}
+        if self._using_ugc:
+            params["routing"] = self._routing
         for id in ids:
-            if self._client.exists(index=self._collection_name, id=id):
-                self._client.delete(index=self._collection_name, id=id)
+            if self._client.exists(index=self._collection_name, id=id, params=params):
+                params = {}
+                if self._using_ugc:
+                    params["routing"] = self._routing
+                self._client.delete(index=self._collection_name, id=id, params=params)
+                self.refresh()
             else:
                 logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.")
 
     def delete(self) -> None:
-        try:
+        if self._using_ugc:
+            routing_filter_query = {
+                "query": {"bool": {"must": [{"term": {f"{self._routing_field}.keyword": self._routing}}]}}
+            }
+            self._client.delete_by_query(self._collection_name, body=routing_filter_query)
+            self.refresh()
+        else:
             if self._client.indices.exists(index=self._collection_name):
                 self._client.indices.delete(index=self._collection_name, params={"timeout": 60})
                 logger.info("Delete index success")
             else:
                 logger.warning(f"Index '{self._collection_name}' does not exist. No deletion performed.")
-        except Exception as e:
-            logger.exception(f"Error occurred while deleting the index: {self._collection_name}")
-            raise e
 
     def text_exists(self, id: str) -> bool:
         try:
-            self._client.get(index=self._collection_name, id=id)
+            params = {}
+            if self._using_ugc:
+                params["routing"] = self._routing
+            self._client.get(index=self._collection_name, id=id, params=params)
             return True
         except:
             return False
 
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
-        # Make sure query_vector is a list
         if not isinstance(query_vector, list):
             raise ValueError("query_vector should be a list of floats")
 
-        # Check whether query_vector is a floating-point number list
         if not all(isinstance(x, float) for x in query_vector):
             raise ValueError("All elements in query_vector should be floats")
 
         top_k = kwargs.get("top_k", 10)
         query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs)
         try:
-            response = self._client.search(index=self._collection_name, body=query)
+            params = {}
+            if self._using_ugc:
+                params["routing"] = self._routing
+            response = self._client.search(index=self._collection_name, body=query, params=params)
         except Exception as e:
             logger.exception(f"Error executing vector search, query: {query}")
             raise
@@ -232,7 +203,7 @@ class LindormVectorStore(BaseVector):
         minimum_should_match = kwargs.get("minimum_should_match", 0)
         top_k = kwargs.get("top_k", 10)
         filters = kwargs.get("filter")
-        routing = kwargs.get("routing")
+        routing = self._routing
         full_text_query = default_text_search_query(
             query_text=query,
             k=top_k,
@@ -243,6 +214,7 @@ class LindormVectorStore(BaseVector):
             minimum_should_match=minimum_should_match,
             filters=filters,
             routing=routing,
+            routing_field=self._routing_field,
         )
         response = self._client.search(index=self._collection_name, body=full_text_query)
         docs = []
@@ -265,17 +237,18 @@ class LindormVectorStore(BaseVector):
                 logger.info(f"Collection {self._collection_name} already exists.")
                 return
             if self._client.indices.exists(index=self._collection_name):
-                logger.info("{self._collection_name.lower()} already exists.")
+                logger.info(f"{self._collection_name.lower()} already exists.")
+                redis_client.set(collection_exist_cache_key, 1, ex=3600)
                 return
             if len(self.kwargs) == 0 and len(kwargs) != 0:
                 self.kwargs = copy.deepcopy(kwargs)
             vector_field = kwargs.pop("vector_field", Field.VECTOR.value)
-            shards = kwargs.pop("shards", 2)
+            shards = kwargs.pop("shards", 4)
 
             engine = kwargs.pop("engine", "lvector")
-            method_name = kwargs.pop("method_name", "hnsw")
+            method_name = kwargs.pop("method_name", dify_config.DEFAULT_INDEX_TYPE)
+            space_type = kwargs.pop("space_type", dify_config.DEFAULT_DISTANCE_TYPE)
             data_type = kwargs.pop("data_type", "float")
-            space_type = kwargs.pop("space_type", "cosinesimil")
 
             hnsw_m = kwargs.pop("hnsw_m", 24)
             hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500)
@@ -288,10 +261,10 @@ class LindormVectorStore(BaseVector):
             mapping = default_text_mapping(
                 dimension,
                 method_name,
+                space_type=space_type,
                 shards=shards,
                 engine=engine,
                 data_type=data_type,
-                space_type=space_type,
                 vector_field=vector_field,
                 hnsw_m=hnsw_m,
                 hnsw_ef_construction=hnsw_ef_construction,
@@ -301,6 +274,7 @@ class LindormVectorStore(BaseVector):
                 centroids_hnsw_m=centroids_hnsw_m,
                 centroids_hnsw_ef_construct=centroids_hnsw_ef_construct,
                 centroids_hnsw_ef_search=centroids_hnsw_ef_search,
+                using_ugc=self._using_ugc,
                 **kwargs,
             )
             self._client.indices.create(index=self._collection_name.lower(), body=mapping)
@@ -309,15 +283,20 @@ class LindormVectorStore(BaseVector):
 
 
 def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dict:
-    routing_field = kwargs.get("routing_field")
     excludes_from_source = kwargs.get("excludes_from_source")
     analyzer = kwargs.get("analyzer", "ik_max_word")
     text_field = kwargs.get("text_field", Field.CONTENT_KEY.value)
     engine = kwargs["engine"]
     shard = kwargs["shards"]
-    space_type = kwargs["space_type"]
+    space_type = kwargs.get("space_type")
+    if space_type is None:
+        if method_name == "hnsw":
+            space_type = "l2"
+        else:
+            space_type = "cosine"
     data_type = kwargs["data_type"]
     vector_field = kwargs.get("vector_field", Field.VECTOR.value)
+    using_ugc = kwargs.get("using_ugc", False)
 
     if method_name == "ivfpq":
         ivfpq_m = kwargs["ivfpq_m"]
@@ -366,13 +345,11 @@ def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dic
     if excludes_from_source:
         mapping["mappings"]["_source"] = {"excludes": excludes_from_source}  # e.g. {"excludes": ["vector_field"]}
 
-    if method_name == "ivfpq" and routing_field is not None:
+    if using_ugc and method_name == "ivfpq":
         mapping["settings"]["index"]["knn_routing"] = True
         mapping["settings"]["index"]["knn.offline.construction"] = True
-
-    if method_name == "flat" and routing_field is not None:
+    elif using_ugc and method_name == "hnsw" or using_ugc and method_name == "flat":
         mapping["settings"]["index"]["knn_routing"] = True
-
     return mapping
 
 
@@ -386,14 +363,12 @@ def default_text_search_query(
     minimum_should_match: int = 0,
     filters: Optional[list[dict]] = None,
     routing: Optional[str] = None,
+    routing_field: Optional[str] = None,
     **kwargs,
 ) -> dict:
     if routing is not None:
-        routing_field = kwargs.get("routing_field", "routing_field")
         query_clause = {
-            "bool": {
-                "must": [{"match": {text_field: query_text}}, {"term": {f"metadata.{routing_field}.keyword": routing}}]
-            }
+            "bool": {"must": [{"match": {text_field: query_text}}, {"term": {f"{routing_field}.keyword": routing}}]}
         }
     else:
         query_clause = {"match": {text_field: query_text}}
@@ -483,16 +458,40 @@ def default_vector_search_query(
 
 class LindormVectorStoreFactory(AbstractVectorFactory):
     def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore:
-        if dataset.index_struct_dict:
-            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
-            collection_name = class_prefix
-        else:
-            dataset_id = dataset.id
-            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
-            dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.LINDORM, collection_name))
         lindorm_config = LindormVectorStoreConfig(
             hosts=dify_config.LINDORM_URL,
             username=dify_config.LINDORM_USERNAME,
             password=dify_config.LINDORM_PASSWORD,
+            using_ugc=dify_config.USING_UGC_INDEX,
         )
-        return LindormVectorStore(collection_name, lindorm_config)
+        using_ugc = dify_config.USING_UGC_INDEX
+        routing_value = None
+        if dataset.index_struct:
+            if using_ugc:
+                dimension = dataset.index_struct_dict["dimension"]
+                index_type = dataset.index_struct_dict["index_type"]
+                distance_type = dataset.index_struct_dict["distance_type"]
+                index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
+                routing_value = dataset.index_struct_dict["vector_store"]["class_prefix"]
+            else:
+                index_name = dataset.index_struct_dict["vector_store"]["class_prefix"]
+        else:
+            embedding_vector = embeddings.embed_query("hello word")
+            dimension = len(embedding_vector)
+            index_type = dify_config.DEFAULT_INDEX_TYPE
+            distance_type = dify_config.DEFAULT_DISTANCE_TYPE
+            class_prefix = Dataset.gen_collection_name_by_id(dataset.id)
+            index_struct_dict = {
+                "type": VectorType.LINDORM,
+                "vector_store": {"class_prefix": class_prefix},
+                "index_type": index_type,
+                "dimension": dimension,
+                "distance_type": distance_type,
+            }
+            dataset.index_struct = json.dumps(index_struct_dict)
+            if using_ugc:
+                index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
+                routing_value = class_prefix
+            else:
+                index_name = class_prefix
+        return LindormVectorStore(index_name, lindorm_config, routing_value=routing_value)

+ 26 - 3
api/tests/integration_tests/vdb/lindorm/test_lindorm.py

@@ -7,9 +7,10 @@ env = environs.Env()
 
 
 class Config:
-    SEARCH_ENDPOINT = env.str("SEARCH_ENDPOINT", "http://ld-*************-proxy-search-pub.lindorm.aliyuncs.com:30070")
+    SEARCH_ENDPOINT = env.str("SEARCH_ENDPOINT", "http://ld-************-proxy-search-pub.lindorm.aliyuncs.com:30070")
     SEARCH_USERNAME = env.str("SEARCH_USERNAME", "ADMIN")
-    SEARCH_PWD = env.str("SEARCH_PWD", "PWD")
+    SEARCH_PWD = env.str("SEARCH_PWD", "ADMIN")
+    USING_UGC = env.bool("USING_UGC", True)
 
 
 class TestLindormVectorStore(AbstractVectorTest):
@@ -31,5 +32,27 @@ class TestLindormVectorStore(AbstractVectorTest):
         assert ids[0] == self.example_doc_id
 
 
-def test_lindorm_vector(setup_mock_redis):
+class TestLindormVectorStoreUGC(AbstractVectorTest):
+    def __init__(self):
+        super().__init__()
+        self.vector = LindormVectorStore(
+            collection_name="ugc_index_test",
+            config=LindormVectorStoreConfig(
+                hosts=Config.SEARCH_ENDPOINT,
+                username=Config.SEARCH_USERNAME,
+                password=Config.SEARCH_PWD,
+                using_ugc=Config.USING_UGC,
+            ),
+            routing_value=self.collection_name,
+        )
+
+    def get_ids_by_metadata_field(self):
+        ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id)
+        assert ids is not None
+        assert len(ids) == 1
+        assert ids[0] == self.example_doc_id
+
+
+def test_lindorm_vector_ugc(setup_mock_redis):
     TestLindormVectorStore().run_all_tests()
+    TestLindormVectorStoreUGC().run_all_tests()