test_embedding.py 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. import os
  2. import pytest
  3. from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
  4. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  5. from core.model_runtime.model_providers.gpustack.text_embedding.text_embedding import (
  6. GPUStackTextEmbeddingModel,
  7. )
  8. def test_validate_credentials():
  9. model = GPUStackTextEmbeddingModel()
  10. with pytest.raises(CredentialsValidateFailedError):
  11. model.validate_credentials(
  12. model="bge-m3",
  13. credentials={
  14. "endpoint_url": "invalid_url",
  15. "api_key": "invalid_api_key",
  16. },
  17. )
  18. model.validate_credentials(
  19. model="bge-m3",
  20. credentials={
  21. "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
  22. "api_key": os.environ.get("GPUSTACK_API_KEY"),
  23. },
  24. )
  25. def test_invoke_model():
  26. model = GPUStackTextEmbeddingModel()
  27. result = model.invoke(
  28. model="bge-m3",
  29. credentials={
  30. "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
  31. "api_key": os.environ.get("GPUSTACK_API_KEY"),
  32. "context_size": 8192,
  33. },
  34. texts=["hello", "world"],
  35. user="abc-123",
  36. )
  37. assert isinstance(result, TextEmbeddingResult)
  38. assert len(result.embeddings) == 2
  39. assert result.usage.total_tokens == 7