# baiduvectordb.py (5.4 KB) — pymochow (Baidu VectorDB) mocks for the tests
  1. import os
  2. from collections import UserDict
  3. from unittest.mock import MagicMock
  4. import pytest
  5. from _pytest.monkeypatch import MonkeyPatch
  6. from pymochow import MochowClient # type: ignore
  7. from pymochow.model.database import Database # type: ignore
  8. from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState # type: ignore
  9. from pymochow.model.schema import HNSWParams, VectorIndex # type: ignore
  10. from pymochow.model.table import Table # type: ignore
  11. from requests.adapters import HTTPAdapter
  12. class AttrDict(UserDict):
  13. def __getattr__(self, item):
  14. return self.get(item)
  15. class MockBaiduVectorDBClass:
  16. def mock_vector_db_client(
  17. self,
  18. config=None,
  19. adapter: HTTPAdapter = None,
  20. ):
  21. self.conn = MagicMock()
  22. self._config = MagicMock()
  23. def list_databases(self, config=None) -> list[Database]:
  24. return [
  25. Database(
  26. conn=self.conn,
  27. database_name="dify",
  28. config=self._config,
  29. )
  30. ]
  31. def create_database(self, database_name: str, config=None) -> Database:
  32. return Database(conn=self.conn, database_name=database_name, config=config)
  33. def list_table(self, config=None) -> list[Table]:
  34. return []
  35. def drop_table(self, table_name: str, config=None):
  36. return {"code": 0, "msg": "Success"}
  37. def create_table(
  38. self,
  39. table_name: str,
  40. replication: int,
  41. partition: int,
  42. schema,
  43. enable_dynamic_field=False,
  44. description: str = "",
  45. config=None,
  46. ) -> Table:
  47. return Table(self, table_name, replication, partition, schema, enable_dynamic_field, description, config)
  48. def describe_table(self, table_name: str, config=None) -> Table:
  49. return Table(
  50. self,
  51. table_name,
  52. 3,
  53. 1,
  54. None,
  55. enable_dynamic_field=False,
  56. description="table for dify",
  57. config=config,
  58. state=TableState.NORMAL,
  59. )
  60. def upsert(self, rows, config=None):
  61. return {"code": 0, "msg": "operation success", "affectedCount": 1}
  62. def rebuild_index(self, index_name: str, config=None):
  63. return {"code": 0, "msg": "Success"}
  64. def describe_index(self, index_name: str, config=None):
  65. return VectorIndex(
  66. index_name=index_name,
  67. index_type=IndexType.HNSW,
  68. field="vector",
  69. metric_type=MetricType.L2,
  70. params=HNSWParams(m=16, efconstruction=200),
  71. auto_build=False,
  72. state=IndexState.NORMAL,
  73. )
  74. def query(
  75. self,
  76. primary_key,
  77. partition_key=None,
  78. projections=None,
  79. retrieve_vector=False,
  80. read_consistency=ReadConsistency.EVENTUAL,
  81. config=None,
  82. ):
  83. return AttrDict(
  84. {
  85. "row": {
  86. "id": primary_key.get("id"),
  87. "vector": [0.23432432, 0.8923744, 0.89238432],
  88. "text": "text",
  89. "metadata": '{"doc_id": "doc_id_001"}',
  90. },
  91. "code": 0,
  92. "msg": "Success",
  93. }
  94. )
  95. def delete(self, primary_key=None, partition_key=None, filter=None, config=None):
  96. return {"code": 0, "msg": "Success"}
  97. def search(
  98. self,
  99. anns,
  100. partition_key=None,
  101. projections=None,
  102. retrieve_vector=False,
  103. read_consistency=ReadConsistency.EVENTUAL,
  104. config=None,
  105. ):
  106. return AttrDict(
  107. {
  108. "rows": [
  109. {
  110. "row": {
  111. "id": "doc_id_001",
  112. "vector": [0.23432432, 0.8923744, 0.89238432],
  113. "text": "text",
  114. "metadata": '{"doc_id": "doc_id_001"}',
  115. },
  116. "distance": 0.1,
  117. "score": 0.5,
  118. }
  119. ],
  120. "code": 0,
  121. "msg": "Success",
  122. }
  123. )
  124. MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
  125. @pytest.fixture
  126. def setup_baiduvectordb_mock(request, monkeypatch: MonkeyPatch):
  127. if MOCK:
  128. monkeypatch.setattr(MochowClient, "__init__", MockBaiduVectorDBClass.mock_vector_db_client)
  129. monkeypatch.setattr(MochowClient, "list_databases", MockBaiduVectorDBClass.list_databases)
  130. monkeypatch.setattr(MochowClient, "create_database", MockBaiduVectorDBClass.create_database)
  131. monkeypatch.setattr(Database, "table", MockBaiduVectorDBClass.describe_table)
  132. monkeypatch.setattr(Database, "list_table", MockBaiduVectorDBClass.list_table)
  133. monkeypatch.setattr(Database, "create_table", MockBaiduVectorDBClass.create_table)
  134. monkeypatch.setattr(Database, "drop_table", MockBaiduVectorDBClass.drop_table)
  135. monkeypatch.setattr(Database, "describe_table", MockBaiduVectorDBClass.describe_table)
  136. monkeypatch.setattr(Table, "rebuild_index", MockBaiduVectorDBClass.rebuild_index)
  137. monkeypatch.setattr(Table, "describe_index", MockBaiduVectorDBClass.describe_index)
  138. monkeypatch.setattr(Table, "delete", MockBaiduVectorDBClass.delete)
  139. monkeypatch.setattr(Table, "query", MockBaiduVectorDBClass.query)
  140. monkeypatch.setattr(Table, "search", MockBaiduVectorDBClass.search)
  141. yield
  142. if MOCK:
  143. monkeypatch.undo()