@@ -9,11 +9,11 @@ from xinference.client import RESTfulChatglmCppChatModelHandle, \


 class XinferenceLLM(Xinference):
     def _call(
-        self,
-        prompt: str,
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
+            self,
+            prompt: str,
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
     ) -> str:
         """Call the xinference model and return the output.
@@ -56,10 +56,10 @@ class XinferenceLLM(Xinference):
             if generate_config and generate_config.get("stream"):
                 combined_text_output = ""
                 for token in self._stream_generate(
-                    model=model,
-                    prompt=prompt,
-                    run_manager=run_manager,
-                    generate_config=generate_config,
+                        model=model,
+                        prompt=prompt,
+                        run_manager=run_manager,
+                        generate_config=generate_config,
                 ):
                     combined_text_output += token
                 return combined_text_output
@@ -73,10 +73,10 @@ class XinferenceLLM(Xinference):
             if generate_config and generate_config.get("stream"):
                 combined_text_output = ""
                 for token in self._stream_generate(
-                    model=model,
-                    prompt=prompt,
-                    run_manager=run_manager,
-                    generate_config=generate_config,
+                        model=model,
+                        prompt=prompt,
+                        run_manager=run_manager,
+                        generate_config=generate_config,
                 ):
                     combined_text_output += token
                 completion = combined_text_output
@@ -89,13 +89,13 @@ class XinferenceLLM(Xinference):

         return completion

-
     def _stream_generate(
-        self,
-        model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
-        prompt: str,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
+            self,
+            model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
+            prompt: str,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            generate_config: Optional[
+                Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
     ) -> Generator[str, None, None]:
         """
         Args:
@@ -123,6 +123,10 @@ class XinferenceLLM(Xinference):
                 if choices:
                     choice = choices[0]
                     if isinstance(choice, dict):
+                        if 'finish_reason' in choice and choice['finish_reason'] \
+                                and choice['finish_reason'] in ['stop', 'length']:
+                            break
+
                         if 'text' in choice:
                             token = choice.get("text", "")
                         elif 'delta' in choice and 'content' in choice['delta']:
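For reference, the substantive change in the last hunk is the early exit on `finish_reason`: once a streamed choice reports `stop` or `length`, the loop stops instead of reading further (possibly empty) chunks. Below is a minimal standalone sketch of that logic, assuming the chunk layout visible in the diff (`choices[0]` is a dict that may carry `finish_reason`, `text`, or `delta.content`); the helper name `iter_stream_tokens` is illustrative and not part of the patched module.

```python
from typing import Any, Dict, Iterable, Iterator


def iter_stream_tokens(chunks: Iterable[Dict[str, Any]]) -> Iterator[str]:
    """Yield text tokens from streamed chunks, stopping on a terminal finish_reason.

    Illustrative sketch only: the real code iterates the streaming response
    inside XinferenceLLM._stream_generate.
    """
    for chunk in chunks:
        choices = chunk.get("choices", [])
        if not choices:
            continue
        choice = choices[0]
        if not isinstance(choice, dict):
            continue
        # Behaviour added by the patch: a terminal finish_reason ends the stream.
        if choice.get("finish_reason") in ("stop", "length"):
            break
        if "text" in choice:
            yield choice.get("text", "")
        elif "delta" in choice and "content" in choice["delta"]:
            yield choice["delta"]["content"]
```

Joining the yielded tokens with `"".join(...)` reproduces what `_call` accumulates in `combined_text_output`.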