# test_zhipuai_model.py — integration tests for the ZhipuAI chat model wrapper.
  1. import json
  2. import os
  3. from unittest.mock import patch
  4. from core.model_providers.models.entity.message import PromptMessage, MessageType
  5. from core.model_providers.models.entity.model_params import ModelKwargs
  6. from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel
  7. from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
  8. from models.provider import Provider, ProviderType
  9. def get_mock_provider(valid_api_key):
  10. return Provider(
  11. id='provider_id',
  12. tenant_id='tenant_id',
  13. provider_name='zhipuai',
  14. provider_type=ProviderType.CUSTOM.value,
  15. encrypted_config=json.dumps({
  16. 'api_key': valid_api_key
  17. }),
  18. is_valid=True,
  19. )
  20. def get_mock_model(model_name: str, streaming: bool = False):
  21. model_kwargs = ModelKwargs(
  22. temperature=0.01,
  23. )
  24. valid_api_key = os.environ['ZHIPUAI_API_KEY']
  25. model_provider = ZhipuAIProvider(provider=get_mock_provider(valid_api_key))
  26. return ZhipuAIModel(
  27. model_provider=model_provider,
  28. name=model_name,
  29. model_kwargs=model_kwargs,
  30. streaming=streaming
  31. )
  32. def decrypt_side_effect(tenant_id, encrypted_api_key):
  33. return encrypted_api_key
  34. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  35. def test_chat_get_num_tokens(mock_decrypt):
  36. model = get_mock_model('chatglm_lite')
  37. rst = model.get_num_tokens([
  38. PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
  39. PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
  40. ])
  41. assert rst > 0
  42. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  43. def test_chat_run(mock_decrypt, mocker):
  44. mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
  45. model = get_mock_model('chatglm_lite')
  46. messages = [
  47. PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
  48. ]
  49. rst = model.run(
  50. messages,
  51. )
  52. assert len(rst.content) > 0
  53. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  54. def test_chat_stream_run(mock_decrypt, mocker):
  55. mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
  56. model = get_mock_model('chatglm_lite', streaming=True)
  57. messages = [
  58. PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
  59. ]
  60. rst = model.run(
  61. messages
  62. )
  63. assert len(rst.content) > 0