test_llm.py

import os
from collections.abc import Generator

import pytest

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.hunyuan.llm.llm import HunyuanLargeLanguageModel
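

# Credential validation should fail with bogus keys and succeed with the real
# HUNYUAN_SECRET_ID / HUNYUAN_SECRET_KEY taken from the environment.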
def test_validate_credentials():
    model = HunyuanLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='hunyuan-standard',
            credentials={
                'secret_id': 'invalid_key',
                'secret_key': 'invalid_key'
            }
        )

    model.validate_credentials(
        model='hunyuan-standard',
        credentials={
            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
        }
    )
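

# A blocking (non-streaming) invocation should return a single LLMResult
# with non-empty assistant content.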
def test_invoke_model():
    model = HunyuanLargeLanguageModel()

    response = model.invoke(
        model='hunyuan-standard',
        credentials={
            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hi'
            )
        ],
        model_parameters={
            'temperature': 0.5,
            'max_tokens': 10
        },
        stop=['How'],
        stream=False,
        user="abc-123"
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
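

# A streaming invocation should yield LLMResultChunk objects; every chunk
# before the final one carries a non-empty piece of assistant content.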
def test_invoke_stream_model():
    model = HunyuanLargeLanguageModel()

    response = model.invoke(
        model='hunyuan-standard',
        credentials={
            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hi'
            )
        ],
        model_parameters={
            'temperature': 0.5,
            'max_tokens': 100,
            'seed': 1234
        },
        stream=True,
        user="abc-123"
    )

    assert isinstance(response, Generator)

    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0
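

# Token counting for a system + user prompt pair; the test pins the expected
# value at 14 tokens for this exact input.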
def test_get_num_tokens():
    model = HunyuanLargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model='hunyuan-standard',
        credentials={
            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
            ),
            UserPromptMessage(
                content='Hello World!'
            )
        ]
    )

    assert num_tokens == 14