tcvectordb.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. import os
  2. from typing import Optional
  3. import pytest
  4. from _pytest.monkeypatch import MonkeyPatch
  5. from requests.adapters import HTTPAdapter
  6. from tcvectordb import VectorDBClient # type: ignore
  7. from tcvectordb.model.database import Collection, Database # type: ignore
  8. from tcvectordb.model.document import Document, Filter # type: ignore
  9. from tcvectordb.model.enum import ReadConsistency # type: ignore
  10. from tcvectordb.model.index import Index # type: ignore
  11. from xinference_client.types import Embedding # type: ignore
  12. class MockTcvectordbClass:
  13. def mock_vector_db_client(
  14. self,
  15. url=None,
  16. username="",
  17. key="",
  18. read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY,
  19. timeout=5,
  20. adapter: HTTPAdapter = None,
  21. ):
  22. self._conn = None
  23. self._read_consistency = read_consistency
  24. def list_databases(self) -> list[Database]:
  25. return [
  26. Database(
  27. conn=self._conn,
  28. read_consistency=self._read_consistency,
  29. name="dify",
  30. )
  31. ]
  32. def list_collections(self, timeout: Optional[float] = None) -> list[Collection]:
  33. return []
  34. def drop_collection(self, name: str, timeout: Optional[float] = None):
  35. return {"code": 0, "msg": "operation success"}
  36. def create_collection(
  37. self,
  38. name: str,
  39. shard: int,
  40. replicas: int,
  41. description: str,
  42. index: Index,
  43. embedding: Embedding = None,
  44. timeout: Optional[float] = None,
  45. ) -> Collection:
  46. return Collection(
  47. self,
  48. name,
  49. shard,
  50. replicas,
  51. description,
  52. index,
  53. embedding=embedding,
  54. read_consistency=self._read_consistency,
  55. timeout=timeout,
  56. )
  57. def describe_collection(self, name: str, timeout: Optional[float] = None) -> Collection:
  58. collection = Collection(self, name, shard=1, replicas=2, description=name, timeout=timeout)
  59. return collection
  60. def collection_upsert(
  61. self, documents: list[Document], timeout: Optional[float] = None, build_index: bool = True, **kwargs
  62. ):
  63. return {"code": 0, "msg": "operation success"}
  64. def collection_search(
  65. self,
  66. vectors: list[list[float]],
  67. filter: Filter = None,
  68. params=None,
  69. retrieve_vector: bool = False,
  70. limit: int = 10,
  71. output_fields: Optional[list[str]] = None,
  72. timeout: Optional[float] = None,
  73. ) -> list[list[dict]]:
  74. return [[{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}]]
  75. def collection_query(
  76. self,
  77. document_ids: Optional[list] = None,
  78. retrieve_vector: bool = False,
  79. limit: Optional[int] = None,
  80. offset: Optional[int] = None,
  81. filter: Optional[Filter] = None,
  82. output_fields: Optional[list[str]] = None,
  83. timeout: Optional[float] = None,
  84. ) -> list[dict]:
  85. return [{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}]
  86. def collection_delete(
  87. self,
  88. document_ids: Optional[list[str]] = None,
  89. filter: Filter = None,
  90. timeout: Optional[float] = None,
  91. ):
  92. return {"code": 0, "msg": "operation success"}
  93. MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
  94. @pytest.fixture
  95. def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch):
  96. if MOCK:
  97. monkeypatch.setattr(VectorDBClient, "__init__", MockTcvectordbClass.mock_vector_db_client)
  98. monkeypatch.setattr(VectorDBClient, "list_databases", MockTcvectordbClass.list_databases)
  99. monkeypatch.setattr(Database, "collection", MockTcvectordbClass.describe_collection)
  100. monkeypatch.setattr(Database, "list_collections", MockTcvectordbClass.list_collections)
  101. monkeypatch.setattr(Database, "drop_collection", MockTcvectordbClass.drop_collection)
  102. monkeypatch.setattr(Database, "create_collection", MockTcvectordbClass.create_collection)
  103. monkeypatch.setattr(Collection, "upsert", MockTcvectordbClass.collection_upsert)
  104. monkeypatch.setattr(Collection, "search", MockTcvectordbClass.collection_search)
  105. monkeypatch.setattr(Collection, "query", MockTcvectordbClass.collection_query)
  106. monkeypatch.setattr(Collection, "delete", MockTcvectordbClass.collection_delete)
  107. yield
  108. if MOCK:
  109. monkeypatch.undo()