
fix(api/model_runtime/azure/llm): Switch to tool_call. (#5541)

-LAN- 10 months ago
parent
commit
ba67206bb9

+ 134 - 137
api/core/model_runtime/model_providers/azure_openai/llm/llm.py
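
Summary: this commit moves the Azure OpenAI provider off OpenAI's deprecated functions/function_call request fields onto the current tools/tool_calls API, replaces the two extraction helpers with a single _update_tool_calls that also accumulates streamed deltas, and adds explicit validation around base_model_name. As a rough sketch of the request-side change (placeholder values, not taken from the commit):

    params = {"type": "object", "properties": {"city": {"type": "string"}}}

    # before: deprecated function-calling fields
    legacy_kwargs = {
        "functions": [{"name": "get_weather", "description": "Look up weather", "parameters": params}],
    }

    # after: tools API, each entry wrapping the same function schema
    tools_kwargs = {
        "tools": [{
            "type": "function",
            "function": {"name": "get_weather", "description": "Look up weather", "parameters": params},
        }],
    }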

@@ -1,14 +1,13 @@
 import copy
 import logging
-from collections.abc import Generator
+from collections.abc import Generator, Sequence
 from typing import Optional, Union, cast
 
 import tiktoken
 from openai import AzureOpenAI, Stream
 from openai.types import Completion
 from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
-from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
-from openai.types.chat.chat_completion_message import FunctionCall
+from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
 
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
@@ -16,6 +15,7 @@ from core.model_runtime.entities.message_entities import (
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContentType,
+    PromptMessageFunction,
     PromptMessageTool,
     SystemPromptMessage,
     TextPromptMessageContent,
@@ -26,7 +26,8 @@ from core.model_runtime.entities.model_entities import AIModelEntity, ModelPrope
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
-from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS, AzureBaseModel
+from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS
+from core.model_runtime.utils import helper
 
 logger = logging.getLogger(__name__)
 
@@ -39,9 +40,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                 stream: bool = True, user: Optional[str] = None) \
             -> Union[LLMResult, Generator]:
 
-        ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
+        base_model_name = credentials.get('base_model_name')
+        if not base_model_name:
+            raise ValueError('Base Model Name is required')
+        ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
 
-        if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
+        if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
             # chat model
             return self._chat_generate(
                 model=model,
@@ -65,18 +69,29 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                 user=user
             )
 
-    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                       tools: Optional[list[PromptMessageTool]] = None) -> int:
-
-        model_mode = self._get_ai_model_entity(credentials.get('base_model_name'), model).entity.model_properties.get(
-            ModelPropertyKey.MODE)
+    def get_num_tokens(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        tools: Optional[list[PromptMessageTool]] = None
+    ) -> int:
+        base_model_name = credentials.get('base_model_name')
+        if not base_model_name:
+            raise ValueError('Base Model Name is required')
+        model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
+        if not model_entity:
+            raise ValueError(f'Base Model Name {base_model_name} is invalid')
+        model_mode = model_entity.entity.model_properties.get(ModelPropertyKey.MODE)
 
         if model_mode == LLMMode.CHAT.value:
             # chat model
             return self._num_tokens_from_messages(credentials, prompt_messages, tools)
         else:
             # text completion model, do not support tool calling
-            return self._num_tokens_from_string(credentials, prompt_messages[0].content)
+            content = prompt_messages[0].content
+            assert isinstance(content, str)
+            return self._num_tokens_from_string(credentials, content)
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         if 'openai_api_base' not in credentials:
@@ -88,7 +103,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         if 'base_model_name' not in credentials:
             raise CredentialsValidateFailedError('Base Model Name is required')
 
-        ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
+        base_model_name = credentials.get('base_model_name')
+        if not base_model_name:
+            raise CredentialsValidateFailedError('Base Model Name is required')
+        ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
 
         if not ai_model_entity:
             raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
@@ -118,7 +136,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             raise CredentialsValidateFailedError(str(ex))
 
     def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
-        ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
+        base_model_name = credentials.get('base_model_name')
+        if not base_model_name:
+            raise ValueError('Base Model Name is required')
+        ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
         return ai_model_entity.entity if ai_model_entity else None
 
     def _generate(self, model: str, credentials: dict,
@@ -149,8 +170,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
 
         return self._handle_generate_response(model, credentials, response, prompt_messages)
 
-    def _handle_generate_response(self, model: str, credentials: dict, response: Completion,
-                                  prompt_messages: list[PromptMessage]) -> LLMResult:
+    def _handle_generate_response(
+        self, model: str, credentials: dict, response: Completion,
+        prompt_messages: list[PromptMessage]
+    ):
         assistant_text = response.choices[0].text
 
         # transform assistant message to prompt message
@@ -165,7 +188,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             completion_tokens = response.usage.completion_tokens
         else:
             # calculate num tokens
-            prompt_tokens = self._num_tokens_from_string(credentials, prompt_messages[0].content)
+            content = prompt_messages[0].content
+            assert isinstance(content, str)
+            prompt_tokens = self._num_tokens_from_string(credentials, content)
             completion_tokens = self._num_tokens_from_string(credentials, assistant_text)
 
         # transform usage
@@ -182,8 +207,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
 
         return result
 
-    def _handle_generate_stream_response(self, model: str, credentials: dict, response: Stream[Completion],
-                                         prompt_messages: list[PromptMessage]) -> Generator:
+    def _handle_generate_stream_response(
+        self, model: str, credentials: dict, response: Stream[Completion],
+        prompt_messages: list[PromptMessage]
+    ) -> Generator:
         full_text = ''
         for chunk in response:
             if len(chunk.choices) == 0:
@@ -210,7 +237,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                     completion_tokens = chunk.usage.completion_tokens
                 else:
                     # calculate num tokens
-                    prompt_tokens = self._num_tokens_from_string(credentials, prompt_messages[0].content)
+                    content = prompt_messages[0].content
+                    assert isinstance(content, str)
+                    prompt_tokens = self._num_tokens_from_string(credentials, content)
                     completion_tokens = self._num_tokens_from_string(credentials, full_text)
 
                 # transform usage
@@ -257,12 +286,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         extra_model_kwargs = {}
 
         if tools:
-            # extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
-            extra_model_kwargs['functions'] = [{
-                "name": tool.name,
-                "description": tool.description,
-                "parameters": tool.parameters
-            } for tool in tools]
+            extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
+            # extra_model_kwargs['functions'] = [{
+            #     "name": tool.name,
+            #     "description": tool.description,
+            #     "parameters": tool.parameters
+            # } for tool in tools]
 
         if stop:
             extra_model_kwargs['stop'] = stop
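
Here helper.dump_model(PromptMessageFunction(function=tool)) presumably serializes each PromptMessageTool into the documented tools entry; a hypothetical equivalent (tool_to_dict is not a real helper, and the exact dump output is an assumption):

    def tool_to_dict(tool):
        # assumed to match what helper.dump_model(PromptMessageFunction(function=tool)) emits
        return {
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description,
                "parameters": tool.parameters,
            },
        }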
@@ -271,8 +300,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             extra_model_kwargs['user'] = user
 
         # chat model
+        messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
         response = client.chat.completions.create(
-            messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
+            messages=messages,
             model=model,
             stream=stream,
             **model_parameters,
@@ -284,18 +314,17 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
 
         return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
 
-    def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion,
-                                       prompt_messages: list[PromptMessage],
-                                       tools: Optional[list[PromptMessageTool]] = None) -> LLMResult:
-
+    def _handle_chat_generate_response(
+        self, model: str, credentials: dict, response: ChatCompletion,
+        prompt_messages: list[PromptMessage],
+        tools: Optional[list[PromptMessageTool]] = None
+    ):
         assistant_message = response.choices[0].message
-        # assistant_message_tool_calls = assistant_message.tool_calls
-        assistant_message_function_call = assistant_message.function_call
+        assistant_message_tool_calls = assistant_message.tool_calls
 
         # extract tool calls from response
-        # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
-        function_call = self._extract_response_function_call(assistant_message_function_call)
-        tool_calls = [function_call] if function_call else []
+        tool_calls = []
+        self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=assistant_message_tool_calls)
 
         # transform assistant message to prompt message
         assistant_prompt_message = AssistantPromptMessage(
@@ -317,7 +346,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
 
         # transform response
-        response = LLMResult(
+        result = LLMResult(
             model=response.model or model,
             prompt_messages=prompt_messages,
             message=assistant_prompt_message,
@@ -325,58 +354,34 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             system_fingerprint=response.system_fingerprint,
         )
 
-        return response
+        return result
 
-    def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
-                                              response: Stream[ChatCompletionChunk],
-                                              prompt_messages: list[PromptMessage],
-                                              tools: Optional[list[PromptMessageTool]] = None) -> Generator:
+    def _handle_chat_generate_stream_response(
+        self,
+        model: str,
+        credentials: dict,
+        response: Stream[ChatCompletionChunk],
+        prompt_messages: list[PromptMessage],
+        tools: Optional[list[PromptMessageTool]] = None
+    ):
         index = 0
         full_assistant_content = ''
-        delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None
         real_model = model
         system_fingerprint = None
         completion = ''
+        tool_calls = []
         for chunk in response:
             if len(chunk.choices) == 0:
                 continue
 
             delta = chunk.choices[0]
 
-            # Handling exceptions when content filters' streaming mode is set to asynchronous modified filter
-            if delta.delta is None or (
-                delta.finish_reason is None
-                and (delta.delta.content is None or delta.delta.content == '')
-                and delta.delta.function_call is None
-            ):
-                continue
-
-            # assistant_message_tool_calls = delta.delta.tool_calls
-            assistant_message_function_call = delta.delta.function_call
-
             # extract tool calls from response
-            if delta_assistant_message_function_call_storage is not None:
-                # handle process of stream function call
-                if assistant_message_function_call:
-                    # message has not ended ever
-                    delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments
-                    continue
-                else:
-                    # message has ended
-                    assistant_message_function_call = delta_assistant_message_function_call_storage
-                    delta_assistant_message_function_call_storage = None
-            else:
-                if assistant_message_function_call:
-                    # start of stream function call
-                    delta_assistant_message_function_call_storage = assistant_message_function_call
-                    if delta_assistant_message_function_call_storage.arguments is None:
-                        delta_assistant_message_function_call_storage.arguments = ''
-                    continue
+            self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=delta.delta.tool_calls)
 
-            # extract tool calls from response
-            # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
-            function_call = self._extract_response_function_call(assistant_message_function_call)
-            tool_calls = [function_call] if function_call else []
+            # Handling exceptions when content filters' streaming mode is set to asynchronous modified filter
+            if delta.finish_reason is None and not delta.delta.content:
+                continue
 
             # transform assistant message to prompt message
             assistant_prompt_message = AssistantPromptMessage(
@@ -426,54 +431,56 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         )
 
     @staticmethod
-    def _extract_response_tool_calls(response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \
-            -> list[AssistantPromptMessage.ToolCall]:
-
-        tool_calls = []
-        if response_tool_calls:
-            for response_tool_call in response_tool_calls:
-                function = AssistantPromptMessage.ToolCall.ToolCallFunction(
-                    name=response_tool_call.function.name,
-                    arguments=response_tool_call.function.arguments
-                )
-
-                tool_call = AssistantPromptMessage.ToolCall(
-                    id=response_tool_call.id,
-                    type=response_tool_call.type,
-                    function=function
-                )
-                tool_calls.append(tool_call)
-
-        return tool_calls
-
-    @staticmethod
-    def _extract_response_function_call(response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \
-            -> AssistantPromptMessage.ToolCall:
-
-        tool_call = None
-        if response_function_call:
-            function = AssistantPromptMessage.ToolCall.ToolCallFunction(
-                name=response_function_call.name,
-                arguments=response_function_call.arguments
-            )
-
-            tool_call = AssistantPromptMessage.ToolCall(
-                id=response_function_call.name,
-                type="function",
-                function=function
-            )
+    def _update_tool_calls(tool_calls: list[AssistantPromptMessage.ToolCall], tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]]) -> None:
+        if tool_calls_response:
+            for response_tool_call in tool_calls_response:
+                if isinstance(response_tool_call, ChatCompletionMessageToolCall):
+                    function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                        name=response_tool_call.function.name,
+                        arguments=response_tool_call.function.arguments
+                    )
 
-        return tool_call
+                    tool_call = AssistantPromptMessage.ToolCall(
+                        id=response_tool_call.id,
+                        type=response_tool_call.type,
+                        function=function
+                    )
+                    tool_calls.append(tool_call)
+                elif isinstance(response_tool_call, ChoiceDeltaToolCall):
+                    index = response_tool_call.index
+                    if index < len(tool_calls):
+                        tool_calls[index].id = response_tool_call.id or tool_calls[index].id
+                        tool_calls[index].type = response_tool_call.type or tool_calls[index].type
+                        if response_tool_call.function:
+                            tool_calls[index].function.name = response_tool_call.function.name or tool_calls[index].function.name
+                            tool_calls[index].function.arguments += response_tool_call.function.arguments or ''
+                    else:
+                        assert response_tool_call.id is not None
+                        assert response_tool_call.type is not None
+                        assert response_tool_call.function is not None
+                        assert response_tool_call.function.name is not None
+                        assert response_tool_call.function.arguments is not None
+
+                        function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            name=response_tool_call.function.name,
+                            arguments=response_tool_call.function.arguments
+                        )
+                        tool_call = AssistantPromptMessage.ToolCall(
+                            id=response_tool_call.id,
+                            type=response_tool_call.type,
+                            function=function
+                        )
+                        tool_calls.append(tool_call)
 
     @staticmethod
-    def _convert_prompt_message_to_dict(message: PromptMessage) -> dict:
-
+    def _convert_prompt_message_to_dict(message: PromptMessage):
         if isinstance(message, UserPromptMessage):
             message = cast(UserPromptMessage, message)
             if isinstance(message.content, str):
                 message_dict = {"role": "user", "content": message.content}
             else:
                 sub_messages = []
+                assert message.content is not None
                 for message_content in message.content:
                     if message_content.type == PromptMessageContentType.TEXT:
                         message_content = cast(TextPromptMessageContent, message_content)
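
Two things worth noting about the streaming path. First, in _handle_chat_generate_stream_response the tool-call deltas are folded into tool_calls before the empty-content skip, so chunks that carry only tool-call fragments still accumulate even though they are not yielded as content. Second, ChoiceDeltaToolCall deltas arrive indexed, with the id/type/name on the first fragment and the JSON arguments split across later ones. A hedged, self-contained illustration (chunk values invented; the openai types exist in openai-python v1):

    from openai.types.chat.chat_completion_chunk import (
        ChoiceDeltaToolCall,
        ChoiceDeltaToolCallFunction,
    )

    from core.model_runtime.model_providers.azure_openai.llm.llm import (
        AzureOpenAILargeLanguageModel,
    )

    chunks = [
        # first delta carries id/type/name and the start of the JSON arguments
        ChoiceDeltaToolCall(
            index=0, id='call_1', type='function',
            function=ChoiceDeltaToolCallFunction(name='get_weather', arguments='{"city": '),
        ),
        # later deltas at the same index only append argument fragments
        ChoiceDeltaToolCall(
            index=0,
            function=ChoiceDeltaToolCallFunction(arguments='"Berlin"}'),
        ),
    ]

    tool_calls = []
    for chunk in chunks:
        AzureOpenAILargeLanguageModel._update_tool_calls(
            tool_calls=tool_calls, tool_calls_response=[chunk]
        )

    assert tool_calls[0].function.arguments == '{"city": "Berlin"}'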
@@ -492,33 +499,22 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                             }
                         }
                         sub_messages.append(sub_message_dict)
-
                 message_dict = {"role": "user", "content": sub_messages}
         elif isinstance(message, AssistantPromptMessage):
             message = cast(AssistantPromptMessage, message)
             message_dict = {"role": "assistant", "content": message.content}
             if message.tool_calls:
-                # message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in
-                #                               message.tool_calls]
-                function_call = message.tool_calls[0]
-                message_dict["function_call"] = {
-                    "name": function_call.function.name,
-                    "arguments": function_call.function.arguments,
-                }
+                message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in message.tool_calls]
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "system", "content": message.content}
         elif isinstance(message, ToolPromptMessage):
             message = cast(ToolPromptMessage, message)
-            # message_dict = {
-            #     "role": "tool",
-            #     "content": message.content,
-            #     "tool_call_id": message.tool_call_id
-            # }
             message_dict = {
-                "role": "function",
+                "role": "tool",
+                "name": message.name,
                 "content": message.content,
-                "name": message.tool_call_id
+                "tool_call_id": message.tool_call_id
             }
         else:
             raise ValueError(f"Got unknown type {message}")
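
The ToolPromptMessage branch now emits the current role "tool" shape instead of the legacy role "function" one, pairing with the assistant's tool_calls by id. A sketch of the resulting message sequence (values invented):

    messages = [
        {"role": "assistant", "content": None,
         "tool_calls": [{"id": "call_1", "type": "function",
                         "function": {"name": "get_weather", "arguments": '{"city": "Berlin"}'}}]},
        {"role": "tool", "name": "get_weather", "content": "12C, cloudy", "tool_call_id": "call_1"},
    ]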
@@ -542,8 +538,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
 
         return num_tokens
 
-    def _num_tokens_from_messages(self, credentials: dict, messages: list[PromptMessage],
-                                  tools: Optional[list[PromptMessageTool]] = None) -> int:
+    def _num_tokens_from_messages(
+        self, credentials: dict, messages: list[PromptMessage],
+        tools: Optional[list[PromptMessageTool]] = None
+    ) -> int:
         """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
 
         Official documentation: https://github.com/openai/openai-cookbook/blob/
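
The cookbook-style counting this docstring describes builds on tiktoken; the underlying primitive is just the following (cl100k_base as the encoding for gpt-3.5/gpt-4-era models is an assumption):

    import tiktoken

    encoding = tiktoken.get_encoding("cl100k_base")  # assumed encoding
    num_tokens = len(encoding.encode("Hello, world"))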
@@ -591,6 +589,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
 
                 if key == "tool_calls":
                     for tool_call in value:
+                        assert isinstance(tool_call, dict)
                        for t_key, t_value in tool_call.items():
                             num_tokens += len(encoding.encode(t_key))
                             if t_key == "function":
@@ -631,12 +630,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             num_tokens += len(encoding.encode('parameters'))
             if 'title' in parameters:
                 num_tokens += len(encoding.encode('title'))
-                num_tokens += len(encoding.encode(parameters.get("title")))
+                num_tokens += len(encoding.encode(parameters['title']))
             num_tokens += len(encoding.encode('type'))
-            num_tokens += len(encoding.encode(parameters.get("type")))
+            num_tokens += len(encoding.encode(parameters['type']))
             if 'properties' in parameters:
                 num_tokens += len(encoding.encode('properties'))
-                for key, value in parameters.get('properties').items():
+                for key, value in parameters['properties'].items():
                     num_tokens += len(encoding.encode(key))
                     for field_key, field_value in value.items():
                         num_tokens += len(encoding.encode(field_key))
@@ -656,7 +655,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         return num_tokens
 
     @staticmethod
-    def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
+    def _get_ai_model_entity(base_model_name: str, model: str):
         for ai_model_entity in LLM_BASE_MODELS:
             if ai_model_entity.base_model_name == base_model_name:
                 ai_model_entity_copy = copy.deepcopy(ai_model_entity)
@@ -664,5 +663,3 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                 ai_model_entity_copy.entity.label.en_US = model
                 ai_model_entity_copy.entity.label.zh_Hans = model
                 return ai_model_entity_copy
-
-        return None

+ 4 - 6
api/tests/integration_tests/model_runtime/__mock/openai_chat.py

@@ -73,17 +73,15 @@ class MockChatClass:
         return FunctionCall(name=function_name, arguments=dumps(parameters))
 
     @staticmethod
-    def generate_tool_calls(
-        tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
-    ) -> Optional[list[ChatCompletionMessageToolCall]]:
+    def generate_tool_calls(tools = NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]:
         list_tool_calls = []
         if not tools or len(tools) == 0:
             return None
-        tool: ChatCompletionToolParam = tools[0]
+        tool = tools[0]
 
-        if tools['type'] != 'function':
+        if 'type' in tools and tools['type'] != 'function':
             return None
-
+
         function = tool['function']
 
         function_call = MockChatClass.generate_function_call(functions=[function])
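
Note: the rewritten guard still indexes tools (the list) rather than the extracted tool, and 'type' in tools is a list-membership test, so the non-function branch likely never fires. If the intent was to skip non-function tools, the check was presumably meant to be:

    if tool.get('type') != 'function':
        return None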