@@ -1,13 +1,10 @@
 import copy
 import json
 import logging
-from collections.abc import Iterable
 from typing import Any, Optional

 from opensearchpy import OpenSearch
-from opensearchpy.helpers import bulk
 from pydantic import BaseModel, model_validator
-from tenacity import retry, stop_after_attempt, wait_fixed

 from configs import dify_config
 from core.rag.datasource.vdb.field import Field
@@ -23,11 +20,15 @@ logger = logging.getLogger(__name__)
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 logging.getLogger("lindorm").setLevel(logging.WARN)

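+# UGC mode packs many logical collections into one shared physical Lindorm index;
+# each collection's documents carry a routing value (stored under ROUTING_FIELD)
+# so reads and writes can be pinned to the shard that owns them.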
+ROUTING_FIELD = "routing_field"
+UGC_INDEX_PREFIX = "ugc_index"
+

 class LindormVectorStoreConfig(BaseModel):
     hosts: str
     username: Optional[str] = None
     password: Optional[str] = None
+    using_ugc: Optional[bool] = False

     @model_validator(mode="before")
     @classmethod
@@ -41,9 +42,7 @@ class LindormVectorStoreConfig(BaseModel):
         return values

     def to_opensearch_params(self) -> dict[str, Any]:
-        params = {
-            "hosts": self.hosts,
-        }
+        params = {"hosts": self.hosts}
         if self.username and self.password:
             params["http_auth"] = (self.username, self.password)
         return params
@@ -51,9 +50,21 @@ class LindormVectorStoreConfig(BaseModel):

 class LindormVectorStore(BaseVector):
     def __init__(self, collection_name: str, config: LindormVectorStoreConfig, **kwargs):
-        super().__init__(collection_name.lower())
+        self._routing = None
+        self._routing_field = None
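+        # In UGC mode the collection name refers to a shared physical index; the
+        # per-dataset 'routing_value' keeps this store's documents on one shard.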
+        if config.using_ugc:
+            routing_value: Optional[str] = kwargs.get("routing_value")
+            if routing_value is None:
+                raise ValueError("UGC index must be initialized with a valid 'routing_value' parameter")
+            self._routing = routing_value.lower()
+            self._routing_field = ROUTING_FIELD
+            ugc_index_name = collection_name
+            super().__init__(ugc_index_name.lower())
+        else:
+            super().__init__(collection_name.lower())
         self._client_config = config
         self._client = OpenSearch(**config.to_opensearch_params())
+        self._using_ugc = config.using_ugc
         self.kwargs = kwargs

     def get_type(self) -> str:
@@ -66,89 +77,37 @@ class LindormVectorStore(BaseVector):
     def refresh(self):
         self._client.indices.refresh(index=self._collection_name)

-    def __filter_existed_ids(
-        self,
-        texts: list[str],
-        metadatas: list[dict],
-        ids: list[str],
-        bulk_size: int = 1024,
-    ) -> tuple[Iterable[str], Optional[list[dict]], Optional[list[str]]]:
-        @retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
-        def __fetch_existing_ids(batch_ids: list[str]) -> set[str]:
-            try:
-                existing_docs = self._client.mget(index=self._collection_name, body={"ids": batch_ids}, _source=False)
-                return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
-            except Exception as e:
-                logger.exception(f"Error fetching batch {batch_ids}")
-                return set()
-
-        @retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
-        def __fetch_existing_routing_ids(batch_ids: list[str], route_ids: list[str]) -> set[str]:
-            try:
-                existing_docs = self._client.mget(
-                    body={
-                        "docs": [
-                            {"_index": self._collection_name, "_id": id, "routing": routing}
-                            for id, routing in zip(batch_ids, route_ids)
-                        ]
-                    },
-                    _source=False,
-                )
-                return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
-            except Exception as e:
-                logger.exception(f"Error fetching batch ids: {batch_ids}")
-                return set()
-
-        if ids is None:
-            return texts, metadatas, ids
-
-        if len(texts) != len(ids):
-            raise RuntimeError(f"texts {len(texts)} != {ids}")
-
-        filtered_texts = []
-        filtered_metadatas = []
-        filtered_ids = []
-
-        def batch(iterable, n):
-            length = len(iterable)
-            for idx in range(0, length, n):
-                yield iterable[idx : min(idx + n, length)]
-
-        for ids_batch, texts_batch, metadatas_batch in zip(
-            batch(ids, bulk_size),
-            batch(texts, bulk_size),
-            batch(metadatas, bulk_size) if metadatas is not None else batch([None] * len(ids), bulk_size),
-        ):
-            existing_ids_set = __fetch_existing_ids(ids_batch)
-            for text, metadata, doc_id in zip(texts_batch, metadatas_batch, ids_batch):
-                if doc_id not in existing_ids_set:
-                    filtered_texts.append(text)
-                    filtered_ids.append(doc_id)
-                    if metadatas is not None:
-                        filtered_metadatas.append(metadata)
-
-        return filtered_texts, metadatas if metadatas is None else filtered_metadatas, filtered_ids
-
     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
         actions = []
         uuids = self._get_uuids(documents)
         for i in range(len(documents)):
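+            # The bulk endpoint consumes NDJSON pairs: an action-metadata line
+            # followed by the document body, so the two are appended separately.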
-            action = {
-                "_op_type": "index",
-                "_index": self._collection_name.lower(),
-                "_id": uuids[i],
-                "_source": {
-                    Field.CONTENT_KEY.value: documents[i].page_content,
-                    Field.VECTOR.value: embeddings[i],  # Make sure you pass an array here
-                    Field.METADATA_KEY.value: documents[i].metadata,
-                },
+            action_header = {
+                "index": {
+                    "_index": self.collection_name.lower(),
+                    "_id": uuids[i],
+                }
+            }
+            action_values = {
+                Field.CONTENT_KEY.value: documents[i].page_content,
+                Field.VECTOR.value: embeddings[i],  # Make sure you pass an array here
+                Field.METADATA_KEY.value: documents[i].metadata,
             }
-            actions.append(action)
-        bulk(self._client, actions)
-        self.refresh()
+            if self._using_ugc:
+                action_header["index"]["routing"] = self._routing
+                action_values[self._routing_field] = self._routing
+            actions.append(action_header)
+            actions.append(action_values)
+        response = self._client.bulk(actions)
+        if response["errors"]:
+            # Only failed items carry an 'error' entry; log them instead of printing.
+            for item in response["items"]:
+                if "error" in item["index"]:
+                    logger.error(f"Bulk index error {item['index']['status']}: {item['index']['error']['type']}")
+        else:
+            self.refresh()

     def get_ids_by_metadata_field(self, key: str, value: str):
-        query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}}
+        query = {"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}}
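+        # In UGC mode, also match on the routing field so only this collection's
+        # documents come back from the shared index.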
+        if self._using_ugc:
+            query["query"]["bool"]["must"].append({"term": {f"{self._routing_field}.keyword": self._routing}})
         response = self._client.search(index=self._collection_name, body=query)
         if response["hits"]["hits"]:
             return [hit["_id"] for hit in response["hits"]["hits"]]
@@ -156,50 +115,62 @@ class LindormVectorStore(BaseVector):
         return None

     def delete_by_metadata_field(self, key: str, value: str):
-        query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
-        results = self._client.search(index=self._collection_name, body=query_str)
-        ids = [hit["_id"] for hit in results["hits"]["hits"]]
+        ids = self.get_ids_by_metadata_field(key, value)
         if ids:
             self.delete_by_ids(ids)

     def delete_by_ids(self, ids: list[str]) -> None:
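+        # Existence checks and deletes against a UGC index must carry the routing
+        # value, otherwise the lookup may hit the wrong shard.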
+        params = {}
+        if self._using_ugc:
+            params["routing"] = self._routing
         for id in ids:
-            if self._client.exists(index=self._collection_name, id=id):
-                self._client.delete(index=self._collection_name, id=id)
+            if self._client.exists(index=self._collection_name, id=id, params=params):
+                self._client.delete(index=self._collection_name, id=id, params=params)
+                self.refresh()
             else:
                 logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.")

     def delete(self) -> None:
-        try:
+        if self._using_ugc:
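+            # Other collections share this physical index, so drop only this
+            # collection's routed documents instead of deleting the whole index.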
+            routing_filter_query = {
+                "query": {"bool": {"must": [{"term": {f"{self._routing_field}.keyword": self._routing}}]}}
+            }
+            self._client.delete_by_query(self._collection_name, body=routing_filter_query)
+            self.refresh()
+        else:
             if self._client.indices.exists(index=self._collection_name):
                 self._client.indices.delete(index=self._collection_name, params={"timeout": 60})
                 logger.info("Delete index success")
             else:
                 logger.warning(f"Index '{self._collection_name}' does not exist. No deletion performed.")
-        except Exception as e:
-            logger.exception(f"Error occurred while deleting the index: {self._collection_name}")
-            raise e

     def text_exists(self, id: str) -> bool:
         try:
-            self._client.get(index=self._collection_name, id=id)
+            params = {}
+            if self._using_ugc:
+                params["routing"] = self._routing
+            self._client.get(index=self._collection_name, id=id, params=params)
             return True
         except:
             return False

     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
-        # Make sure query_vector is a list
         if not isinstance(query_vector, list):
             raise ValueError("query_vector should be a list of floats")

-        # Check whether query_vector is a floating-point number list
         if not all(isinstance(x, float) for x in query_vector):
             raise ValueError("All elements in query_vector should be floats")

         top_k = kwargs.get("top_k", 10)
         query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs)
         try:
-            response = self._client.search(index=self._collection_name, body=query)
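+            # Routing pins the k-NN search to the shard holding this collection's
+            # documents, so UGC queries don't fan out across the whole shared index.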
+            params = {}
+            if self._using_ugc:
+                params["routing"] = self._routing
+            response = self._client.search(index=self._collection_name, body=query, params=params)
         except Exception as e:
             logger.exception(f"Error executing vector search, query: {query}")
             raise
@@ -232,7 +203,7 @@ class LindormVectorStore(BaseVector):
         minimum_should_match = kwargs.get("minimum_should_match", 0)
         top_k = kwargs.get("top_k", 10)
         filters = kwargs.get("filter")
-        routing = kwargs.get("routing")
+        routing = self._routing
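+        # The store's own routing value (rather than a caller-supplied one) scopes
+        # full-text search to this collection in UGC mode.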
         full_text_query = default_text_search_query(
             query_text=query,
             k=top_k,
@@ -243,6 +214,7 @@ class LindormVectorStore(BaseVector):
             minimum_should_match=minimum_should_match,
             filters=filters,
             routing=routing,
+            routing_field=self._routing_field,
         )
         response = self._client.search(index=self._collection_name, body=full_text_query)
         docs = []
@@ -265,17 +237,18 @@ class LindormVectorStore(BaseVector):
                 logger.info(f"Collection {self._collection_name} already exists.")
                 return
             if self._client.indices.exists(index=self._collection_name):
-                logger.info("{self._collection_name.lower()} already exists.")
+                logger.info(f"{self._collection_name.lower()} already exists.")
+                redis_client.set(collection_exist_cache_key, 1, ex=3600)
                 return
             if len(self.kwargs) == 0 and len(kwargs) != 0:
                 self.kwargs = copy.deepcopy(kwargs)
             vector_field = kwargs.pop("vector_field", Field.VECTOR.value)
-            shards = kwargs.pop("shards", 2)
+            shards = kwargs.pop("shards", 4)

             engine = kwargs.pop("engine", "lvector")
-            method_name = kwargs.pop("method_name", "hnsw")
+            method_name = kwargs.pop("method_name", dify_config.DEFAULT_INDEX_TYPE)
+            space_type = kwargs.pop("space_type", dify_config.DEFAULT_DISTANCE_TYPE)
             data_type = kwargs.pop("data_type", "float")
-            space_type = kwargs.pop("space_type", "cosinesimil")

             hnsw_m = kwargs.pop("hnsw_m", 24)
             hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500)
@@ -288,10 +261,10 @@ class LindormVectorStore(BaseVector):
             mapping = default_text_mapping(
                 dimension,
                 method_name,
+                space_type=space_type,
                 shards=shards,
                 engine=engine,
                 data_type=data_type,
-                space_type=space_type,
                 vector_field=vector_field,
                 hnsw_m=hnsw_m,
                 hnsw_ef_construction=hnsw_ef_construction,
@@ -301,6 +274,7 @@ class LindormVectorStore(BaseVector):
                 centroids_hnsw_m=centroids_hnsw_m,
                 centroids_hnsw_ef_construct=centroids_hnsw_ef_construct,
                 centroids_hnsw_ef_search=centroids_hnsw_ef_search,
+                using_ugc=self._using_ugc,
                 **kwargs,
             )
             self._client.indices.create(index=self._collection_name.lower(), body=mapping)
@@ -309,15 +283,20 @@


 def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dict:
-    routing_field = kwargs.get("routing_field")
     excludes_from_source = kwargs.get("excludes_from_source")
     analyzer = kwargs.get("analyzer", "ik_max_word")
     text_field = kwargs.get("text_field", Field.CONTENT_KEY.value)
     engine = kwargs["engine"]
     shard = kwargs["shards"]
-    space_type = kwargs["space_type"]
+    space_type = kwargs.get("space_type")
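+    # Default the distance metric by index type: l2 for hnsw, cosine otherwise.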
+    if space_type is None:
+        if method_name == "hnsw":
+            space_type = "l2"
+        else:
+            space_type = "cosine"
     data_type = kwargs["data_type"]
     vector_field = kwargs.get("vector_field", Field.VECTOR.value)
+    using_ugc = kwargs.get("using_ugc", False)

     if method_name == "ivfpq":
         ivfpq_m = kwargs["ivfpq_m"]
@@ -366,13 +345,11 @@ def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dict:
     if excludes_from_source:
         mapping["mappings"]["_source"] = {"excludes": excludes_from_source}  # e.g. {"excludes": ["vector_field"]}

-    if method_name == "ivfpq" and routing_field is not None:
+    if using_ugc and method_name == "ivfpq":
         mapping["settings"]["index"]["knn_routing"] = True
         mapping["settings"]["index"]["knn.offline.construction"] = True
-
-    if method_name == "flat" and routing_field is not None:
+    elif using_ugc and method_name in ("hnsw", "flat"):
         mapping["settings"]["index"]["knn_routing"] = True
-
     return mapping
@@ -386,14 +363,12 @@ def default_text_search_query(
     minimum_should_match: int = 0,
     filters: Optional[list[dict]] = None,
     routing: Optional[str] = None,
+    routing_field: Optional[str] = None,
     **kwargs,
 ) -> dict:
     if routing is not None:
-        routing_field = kwargs.get("routing_field", "routing_field")
         query_clause = {
-            "bool": {
-                "must": [{"match": {text_field: query_text}}, {"term": {f"metadata.{routing_field}.keyword": routing}}]
-            }
+            "bool": {"must": [{"match": {text_field: query_text}}, {"term": {f"{routing_field}.keyword": routing}}]}
         }
     else:
         query_clause = {"match": {text_field: query_text}}
@@ -483,16 +458,40 @@ def default_vector_search_query(

 class LindormVectorStoreFactory(AbstractVectorFactory):
     def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore:
-        if dataset.index_struct_dict:
-            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
-            collection_name = class_prefix
-        else:
-            dataset_id = dataset.id
-            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
-            dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.LINDORM, collection_name))
         lindorm_config = LindormVectorStoreConfig(
             hosts=dify_config.LINDORM_URL,
             username=dify_config.LINDORM_USERNAME,
             password=dify_config.LINDORM_PASSWORD,
+            using_ugc=dify_config.USING_UGC_INDEX,
         )
-        return LindormVectorStore(collection_name, lindorm_config)
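+        # UGC index names encode the vector geometry ("ugc_index_<dim>_<type>_<distance>"),
+        # so datasets with matching settings land in the same physical index; the
+        # dataset's class_prefix becomes its routing value.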
+        using_ugc = dify_config.USING_UGC_INDEX
+        routing_value = None
+        if dataset.index_struct:
+            if using_ugc:
+                dimension = dataset.index_struct_dict["dimension"]
+                index_type = dataset.index_struct_dict["index_type"]
+                distance_type = dataset.index_struct_dict["distance_type"]
+                index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
+                routing_value = dataset.index_struct_dict["vector_store"]["class_prefix"]
+            else:
+                index_name = dataset.index_struct_dict["vector_store"]["class_prefix"]
+        else:
+            # Probe the embedding model once to discover the vector dimension.
+            embedding_vector = embeddings.embed_query("hello world")
+            dimension = len(embedding_vector)
+            index_type = dify_config.DEFAULT_INDEX_TYPE
+            distance_type = dify_config.DEFAULT_DISTANCE_TYPE
+            class_prefix = Dataset.gen_collection_name_by_id(dataset.id)
+            index_struct_dict = {
+                "type": VectorType.LINDORM,
+                "vector_store": {"class_prefix": class_prefix},
+                "index_type": index_type,
+                "dimension": dimension,
+                "distance_type": distance_type,
+            }
+            dataset.index_struct = json.dumps(index_struct_dict)
+            if using_ugc:
+                index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
+                routing_value = class_prefix
+            else:
+                index_name = class_prefix
+        return LindormVectorStore(index_name, lindorm_config, routing_value=routing_value)