|
@@ -15,6 +15,7 @@ from core.model_runtime.errors.invoke import (
|
|
|
)
|
|
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
|
|
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
|
|
|
+from core.model_runtime.model_providers.xinference.xinference_helper import validate_model_uid
|
|
|
|
|
|
|
|
|
class XinferenceRerankModel(RerankModel):
|
|
@@ -77,10 +78,7 @@ class XinferenceRerankModel(RerankModel):
|
|
|
)
|
|
|
|
|
|
# score threshold check
|
|
|
- if score_threshold is not None:
|
|
|
- if result["relevance_score"] >= score_threshold:
|
|
|
- rerank_documents.append(rerank_document)
|
|
|
- else:
|
|
|
+ if score_threshold is None or result["relevance_score"] >= score_threshold:
|
|
|
rerank_documents.append(rerank_document)
|
|
|
|
|
|
return RerankResult(model=model, docs=rerank_documents)
|
|
@@ -94,7 +92,7 @@ class XinferenceRerankModel(RerankModel):
|
|
|
:return:
|
|
|
"""
|
|
|
try:
|
|
|
- if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
|
|
|
+ if not validate_model_uid(credentials):
|
|
|
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
|
|
|
|
|
|
credentials["server_url"] = credentials["server_url"].removesuffix("/")
|