|
@@ -63,6 +63,9 @@ from core.model_runtime.model_providers.xinference.xinference_helper import (
|
|
|
)
|
|
|
from core.model_runtime.utils import helper
|
|
|
|
|
|
+DEFAULT_MAX_RETRIES = 3
|
|
|
+DEFAULT_INVOKE_TIMEOUT = 60
|
|
|
+
|
|
|
|
|
|
class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
|
|
def _invoke(
|
|
@@ -315,7 +318,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
|
|
message_dict = {"role": "system", "content": message.content}
|
|
|
elif isinstance(message, ToolPromptMessage):
|
|
|
message = cast(ToolPromptMessage, message)
|
|
|
- message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
|
|
|
+ message_dict = {
|
|
|
+ "tool_call_id": message.tool_call_id,
|
|
|
+ "role": "tool",
|
|
|
+ "content": message.content,
|
|
|
+ "name": message.name,
|
|
|
+ }
|
|
|
else:
|
|
|
raise ValueError(f"Unknown message type {type(message)}")
|
|
|
|
|
@@ -466,8 +474,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
|
|
client = OpenAI(
|
|
|
base_url=f'{credentials["server_url"]}/v1',
|
|
|
api_key=api_key,
|
|
|
- max_retries=3,
|
|
|
- timeout=60,
|
|
|
+ max_retries=int(credentials.get("max_retries") or DEFAULT_MAX_RETRIES),
|
|
|
+ timeout=int(credentials.get("invoke_timeout") or DEFAULT_INVOKE_TIMEOUT),
|
|
|
)
|
|
|
|
|
|
xinference_client = Client(
|