@@ -25,6 +25,7 @@ from core.model_runtime.entities.model_entities import (
     AIModelEntity,
     DefaultParameterName,
     FetchFrom,
+    ModelFeature,
     ModelPropertyKey,
     ModelType,
     ParameterRule,
@@ -166,11 +167,23 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         """
        generate custom model entities from credentials
         """
+        support_function_call = False
+        features = []
+        function_calling_type = credentials.get('function_calling_type', 'no_call')
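+        # optional credential switch; the default 'no_call' leaves function calling disabled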
+        if function_calling_type == 'function_call':
+            features = [ModelFeature.TOOL_CALL]
+            support_function_call = True
+            endpoint_url = credentials["endpoint_url"]
+            # if not endpoint_url.endswith('/'):
+            #     endpoint_url += '/'
+            # if 'https://api.openai.com/v1/' == endpoint_url:
+            #     features = [ModelFeature.STREAM_TOOL_CALL]
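+            # the disabled block above presumably sketches STREAM_TOOL_CALL detection for the
+            # official OpenAI endpoint; endpoint_url is otherwise unused here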
         entity = AIModelEntity(
             model=model,
             label=I18nObject(en_US=model),
             model_type=ModelType.LLM,
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            features=features if support_function_call else [],
             model_properties={
                 ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', "4096")),
                 ModelPropertyKey.MODE: credentials.get('mode'),
@@ -194,14 +207,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                 max=1,
                 precision=2
             ),
-            ParameterRule(
-                name="top_k",
-                label=I18nObject(en_US="Top K"),
-                type=ParameterType.INT,
-                default=int(credentials.get('top_k', 1)),
-                min=1,
-                max=100
-            ),
             ParameterRule(
                 name=DefaultParameterName.FREQUENCY_PENALTY.value,
                 label=I18nObject(en_US="Frequency Penalty"),
@@ -232,7 +237,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                 output=Decimal(credentials.get('output_price', 0)),
                 unit=Decimal(credentials.get('unit', 0)),
                 currency=credentials.get('currency', "USD")
-            )
+            ),
         )

         if credentials['mode'] == 'chat':
@@ -292,14 +297,22 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             raise ValueError("Unsupported completion type for model configuration.")

         # annotate tools with names, descriptions, etc.
+        function_calling_type = credentials.get('function_calling_type', 'no_call')
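+        # the legacy 'function_call' API takes a plain 'functions' list; the newer
+        # 'tool_call' API uses 'tools' together with 'tool_choice'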
         formatted_tools = []
         if tools:
-            data["tool_choice"] = "auto"
+            if function_calling_type == 'function_call':
+                data['functions'] = [{
+                    "name": tool.name,
+                    "description": tool.description,
+                    "parameters": tool.parameters
+                } for tool in tools]
+            elif function_calling_type == 'tool_call':
+                data["tool_choice"] = "auto"

-            for tool in tools:
-                formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))
+                for tool in tools:
+                    formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))

-            data["tools"] = formatted_tools
+                data["tools"] = formatted_tools

         if stop:
             data["stop"] = stop
@@ -367,9 +380,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):

         for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
             if chunk:
-                #ignore sse comments
+                # ignore sse comments
                 if chunk.startswith(':'):
-                        continue
+                    continue
                 decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
                 chunk_json = None
                 try:
@@ -452,10 +465,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):

         response_content = ''
         tool_calls = None
-
+        function_calling_type = credentials.get('function_calling_type', 'no_call')
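+        # which response field carries tool invocations depends on the API flavour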
         if completion_type is LLMMode.CHAT:
             response_content = output.get('message', {})['content']
+            if function_calling_type == 'tool_call':
+                tool_calls = output.get('message', {}).get('tool_calls')
+            elif function_calling_type == 'function_call':
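+                # legacy responses return a single 'function_call' object rather than a list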
+                tool_calls = output.get('message', {}).get('function_call')

         elif completion_type is LLMMode.COMPLETION:
             response_content = output['text']
@@ -463,7 +479,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         assistant_message = AssistantPromptMessage(content=response_content, tool_calls=[])

         if tool_calls:
-            assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls)
+            if function_calling_type == 'tool_call':
+                assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls)
+            elif function_calling_type == 'function_call':
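+                # wrap the single legacy call in a list so tool_calls keeps a uniform shape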
+                assistant_message.tool_calls = [self._extract_response_function_call(tool_calls)]

         usage = response_json.get("usage")
         if usage:
@@ -522,33 +541,34 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             message = cast(AssistantPromptMessage, message)
             message_dict = {"role": "assistant", "content": message.content}
             if message.tool_calls:
-                message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
-                                              in
-                                              message.tool_calls]
-                # function_call = message.tool_calls[0]
-                # message_dict["function_call"] = {
-                #     "name": function_call.function.name,
-                #     "arguments": function_call.function.arguments,
-                # }
+                # message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
+                #                               in
+                #                               message.tool_calls]
+
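+                # the legacy wire format carries one call per message; only the first tool call is serialized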
+                function_call = message.tool_calls[0]
+                message_dict["function_call"] = {
+                    "name": function_call.function.name,
+                    "arguments": function_call.function.arguments,
+                }
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "system", "content": message.content}
         elif isinstance(message, ToolPromptMessage):
             message = cast(ToolPromptMessage, message)
-            message_dict = {
-                "role": "tool",
-                "content": message.content,
-                "tool_call_id": message.tool_call_id
-            }
             # message_dict = {
-            #     "role": "function",
+            #     "role": "tool",
             #     "content": message.content,
-            #     "name": message.tool_call_id
+            #     "tool_call_id": message.tool_call_id
             # }
+            message_dict = {
+                "role": "function",
+                "content": message.content,
+                "name": message.tool_call_id
+            }
         else:
             raise ValueError(f"Got unknown type {message}")

-        if message.name is not None:
+        if message.name:
             message_dict["name"] = message.name

         return message_dict
@@ -693,3 +713,26 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                 tool_calls.append(tool_call)

         return tool_calls
+
+    def _extract_response_function_call(self, response_function_call) \
+            -> AssistantPromptMessage.ToolCall:
+        """
+        Extract function call from response
+
+        :param response_function_call: response function call
+        :return: tool call
+        """
+        tool_call = None
+        if response_function_call:
+            function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                name=response_function_call['name'],
+                arguments=response_function_call['arguments']
+            )
+
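+            # legacy function_call payloads carry no call id; the function name is reused as the identifier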
+            tool_call = AssistantPromptMessage.ToolCall(
+                id=response_function_call['name'],
+                type="function",
+                function=function
+            )
+
+        return tool_call