瀏覽代碼

feat: add xinference llm context size (#2336)

Yeuoly 1 年之前
父節點
當前提交
0c330fc020

+ 5 - 0
api/core/model_runtime/model_providers/xinference/llm/llm.py

@@ -75,6 +75,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             if extra_param.support_function_call:
                 credentials['support_function_call'] = True
 
+            if extra_param.context_length:
+                credentials['context_length'] = extra_param.context_length
+
         except RuntimeError as e:
             raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}')
         except KeyError as e:
@@ -296,6 +299,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')
             
         support_function_call = credentials.get('support_function_call', False)
+        context_length = credentials.get('context_length', 2048)
 
         entity = AIModelEntity(
             model=model,
@@ -309,6 +313,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             ] if support_function_call else [],
             model_properties={ 
                 ModelPropertyKey.MODE: completion_type,
+                ModelPropertyKey.CONTEXT_SIZE: context_length
             },
             parameter_rules=rules
         )

+ 8 - 3
api/core/model_runtime/model_providers/xinference/xinference_helper.py

@@ -14,15 +14,17 @@ class XinferenceModelExtraParameter(object):
     model_handle_type: str
     model_ability: List[str]
     max_tokens: int = 512
+    context_length: int = 2048
     support_function_call: bool = False
 
     def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str], 
-                 support_function_call: bool, max_tokens: int) -> None:
+                 support_function_call: bool, max_tokens: int, context_length: int) -> None:
         self.model_format = model_format
         self.model_handle_type = model_handle_type
         self.model_ability = model_ability
         self.support_function_call = support_function_call
         self.max_tokens = max_tokens
+        self.context_length = context_length
 
 cache = {}
 cache_lock = Lock()
@@ -57,7 +59,7 @@ class XinferenceHelper:
 
         url = path.join(server_url, 'v1/models', model_uid)
 
-        # this methid is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3
+        # this method is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3
         session = Session()
         session.mount('http://', HTTPAdapter(max_retries=3))
         session.mount('https://', HTTPAdapter(max_retries=3))
@@ -88,11 +90,14 @@ class XinferenceHelper:
         
         support_function_call = 'tools' in model_ability
         max_tokens = response_json.get('max_tokens', 512)
+
+        context_length = response_json.get('context_length', 2048)
         
         return XinferenceModelExtraParameter(
             model_format=model_format,
             model_handle_type=model_handle_type,
             model_ability=model_ability,
             support_function_call=support_function_call,
-            max_tokens=max_tokens
+            max_tokens=max_tokens,
+            context_length=context_length
         )