@@ -28,7 +28,10 @@ from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
+    ImagePromptMessageContent,
     PromptMessage,
+    PromptMessageContent,
+    PromptMessageContentType,
     PromptMessageTool,
     SystemPromptMessage,
     ToolPromptMessage,
@@ -61,8 +64,8 @@ from core.model_runtime.utils import helper


 class XinferenceAILargeLanguageModel(LargeLanguageModel):
-    def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                model_parameters: dict, tools: list[PromptMessageTool] | None = None,
+    def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                model_parameters: dict, tools: list[PromptMessageTool] | None = None,
                 stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         """
@@ -99,7 +102,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         try:
             if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
                 raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
-
+
             extra_param = XinferenceHelper.get_xinference_extra_parameter(
                 server_url=credentials['server_url'],
                 model_uid=credentials['model_uid']
@@ -111,10 +114,13 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 credentials['completion_type'] = 'completion'
             else:
                 raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type')
-
+
             if extra_param.support_function_call:
                 credentials['support_function_call'] = True

+            if extra_param.support_vision:
+                credentials['support_vision'] = True
+
             if extra_param.context_length:
                 credentials['context_length'] = extra_param.context_length

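
The new `support_vision` flag is detected the same way as `support_function_call`: `XinferenceHelper` probes the server for the model's declared abilities. As a rough, hypothetical sketch of where such a flag can come from (the endpoint and field names are assumptions modeled on Xinference's model-description API, not a copy of the helper):

```python
import requests

def probe_vision_support(server_url: str, model_uid: str) -> bool:
    """Hypothetical probe: ask Xinference to describe a deployed model
    and check whether 'vision' is among its declared abilities."""
    resp = requests.get(f"{server_url}/v1/models/{model_uid}", timeout=10)
    resp.raise_for_status()
    # Xinference-style model descriptions list capabilities such as
    # "chat", "generate", or "vision" under a model_ability field.
    return "vision" in resp.json().get("model_ability", [])
```
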
@@ -135,7 +141,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         """
         return self._num_tokens_from_messages(prompt_messages, tools)

-    def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool],
+    def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool],
                                   is_completion_model: bool = False) -> int:
         def tokens(text: str):
             return self._get_num_tokens_by_gpt2(text)
@@ -155,7 +161,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                     text = ''
                     for item in value:
                         if isinstance(item, dict) and item['type'] == 'text':
-                            text += item.text
+                            text += item['text']

                     value = text

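
This hunk is a real bug fix among the whitespace cleanups: the content parts being counted here are plain dicts in OpenAI style, so attribute access raised at runtime. A one-line illustration:

```python
item = {"type": "text", "text": "hello"}
# item.text        -> AttributeError: 'dict' object has no attribute 'text'
text = item["text"]  # a plain dict requires subscript access
```
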
@@ -191,7 +197,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             num_tokens += self._num_tokens_for_tools(tools)

         return num_tokens
-
+
     def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
         """
         Calculate num tokens for tool calling
@@ -234,7 +240,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 num_tokens += tokens(required_field)

         return num_tokens
-
+
     def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str:
         """
         convert prompt message to text
@@ -260,7 +266,26 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             if isinstance(message.content, str):
                 message_dict = {"role": "user", "content": message.content}
             else:
-                raise ValueError("User message content must be str")
+                sub_messages = []
+                for message_content in message.content:
+                    if message_content.type == PromptMessageContentType.TEXT:
+                        message_content = cast(PromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "text",
+                            "text": message_content.data
+                        }
+                        sub_messages.append(sub_message_dict)
+                    elif message_content.type == PromptMessageContentType.IMAGE:
+                        message_content = cast(ImagePromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": message_content.data,
+                                "detail": message_content.detail.value
+                            }
+                        }
+                        sub_messages.append(sub_message_dict)
+                message_dict = {"role": "user", "content": sub_messages}
         elif isinstance(message, AssistantPromptMessage):
             message = cast(AssistantPromptMessage, message)
             message_dict = {"role": "assistant", "content": message.content}
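
The new branch maps Dify's typed message contents onto the OpenAI-compatible multimodal format that Xinference's chat endpoint accepts. For a user message carrying one text part and one image part, `_convert_prompt_message_to_dict` now produces a dict of roughly this shape (values invented for illustration):

```python
message_dict = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this picture?"},
        {
            "type": "image_url",
            "image_url": {
                # data URL or remote URL taken from ImagePromptMessageContent.data
                "url": "data:image/png;base64,iVBORw0KGgo...",
                "detail": "low",  # ImagePromptMessageContent.detail.value
            },
        },
    ],
}
```
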
@@ -277,7 +302,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
         else:
             raise ValueError(f"Unknown message type {type(message)}")
-
+
         return message_dict

     def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
@@ -338,8 +363,18 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             completion_type = LLMMode.COMPLETION.value
         else:
             raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')
-
+
+
+        features = []
+
         support_function_call = credentials.get('support_function_call', False)
+        if support_function_call:
+            features.append(ModelFeature.TOOL_CALL)
+
+        support_vision = credentials.get('support_vision', False)
+        if support_vision:
+            features.append(ModelFeature.VISION)
+
         context_length = credentials.get('context_length', 2048)

         entity = AIModelEntity(
@@ -349,10 +384,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             ),
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.LLM,
-            features=[
-                ModelFeature.TOOL_CALL
-            ] if support_function_call else [],
-            model_properties={
+            features=features,
+            model_properties={
                 ModelPropertyKey.MODE: completion_type,
                 ModelPropertyKey.CONTEXT_SIZE: context_length
             },
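
Replacing the inline conditional with an accumulated list is the design change that lets both capabilities coexist: each flag appends independently, so a model supporting tool calling and vision advertises both. Condensed, the assembly from the two hunks above reads:

```python
features = []
if credentials.get('support_function_call', False):
    features.append(ModelFeature.TOOL_CALL)
if credentials.get('support_vision', False):
    features.append(ModelFeature.VISION)
# with both flags set: [ModelFeature.TOOL_CALL, ModelFeature.VISION]
```
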
@@ -360,22 +393,22 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         )

         return entity
-
-    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+
+    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                   model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter,
-                  tools: list[PromptMessageTool] | None = None,
+                  tools: list[PromptMessageTool] | None = None,
                   stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         """
         generate text from LLM

         see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._generate`
-
+
         extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter`
         """
         if 'server_url' not in credentials:
             raise CredentialsValidateFailedError('server_url is required in credentials')
-
+
         if credentials['server_url'].endswith('/'):
             credentials['server_url'] = credentials['server_url'][:-1]

@@ -408,11 +441,11 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                     'function': helper.dump_model(tool)
                 } for tool in tools
             ]
-
+        vision = credentials.get('support_vision', False)
         if isinstance(xinference_model, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle):
             resp = client.chat.completions.create(
                 model=credentials['model_uid'],
-                messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages],
+                messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages],
                 stream=stream,
                 user=user,
                 **generate_config,
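
Since Xinference exposes an OpenAI-compatible API, the multimodal messages built by `_convert_prompt_message_to_dict` travel through the standard client unchanged. A minimal sketch of the equivalent direct call, assuming a locally running server (the URL, model UID, and base64 payload are placeholders):

```python
from openai import OpenAI

# Xinference serves an OpenAI-compatible API under /v1; it does not
# require a real API key, but the client insists on a non-empty one.
client = OpenAI(base_url="http://127.0.0.1:9997/v1", api_key="not-needed")

resp = client.chat.completions.create(
    model="my-vision-model-uid",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            {"type": "image_url",
             "image_url": {"url": "data:image/jpeg;base64,...", "detail": "low"}},
        ],
    }],
)
print(resp.choices[0].message.content)
```
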
@@ -497,7 +530,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         """
         if len(resp.choices) == 0:
             raise InvokeServerUnavailableError("Empty response")
-
+
         assistant_message = resp.choices[0].message

         # convert tool call to assistant message tool call
@@ -527,7 +560,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         )

         return response
-
+
     def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                                      tools: list[PromptMessageTool],
                                      resp: Iterator[ChatCompletionChunk]) -> Generator:
@@ -544,7 +577,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):

             if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
                 continue
-
+
             # check if there is a tool call in the response
             function_call = None
             tool_calls = []
@@ -573,9 +606,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
                 completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])

-                usage = self._calc_response_usage(model=model, credentials=credentials,
+                usage = self._calc_response_usage(model=model, credentials=credentials,
                                                   prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
-
+
                 yield LLMResultChunk(
                     model=model,
                     prompt_messages=prompt_messages,
@@ -608,7 +641,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         """
         if len(resp.choices) == 0:
             raise InvokeServerUnavailableError("Empty response")
-
+
         assistant_message = resp.choices[0].text

         # transform assistant message to prompt message
@@ -670,9 +703,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 completion_tokens = self._num_tokens_from_messages(
                     messages=[temp_assistant_prompt_message], tools=[], is_completion_model=True
                 )
-                usage = self._calc_response_usage(model=model, credentials=credentials,
+                usage = self._calc_response_usage(model=model, credentials=credentials,
                                                   prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
-
+
                 yield LLMResultChunk(
                     model=model,
                     prompt_messages=prompt_messages,