# test_llm.py — tests for the Spark provider's SparkLargeLanguageModel.

import os
from collections.abc import Generator

import pytest

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.spark.llm.llm import SparkLargeLanguageModel


  8. def test_validate_credentials():
  9. model = SparkLargeLanguageModel()
  10. with pytest.raises(CredentialsValidateFailedError):
  11. model.validate_credentials(model="spark-1.5", credentials={"app_id": "invalid_key"})
  12. model.validate_credentials(
  13. model="spark-1.5",
  14. credentials={
  15. "app_id": os.environ.get("SPARK_APP_ID"),
  16. "api_secret": os.environ.get("SPARK_API_SECRET"),
  17. "api_key": os.environ.get("SPARK_API_KEY"),
  18. },
  19. )
  20. def test_invoke_model():
  21. model = SparkLargeLanguageModel()
  22. response = model.invoke(
  23. model="spark-1.5",
  24. credentials={
  25. "app_id": os.environ.get("SPARK_APP_ID"),
  26. "api_secret": os.environ.get("SPARK_API_SECRET"),
  27. "api_key": os.environ.get("SPARK_API_KEY"),
  28. },
  29. prompt_messages=[UserPromptMessage(content="Who are you?")],
  30. model_parameters={"temperature": 0.5, "max_tokens": 10},
  31. stop=["How"],
  32. stream=False,
  33. user="abc-123",
  34. )
  35. assert isinstance(response, LLMResult)
  36. assert len(response.message.content) > 0
  37. def test_invoke_stream_model():
  38. model = SparkLargeLanguageModel()
  39. response = model.invoke(
  40. model="spark-1.5",
  41. credentials={
  42. "app_id": os.environ.get("SPARK_APP_ID"),
  43. "api_secret": os.environ.get("SPARK_API_SECRET"),
  44. "api_key": os.environ.get("SPARK_API_KEY"),
  45. },
  46. prompt_messages=[UserPromptMessage(content="Hello World!")],
  47. model_parameters={"temperature": 0.5, "max_tokens": 100},
  48. stream=True,
  49. user="abc-123",
  50. )
  51. assert isinstance(response, Generator)
  52. for chunk in response:
  53. assert isinstance(chunk, LLMResultChunk)
  54. assert isinstance(chunk.delta, LLMResultChunkDelta)
  55. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  56. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  57. def test_get_num_tokens():
  58. model = SparkLargeLanguageModel()
  59. num_tokens = model.get_num_tokens(
  60. model="spark-1.5",
  61. credentials={
  62. "app_id": os.environ.get("SPARK_APP_ID"),
  63. "api_secret": os.environ.get("SPARK_API_SECRET"),
  64. "api_key": os.environ.get("SPARK_API_KEY"),
  65. },
  66. prompt_messages=[
  67. SystemPromptMessage(
  68. content="You are a helpful AI assistant.",
  69. ),
  70. UserPromptMessage(content="Hello World!"),
  71. ],
  72. )
  73. assert num_tokens == 14