test_llm.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. import os
  2. from collections.abc import Generator
  3. import pytest
  4. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
  5. from core.model_runtime.entities.message_entities import (
  6. AssistantPromptMessage,
  7. ImagePromptMessageContent,
  8. PromptMessageTool,
  9. SystemPromptMessage,
  10. TextPromptMessageContent,
  11. UserPromptMessage,
  12. )
  13. from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
  14. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  15. from core.model_runtime.model_providers.stepfun.llm.llm import StepfunLargeLanguageModel
  16. def test_validate_credentials():
  17. model = StepfunLargeLanguageModel()
  18. with pytest.raises(CredentialsValidateFailedError):
  19. model.validate_credentials(model="step-1-8k", credentials={"api_key": "invalid_key"})
  20. model.validate_credentials(model="step-1-8k", credentials={"api_key": os.environ.get("STEPFUN_API_KEY")})
  21. def test_invoke_model():
  22. model = StepfunLargeLanguageModel()
  23. response = model.invoke(
  24. model="step-1-8k",
  25. credentials={"api_key": os.environ.get("STEPFUN_API_KEY")},
  26. prompt_messages=[UserPromptMessage(content="Hello World!")],
  27. model_parameters={"temperature": 0.9, "top_p": 0.7},
  28. stop=["Hi"],
  29. stream=False,
  30. user="abc-123",
  31. )
  32. assert isinstance(response, LLMResult)
  33. assert len(response.message.content) > 0
  34. def test_invoke_stream_model():
  35. model = StepfunLargeLanguageModel()
  36. response = model.invoke(
  37. model="step-1-8k",
  38. credentials={"api_key": os.environ.get("STEPFUN_API_KEY")},
  39. prompt_messages=[
  40. SystemPromptMessage(
  41. content="You are a helpful AI assistant.",
  42. ),
  43. UserPromptMessage(content="Hello World!"),
  44. ],
  45. model_parameters={"temperature": 0.9, "top_p": 0.7},
  46. stream=True,
  47. user="abc-123",
  48. )
  49. assert isinstance(response, Generator)
  50. for chunk in response:
  51. assert isinstance(chunk, LLMResultChunk)
  52. assert isinstance(chunk.delta, LLMResultChunkDelta)
  53. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  54. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  55. def test_get_customizable_model_schema():
  56. model = StepfunLargeLanguageModel()
  57. schema = model.get_customizable_model_schema(
  58. model="step-1-8k", credentials={"api_key": os.environ.get("STEPFUN_API_KEY")}
  59. )
  60. assert isinstance(schema, AIModelEntity)
  61. def test_invoke_chat_model_with_tools():
  62. model = StepfunLargeLanguageModel()
  63. result = model.invoke(
  64. model="step-1-8k",
  65. credentials={"api_key": os.environ.get("STEPFUN_API_KEY")},
  66. prompt_messages=[
  67. SystemPromptMessage(
  68. content="You are a helpful AI assistant.",
  69. ),
  70. UserPromptMessage(
  71. content="what's the weather today in Shanghai?",
  72. ),
  73. ],
  74. model_parameters={"temperature": 0.9, "max_tokens": 100},
  75. tools=[
  76. PromptMessageTool(
  77. name="get_weather",
  78. description="Determine weather in my location",
  79. parameters={
  80. "type": "object",
  81. "properties": {
  82. "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
  83. "unit": {"type": "string", "enum": ["c", "f"]},
  84. },
  85. "required": ["location"],
  86. },
  87. ),
  88. PromptMessageTool(
  89. name="get_stock_price",
  90. description="Get the current stock price",
  91. parameters={
  92. "type": "object",
  93. "properties": {"symbol": {"type": "string", "description": "The stock symbol"}},
  94. "required": ["symbol"],
  95. },
  96. ),
  97. ],
  98. stream=False,
  99. user="abc-123",
  100. )
  101. assert isinstance(result, LLMResult)
  102. assert isinstance(result.message, AssistantPromptMessage)
  103. assert len(result.message.tool_calls) > 0