فهرست منبع

fix: xinference last token being ignored (#1013)

Uranus 1 سال پیش
والد
کامیت
2d9616c29c
1فایلهای تغییر یافته به همراه55 افزوده شده و 37 حذف شده
  1. 55 37
      api/core/third_party/langchain/llms/xinference_llm.py

+ 55 - 37
api/core/third_party/langchain/llms/xinference_llm.py

@@ -3,17 +3,20 @@ from typing import Optional, List, Any, Union, Generator
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms import Xinference
 from langchain.llms.utils import enforce_stop_tokens
-from xinference.client import RESTfulChatglmCppChatModelHandle, \
-    RESTfulChatModelHandle, RESTfulGenerateModelHandle
+from xinference.client import (
+    RESTfulChatglmCppChatModelHandle,
+    RESTfulChatModelHandle,
+    RESTfulGenerateModelHandle,
+)
 
 
 class XinferenceLLM(Xinference):
     def _call(
-            self,
-            prompt: str,
-            stop: Optional[List[str]] = None,
-            run_manager: Optional[CallbackManagerForLLMRun] = None,
-            **kwargs: Any,
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call the xinference model and return the output.
 
@@ -29,7 +32,9 @@ class XinferenceLLM(Xinference):
         model = self.client.get_model(self.model_uid)
 
         if isinstance(model, RESTfulChatModelHandle):
-            generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
+            generate_config: "LlamaCppGenerateConfig" = kwargs.get(
+                "generate_config", {}
+            )
 
             if stop:
                 generate_config["stop"] = stop
@@ -37,10 +42,10 @@ class XinferenceLLM(Xinference):
             if generate_config and generate_config.get("stream"):
                 combined_text_output = ""
                 for token in self._stream_generate(
-                        model=model,
-                        prompt=prompt,
-                        run_manager=run_manager,
-                        generate_config=generate_config,
+                    model=model,
+                    prompt=prompt,
+                    run_manager=run_manager,
+                    generate_config=generate_config,
                 ):
                     combined_text_output += token
                 return combined_text_output
@@ -48,7 +53,9 @@ class XinferenceLLM(Xinference):
                 completion = model.chat(prompt=prompt, generate_config=generate_config)
                 return completion["choices"][0]["message"]["content"]
         elif isinstance(model, RESTfulGenerateModelHandle):
-            generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
+            generate_config: "LlamaCppGenerateConfig" = kwargs.get(
+                "generate_config", {}
+            )
 
             if stop:
                 generate_config["stop"] = stop
@@ -56,27 +63,31 @@ class XinferenceLLM(Xinference):
             if generate_config and generate_config.get("stream"):
                 combined_text_output = ""
                 for token in self._stream_generate(
-                        model=model,
-                        prompt=prompt,
-                        run_manager=run_manager,
-                        generate_config=generate_config,
+                    model=model,
+                    prompt=prompt,
+                    run_manager=run_manager,
+                    generate_config=generate_config,
                 ):
                     combined_text_output += token
                 return combined_text_output
 
             else:
-                completion = model.generate(prompt=prompt, generate_config=generate_config)
+                completion = model.generate(
+                    prompt=prompt, generate_config=generate_config
+                )
                 return completion["choices"][0]["text"]
         elif isinstance(model, RESTfulChatglmCppChatModelHandle):
-            generate_config: "ChatglmCppGenerateConfig" = kwargs.get("generate_config", {})
+            generate_config: "ChatglmCppGenerateConfig" = kwargs.get(
+                "generate_config", {}
+            )
 
             if generate_config and generate_config.get("stream"):
                 combined_text_output = ""
                 for token in self._stream_generate(
-                        model=model,
-                        prompt=prompt,
-                        run_manager=run_manager,
-                        generate_config=generate_config,
+                    model=model,
+                    prompt=prompt,
+                    run_manager=run_manager,
+                    generate_config=generate_config,
                 ):
                     combined_text_output += token
                 completion = combined_text_output
@@ -90,12 +101,21 @@ class XinferenceLLM(Xinference):
             return completion
 
     def _stream_generate(
-            self,
-            model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
-            prompt: str,
-            run_manager: Optional[CallbackManagerForLLMRun] = None,
-            generate_config: Optional[
-                Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
+        self,
+        model: Union[
+            "RESTfulGenerateModelHandle",
+            "RESTfulChatModelHandle",
+            "RESTfulChatglmCppChatModelHandle",
+        ],
+        prompt: str,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        generate_config: Optional[
+            Union[
+                "LlamaCppGenerateConfig",
+                "PytorchGenerateConfig",
+                "ChatglmCppGenerateConfig",
+            ]
+        ] = None,
     ) -> Generator[str, None, None]:
         """
         Args:
@@ -108,7 +128,9 @@ class XinferenceLLM(Xinference):
         Yields:
             A string token.
         """
-        if isinstance(model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)):
+        if isinstance(
+            model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)
+        ):
             streaming_response = model.chat(
                 prompt=prompt, generate_config=generate_config
             )
@@ -123,14 +145,10 @@ class XinferenceLLM(Xinference):
                 if choices:
                     choice = choices[0]
                     if isinstance(choice, dict):
-                        if 'finish_reason' in choice and choice['finish_reason'] \
-                                and choice['finish_reason'] in ['stop', 'length']:
-                            break
-
-                        if 'text' in choice:
+                        if "text" in choice:
                             token = choice.get("text", "")
-                        elif 'delta' in choice and 'content' in choice['delta']:
-                            token = choice.get('delta').get('content')
+                        elif "delta" in choice and "content" in choice["delta"]:
+                            token = choice.get("delta").get("content")
                         else:
                             continue
                         log_probs = choice.get("logprobs")