|
@@ -119,8 +119,15 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|
|
if stop:
|
|
|
req_params['stop'] = stop
|
|
|
|
|
|
+ extra_model_kwargs = {}
|
|
|
+
|
|
|
+ if tools:
|
|
|
+ extra_model_kwargs['tools'] = [
|
|
|
+ MaaSClient.transform_tool_prompt_to_maas_config(tool) for tool in tools
|
|
|
+ ]
|
|
|
+
|
|
|
resp = MaaSClient.wrap_exception(
|
|
|
- lambda: client.chat(req_params, prompt_messages, stream))
|
|
|
+ lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs))
|
|
|
if not stream:
|
|
|
return self._handle_chat_response(model, credentials, prompt_messages, resp)
|
|
|
return self._handle_stream_chat_response(model, credentials, prompt_messages, resp)
|
|
@@ -156,12 +163,26 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|
|
choice = choices[0]
|
|
|
message = choice['message']
|
|
|
|
|
|
+ # parse tool calls
|
|
|
+ tool_calls = []
|
|
|
+ if message['tool_calls']:
|
|
|
+ for call in message['tool_calls']:
|
|
|
+ tool_call = AssistantPromptMessage.ToolCall(
|
|
|
+ id=call['function']['name'],
|
|
|
+ type=call['type'],
|
|
|
+ function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
|
|
+ name=call['function']['name'],
|
|
|
+ arguments=call['function']['arguments']
|
|
|
+ )
|
|
|
+ )
|
|
|
+ tool_calls.append(tool_call)
|
|
|
+
|
|
|
return LLMResult(
|
|
|
model=model,
|
|
|
prompt_messages=prompt_messages,
|
|
|
message=AssistantPromptMessage(
|
|
|
content=message['content'] if message['content'] else '',
|
|
|
- tool_calls=[],
|
|
|
+ tool_calls=tool_calls,
|
|
|
),
|
|
|
usage=self._calc_usage(model, credentials, resp['usage']),
|
|
|
)
|
|
@@ -252,6 +273,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|
|
if credentials.get('context_size'):
|
|
|
model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
|
|
|
credentials.get('context_size', 4096))
|
|
|
+
|
|
|
+ model_features = ModelConfigs.get(
|
|
|
+ credentials['base_model_name'], {}).get('features', [])
|
|
|
+
|
|
|
entity = AIModelEntity(
|
|
|
model=model,
|
|
|
label=I18nObject(
|
|
@@ -260,7 +285,8 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
|
|
model_type=ModelType.LLM,
|
|
|
model_properties=model_properties,
|
|
|
- parameter_rules=rules
|
|
|
+ parameter_rules=rules,
|
|
|
+ features=model_features,
|
|
|
)
|
|
|
|
|
|
return entity
|