xinference.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. import os
  2. import re
  3. from typing import Union
  4. import pytest
  5. from _pytest.monkeypatch import MonkeyPatch
  6. from requests import Response
  7. from requests.exceptions import ConnectionError
  8. from requests.sessions import Session
  9. from xinference_client.client.restful.restful_client import (
  10. Client,
  11. RESTfulChatglmCppChatModelHandle,
  12. RESTfulChatModelHandle,
  13. RESTfulEmbeddingModelHandle,
  14. RESTfulGenerateModelHandle,
  15. RESTfulRerankModelHandle,
  16. )
  17. from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage
  18. class MockXinferenceClass:
  19. def get_chat_model(
  20. self: Client, model_uid: str
  21. ) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]:
  22. if not re.match(r"https?:\/\/[^\s\/$.?#].[^\s]*$", self.base_url):
  23. raise RuntimeError("404 Not Found")
  24. if "generate" == model_uid:
  25. return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={})
  26. if "chat" == model_uid:
  27. return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={})
  28. if "embedding" == model_uid:
  29. return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={})
  30. if "rerank" == model_uid:
  31. return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={})
  32. raise RuntimeError("404 Not Found")
  33. def get(self: Session, url: str, **kwargs):
  34. response = Response()
  35. if "v1/models/" in url:
  36. # get model uid
  37. model_uid = url.split("/")[-1] or ""
  38. if not re.match(
  39. r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", model_uid
  40. ) and model_uid not in ["generate", "chat", "embedding", "rerank"]:
  41. response.status_code = 404
  42. response._content = b"{}"
  43. return response
  44. # check if url is valid
  45. if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", url):
  46. response.status_code = 404
  47. response._content = b"{}"
  48. return response
  49. if model_uid in ["generate", "chat"]:
  50. response.status_code = 200
  51. response._content = b"""{
  52. "model_type": "LLM",
  53. "address": "127.0.0.1:43877",
  54. "accelerators": [
  55. "0",
  56. "1"
  57. ],
  58. "model_name": "chatglm3-6b",
  59. "model_lang": [
  60. "en"
  61. ],
  62. "model_ability": [
  63. "generate",
  64. "chat"
  65. ],
  66. "model_description": "latest chatglm3",
  67. "model_format": "pytorch",
  68. "model_size_in_billions": 7,
  69. "quantization": "none",
  70. "model_hub": "huggingface",
  71. "revision": null,
  72. "context_length": 2048,
  73. "replica": 1
  74. }"""
  75. return response
  76. elif model_uid == "embedding":
  77. response.status_code = 200
  78. response._content = b"""{
  79. "model_type": "embedding",
  80. "address": "127.0.0.1:43877",
  81. "accelerators": [
  82. "0",
  83. "1"
  84. ],
  85. "model_name": "bge",
  86. "model_lang": [
  87. "en"
  88. ],
  89. "revision": null,
  90. "max_tokens": 512
  91. }"""
  92. return response
  93. elif "v1/cluster/auth" in url:
  94. response.status_code = 200
  95. response._content = b"""{
  96. "auth": true
  97. }"""
  98. return response
  99. def _check_cluster_authenticated(self):
  100. self._cluster_authed = True
  101. def rerank(
  102. self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool
  103. ) -> dict:
  104. # check if self._model_uid is a valid uuid
  105. if (
  106. not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid)
  107. and self._model_uid != "rerank"
  108. ):
  109. raise RuntimeError("404 Not Found")
  110. if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._base_url):
  111. raise RuntimeError("404 Not Found")
  112. if top_n is None:
  113. top_n = 1
  114. return {
  115. "results": [
  116. {"index": i, "document": doc, "relevance_score": 0.9} for i, doc in enumerate(documents[:top_n])
  117. ]
  118. }
  119. def create_embedding(self: RESTfulGenerateModelHandle, input: Union[str, list[str]], **kwargs) -> dict:
  120. # check if self._model_uid is a valid uuid
  121. if (
  122. not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid)
  123. and self._model_uid != "embedding"
  124. ):
  125. raise RuntimeError("404 Not Found")
  126. if isinstance(input, str):
  127. input = [input]
  128. ipt_len = len(input)
  129. embedding = Embedding(
  130. object="list",
  131. model=self._model_uid,
  132. data=[
  133. EmbeddingData(index=i, object="embedding", embedding=[1919.810 for _ in range(768)])
  134. for i in range(ipt_len)
  135. ],
  136. usage=EmbeddingUsage(prompt_tokens=ipt_len, total_tokens=ipt_len),
  137. )
  138. return embedding
  139. MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
  140. @pytest.fixture
  141. def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
  142. if MOCK:
  143. monkeypatch.setattr(Client, "get_model", MockXinferenceClass.get_chat_model)
  144. monkeypatch.setattr(Client, "_check_cluster_authenticated", MockXinferenceClass._check_cluster_authenticated)
  145. monkeypatch.setattr(Session, "get", MockXinferenceClass.get)
  146. monkeypatch.setattr(RESTfulEmbeddingModelHandle, "create_embedding", MockXinferenceClass.create_embedding)
  147. monkeypatch.setattr(RESTfulRerankModelHandle, "rerank", MockXinferenceClass.rerank)
  148. yield
  149. if MOCK:
  150. monkeypatch.undo()