google.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. from unittest.mock import MagicMock
  2. import google.generativeai.types.generation_types as generation_config_types # type: ignore
  3. import pytest
  4. from _pytest.monkeypatch import MonkeyPatch
  5. from google.ai import generativelanguage as glm
  6. from google.ai.generativelanguage_v1beta.types import content as gag_content
  7. from google.generativeai import GenerativeModel
  8. from google.generativeai.types import GenerateContentResponse, content_types, safety_types
  9. from google.generativeai.types.generation_types import BaseGenerateContentResponse
  10. from extensions import ext_redis
  11. class MockGoogleResponseClass:
  12. _done = False
  13. def __iter__(self):
  14. full_response_text = "it's google!"
  15. for i in range(0, len(full_response_text) + 1, 1):
  16. if i == len(full_response_text):
  17. self._done = True
  18. yield GenerateContentResponse(
  19. done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]
  20. )
  21. else:
  22. yield GenerateContentResponse(
  23. done=False, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]
  24. )
  25. class MockGoogleResponseCandidateClass:
  26. finish_reason = "stop"
  27. @property
  28. def content(self) -> gag_content.Content:
  29. return gag_content.Content(parts=[gag_content.Part(text="it's google!")])
  30. class MockGoogleClass:
  31. @staticmethod
  32. def generate_content_sync() -> GenerateContentResponse:
  33. return GenerateContentResponse(done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[])
  34. @staticmethod
  35. def generate_content_stream() -> MockGoogleResponseClass:
  36. return MockGoogleResponseClass()
  37. def generate_content(
  38. self: GenerativeModel,
  39. contents: content_types.ContentsType,
  40. *,
  41. generation_config: generation_config_types.GenerationConfigType | None = None,
  42. safety_settings: safety_types.SafetySettingOptions | None = None,
  43. stream: bool = False,
  44. **kwargs,
  45. ) -> GenerateContentResponse:
  46. if stream:
  47. return MockGoogleClass.generate_content_stream()
  48. return MockGoogleClass.generate_content_sync()
  49. @property
  50. def generative_response_text(self) -> str:
  51. return "it's google!"
  52. @property
  53. def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]:
  54. return [MockGoogleResponseCandidateClass()]
  55. def mock_configure(api_key: str):
  56. if len(api_key) < 16:
  57. raise Exception("Invalid API key")
  58. class MockFileState:
  59. def __init__(self):
  60. self.name = "FINISHED"
  61. class MockGoogleFile:
  62. def __init__(self, name: str = "mock_file_name"):
  63. self.name = name
  64. self.state = MockFileState()
  65. def mock_get_file(name: str) -> MockGoogleFile:
  66. return MockGoogleFile(name)
  67. def mock_upload_file(path: str, mime_type: str) -> MockGoogleFile:
  68. return MockGoogleFile()
  69. @pytest.fixture
  70. def setup_google_mock(request, monkeypatch: MonkeyPatch):
  71. monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text)
  72. monkeypatch.setattr(BaseGenerateContentResponse, "candidates", MockGoogleClass.generative_response_candidates)
  73. monkeypatch.setattr(GenerativeModel, "generate_content", MockGoogleClass.generate_content)
  74. monkeypatch.setattr("google.generativeai.configure", mock_configure)
  75. monkeypatch.setattr("google.generativeai.get_file", mock_get_file)
  76. monkeypatch.setattr("google.generativeai.upload_file", mock_upload_file)
  77. yield
  78. monkeypatch.undo()
  79. @pytest.fixture
  80. def setup_mock_redis() -> None:
  81. ext_redis.redis_client.get = MagicMock(return_value=None)
  82. ext_redis.redis_client.setex = MagicMock(return_value=None)
  83. ext_redis.redis_client.exists = MagicMock(return_value=True)