test_rerank.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. import os
  2. import pytest
  3. from core.model_runtime.entities.rerank_entities import RerankResult
  4. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  5. from core.model_runtime.model_providers.xinference.rerank.rerank import XinferenceRerankModel
  6. from tests.integration_tests.model_runtime.__mock.xinference import MOCK, setup_xinference_mock
  7. @pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True)
  8. def test_validate_credentials(setup_xinference_mock):
  9. model = XinferenceRerankModel()
  10. with pytest.raises(CredentialsValidateFailedError):
  11. model.validate_credentials(
  12. model="bge-reranker-base",
  13. credentials={"server_url": "awdawdaw", "model_uid": os.environ.get("XINFERENCE_RERANK_MODEL_UID")},
  14. )
  15. model.validate_credentials(
  16. model="bge-reranker-base",
  17. credentials={
  18. "server_url": os.environ.get("XINFERENCE_SERVER_URL"),
  19. "model_uid": os.environ.get("XINFERENCE_RERANK_MODEL_UID"),
  20. },
  21. )
  22. @pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True)
  23. def test_invoke_model(setup_xinference_mock):
  24. model = XinferenceRerankModel()
  25. result = model.invoke(
  26. model="bge-reranker-base",
  27. credentials={
  28. "server_url": os.environ.get("XINFERENCE_SERVER_URL"),
  29. "model_uid": os.environ.get("XINFERENCE_RERANK_MODEL_UID"),
  30. },
  31. query="Who is Kasumi?",
  32. docs=[
  33. 'Kasumi is a girl\'s name of Japanese origin meaning "mist".',
  34. "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ",
  35. "and she leads a team named PopiParty.",
  36. ],
  37. score_threshold=0.8,
  38. )
  39. assert isinstance(result, RerankResult)
  40. assert len(result.docs) == 1
  41. assert result.docs[0].index == 0
  42. assert result.docs[0].score >= 0.8