test_llm.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. import os
  2. from collections.abc import Generator
  3. import pytest
  4. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
  5. from core.model_runtime.entities.message_entities import (
  6. AssistantPromptMessage,
  7. ImagePromptMessageContent,
  8. PromptMessageTool,
  9. SystemPromptMessage,
  10. TextPromptMessageContent,
  11. UserPromptMessage,
  12. )
  13. from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
  14. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  15. from core.model_runtime.model_providers.stepfun.llm.llm import StepfunLargeLanguageModel
  16. def test_validate_credentials():
  17. model = StepfunLargeLanguageModel()
  18. with pytest.raises(CredentialsValidateFailedError):
  19. model.validate_credentials(
  20. model='step-1-8k',
  21. credentials={
  22. 'api_key': 'invalid_key'
  23. }
  24. )
  25. model.validate_credentials(
  26. model='step-1-8k',
  27. credentials={
  28. 'api_key': os.environ.get('STEPFUN_API_KEY')
  29. }
  30. )
  31. def test_invoke_model():
  32. model = StepfunLargeLanguageModel()
  33. response = model.invoke(
  34. model='step-1-8k',
  35. credentials={
  36. 'api_key': os.environ.get('STEPFUN_API_KEY')
  37. },
  38. prompt_messages=[
  39. UserPromptMessage(
  40. content='Hello World!'
  41. )
  42. ],
  43. model_parameters={
  44. 'temperature': 0.9,
  45. 'top_p': 0.7
  46. },
  47. stop=['Hi'],
  48. stream=False,
  49. user="abc-123"
  50. )
  51. assert isinstance(response, LLMResult)
  52. assert len(response.message.content) > 0
  53. def test_invoke_stream_model():
  54. model = StepfunLargeLanguageModel()
  55. response = model.invoke(
  56. model='step-1-8k',
  57. credentials={
  58. 'api_key': os.environ.get('STEPFUN_API_KEY')
  59. },
  60. prompt_messages=[
  61. SystemPromptMessage(
  62. content='You are a helpful AI assistant.',
  63. ),
  64. UserPromptMessage(
  65. content='Hello World!'
  66. )
  67. ],
  68. model_parameters={
  69. 'temperature': 0.9,
  70. 'top_p': 0.7
  71. },
  72. stream=True,
  73. user="abc-123"
  74. )
  75. assert isinstance(response, Generator)
  76. for chunk in response:
  77. assert isinstance(chunk, LLMResultChunk)
  78. assert isinstance(chunk.delta, LLMResultChunkDelta)
  79. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  80. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  81. def test_get_customizable_model_schema():
  82. model = StepfunLargeLanguageModel()
  83. schema = model.get_customizable_model_schema(
  84. model='step-1-8k',
  85. credentials={
  86. 'api_key': os.environ.get('STEPFUN_API_KEY')
  87. }
  88. )
  89. assert isinstance(schema, AIModelEntity)
  90. def test_invoke_chat_model_with_tools():
  91. model = StepfunLargeLanguageModel()
  92. result = model.invoke(
  93. model='step-1-8k',
  94. credentials={
  95. 'api_key': os.environ.get('STEPFUN_API_KEY')
  96. },
  97. prompt_messages=[
  98. SystemPromptMessage(
  99. content='You are a helpful AI assistant.',
  100. ),
  101. UserPromptMessage(
  102. content="what's the weather today in Shanghai?",
  103. )
  104. ],
  105. model_parameters={
  106. 'temperature': 0.9,
  107. 'max_tokens': 100
  108. },
  109. tools=[
  110. PromptMessageTool(
  111. name='get_weather',
  112. description='Determine weather in my location',
  113. parameters={
  114. "type": "object",
  115. "properties": {
  116. "location": {
  117. "type": "string",
  118. "description": "The city and state e.g. San Francisco, CA"
  119. },
  120. "unit": {
  121. "type": "string",
  122. "enum": [
  123. "c",
  124. "f"
  125. ]
  126. }
  127. },
  128. "required": [
  129. "location"
  130. ]
  131. }
  132. ),
  133. PromptMessageTool(
  134. name='get_stock_price',
  135. description='Get the current stock price',
  136. parameters={
  137. "type": "object",
  138. "properties": {
  139. "symbol": {
  140. "type": "string",
  141. "description": "The stock symbol"
  142. }
  143. },
  144. "required": [
  145. "symbol"
  146. ]
  147. }
  148. )
  149. ],
  150. stream=False,
  151. user="abc-123"
  152. )
  153. assert isinstance(result, LLMResult)
  154. assert isinstance(result.message, AssistantPromptMessage)
  155. assert len(result.message.tool_calls) > 0