test_llm.py

import os
from unittest.mock import MagicMock

import pytest

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers import ModelProviderFactory
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.llm.llm_node import LLMNode
from extensions.ext_database import db
from models.provider import ProviderType
from models.workflow import WorkflowNodeExecutionStatus

"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
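
# setup_openai_mock patches the OpenAI chat completions client so this test
# can run without calling the real API.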

@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
def test_execute_llm(setup_openai_mock):
    node = LLMNode(
        tenant_id='1',
        app_id='1',
        workflow_id='1',
        user_id='1',
        user_from=UserFrom.ACCOUNT,
        config={
            'id': 'llm',
            'data': {
                'title': '123',
                'type': 'llm',
                'model': {
                    'provider': 'openai',
                    'name': 'gpt-3.5-turbo',
                    'mode': 'chat',
                    'completion_params': {}
                },
                'prompt_template': [
                    {
                        'role': 'system',
                        'text': 'you are a helpful assistant.\ntoday\'s weather is {{#abc.output#}}.'
                    },
                    {
                        'role': 'user',
                        'text': '{{#sys.query#}}'
                    }
                ],
                'memory': None,
                'context': {
                    'enabled': False
                },
                'vision': {
                    'enabled': False
                }
            }
        }
    )

    # construct variable pool
    pool = VariablePool(system_variables={
        SystemVariable.QUERY: 'what\'s the weather today?',
        SystemVariable.FILES: [],
        SystemVariable.CONVERSATION: 'abababa'
    }, user_inputs={})
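    # 'sunny' backs the {{#abc.output#}} placeholder used in the system prompt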
    pool.append_variable(node_id='abc', variable_key_list=['output'], value='sunny')
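
    # credentials come from the environment; no real key is needed while the OpenAI mock is active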
    credentials = {
        'openai_api_key': os.environ.get('OPENAI_API_KEY')
    }
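
    # manually assemble the custom-provider bundle and model config for gpt-3.5-turbo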
    provider_instance = ModelProviderFactory().get_provider_instance('openai')
    model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
    provider_model_bundle = ProviderModelBundle(
        configuration=ProviderConfiguration(
            tenant_id='1',
            provider=provider_instance.get_provider_schema(),
            preferred_provider_type=ProviderType.CUSTOM,
            using_provider_type=ProviderType.CUSTOM,
            system_configuration=SystemConfiguration(
                enabled=False
            ),
            custom_configuration=CustomConfiguration(
                provider=CustomProviderConfiguration(
                    credentials=credentials
                )
            )
        ),
        provider_instance=provider_instance,
        model_type_instance=model_type_instance
    )
    model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo')
    model_config = ModelConfigWithCredentialsEntity(
        model='gpt-3.5-turbo',
        provider='openai',
        mode='chat',
        credentials=credentials,
        parameters={},
        model_schema=model_type_instance.get_model_schema('gpt-3.5-turbo'),
        provider_model_bundle=provider_model_bundle
    )

    # Mock db.session.close()
    db.session.close = MagicMock()
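
    # skip the DB-backed model config lookup and return the objects built above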
    node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config]))

    # execute node
    result = node.run(pool)
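
    # the node should succeed and report generated text and token usage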
    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert result.outputs['text'] is not None
    assert result.outputs['usage']['total_tokens'] > 0