test_rerank.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. import os
  2. import pytest
  3. from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
  4. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  5. from core.model_runtime.model_providers.gpustack.rerank.rerank import (
  6. GPUStackRerankModel,
  7. )
  8. def test_validate_credentials_for_rerank_model():
  9. model = GPUStackRerankModel()
  10. with pytest.raises(CredentialsValidateFailedError):
  11. model.validate_credentials(
  12. model="bge-reranker-v2-m3",
  13. credentials={
  14. "endpoint_url": "invalid_url",
  15. "api_key": "invalid_api_key",
  16. },
  17. )
  18. model.validate_credentials(
  19. model="bge-reranker-v2-m3",
  20. credentials={
  21. "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
  22. "api_key": os.environ.get("GPUSTACK_API_KEY"),
  23. },
  24. )
  25. def test_invoke_rerank_model():
  26. model = GPUStackRerankModel()
  27. response = model.invoke(
  28. model="bge-reranker-v2-m3",
  29. credentials={
  30. "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
  31. "api_key": os.environ.get("GPUSTACK_API_KEY"),
  32. },
  33. query="Organic skincare products for sensitive skin",
  34. docs=[
  35. "Eco-friendly kitchenware for modern homes",
  36. "Biodegradable cleaning supplies for eco-conscious consumers",
  37. "Organic cotton baby clothes for sensitive skin",
  38. "Natural organic skincare range for sensitive skin",
  39. "Tech gadgets for smart homes: 2024 edition",
  40. "Sustainable gardening tools and compost solutions",
  41. "Sensitive skin-friendly facial cleansers and toners",
  42. "Organic food wraps and storage solutions",
  43. "Yoga mats made from recycled materials",
  44. ],
  45. top_n=3,
  46. score_threshold=-0.75,
  47. user="abc-123",
  48. )
  49. assert isinstance(response, RerankResult)
  50. assert len(response.docs) == 3
  51. def test__invoke():
  52. model = GPUStackRerankModel()
  53. # Test case 1: Empty docs
  54. result = model._invoke(
  55. model="bge-reranker-v2-m3",
  56. credentials={
  57. "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
  58. "api_key": os.environ.get("GPUSTACK_API_KEY"),
  59. },
  60. query="Organic skincare products for sensitive skin",
  61. docs=[],
  62. top_n=3,
  63. score_threshold=0.75,
  64. user="abc-123",
  65. )
  66. assert isinstance(result, RerankResult)
  67. assert len(result.docs) == 0
  68. # Test case 2: Expected docs
  69. result = model._invoke(
  70. model="bge-reranker-v2-m3",
  71. credentials={
  72. "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
  73. "api_key": os.environ.get("GPUSTACK_API_KEY"),
  74. },
  75. query="Organic skincare products for sensitive skin",
  76. docs=[
  77. "Eco-friendly kitchenware for modern homes",
  78. "Biodegradable cleaning supplies for eco-conscious consumers",
  79. "Organic cotton baby clothes for sensitive skin",
  80. "Natural organic skincare range for sensitive skin",
  81. "Tech gadgets for smart homes: 2024 edition",
  82. "Sustainable gardening tools and compost solutions",
  83. "Sensitive skin-friendly facial cleansers and toners",
  84. "Organic food wraps and storage solutions",
  85. "Yoga mats made from recycled materials",
  86. ],
  87. top_n=3,
  88. score_threshold=-0.75,
  89. user="abc-123",
  90. )
  91. assert isinstance(result, RerankResult)
  92. assert len(result.docs) == 3
  93. assert all(isinstance(doc, RerankDocument) for doc in result.docs)