Browse Source

refactor: optimize the calculation of rerank threshold and the logic for forbidden characters in model_uid (#8879)

zhuhao 6 months ago
parent
commit
77aef9ff1d

+ 2 - 1
api/core/model_runtime/model_providers/xinference/llm/llm.py

@@ -59,6 +59,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
 from core.model_runtime.model_providers.xinference.xinference_helper import (
     XinferenceHelper,
     XinferenceModelExtraParameter,
+    validate_model_uid,
 )
 from core.model_runtime.utils import helper
 
@@ -114,7 +115,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         }
         """
         try:
-            if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
+            if not validate_model_uid(credentials):
                 raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
 
             extra_param = XinferenceHelper.get_xinference_extra_parameter(

+ 3 - 5
api/core/model_runtime/model_providers/xinference/rerank/rerank.py

@@ -15,6 +15,7 @@ from core.model_runtime.errors.invoke import (
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.rerank_model import RerankModel
+from core.model_runtime.model_providers.xinference.xinference_helper import validate_model_uid
 
 
 class XinferenceRerankModel(RerankModel):
@@ -77,10 +78,7 @@ class XinferenceRerankModel(RerankModel):
             )
 
             # score threshold check
-            if score_threshold is not None:
-                if result["relevance_score"] >= score_threshold:
-                    rerank_documents.append(rerank_document)
-            else:
+            if score_threshold is None or result["relevance_score"] >= score_threshold:
                 rerank_documents.append(rerank_document)
 
         return RerankResult(model=model, docs=rerank_documents)
@@ -94,7 +92,7 @@ class XinferenceRerankModel(RerankModel):
         :return:
         """
         try:
-            if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
+            if not validate_model_uid(credentials):
                 raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
 
             credentials["server_url"] = credentials["server_url"].removesuffix("/")

+ 2 - 1
api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py

@@ -14,6 +14,7 @@ from core.model_runtime.errors.invoke import (
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
+from core.model_runtime.model_providers.xinference.xinference_helper import validate_model_uid
 
 
 class XinferenceSpeech2TextModel(Speech2TextModel):
@@ -42,7 +43,7 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
         :return:
         """
         try:
-            if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
+            if not validate_model_uid(credentials):
                 raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
 
             credentials["server_url"] = credentials["server_url"].removesuffix("/")

+ 2 - 2
api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py

@@ -17,7 +17,7 @@ from core.model_runtime.errors.invoke import (
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
-from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper
+from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper, validate_model_uid
 
 
 class XinferenceTextEmbeddingModel(TextEmbeddingModel):
@@ -110,7 +110,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
         :return:
         """
         try:
-            if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
+            if not validate_model_uid(credentials):
                 raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
 
             server_url = credentials["server_url"]

+ 2 - 2
api/core/model_runtime/model_providers/xinference/tts/tts.py

@@ -15,7 +15,7 @@ from core.model_runtime.errors.invoke import (
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.tts_model import TTSModel
-from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper
+from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper, validate_model_uid
 
 
 class XinferenceText2SpeechModel(TTSModel):
@@ -70,7 +70,7 @@ class XinferenceText2SpeechModel(TTSModel):
         :return:
         """
         try:
-            if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
+            if not validate_model_uid(credentials):
                 raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
 
             credentials["server_url"] = credentials["server_url"].removesuffix("/")

+ 13 - 0
api/core/model_runtime/model_providers/xinference/xinference_helper.py

@@ -132,3 +132,16 @@ class XinferenceHelper:
             context_length=context_length,
             model_family=model_family,
         )
+
+
+def validate_model_uid(credentials: dict) -> bool:
+    """
+    Validate the model_uid within the credentials dictionary to ensure it does not
+    contain forbidden characters ("/", "?", "#").
+
+    param credentials: model credentials
+    :return: True if the model_uid does not contain forbidden characters ("/", "?", "#"), else False.
+    """
+    forbidden_characters = ["/", "?", "#"]
+    model_uid = credentials.get("model_uid", "")
+    return not any(char in forbidden_characters for char in model_uid)