test_openai_model.py
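
"""
Integration tests for OpenAIModel (completion and chat variants).

These tests call the live OpenAI API (responses are not mocked), so the
OPENAI_API_KEY environment variable must be set. Token decryption and
provider bookkeeping are patched out below.
"""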

import json
import os
from unittest.mock import patch

from langchain.schema import Generation, ChatGeneration, AIMessage

from core.model_providers.providers.openai_provider import OpenAIProvider
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.openai_model import OpenAIModel
from models.provider import Provider, ProviderType


def get_mock_provider(valid_openai_api_key):
    # Build a custom-type Provider record whose encrypted_config carries the API key.
    return Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='openai',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=json.dumps({'openai_api_key': valid_openai_api_key}),
        is_valid=True,
    )


def get_mock_openai_model(model_name):
    # Wrap an OpenAIProvider (keyed from the environment) around the requested model.
    model_kwargs = ModelKwargs(
        max_tokens=10,
        temperature=0
    )
    valid_openai_api_key = os.environ['OPENAI_API_KEY']
    openai_provider = OpenAIProvider(provider=get_mock_provider(valid_openai_api_key))
    return OpenAIModel(
        model_provider=openai_provider,
        name=model_name,
        model_kwargs=model_kwargs
    )


def decrypt_side_effect(tenant_id, encrypted_openai_api_key):
    # Stand-in for core.helper.encrypter.decrypt_token: return the stored value unchanged.
    return encrypted_openai_api_key


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt):
    openai_model = get_mock_openai_model('gpt-3.5-turbo-instruct')
    rst = openai_model.get_num_tokens([PromptMessage(content='you are a kindness Assistant.')])
    assert rst == 6


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_get_num_tokens(mock_decrypt):
    openai_model = get_mock_openai_model('gpt-3.5-turbo')
    rst = openai_model.get_num_tokens([
        PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
    ])
    assert rst == 22


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt, mocker):
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
    openai_model = get_mock_openai_model('gpt-3.5-turbo-instruct')
    rst = openai_model.run(
        [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')],
        stop=['\nHuman:'],
    )
    assert len(rst.content) > 0
    assert rst.content.strip() == 'n'


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_run(mock_decrypt, mocker):
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
    openai_model = get_mock_openai_model('gpt-3.5-turbo')
    messages = [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')]
    rst = openai_model.run(
        messages,
        stop=['\nHuman:'],
    )
    assert len(rst.content) > 0