# test_llm.py — integration tests for the OCI large language model provider
  1. import os
  2. from collections.abc import Generator
  3. import pytest
  4. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
  5. from core.model_runtime.entities.message_entities import (
  6. AssistantPromptMessage,
  7. PromptMessageTool,
  8. SystemPromptMessage,
  9. TextPromptMessageContent,
  10. UserPromptMessage,
  11. )
  12. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  13. from core.model_runtime.model_providers.oci.llm.llm import OCILargeLanguageModel
  14. def test_validate_credentials():
  15. model = OCILargeLanguageModel()
  16. with pytest.raises(CredentialsValidateFailedError):
  17. model.validate_credentials(
  18. model="cohere.command-r-plus",
  19. credentials={"oci_config_content": "invalid_key", "oci_key_content": "invalid_key"},
  20. )
  21. model.validate_credentials(
  22. model="cohere.command-r-plus",
  23. credentials={
  24. "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
  25. "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
  26. },
  27. )
  28. def test_invoke_model():
  29. model = OCILargeLanguageModel()
  30. response = model.invoke(
  31. model="cohere.command-r-plus",
  32. credentials={
  33. "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
  34. "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
  35. },
  36. prompt_messages=[UserPromptMessage(content="Hi")],
  37. model_parameters={"temperature": 0.5, "max_tokens": 10},
  38. stop=["How"],
  39. stream=False,
  40. user="abc-123",
  41. )
  42. assert isinstance(response, LLMResult)
  43. assert len(response.message.content) > 0
  44. def test_invoke_stream_model():
  45. model = OCILargeLanguageModel()
  46. response = model.invoke(
  47. model="meta.llama-3-70b-instruct",
  48. credentials={
  49. "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
  50. "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
  51. },
  52. prompt_messages=[UserPromptMessage(content="Hi")],
  53. model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234},
  54. stream=True,
  55. user="abc-123",
  56. )
  57. assert isinstance(response, Generator)
  58. for chunk in response:
  59. assert isinstance(chunk, LLMResultChunk)
  60. assert isinstance(chunk.delta, LLMResultChunkDelta)
  61. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  62. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  63. def test_invoke_model_with_function():
  64. model = OCILargeLanguageModel()
  65. response = model.invoke(
  66. model="cohere.command-r-plus",
  67. credentials={
  68. "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
  69. "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
  70. },
  71. prompt_messages=[UserPromptMessage(content="Hi")],
  72. model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234},
  73. stream=False,
  74. user="abc-123",
  75. tools=[
  76. PromptMessageTool(
  77. name="get_current_weather",
  78. description="Get the current weather in a given location",
  79. parameters={
  80. "type": "object",
  81. "properties": {
  82. "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
  83. "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
  84. },
  85. "required": ["location"],
  86. },
  87. )
  88. ],
  89. )
  90. assert isinstance(response, LLMResult)
  91. assert len(response.message.content) > 0
  92. def test_get_num_tokens():
  93. model = OCILargeLanguageModel()
  94. num_tokens = model.get_num_tokens(
  95. model="cohere.command-r-plus",
  96. credentials={
  97. "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
  98. "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
  99. },
  100. prompt_messages=[
  101. SystemPromptMessage(
  102. content="You are a helpful AI assistant.",
  103. ),
  104. UserPromptMessage(content="Hello World!"),
  105. ],
  106. )
  107. assert num_tokens == 18