@@ -1,14 +1,13 @@
 import copy
 import logging
-from collections.abc import Generator
+from collections.abc import Generator, Sequence
 from typing import Optional, Union, cast

 import tiktoken
 from openai import AzureOpenAI, Stream
 from openai.types import Completion
 from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
-from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
-from openai.types.chat.chat_completion_message import FunctionCall
+from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall

 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
@@ -16,6 +15,7 @@ from core.model_runtime.entities.message_entities import (
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContentType,
+    PromptMessageFunction,
     PromptMessageTool,
     SystemPromptMessage,
     TextPromptMessageContent,
@@ -26,7 +26,8 @@ from core.model_runtime.entities.model_entities import AIModelEntity, ModelPrope
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
-from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS, AzureBaseModel
+from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS
+from core.model_runtime.utils import helper

 logger = logging.getLogger(__name__)

@@ -39,9 +40,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                 stream: bool = True, user: Optional[str] = None) \
             -> Union[LLMResult, Generator]:

-        ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
+        base_model_name = credentials.get('base_model_name')
+        if not base_model_name:
+            raise ValueError('Base Model Name is required')
+        ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)

-        if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
+        if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
             # chat model
             return self._chat_generate(
                 model=model,
@@ -65,18 +69,29 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             user=user
         )

-    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                       tools: Optional[list[PromptMessageTool]] = None) -> int:
-
-        model_mode = self._get_ai_model_entity(credentials.get('base_model_name'), model).entity.model_properties.get(
-            ModelPropertyKey.MODE)
+    def get_num_tokens(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        tools: Optional[list[PromptMessageTool]] = None
+    ) -> int:
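+        # resolve and validate the base model before selecting a tokenization path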
+        base_model_name = credentials.get('base_model_name')
+        if not base_model_name:
+            raise ValueError('Base Model Name is required')
+        model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
+        if not model_entity:
+            raise ValueError(f'Base Model Name {base_model_name} is invalid')
+        model_mode = model_entity.entity.model_properties.get(ModelPropertyKey.MODE)

         if model_mode == LLMMode.CHAT.value:
             # chat model
             return self._num_tokens_from_messages(credentials, prompt_messages, tools)
         else:
             # text completion model, do not support tool calling
-            return self._num_tokens_from_string(credentials, prompt_messages[0].content)
+            content = prompt_messages[0].content
+            assert isinstance(content, str)
+            return self._num_tokens_from_string(credentials, content)

     def validate_credentials(self, model: str, credentials: dict) -> None:
         if 'openai_api_base' not in credentials:
@@ -88,7 +103,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         if 'base_model_name' not in credentials:
             raise CredentialsValidateFailedError('Base Model Name is required')

-        ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
+        base_model_name = credentials.get('base_model_name')
+        if not base_model_name:
+            raise CredentialsValidateFailedError('Base Model Name is required')
+        ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)

         if not ai_model_entity:
             raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
@@ -118,7 +136,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             raise CredentialsValidateFailedError(str(ex))

     def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
-        ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
+        base_model_name = credentials.get('base_model_name')
+        if not base_model_name:
+            raise ValueError('Base Model Name is required')
+        ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
         return ai_model_entity.entity if ai_model_entity else None

     def _generate(self, model: str, credentials: dict,
@@ -149,8 +170,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):

         return self._handle_generate_response(model, credentials, response, prompt_messages)

-    def _handle_generate_response(self, model: str, credentials: dict, response: Completion,
-                                  prompt_messages: list[PromptMessage]) -> LLMResult:
+    def _handle_generate_response(
+        self, model: str, credentials: dict, response: Completion,
+        prompt_messages: list[PromptMessage]
+    ):
         assistant_text = response.choices[0].text

         # transform assistant message to prompt message
@@ -165,7 +188,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             completion_tokens = response.usage.completion_tokens
         else:
             # calculate num tokens
-            prompt_tokens = self._num_tokens_from_string(credentials, prompt_messages[0].content)
+            content = prompt_messages[0].content
+            assert isinstance(content, str)
+            prompt_tokens = self._num_tokens_from_string(credentials, content)
             completion_tokens = self._num_tokens_from_string(credentials, assistant_text)

         # transform usage
@@ -182,8 +207,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):

         return result

-    def _handle_generate_stream_response(self, model: str, credentials: dict, response: Stream[Completion],
-                                         prompt_messages: list[PromptMessage]) -> Generator:
+    def _handle_generate_stream_response(
+        self, model: str, credentials: dict, response: Stream[Completion],
+        prompt_messages: list[PromptMessage]
+    ) -> Generator:
         full_text = ''
         for chunk in response:
             if len(chunk.choices) == 0:
@@ -210,7 +237,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             completion_tokens = chunk.usage.completion_tokens
         else:
             # calculate num tokens
-            prompt_tokens = self._num_tokens_from_string(credentials, prompt_messages[0].content)
+            content = prompt_messages[0].content
+            assert isinstance(content, str)
+            prompt_tokens = self._num_tokens_from_string(credentials, content)
             completion_tokens = self._num_tokens_from_string(credentials, full_text)

         # transform usage
@@ -257,12 +286,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         extra_model_kwargs = {}

         if tools:
-            # extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
-            extra_model_kwargs['functions'] = [{
-                "name": tool.name,
-                "description": tool.description,
-                "parameters": tool.parameters
-            } for tool in tools]
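+            # switch from the deprecated 'functions' payload to the tool-calling API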
+            extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
+            # extra_model_kwargs['functions'] = [{
+            #     "name": tool.name,
+            #     "description": tool.description,
+            #     "parameters": tool.parameters
+            # } for tool in tools]

         if stop:
             extra_model_kwargs['stop'] = stop
@@ -271,8 +300,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             extra_model_kwargs['user'] = user

         # chat model
+        messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
         response = client.chat.completions.create(
-            messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
+            messages=messages,
             model=model,
             stream=stream,
             **model_parameters,
@@ -284,18 +314,17 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):

         return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)

-    def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion,
-                                       prompt_messages: list[PromptMessage],
-                                       tools: Optional[list[PromptMessageTool]] = None) -> LLMResult:
-
+    def _handle_chat_generate_response(
+        self, model: str, credentials: dict, response: ChatCompletion,
+        prompt_messages: list[PromptMessage],
+        tools: Optional[list[PromptMessageTool]] = None
+    ):
         assistant_message = response.choices[0].message
-        # assistant_message_tool_calls = assistant_message.tool_calls
-        assistant_message_function_call = assistant_message.function_call
+        assistant_message_tool_calls = assistant_message.tool_calls

         # extract tool calls from response
-        # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
-        function_call = self._extract_response_function_call(assistant_message_function_call)
-        tool_calls = [function_call] if function_call else []
+        tool_calls = []
+        self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=assistant_message_tool_calls)

         # transform assistant message to prompt message
         assistant_prompt_message = AssistantPromptMessage(
@@ -317,7 +346,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

         # transform response
-        response = LLMResult(
+        result = LLMResult(
             model=response.model or model,
             prompt_messages=prompt_messages,
             message=assistant_prompt_message,
@@ -325,58 +354,34 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             system_fingerprint=response.system_fingerprint,
         )

-        return response
+        return result

-    def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
-                                              response: Stream[ChatCompletionChunk],
-                                              prompt_messages: list[PromptMessage],
-                                              tools: Optional[list[PromptMessageTool]] = None) -> Generator:
+    def _handle_chat_generate_stream_response(
+        self,
+        model: str,
+        credentials: dict,
+        response: Stream[ChatCompletionChunk],
+        prompt_messages: list[PromptMessage],
+        tools: Optional[list[PromptMessageTool]] = None
+    ):
         index = 0
         full_assistant_content = ''
-        delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None
         real_model = model
         system_fingerprint = None
         completion = ''
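+        # accumulates tool-call fragments merged from streamed deltas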
+        tool_calls = []
         for chunk in response:
             if len(chunk.choices) == 0:
                 continue

             delta = chunk.choices[0]

-            # Handling exceptions when content filters' streaming mode is set to asynchronous modified filter
-            if delta.delta is None or (
-                delta.finish_reason is None
-                and (delta.delta.content is None or delta.delta.content == '')
-                and delta.delta.function_call is None
-            ):
-                continue
-
-            # assistant_message_tool_calls = delta.delta.tool_calls
-            assistant_message_function_call = delta.delta.function_call
-
             # extract tool calls from response
-            if delta_assistant_message_function_call_storage is not None:
-                # handle process of stream function call
-                if assistant_message_function_call:
-                    # message has not ended ever
-                    delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments
-                    continue
-                else:
-                    # message has ended
-                    assistant_message_function_call = delta_assistant_message_function_call_storage
-                    delta_assistant_message_function_call_storage = None
-            else:
-                if assistant_message_function_call:
-                    # start of stream function call
-                    delta_assistant_message_function_call_storage = assistant_message_function_call
-                    if delta_assistant_message_function_call_storage.arguments is None:
-                        delta_assistant_message_function_call_storage.arguments = ''
-                    continue
+            self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=delta.delta.tool_calls)

-            # extract tool calls from response
-            # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
-            function_call = self._extract_response_function_call(assistant_message_function_call)
-            tool_calls = [function_call] if function_call else []
+            # skip empty chunks emitted when the content filter's streaming mode is set to "asynchronous modified filter"
+            if delta.finish_reason is None and not delta.delta.content:
+                continue

             # transform assistant message to prompt message
             assistant_prompt_message = AssistantPromptMessage(
@@ -426,54 +431,56 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         )

     @staticmethod
-    def _extract_response_tool_calls(response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \
-            -> list[AssistantPromptMessage.ToolCall]:
-
-        tool_calls = []
-        if response_tool_calls:
-            for response_tool_call in response_tool_calls:
-                function = AssistantPromptMessage.ToolCall.ToolCallFunction(
-                    name=response_tool_call.function.name,
-                    arguments=response_tool_call.function.arguments
-                )
-
-                tool_call = AssistantPromptMessage.ToolCall(
-                    id=response_tool_call.id,
-                    type=response_tool_call.type,
-                    function=function
-                )
-                tool_calls.append(tool_call)
-
-        return tool_calls
-
-    @staticmethod
-    def _extract_response_function_call(response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \
-            -> AssistantPromptMessage.ToolCall:
-
-        tool_call = None
-        if response_function_call:
-            function = AssistantPromptMessage.ToolCall.ToolCallFunction(
-                name=response_function_call.name,
-                arguments=response_function_call.arguments
-            )
-
-            tool_call = AssistantPromptMessage.ToolCall(
-                id=response_function_call.name,
-                type="function",
-                function=function
-            )
+    def _update_tool_calls(
+        tool_calls: list[AssistantPromptMessage.ToolCall],
+        tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]]
+    ) -> None:
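+        # blocking responses carry complete ChatCompletionMessageToolCall objects,
+        # while streaming responses emit ChoiceDeltaToolCall fragments that are
+        # merged into tool_calls by index as chunks arrive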
+        if tool_calls_response:
+            for response_tool_call in tool_calls_response:
+                if isinstance(response_tool_call, ChatCompletionMessageToolCall):
+                    function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                        name=response_tool_call.function.name,
+                        arguments=response_tool_call.function.arguments
+                    )

-        return tool_call
+                    tool_call = AssistantPromptMessage.ToolCall(
+                        id=response_tool_call.id,
+                        type=response_tool_call.type,
+                        function=function
+                    )
+                    tool_calls.append(tool_call)
+                elif isinstance(response_tool_call, ChoiceDeltaToolCall):
+                    index = response_tool_call.index
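+                    # a delta for an index seen before extends the existing call;
+                    # otherwise it starts a new entry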
+                    if index < len(tool_calls):
+                        tool_calls[index].id = response_tool_call.id or tool_calls[index].id
+                        tool_calls[index].type = response_tool_call.type or tool_calls[index].type
+                        if response_tool_call.function:
+                            tool_calls[index].function.name = response_tool_call.function.name or tool_calls[index].function.name
+                            tool_calls[index].function.arguments += response_tool_call.function.arguments or ''
+                    else:
+                        assert response_tool_call.id is not None
+                        assert response_tool_call.type is not None
+                        assert response_tool_call.function is not None
+                        assert response_tool_call.function.name is not None
+                        assert response_tool_call.function.arguments is not None
+
+                        function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            name=response_tool_call.function.name,
+                            arguments=response_tool_call.function.arguments
+                        )
+                        tool_call = AssistantPromptMessage.ToolCall(
+                            id=response_tool_call.id,
+                            type=response_tool_call.type,
+                            function=function
+                        )
+                        tool_calls.append(tool_call)

     @staticmethod
-    def _convert_prompt_message_to_dict(message: PromptMessage) -> dict:
-
+    def _convert_prompt_message_to_dict(message: PromptMessage):
         if isinstance(message, UserPromptMessage):
             message = cast(UserPromptMessage, message)
             if isinstance(message.content, str):
                 message_dict = {"role": "user", "content": message.content}
             else:
                 sub_messages = []
+                assert message.content is not None
                 for message_content in message.content:
                     if message_content.type == PromptMessageContentType.TEXT:
                         message_content = cast(TextPromptMessageContent, message_content)
@@ -492,33 +499,22 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                             }
                         }
                         sub_messages.append(sub_message_dict)
-
                 message_dict = {"role": "user", "content": sub_messages}
         elif isinstance(message, AssistantPromptMessage):
             message = cast(AssistantPromptMessage, message)
             message_dict = {"role": "assistant", "content": message.content}
             if message.tool_calls:
-                # message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in
-                #                               message.tool_calls]
-                function_call = message.tool_calls[0]
-                message_dict["function_call"] = {
-                    "name": function_call.function.name,
-                    "arguments": function_call.function.arguments,
-                }
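+                # serialize every accumulated tool call instead of only the first function call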
+                message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in message.tool_calls]
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "system", "content": message.content}
         elif isinstance(message, ToolPromptMessage):
             message = cast(ToolPromptMessage, message)
-            # message_dict = {
-            #     "role": "tool",
-            #     "content": message.content,
-            #     "tool_call_id": message.tool_call_id
-            # }
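+            # tool results now use the 'tool' role and echo the originating tool_call_id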
             message_dict = {
-                "role": "function",
+                "role": "tool",
+                "name": message.name,
                 "content": message.content,
-                "name": message.tool_call_id
+                "tool_call_id": message.tool_call_id
             }
         else:
             raise ValueError(f"Got unknown type {message}")
@@ -542,8 +538,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):

         return num_tokens

-    def _num_tokens_from_messages(self, credentials: dict, messages: list[PromptMessage],
-                                  tools: Optional[list[PromptMessageTool]] = None) -> int:
+    def _num_tokens_from_messages(
+        self, credentials: dict, messages: list[PromptMessage],
+        tools: Optional[list[PromptMessageTool]] = None
+    ) -> int:
         """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

         Official documentation: https://github.com/openai/openai-cookbook/blob/
@@ -591,6 +589,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):

                 if key == "tool_calls":
                     for tool_call in value:
+                        assert isinstance(tool_call, dict)
                         for t_key, t_value in tool_call.items():
                             num_tokens += len(encoding.encode(t_key))
                             if t_key == "function":
@@ -631,12 +630,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             num_tokens += len(encoding.encode('parameters'))
             if 'title' in parameters:
                 num_tokens += len(encoding.encode('title'))
-                num_tokens += len(encoding.encode(parameters.get("title")))
+                num_tokens += len(encoding.encode(parameters['title'])))
             num_tokens += len(encoding.encode('type'))
-            num_tokens += len(encoding.encode(parameters.get("type")))
+            num_tokens += len(encoding.encode(parameters['type']))
             if 'properties' in parameters:
                 num_tokens += len(encoding.encode('properties'))
-                for key, value in parameters.get('properties').items():
+                for key, value in parameters['properties'].items():
                     num_tokens += len(encoding.encode(key))
                     for field_key, field_value in value.items():
                         num_tokens += len(encoding.encode(field_key))
@@ -656,7 +655,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         return num_tokens

     @staticmethod
-    def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
+    def _get_ai_model_entity(base_model_name: str, model: str):
         for ai_model_entity in LLM_BASE_MODELS:
             if ai_model_entity.base_model_name == base_model_name:
                 ai_model_entity_copy = copy.deepcopy(ai_model_entity)
@@ -664,5 +663,3 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                 ai_model_entity_copy.entity.label.en_US = model
                 ai_model_entity_copy.entity.label.zh_Hans = model
                 return ai_model_entity_copy
-
-        return None