123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215 |
- import os
- from typing import Union
- from unittest.mock import MagicMock
- import pytest
- from _pytest.monkeypatch import MonkeyPatch
- from volcengine.viking_db import ( # type: ignore
- Collection,
- Data,
- DistanceType,
- Field,
- FieldType,
- Index,
- IndexType,
- QuantType,
- VectorIndexParams,
- VikingDBService,
- )
- from core.rag.datasource.vdb.field import Field as vdb_Field
class MockVikingDBClass:
    """Canned replacements for the VikingDB SDK calls exercised by the tests.

    ``setup_vikingdb_mock`` installs these functions onto ``VikingDBService``,
    ``Collection`` and ``Index``, so at call time ``self`` is an instance of
    whichever SDK class the function was patched onto, not of this class.
    The methods therefore only touch ``self._viking_db_service`` (set by
    ``__init__``) and never call helpers defined on this class.
    """

    def __init__(
        self,
        host="api-vikingdb.volces.com",
        region="cn-north-1",
        ak="",
        sk="",
        scheme="http",
        connection_timeout=30,
        socket_timeout=30,
        proxy=None,
    ):
        """Accept the connection arguments, ignore them, and attach a ``MagicMock`` service."""
        service = MagicMock()
        # ``get_exception`` always answers with the same JSON text.
        service.get_exception = MagicMock(return_value='{"data": {"primary_key": "test_id"}}')
        self._viking_db_service = service

    def get_collection(self, collection_name) -> Collection:
        """Build a ``Collection`` called *collection_name* with a fixed schema and one index."""
        service = self._viking_db_service
        schema = [
            Field(field_name=vdb_Field.PRIMARY_KEY.value, field_type=FieldType.String, is_primary_key=True),
            Field(field_name=vdb_Field.METADATA_KEY.value, field_type=FieldType.String),
            Field(field_name=vdb_Field.GROUP_KEY.value, field_type=FieldType.String),
            Field(field_name=vdb_Field.CONTENT_KEY.value, field_type=FieldType.Text),
            Field(field_name=vdb_Field.VECTOR.value, field_type=FieldType.Vector, dim=768),
        ]
        default_index = Index(
            collection_name=collection_name,
            index_name=f"{collection_name}_idx",
            vector_index=VectorIndexParams(
                distance=DistanceType.L2,
                index_type=IndexType.HNSW,
                quant=QuantType.Float,
            ),
            scalar_index=None,
            stat=None,
            viking_db_service=service,
        )
        return Collection(
            collection_name=collection_name,
            description="Collection For Dify",
            viking_db_service=service,
            primary_key=vdb_Field.PRIMARY_KEY.value,
            fields=schema,
            indexes=[default_index],
        )

    def drop_collection(self, collection_name):
        """Do nothing beyond asserting that *collection_name* is not the empty string."""
        assert collection_name != ""

    def create_collection(self, collection_name, fields, description="") -> Collection:
        """Return a ``Collection`` built from the given name, *fields* and *description*."""
        return Collection(
            collection_name=collection_name,
            description=description,
            primary_key=vdb_Field.PRIMARY_KEY.value,
            viking_db_service=self._viking_db_service,
            fields=fields,
        )

    def get_index(self, collection_name, index_name) -> Index:
        """Return an ``Index`` with the requested names and fixed HNSW / L2 / Float parameters."""
        vector_params = VectorIndexParams(
            distance=DistanceType.L2,
            index_type=IndexType.HNSW,
            quant=QuantType.Float,
        )
        return Index(
            collection_name=collection_name,
            index_name=index_name,
            viking_db_service=self._viking_db_service,
            stat=None,
            scalar_index=None,
            vector_index=vector_params,
        )

    def create_index(
        self,
        collection_name,
        index_name,
        vector_index=None,
        cpu_quota=2,
        description="",
        partition_by="",
        scalar_index=None,
        shard_count=None,
        shard_policy=None,
    ):
        """Return an ``Index`` that simply carries every argument it was given."""
        return Index(
            collection_name=collection_name,
            index_name=index_name,
            vector_index=vector_index,
            cpu_quota=cpu_quota,
            description=description,
            partition_by=partition_by,
            scalar_index=scalar_index,
            shard_count=shard_count,
            shard_policy=shard_policy,
            viking_db_service=self._viking_db_service,
            stat=None,
        )

    def drop_index(self, collection_name, index_name):
        """Do nothing beyond asserting that neither name is the empty string."""
        assert collection_name != ""
        assert index_name != ""

    def upsert_data(self, data: Union[Data, list[Data]]):
        """Do nothing beyond asserting that *data* is not ``None``."""
        assert data is not None

    def fetch_data(self, id: Union[str, list[str], int, list[int]]):
        """Return one canned ``Data`` record whose primary key and ``id`` are the given *id*."""
        record_fields = {
            vdb_Field.GROUP_KEY.value: "test_group",
            vdb_Field.METADATA_KEY.value: "{}",
            vdb_Field.CONTENT_KEY.value: "content",
            vdb_Field.PRIMARY_KEY.value: id,
            vdb_Field.VECTOR.value: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398],
        }
        return Data(fields=record_fields, id=id)

    def delete_data(self, id: Union[str, list[str], int, list[int]]):
        """Do nothing beyond asserting that *id* is not ``None``."""
        assert id is not None

    def search_by_vector(
        self,
        vector,
        sparse_vectors=None,
        filter=None,
        limit=10,
        output_fields=None,
        partition="default",
        dense_weight=None,
    ) -> list[Data]:
        """Return a single canned hit that echoes *vector*; all other arguments are ignored."""
        metadata_json = (
            '{"source": "/var/folders/ml/xxx/xxx.txt", '
            '"document_id": "test_document_id", '
            '"dataset_id": "test_dataset_id", '
            '"doc_id": "test_id", '
            '"doc_hash": "test_hash"}'
        )
        hit = Data(
            fields={
                vdb_Field.GROUP_KEY.value: "test_group",
                vdb_Field.METADATA_KEY.value: metadata_json,
                vdb_Field.CONTENT_KEY.value: "content",
                vdb_Field.PRIMARY_KEY.value: "test_id",
                vdb_Field.VECTOR.value: vector,
            },
            id="test_id",
            score=0.10,
        )
        return [hit]

    def search(
        self, order=None, filter=None, limit=10, output_fields=None, partition="default", dense_weight=None
    ) -> list[Data]:
        """Return a single canned hit with a fixed vector; every argument is ignored."""
        metadata_json = (
            '{"source": "/var/folders/ml/xxx/xxx.txt", '
            '"document_id": "test_document_id", '
            '"dataset_id": "test_dataset_id", '
            '"doc_id": "test_id", '
            '"doc_hash": "test_hash"}'
        )
        hit = Data(
            fields={
                vdb_Field.GROUP_KEY.value: "test_group",
                vdb_Field.METADATA_KEY.value: metadata_json,
                vdb_Field.CONTENT_KEY.value: "content",
                vdb_Field.PRIMARY_KEY.value: "test_id",
                vdb_Field.VECTOR.value: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398],
            },
            id="test_id",
            score=0.10,
        )
        return [hit]
# Mocking is opt-in: it is enabled only when the MOCK_SWITCH environment
# variable is set to "true" (compared case-insensitively); unset means off.
MOCK = os.environ.get("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_vikingdb_mock(monkeypatch: MonkeyPatch):
    """Swap the VikingDB SDK methods for ``MockVikingDBClass`` versions while ``MOCK`` is on.

    When ``MOCK`` is false the fixture patches nothing and just yields.
    """
    if MOCK:
        # (SDK class to patch, names of the MockVikingDBClass functions to install on it)
        patch_plan = (
            (
                VikingDBService,
                (
                    "__init__",
                    "get_collection",
                    "create_collection",
                    "drop_collection",
                    "get_index",
                    "create_index",
                    "drop_index",
                ),
            ),
            (Collection, ("upsert_data", "fetch_data", "delete_data")),
            (Index, ("search_by_vector", "search")),
        )
        for sdk_class, method_names in patch_plan:
            for method_name in method_names:
                monkeypatch.setattr(sdk_class, method_name, getattr(MockVikingDBClass, method_name))

    yield

    if MOCK:
        # Roll the patches back as soon as the test finishes.
        monkeypatch.undo()
|