@@ -26,13 +26,15 @@ cache_lock = Lock()
 
 
 class TeiHelper:
     @staticmethod
-    def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
+    def get_tei_extra_parameter(
+        server_url: str, model_name: str, headers: Optional[dict] = None
+    ) -> TeiModelExtraParameter:
         TeiHelper._clean_cache()
         with cache_lock:
             if model_name not in cache:
                 cache[model_name] = {
                     "expires": time() + 300,
-                    "value": TeiHelper._get_tei_extra_parameter(server_url),
+                    "value": TeiHelper._get_tei_extra_parameter(server_url, headers),
                 }
             return cache[model_name]["value"]
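
Reviewer note: the new optional `headers` argument threads authentication through to
the TEI server (the type annotations assume `Optional` is imported from `typing` in
this file). A minimal caller sketch, assuming a hypothetical `api_key` credential
field and a bearer-token scheme; neither is prescribed by this patch:

    # Hypothetical caller: turn an optional API key into request headers.
    api_key = credentials.get("api_key")  # assumed credential field
    headers = {"Authorization": f"Bearer {api_key}"} if api_key else None
    extra = TeiHelper.get_tei_extra_parameter(server_url, model_name, headers)

Since the cache is keyed by model_name alone with a 300-second TTL, a change of
headers only takes effect once the cached entry expires.
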
@@ -47,7 +49,7 @@ class TeiHelper:
         pass
 
     @staticmethod
-    def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter:
+    def _get_tei_extra_parameter(server_url: str, headers: Optional[dict] = None) -> TeiModelExtraParameter:
         """
         get tei model extra parameter like model_type, max_input_length, max_batch_requests
         """
@@ -61,7 +63,7 @@ class TeiHelper:
         session.mount("https://", HTTPAdapter(max_retries=3))
 
         try:
-            response = session.get(url, timeout=10)
+            response = session.get(url, headers=headers, timeout=10)
         except (MissingSchema, ConnectionError, Timeout) as e:
             raise RuntimeError(f"get tei model extra parameter failed, url: {url}, error: {e}")
         if response.status_code != 200:
@@ -86,7 +88,7 @@ class TeiHelper:
         )
 
     @staticmethod
-    def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
+    def invoke_tokenize(server_url: str, texts: list[str], headers: Optional[dict] = None) -> list[list[dict]]:
         """
         Invoke tokenize endpoint
 
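
Reviewer note: a usage sketch for the updated `invoke_tokenize` signature. The token
fields shown follow the upstream TEI /tokenize response schema and are an assumption,
not something this patch verifies:

    # Hypothetical call, reusing the same headers as above.
    tokens = TeiHelper.invoke_tokenize("http://localhost:8080", ["Hello world"], headers)
    # Expected shape, one token list per input, e.g.:
    # [[{"id": 0, "text": "Hello", "special": false, "start": 0, "stop": 5}, ...]]
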
@@ -114,15 +116,15 @@ class TeiHelper:
         :param server_url: server url
         :param texts: texts to tokenize
         """
-        resp = httpx.post(
-            f"{server_url}/tokenize",
-            json={"inputs": texts},
-        )
+        url = f"{server_url}/tokenize"
+        json_data = {"inputs": texts}
+        resp = httpx.post(url, json=json_data, headers=headers)
+
         resp.raise_for_status()
         return resp.json()
 
     @staticmethod
-    def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
+    def invoke_embeddings(server_url: str, texts: list[str], headers: Optional[dict] = None) -> dict:
         """
         Invoke embeddings endpoint
 
@@ -147,15 +149,14 @@ class TeiHelper:
         :param texts: texts to embed
         """
         # Use OpenAI compatible API here, which has usage tracking
-        resp = httpx.post(
-            f"{server_url}/v1/embeddings",
-            json={"input": texts},
-        )
+        url = f"{server_url}/v1/embeddings"
+        json_data = {"input": texts}
+        resp = httpx.post(url, json=json_data, headers=headers)
         resp.raise_for_status()
         return resp.json()
 
     @staticmethod
-    def invoke_rerank(server_url: str, query: str, docs: list[str]) -> list[dict]:
+    def invoke_rerank(server_url: str, query: str, docs: list[str], headers: Optional[dict] = None) -> list[dict]:
         """
         Invoke rerank endpoint
 
@@ -173,10 +174,7 @@ class TeiHelper:
         :param candidates: candidates to rerank
         """
         params = {"query": query, "texts": docs, "return_text": True}
-
-        response = httpx.post(
-            server_url + "/rerank",
-            json=params,
-        )
+        url = f"{server_url}/rerank"
+        response = httpx.post(url, json=params, headers=headers)
         response.raise_for_status()
         return response.json()
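
Reviewer note: to close, a self-contained sketch exercising the other two helpers with
the new argument. The server URL and key are placeholders, and the response fields read
below assume the OpenAI-compatible embeddings schema and the TEI rerank schema:

    server_url = "http://localhost:8080"  # placeholder
    headers = {"Authorization": "Bearer my-placeholder-key"}  # assumed auth scheme

    emb = TeiHelper.invoke_embeddings(server_url, ["first text", "second text"], headers)
    # OpenAI-compatible payload: vectors under "data", token counts under "usage".
    vectors = [item["embedding"] for item in emb["data"]]
    used_tokens = emb["usage"]["total_tokens"]

    ranked = TeiHelper.invoke_rerank(server_url, "what is TEI?", ["doc a", "doc b"], headers)
    # With return_text=True each entry carries index, score, and the original text.
    best = max(ranked, key=lambda r: r["score"])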