|
@@ -43,16 +43,17 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
|
|
"""
|
|
|
server_url = credentials['server_url']
|
|
|
model_uid = credentials['model_uid']
|
|
|
-
|
|
|
+ api_key = credentials.get('api_key')
|
|
|
if server_url.endswith('/'):
|
|
|
server_url = server_url[:-1]
|
|
|
+ auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
|
|
|
|
|
|
try:
|
|
|
- handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers={})
|
|
|
+ handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers)
|
|
|
embeddings = handle.create_embedding(input=texts)
|
|
|
except RuntimeError as e:
|
|
|
- raise InvokeServerUnavailableError(e)
|
|
|
-
|
|
|
+ raise InvokeServerUnavailableError(str(e))
|
|
|
+
|
|
|
"""
|
|
|
for convenience, the response json is like:
|
|
|
class Embedding(TypedDict):
|
|
@@ -106,7 +107,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
|
|
try:
|
|
|
if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
|
|
|
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
|
|
|
-
|
|
|
+
|
|
|
server_url = credentials['server_url']
|
|
|
model_uid = credentials['model_uid']
|
|
|
extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid)
|
|
@@ -117,7 +118,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
|
|
server_url = server_url[:-1]
|
|
|
|
|
|
client = Client(base_url=server_url)
|
|
|
-
|
|
|
+
|
|
|
try:
|
|
|
handle = client.get_model(model_uid=model_uid)
|
|
|
except RuntimeError as e:
|
|
@@ -151,7 +152,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
|
|
KeyError
|
|
|
]
|
|
|
}
|
|
|
-
|
|
|
+
|
|
|
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
|
|
"""
|
|
|
Calculate response usage
|
|
@@ -186,7 +187,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
|
|
"""
|
|
|
used to define customizable model schema
|
|
|
"""
|
|
|
-
|
|
|
+
|
|
|
entity = AIModelEntity(
|
|
|
model=model,
|
|
|
label=I18nObject(
|