# test_llm.py

import os
from collections.abc import Generator

import pytest

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ParameterRule
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.localai.llm.llm import LocalAILanguageModel
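

# A bogus server_url should fail credential validation; a reachable LocalAI
# endpoint (read from LOCALAI_SERVER_URL) should validate without raising.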
def test_validate_credentials_for_chat_model():
    model = LocalAILanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="chinese-llama-2-7b",
            credentials={
                "server_url": "hahahaha",
                "completion_type": "completion",
            },
        )

    model.validate_credentials(
        model="chinese-llama-2-7b",
        credentials={
            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
            "completion_type": "completion",
        },
    )
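

# Blocking (non-streaming) invocation against a LocalAI completion endpoint.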
def test_invoke_completion_model():
    model = LocalAILanguageModel()

    response = model.invoke(
        model="chinese-llama-2-7b",
        credentials={
            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
            "completion_type": "completion",
        },
        prompt_messages=[UserPromptMessage(content="ping")],
        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
        stop=[],
        user="abc-123",
        stream=False,
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
    assert response.usage.total_tokens > 0
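

# Same blocking invocation, but through the chat_completion endpoint.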
def test_invoke_chat_model():
    model = LocalAILanguageModel()

    response = model.invoke(
        model="chinese-llama-2-7b",
        credentials={
            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
            "completion_type": "chat_completion",
        },
        prompt_messages=[UserPromptMessage(content="ping")],
        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
        stop=[],
        user="abc-123",
        stream=False,
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
    assert response.usage.total_tokens > 0
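

# Streaming completion: invoke() should return a generator of LLMResultChunk,
# where every non-final chunk carries non-empty content.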
def test_invoke_stream_completion_model():
    model = LocalAILanguageModel()

    response = model.invoke(
        model="chinese-llama-2-7b",
        credentials={
            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
            "completion_type": "completion",
        },
        prompt_messages=[UserPromptMessage(content="Hello World!")],
        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
        stop=["you"],
        stream=True,
        user="abc-123",
    )

    assert isinstance(response, Generator)

    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0
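

# Streaming chat_completion: same chunk-level checks as the completion test.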
def test_invoke_stream_chat_model():
    model = LocalAILanguageModel()

    response = model.invoke(
        model="chinese-llama-2-7b",
        credentials={
            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
            "completion_type": "chat_completion",
        },
        prompt_messages=[UserPromptMessage(content="Hello World!")],
        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
        stop=["you"],
        stream=True,
        user="abc-123",
    )

    assert isinstance(response, Generator)

    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0
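

# Token counting is computed locally, so the model name appears to be a
# placeholder; the expected counts (77 with a tool definition, 10 without)
# reflect the provider's local tokenizer estimate for these prompts.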
def test_get_num_tokens():
    model = LocalAILanguageModel()

    num_tokens = model.get_num_tokens(
        model="????",
        credentials={
            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
            "completion_type": "chat_completion",
        },
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
        tools=[
            PromptMessageTool(
                name="get_current_weather",
                description="Get the current weather in a given location",
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
                        "unit": {"type": "string", "enum": ["c", "f"]},
                    },
                    "required": ["location"],
                },
            )
        ],
    )

    assert isinstance(num_tokens, int)
    assert num_tokens == 77

    num_tokens = model.get_num_tokens(
        model="????",
        credentials={
            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
            "completion_type": "chat_completion",
        },
        prompt_messages=[UserPromptMessage(content="Hello World!")],
    )

    assert isinstance(num_tokens, int)
    assert num_tokens == 10
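

# To run these integration tests, point LOCALAI_SERVER_URL at a running
# LocalAI instance serving the model under test, e.g. (assuming LocalAI's
# default port):
#   LOCALAI_SERVER_URL=http://localhost:8080 pytest test_llm.py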