"""Integration tests for the GPUStack language-model provider.

Requires GPUSTACK_SERVER_URL and GPUSTACK_API_KEY in the environment;
these tests talk to a live GPUStack endpoint.
"""
import os
from collections.abc import Generator

import pytest

from core.model_runtime.entities.llm_entities import (
    LLMResult,
    LLMResultChunk,
    LLMResultChunkDelta,
)
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.gpustack.llm.llm import GPUStackLanguageModel
  17. def test_validate_credentials_for_chat_model():
  18. model = GPUStackLanguageModel()
  19. with pytest.raises(CredentialsValidateFailedError):
  20. model.validate_credentials(
  21. model="llama-3.2-1b-instruct",
  22. credentials={
  23. "endpoint_url": "invalid_url",
  24. "api_key": "invalid_api_key",
  25. "mode": "chat",
  26. },
  27. )
  28. model.validate_credentials(
  29. model="llama-3.2-1b-instruct",
  30. credentials={
  31. "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
  32. "api_key": os.environ.get("GPUSTACK_API_KEY"),
  33. "mode": "chat",
  34. },
  35. )
  36. def test_invoke_completion_model():
  37. model = GPUStackLanguageModel()
  38. response = model.invoke(
  39. model="llama-3.2-1b-instruct",
  40. credentials={
  41. "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
  42. "api_key": os.environ.get("GPUSTACK_API_KEY"),
  43. "mode": "completion",
  44. },
  45. prompt_messages=[UserPromptMessage(content="ping")],
  46. model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
  47. stop=[],
  48. user="abc-123",
  49. stream=False,
  50. )
  51. assert isinstance(response, LLMResult)
  52. assert len(response.message.content) > 0
  53. assert response.usage.total_tokens > 0
  54. def test_invoke_chat_model():
  55. model = GPUStackLanguageModel()
  56. response = model.invoke(
  57. model="llama-3.2-1b-instruct",
  58. credentials={
  59. "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
  60. "api_key": os.environ.get("GPUSTACK_API_KEY"),
  61. "mode": "chat",
  62. },
  63. prompt_messages=[UserPromptMessage(content="ping")],
  64. model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
  65. stop=[],
  66. user="abc-123",
  67. stream=False,
  68. )
  69. assert isinstance(response, LLMResult)
  70. assert len(response.message.content) > 0
  71. assert response.usage.total_tokens > 0
  72. def test_invoke_stream_chat_model():
  73. model = GPUStackLanguageModel()
  74. response = model.invoke(
  75. model="llama-3.2-1b-instruct",
  76. credentials={
  77. "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
  78. "api_key": os.environ.get("GPUSTACK_API_KEY"),
  79. "mode": "chat",
  80. },
  81. prompt_messages=[UserPromptMessage(content="Hello World!")],
  82. model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
  83. stop=["you"],
  84. stream=True,
  85. user="abc-123",
  86. )
  87. assert isinstance(response, Generator)
  88. for chunk in response:
  89. assert isinstance(chunk, LLMResultChunk)
  90. assert isinstance(chunk.delta, LLMResultChunkDelta)
  91. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  92. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  93. def test_get_num_tokens():
  94. model = GPUStackLanguageModel()
  95. num_tokens = model.get_num_tokens(
  96. model="????",
  97. credentials={
  98. "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
  99. "api_key": os.environ.get("GPUSTACK_API_KEY"),
  100. "mode": "chat",
  101. },
  102. prompt_messages=[
  103. SystemPromptMessage(
  104. content="You are a helpful AI assistant.",
  105. ),
  106. UserPromptMessage(content="Hello World!"),
  107. ],
  108. tools=[
  109. PromptMessageTool(
  110. name="get_current_weather",
  111. description="Get the current weather in a given location",
  112. parameters={
  113. "type": "object",
  114. "properties": {
  115. "location": {
  116. "type": "string",
  117. "description": "The city and state e.g. San Francisco, CA",
  118. },
  119. "unit": {"type": "string", "enum": ["c", "f"]},
  120. },
  121. "required": ["location"],
  122. },
  123. )
  124. ],
  125. )
  126. assert isinstance(num_tokens, int)
  127. assert num_tokens == 80
  128. num_tokens = model.get_num_tokens(
  129. model="????",
  130. credentials={
  131. "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
  132. "api_key": os.environ.get("GPUSTACK_API_KEY"),
  133. "mode": "chat",
  134. },
  135. prompt_messages=[UserPromptMessage(content="Hello World!")],
  136. )
  137. assert isinstance(num_tokens, int)
  138. assert num_tokens == 10