
feat: optimize xinference request max token key and stop reason (#998)

takatost 1 year ago
parent commit 9ae91a2ec3

+ 1 - 2
api/core/model_providers/providers/xinference_provider.py

@@ -2,7 +2,6 @@ import json
 from typing import Type
 
 import requests
-from xinference.client import RESTfulGenerateModelHandle, RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle
 
 from core.helper import encrypter
 from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
@@ -73,7 +72,7 @@ class XinferenceProvider(BaseModelProvider):
                 top_p=KwargRule[float](min=0, max=1, default=0.7),
                 presence_penalty=KwargRule[float](enabled=False),
                 frequency_penalty=KwargRule[float](enabled=False),
-                max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=256),
+                max_tokens=KwargRule[int](min=10, max=4000, default=256),
             )
 
 
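Not part of the diff: a minimal sketch of what dropping the alias changes. Rule and build_generate_config below are hypothetical stand-ins for Dify's KwargRule machinery, assuming the alias controls the key name used in the outgoing request; only the two key names (max_new_tokens before, max_tokens after) come from the patch itself.

    # Hypothetical illustration only; not Dify's actual KwargRule API.
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Rule:
        default: int
        alias: Optional[str] = None  # assumed: key name used in the outgoing request

    def build_generate_config(rule: Rule, value: int) -> dict:
        key = rule.alias or "max_tokens"
        return {key: value}

    # Before the patch the value would be sent under the aliased key:
    print(build_generate_config(Rule(default=256, alias="max_new_tokens"), 256))  # {'max_new_tokens': 256}
    # After the patch it is sent as max_tokens:
    print(build_generate_config(Rule(default=256), 256))                          # {'max_tokens': 256}
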

+ 23 - 19
api/core/third_party/langchain/llms/xinference_llm.py

@@ -9,11 +9,11 @@ from xinference.client import RESTfulChatglmCppChatModelHandle, \
 
 class XinferenceLLM(Xinference):
     def _call(
-        self,
-        prompt: str,
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
+            self,
+            prompt: str,
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
     ) -> str:
         """Call the xinference model and return the output.
 
@@ -56,10 +56,10 @@ class XinferenceLLM(Xinference):
             if generate_config and generate_config.get("stream"):
                 combined_text_output = ""
                 for token in self._stream_generate(
-                    model=model,
-                    prompt=prompt,
-                    run_manager=run_manager,
-                    generate_config=generate_config,
+                        model=model,
+                        prompt=prompt,
+                        run_manager=run_manager,
+                        generate_config=generate_config,
                 ):
                     combined_text_output += token
                 return combined_text_output
@@ -73,10 +73,10 @@ class XinferenceLLM(Xinference):
             if generate_config and generate_config.get("stream"):
                 combined_text_output = ""
                 for token in self._stream_generate(
-                    model=model,
-                    prompt=prompt,
-                    run_manager=run_manager,
-                    generate_config=generate_config,
+                        model=model,
+                        prompt=prompt,
+                        run_manager=run_manager,
+                        generate_config=generate_config,
                 ):
                     combined_text_output += token
                 completion = combined_text_output
@@ -89,13 +89,13 @@ class XinferenceLLM(Xinference):
 
             return completion
 
-
     def _stream_generate(
-        self,
-        model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
-        prompt: str,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
+            self,
+            model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
+            prompt: str,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            generate_config: Optional[
+                Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
     ) -> Generator[str, None, None]:
         """
         Args:
@@ -123,6 +123,10 @@ class XinferenceLLM(Xinference):
                 if choices:
                     choice = choices[0]
                     if isinstance(choice, dict):
+                        if 'finish_reason' in choice and choice['finish_reason'] \
+                                and choice['finish_reason'] in ['stop', 'length']:
+                            break
+
                         if 'text' in choice:
                             token = choice.get("text", "")
                         elif 'delta' in choice and 'content' in choice['delta']:
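
Not part of the diff: a self-contained sketch of the stop-reason check added above. fake_stream and stream_tokens are illustrative stand-ins for the real streaming loop; in the patched code the same check runs inside _stream_generate over the Xinference RESTful client's streaming response.

    from typing import Dict, Generator, List

    def fake_stream() -> List[Dict]:
        """Stand-in for a streaming completion response (illustrative data)."""
        return [
            {"choices": [{"text": "Hello", "finish_reason": None}]},
            {"choices": [{"text": " world", "finish_reason": None}]},
            {"choices": [{"text": "", "finish_reason": "stop"}]},
        ]

    def stream_tokens() -> Generator[str, None, None]:
        for chunk in fake_stream():
            choices = chunk.get("choices", [])
            if not choices:
                continue
            choice = choices[0]
            if isinstance(choice, dict):
                # Same check the patch adds: stop consuming the stream once the
                # model reports a terminal finish_reason ('stop' or 'length').
                if choice.get("finish_reason") in ("stop", "length"):
                    break
                if "text" in choice:
                    yield choice.get("text", "")
                elif "delta" in choice and "content" in choice["delta"]:
                    yield choice["delta"]["content"]

    if __name__ == "__main__":
        print("".join(stream_tokens()))  # -> "Hello world"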