xinference.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. import os
  2. import re
  3. from typing import List, Union
  4. import pytest
  5. from _pytest.monkeypatch import MonkeyPatch
  6. from requests import Response
  7. from requests.exceptions import ConnectionError
  8. from requests.sessions import Session
  9. from xinference_client.client.restful.restful_client import (Client, RESTfulChatglmCppChatModelHandle,
  10. RESTfulChatModelHandle, RESTfulEmbeddingModelHandle,
  11. RESTfulGenerateModelHandle, RESTfulRerankModelHandle)
  12. from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage
  13. class MockXinferenceClass(object):
  14. def get_chat_model(self: Client, model_uid: str) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]:
  15. if not re.match(r'https?:\/\/[^\s\/$.?#].[^\s]*$', self.base_url):
  16. raise RuntimeError('404 Not Found')
  17. if 'generate' == model_uid:
  18. return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={})
  19. if 'chat' == model_uid:
  20. return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={})
  21. if 'embedding' == model_uid:
  22. return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={})
  23. if 'rerank' == model_uid:
  24. return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={})
  25. raise RuntimeError('404 Not Found')
  26. def get(self: Session, url: str, **kwargs):
  27. response = Response()
  28. if 'v1/models/' in url:
  29. # get model uid
  30. model_uid = url.split('/')[-1]
  31. if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \
  32. model_uid not in ['generate', 'chat', 'embedding', 'rerank']:
  33. response.status_code = 404
  34. return response
  35. # check if url is valid
  36. if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url):
  37. response.status_code = 404
  38. return response
  39. if model_uid in ['generate', 'chat']:
  40. response.status_code = 200
  41. response._content = b'''{
  42. "model_type": "LLM",
  43. "address": "127.0.0.1:43877",
  44. "accelerators": [
  45. "0",
  46. "1"
  47. ],
  48. "model_name": "chatglm3-6b",
  49. "model_lang": [
  50. "en"
  51. ],
  52. "model_ability": [
  53. "generate",
  54. "chat"
  55. ],
  56. "model_description": "latest chatglm3",
  57. "model_format": "pytorch",
  58. "model_size_in_billions": 7,
  59. "quantization": "none",
  60. "model_hub": "huggingface",
  61. "revision": null,
  62. "context_length": 2048,
  63. "replica": 1
  64. }'''
  65. return response
  66. elif model_uid == 'embedding':
  67. response.status_code = 200
  68. response._content = b'''{
  69. "model_type": "embedding",
  70. "address": "127.0.0.1:43877",
  71. "accelerators": [
  72. "0",
  73. "1"
  74. ],
  75. "model_name": "bge",
  76. "model_lang": [
  77. "en"
  78. ],
  79. "revision": null,
  80. "max_tokens": 512
  81. }'''
  82. return response
  83. elif 'v1/cluster/auth' in url:
  84. response.status_code = 200
  85. response._content = b'''{
  86. "auth": true
  87. }'''
  88. return response
  89. def _check_cluster_authenticated(self):
  90. self._cluster_authed = True
  91. def rerank(self: RESTfulRerankModelHandle, documents: List[str], query: str, top_n: int) -> dict:
  92. # check if self._model_uid is a valid uuid
  93. if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \
  94. self._model_uid != 'rerank':
  95. raise RuntimeError('404 Not Found')
  96. if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._base_url):
  97. raise RuntimeError('404 Not Found')
  98. if top_n is None:
  99. top_n = 1
  100. return {
  101. 'results': [
  102. {
  103. 'index': i,
  104. 'document': doc,
  105. 'relevance_score': 0.9
  106. }
  107. for i, doc in enumerate(documents[:top_n])
  108. ]
  109. }
  110. def create_embedding(
  111. self: RESTfulGenerateModelHandle,
  112. input: Union[str, List[str]],
  113. **kwargs
  114. ) -> dict:
  115. # check if self._model_uid is a valid uuid
  116. if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \
  117. self._model_uid != 'embedding':
  118. raise RuntimeError('404 Not Found')
  119. if isinstance(input, str):
  120. input = [input]
  121. ipt_len = len(input)
  122. embedding = Embedding(
  123. object="list",
  124. model=self._model_uid,
  125. data=[
  126. EmbeddingData(
  127. index=i,
  128. object="embedding",
  129. embedding=[1919.810 for _ in range(768)]
  130. )
  131. for i in range(ipt_len)
  132. ],
  133. usage=EmbeddingUsage(
  134. prompt_tokens=ipt_len,
  135. total_tokens=ipt_len
  136. )
  137. )
  138. return embedding
  139. MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
  140. @pytest.fixture
  141. def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
  142. if MOCK:
  143. monkeypatch.setattr(Client, 'get_model', MockXinferenceClass.get_chat_model)
  144. monkeypatch.setattr(Client, '_check_cluster_authenticated', MockXinferenceClass._check_cluster_authenticated)
  145. monkeypatch.setattr(Session, 'get', MockXinferenceClass.get)
  146. monkeypatch.setattr(RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding)
  147. monkeypatch.setattr(RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank)
  148. yield
  149. if MOCK:
  150. monkeypatch.undo()