test_openai_model.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. import json
  2. import os
  3. from unittest.mock import patch
  4. from langchain.schema import Generation, ChatGeneration, AIMessage
  5. from core.model_providers.providers.openai_provider import OpenAIProvider
  6. from core.model_providers.models.entity.message import PromptMessage, MessageType, ImageMessageFile
  7. from core.model_providers.models.entity.model_params import ModelKwargs
  8. from core.model_providers.models.llm.openai_model import OpenAIModel
  9. from models.provider import Provider, ProviderType
  10. def get_mock_provider(valid_openai_api_key):
  11. return Provider(
  12. id='provider_id',
  13. tenant_id='tenant_id',
  14. provider_name='openai',
  15. provider_type=ProviderType.CUSTOM.value,
  16. encrypted_config=json.dumps({'openai_api_key': valid_openai_api_key}),
  17. is_valid=True,
  18. )
  19. def get_mock_openai_model(model_name):
  20. model_kwargs = ModelKwargs(
  21. max_tokens=10,
  22. temperature=0
  23. )
  24. model_name = model_name
  25. valid_openai_api_key = os.environ['OPENAI_API_KEY']
  26. openai_provider = OpenAIProvider(provider=get_mock_provider(valid_openai_api_key))
  27. return OpenAIModel(
  28. model_provider=openai_provider,
  29. name=model_name,
  30. model_kwargs=model_kwargs
  31. )
  32. def decrypt_side_effect(tenant_id, encrypted_openai_api_key):
  33. return encrypted_openai_api_key
  34. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  35. def test_get_num_tokens(mock_decrypt):
  36. openai_model = get_mock_openai_model('gpt-3.5-turbo-instruct')
  37. rst = openai_model.get_num_tokens([PromptMessage(content='you are a kindness Assistant.')])
  38. assert rst == 6
  39. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  40. def test_chat_get_num_tokens(mock_decrypt):
  41. openai_model = get_mock_openai_model('gpt-3.5-turbo')
  42. rst = openai_model.get_num_tokens([
  43. PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
  44. PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
  45. ])
  46. assert rst == 22
  47. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  48. def test_vision_chat_get_num_tokens(mock_decrypt):
  49. openai_model = get_mock_openai_model('gpt-4-vision-preview')
  50. messages = [
  51. PromptMessage(content='What’s in first image?', files=[
  52. ImageMessageFile(
  53. data='https://upload.wikimedia.org/wikipedia/commons/0/00/1890s_Carlisle_Boarding_School_Graduates_PA.jpg')
  54. ])
  55. ]
  56. rst = openai_model.get_num_tokens(messages)
  57. assert rst == 77
  58. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  59. def test_run(mock_decrypt, mocker):
  60. mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
  61. openai_model = get_mock_openai_model('gpt-3.5-turbo-instruct')
  62. rst = openai_model.run(
  63. [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')],
  64. stop=['\nHuman:'],
  65. )
  66. assert len(rst.content) > 0
  67. assert rst.content.strip() == 'n'
  68. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  69. def test_chat_run(mock_decrypt, mocker):
  70. mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
  71. openai_model = get_mock_openai_model('gpt-3.5-turbo')
  72. messages = [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')]
  73. rst = openai_model.run(
  74. messages,
  75. stop=['\nHuman:'],
  76. )
  77. assert (len(rst.content) > 0)
  78. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  79. def test_vision_run(mock_decrypt, mocker):
  80. mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
  81. openai_model = get_mock_openai_model('gpt-4-vision-preview')
  82. messages = [
  83. PromptMessage(content='What’s in first image?', files=[
  84. ImageMessageFile(data='https://upload.wikimedia.org/wikipedia/commons/0/00/1890s_Carlisle_Boarding_School_Graduates_PA.jpg')
  85. ])
  86. ]
  87. rst = openai_model.run(
  88. messages,
  89. )
  90. assert len(rst.content) > 0