vikingdb.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. import os
  2. from typing import Union
  3. from unittest.mock import MagicMock
  4. import pytest
  5. from _pytest.monkeypatch import MonkeyPatch
  6. from volcengine.viking_db import ( # type: ignore
  7. Collection,
  8. Data,
  9. DistanceType,
  10. Field,
  11. FieldType,
  12. Index,
  13. IndexType,
  14. QuantType,
  15. VectorIndexParams,
  16. VikingDBService,
  17. )
  18. from core.rag.datasource.vdb.field import Field as vdb_Field
  19. class MockVikingDBClass:
  20. def __init__(
  21. self,
  22. host="api-vikingdb.volces.com",
  23. region="cn-north-1",
  24. ak="",
  25. sk="",
  26. scheme="http",
  27. connection_timeout=30,
  28. socket_timeout=30,
  29. proxy=None,
  30. ):
  31. self._viking_db_service = MagicMock()
  32. self._viking_db_service.get_exception = MagicMock(return_value='{"data": {"primary_key": "test_id"}}')
  33. def get_collection(self, collection_name) -> Collection:
  34. return Collection(
  35. collection_name=collection_name,
  36. description="Collection For Dify",
  37. viking_db_service=self._viking_db_service,
  38. primary_key=vdb_Field.PRIMARY_KEY.value,
  39. fields=[
  40. Field(field_name=vdb_Field.PRIMARY_KEY.value, field_type=FieldType.String, is_primary_key=True),
  41. Field(field_name=vdb_Field.METADATA_KEY.value, field_type=FieldType.String),
  42. Field(field_name=vdb_Field.GROUP_KEY.value, field_type=FieldType.String),
  43. Field(field_name=vdb_Field.CONTENT_KEY.value, field_type=FieldType.Text),
  44. Field(field_name=vdb_Field.VECTOR.value, field_type=FieldType.Vector, dim=768),
  45. ],
  46. indexes=[
  47. Index(
  48. collection_name=collection_name,
  49. index_name=f"{collection_name}_idx",
  50. vector_index=VectorIndexParams(
  51. distance=DistanceType.L2,
  52. index_type=IndexType.HNSW,
  53. quant=QuantType.Float,
  54. ),
  55. scalar_index=None,
  56. stat=None,
  57. viking_db_service=self._viking_db_service,
  58. )
  59. ],
  60. )
  61. def drop_collection(self, collection_name):
  62. assert collection_name != ""
  63. def create_collection(self, collection_name, fields, description="") -> Collection:
  64. return Collection(
  65. collection_name=collection_name,
  66. description=description,
  67. primary_key=vdb_Field.PRIMARY_KEY.value,
  68. viking_db_service=self._viking_db_service,
  69. fields=fields,
  70. )
  71. def get_index(self, collection_name, index_name) -> Index:
  72. return Index(
  73. collection_name=collection_name,
  74. index_name=index_name,
  75. viking_db_service=self._viking_db_service,
  76. stat=None,
  77. scalar_index=None,
  78. vector_index=VectorIndexParams(
  79. distance=DistanceType.L2,
  80. index_type=IndexType.HNSW,
  81. quant=QuantType.Float,
  82. ),
  83. )
  84. def create_index(
  85. self,
  86. collection_name,
  87. index_name,
  88. vector_index=None,
  89. cpu_quota=2,
  90. description="",
  91. partition_by="",
  92. scalar_index=None,
  93. shard_count=None,
  94. shard_policy=None,
  95. ):
  96. return Index(
  97. collection_name=collection_name,
  98. index_name=index_name,
  99. vector_index=vector_index,
  100. cpu_quota=cpu_quota,
  101. description=description,
  102. partition_by=partition_by,
  103. scalar_index=scalar_index,
  104. shard_count=shard_count,
  105. shard_policy=shard_policy,
  106. viking_db_service=self._viking_db_service,
  107. stat=None,
  108. )
  109. def drop_index(self, collection_name, index_name):
  110. assert collection_name != ""
  111. assert index_name != ""
  112. def upsert_data(self, data: Union[Data, list[Data]]):
  113. assert data is not None
  114. def fetch_data(self, id: Union[str, list[str], int, list[int]]):
  115. return Data(
  116. fields={
  117. vdb_Field.GROUP_KEY.value: "test_group",
  118. vdb_Field.METADATA_KEY.value: "{}",
  119. vdb_Field.CONTENT_KEY.value: "content",
  120. vdb_Field.PRIMARY_KEY.value: id,
  121. vdb_Field.VECTOR.value: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398],
  122. },
  123. id=id,
  124. )
  125. def delete_data(self, id: Union[str, list[str], int, list[int]]):
  126. assert id is not None
  127. def search_by_vector(
  128. self,
  129. vector,
  130. sparse_vectors=None,
  131. filter=None,
  132. limit=10,
  133. output_fields=None,
  134. partition="default",
  135. dense_weight=None,
  136. ) -> list[Data]:
  137. return [
  138. Data(
  139. fields={
  140. vdb_Field.GROUP_KEY.value: "test_group",
  141. vdb_Field.METADATA_KEY.value: '\
  142. {"source": "/var/folders/ml/xxx/xxx.txt", \
  143. "document_id": "test_document_id", \
  144. "dataset_id": "test_dataset_id", \
  145. "doc_id": "test_id", \
  146. "doc_hash": "test_hash"}',
  147. vdb_Field.CONTENT_KEY.value: "content",
  148. vdb_Field.PRIMARY_KEY.value: "test_id",
  149. vdb_Field.VECTOR.value: vector,
  150. },
  151. id="test_id",
  152. score=0.10,
  153. )
  154. ]
  155. def search(
  156. self, order=None, filter=None, limit=10, output_fields=None, partition="default", dense_weight=None
  157. ) -> list[Data]:
  158. return [
  159. Data(
  160. fields={
  161. vdb_Field.GROUP_KEY.value: "test_group",
  162. vdb_Field.METADATA_KEY.value: '\
  163. {"source": "/var/folders/ml/xxx/xxx.txt", \
  164. "document_id": "test_document_id", \
  165. "dataset_id": "test_dataset_id", \
  166. "doc_id": "test_id", \
  167. "doc_hash": "test_hash"}',
  168. vdb_Field.CONTENT_KEY.value: "content",
  169. vdb_Field.PRIMARY_KEY.value: "test_id",
  170. vdb_Field.VECTOR.value: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398],
  171. },
  172. id="test_id",
  173. score=0.10,
  174. )
  175. ]
  176. MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
  177. @pytest.fixture
  178. def setup_vikingdb_mock(monkeypatch: MonkeyPatch):
  179. if MOCK:
  180. monkeypatch.setattr(VikingDBService, "__init__", MockVikingDBClass.__init__)
  181. monkeypatch.setattr(VikingDBService, "get_collection", MockVikingDBClass.get_collection)
  182. monkeypatch.setattr(VikingDBService, "create_collection", MockVikingDBClass.create_collection)
  183. monkeypatch.setattr(VikingDBService, "drop_collection", MockVikingDBClass.drop_collection)
  184. monkeypatch.setattr(VikingDBService, "get_index", MockVikingDBClass.get_index)
  185. monkeypatch.setattr(VikingDBService, "create_index", MockVikingDBClass.create_index)
  186. monkeypatch.setattr(VikingDBService, "drop_index", MockVikingDBClass.drop_index)
  187. monkeypatch.setattr(Collection, "upsert_data", MockVikingDBClass.upsert_data)
  188. monkeypatch.setattr(Collection, "fetch_data", MockVikingDBClass.fetch_data)
  189. monkeypatch.setattr(Collection, "delete_data", MockVikingDBClass.delete_data)
  190. monkeypatch.setattr(Index, "search_by_vector", MockVikingDBClass.search_by_vector)
  191. monkeypatch.setattr(Index, "search", MockVikingDBClass.search)
  192. yield
  193. if MOCK:
  194. monkeypatch.undo()