test_rerank.py 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. import os
  2. from unittest.mock import Mock, patch
  3. import pytest
  4. from core.model_runtime.entities.rerank_entities import RerankResult
  5. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  6. from core.model_runtime.model_providers.voyage.rerank.rerank import VoyageRerankModel
  7. def test_validate_credentials():
  8. model = VoyageRerankModel()
  9. with pytest.raises(CredentialsValidateFailedError):
  10. model.validate_credentials(
  11. model="rerank-lite-1",
  12. credentials={"api_key": "invalid_key"},
  13. )
  14. with patch("httpx.post") as mock_post:
  15. mock_response = Mock()
  16. mock_response.json.return_value = {
  17. "object": "list",
  18. "data": [
  19. {
  20. "relevance_score": 0.546875,
  21. "index": 0,
  22. "document": "Carson City is the capital city of the American state of Nevada. At the 2010 United "
  23. "States Census, Carson City had a population of 55,274.",
  24. },
  25. {
  26. "relevance_score": 0.4765625,
  27. "index": 1,
  28. "document": "The Commonwealth of the Northern Mariana Islands is a group of islands in the "
  29. "Pacific Ocean that are a political division controlled by the United States. Its "
  30. "capital is Saipan.",
  31. },
  32. ],
  33. "model": "rerank-lite-1",
  34. "usage": {"total_tokens": 96},
  35. }
  36. mock_response.status_code = 200
  37. mock_post.return_value = mock_response
  38. model.validate_credentials(
  39. model="rerank-lite-1",
  40. credentials={
  41. "api_key": os.environ.get("VOYAGE_API_KEY"),
  42. },
  43. )
  44. def test_invoke_model():
  45. model = VoyageRerankModel()
  46. with patch("httpx.post") as mock_post:
  47. mock_response = Mock()
  48. mock_response.json.return_value = {
  49. "object": "list",
  50. "data": [
  51. {
  52. "relevance_score": 0.84375,
  53. "index": 0,
  54. "document": "Kasumi is a girl name of Japanese origin meaning mist.",
  55. },
  56. {
  57. "relevance_score": 0.4765625,
  58. "index": 1,
  59. "document": "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she "
  60. "leads a team named PopiParty.",
  61. },
  62. ],
  63. "model": "rerank-lite-1",
  64. "usage": {"total_tokens": 59},
  65. }
  66. mock_response.status_code = 200
  67. mock_post.return_value = mock_response
  68. result = model.invoke(
  69. model="rerank-lite-1",
  70. credentials={
  71. "api_key": os.environ.get("VOYAGE_API_KEY"),
  72. },
  73. query="Who is Kasumi?",
  74. docs=[
  75. "Kasumi is a girl name of Japanese origin meaning mist.",
  76. "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she leads a team named "
  77. "PopiParty.",
  78. ],
  79. score_threshold=0.5,
  80. )
  81. assert isinstance(result, RerankResult)
  82. assert len(result.docs) == 1
  83. assert result.docs[0].index == 0
  84. assert result.docs[0].score >= 0.5