@@ -0,0 +1,239 @@
+import json
+from typing import Any
+
+from pydantic import BaseModel
+from volcengine.viking_db import (
+    Data,
+    DistanceType,
+    Field,
+    FieldType,
+    IndexType,
+    QuantType,
+    VectorIndexParams,
+    VikingDBService,
+)
+
+from configs import dify_config
+from core.rag.datasource.entity.embedding import Embeddings
+from core.rag.datasource.vdb.field import Field as vdb_Field
+from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
+from models.dataset import Dataset
+
+
+class VikingDBConfig(BaseModel):
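+    """Connection settings and index build parameters for Volcengine VikingDB."""
+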
+    access_key: str
+    secret_key: str
+    host: str
+    region: str
+    scheme: str
+    connection_timeout: int
+    socket_timeout: int
+    index_type: str = IndexType.HNSW
+    distance: str = DistanceType.L2
+    quant: str = QuantType.Float
+
+
+class VikingDBVector(BaseVector):
+    def __init__(self, collection_name: str, group_id: str, config: VikingDBConfig):
+        super().__init__(collection_name)
+        self._group_id = group_id
+        self._client_config = config
+        self._index_name = f"{self._collection_name}_idx"
+        self._client = VikingDBService(
+            host=config.host,
+            region=config.region,
+            scheme=config.scheme,
+            connection_timeout=config.connection_timeout,
+            socket_timeout=config.socket_timeout,
+            ak=config.access_key,
+            sk=config.secret_key,
+        )
+
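+    # The VikingDB SDK signals a missing collection or index by raising, so the
+    # existence checks below probe with get_* calls and treat any exception as
+    # "not found".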
+    def _has_collection(self) -> bool:
+        try:
+            self._client.get_collection(self._collection_name)
+        except Exception:
+            return False
+        return True
+
+    def _has_index(self) -> bool:
+        try:
+            self._client.get_index(self._collection_name, self._index_name)
+        except Exception:
+            return False
+        return True
+
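+    # Creation is guarded by a Redis lock so concurrent indexing workers do not
+    # race, and success is cached for an hour to skip repeated existence probes.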
+    def _create_collection(self, dimension: int):
+        lock_name = f"vector_indexing_lock_{self._collection_name}"
+        with redis_client.lock(lock_name, timeout=20):
+            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
+            if redis_client.get(collection_exist_cache_key):
+                return
+
+            if not self._has_collection():
+                fields = [
+                    Field(field_name=vdb_Field.PRIMARY_KEY.value, field_type=FieldType.String, is_primary_key=True),
+                    Field(field_name=vdb_Field.METADATA_KEY.value, field_type=FieldType.String),
+                    Field(field_name=vdb_Field.GROUP_KEY.value, field_type=FieldType.String),
+                    Field(field_name=vdb_Field.CONTENT_KEY.value, field_type=FieldType.Text),
+                    Field(field_name=vdb_Field.VECTOR.value, field_type=FieldType.Vector, dim=dimension),
+                ]
+
+                self._client.create_collection(
+                    collection_name=self._collection_name,
+                    fields=fields,
+                    description="Collection For Dify",
+                )
+
+            if not self._has_index():
+                vector_index = VectorIndexParams(
+                    distance=self._client_config.distance,
+                    index_type=self._client_config.index_type,
+                    quant=self._client_config.quant,
+                )
+
+                self._client.create_index(
+                    collection_name=self._collection_name,
+                    index_name=self._index_name,
+                    vector_index=vector_index,
+                    partition_by=vdb_Field.GROUP_KEY.value,
+                    description="Index For Dify",
+                )
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
+
+    def get_type(self) -> str:
+        return VectorType.VIKINGDB
+
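+    # Collection creation is deferred to the first write because the vector
+    # dimension is only known once embeddings are available.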
+    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
+        dimension = len(embeddings[0])
+        self._create_collection(dimension)
+        self.add_texts(texts, embeddings, **kwargs)
+
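+    # Each document is written as one VikingDB row: "doc_id" from the metadata
+    # serves as the primary key, and the remaining metadata is serialized to a
+    # JSON string since VikingDB has no native JSON field type.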
+    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+        page_contents = [doc.page_content for doc in documents]
+        metadatas = [doc.metadata for doc in documents]
+        docs = []
+
+        for i, page_content in enumerate(page_contents):
+            metadata = {}
+            if metadatas[i] is not None:
+                for key, val in metadatas[i].items():
+                    metadata[key] = val
+            doc = Data(
+                {
+                    vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"],
+                    vdb_Field.VECTOR.value: embeddings[i] if embeddings else None,
+                    vdb_Field.CONTENT_KEY.value: page_content,
+                    vdb_Field.METADATA_KEY.value: json.dumps(metadata),
+                    vdb_Field.GROUP_KEY.value: self._group_id,
+                }
+            )
+            docs.append(doc)
+
+        self._client.get_collection(self._collection_name).upsert_data(docs)
+
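+    # Existence is detected from the fetched record itself: a missing id is
+    # expected to come back with a "data does not exist" message rather than
+    # as an exception.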
+    def text_exists(self, id: str) -> bool:
+        docs = self._client.get_collection(self._collection_name).fetch_data(id)
+        not_exists_str = "data does not exist"
+        if docs is not None and not_exists_str not in docs.fields.get("message", ""):
+            return True
+        return False
+
+    def delete_by_ids(self, ids: list[str]) -> None:
+        self._client.get_collection(self._collection_name).delete_data(ids)
+
+    def get_ids_by_metadata_field(self, key: str, value: str):
+        # Note: the metadata value is a dict, but VikingDB fields do not support
+        # a JSON type, so metadata is stored as a JSON string and must be
+        # filtered client-side after fetching the whole group.
+        results = self._client.get_index(self._collection_name, self._index_name).search(
+            filter={"op": "must", "field": vdb_Field.GROUP_KEY.value, "conds": [self._group_id]},
+            # maximum allowed value is 5000
+            limit=5000,
+        )
+
+        if not results:
+            return []
+
+        ids = []
+        for result in results:
+            metadata = result.fields.get(vdb_Field.METADATA_KEY.value)
+            if metadata is not None:
+                metadata = json.loads(metadata)
+                if metadata.get(key) == value:
+                    ids.append(result.id)
+        return ids
+
+    def delete_by_metadata_field(self, key: str, value: str) -> None:
+        ids = self.get_ids_by_metadata_field(key, value)
+        self.delete_by_ids(ids)
+
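+    # top_k defaults to 50 results; score_threshold defaults to 0.0 when unset,
+    # so only results with a strictly positive score are kept by default.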
+    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+        results = self._client.get_index(self._collection_name, self._index_name).search_by_vector(
+            query_vector, limit=kwargs.get("top_k", 50)
+        )
+        score_threshold = float(kwargs.get("score_threshold") or 0.0)
+        return self._get_search_res(results, score_threshold)
+
+    def _get_search_res(self, results, score_threshold):
+        if len(results) == 0:
+            return []
+
+        docs = []
+        for result in results:
+            metadata = result.fields.get(vdb_Field.METADATA_KEY.value)
+            if metadata is not None:
+                metadata = json.loads(metadata)
+                if result.score > score_threshold:
+                    metadata["score"] = result.score
+                    doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata)
+                    docs.append(doc)
+        docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
+        return docs
+
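+    # Full-text retrieval is not implemented for VikingDB; returning an empty
+    # list makes keyword search a no-op for this vector store.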
+    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        return []
+
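+    # Tear down in dependency order: drop the index first, then the collection
+    # that backs it.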
+    def delete(self) -> None:
+        if self._has_index():
+            self._client.drop_index(self._collection_name, self._index_name)
+        if self._has_collection():
+            self._client.drop_collection(self._collection_name)
+
+
+class VikingDBVectorFactory(AbstractVectorFactory):
+    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> VikingDBVector:
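+        # Reuse the collection name already recorded in the dataset's index
+        # struct; otherwise derive one from the dataset id and persist it.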
+        if dataset.index_struct_dict:
+            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
+            collection_name = class_prefix.lower()
+        else:
+            dataset_id = dataset.id
+            collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
+            dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.VIKINGDB, collection_name))
+
+        if dify_config.VIKINGDB_ACCESS_KEY is None:
+            raise ValueError("VIKINGDB_ACCESS_KEY should not be None")
+        if dify_config.VIKINGDB_SECRET_KEY is None:
+            raise ValueError("VIKINGDB_SECRET_KEY should not be None")
+        if dify_config.VIKINGDB_HOST is None:
+            raise ValueError("VIKINGDB_HOST should not be None")
+        if dify_config.VIKINGDB_REGION is None:
+            raise ValueError("VIKINGDB_REGION should not be None")
+        if dify_config.VIKINGDB_SCHEME is None:
+            raise ValueError("VIKINGDB_SCHEME should not be None")
+        return VikingDBVector(
+            collection_name=collection_name,
+            group_id=dataset.id,
+            config=VikingDBConfig(
+                access_key=dify_config.VIKINGDB_ACCESS_KEY,
+                secret_key=dify_config.VIKINGDB_SECRET_KEY,
+                host=dify_config.VIKINGDB_HOST,
+                region=dify_config.VIKINGDB_REGION,
+                scheme=dify_config.VIKINGDB_SCHEME,
+                connection_timeout=dify_config.VIKINGDB_CONNECTION_TIMEOUT,
+                socket_timeout=dify_config.VIKINGDB_SOCKET_TIMEOUT,
+            ),
+        )