浏览代码

fix(api/model_runtime/azure/llm): Switch to tool_call. (#5541)

-LAN- 10 月之前
父节点
当前提交
ba67206bb9

+ 134 - 137
api/core/model_runtime/model_providers/azure_openai/llm/llm.py

@@ -1,14 +1,13 @@
 import copy
 import logging
-from collections.abc import Generator
+from collections.abc import Generator, Sequence
 from typing import Optional, Union, cast
 
 import tiktoken
 from openai import AzureOpenAI, Stream
 from openai.types import Completion
 from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
-from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
-from openai.types.chat.chat_completion_message import FunctionCall
+from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
 
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
@@ -16,6 +15,7 @@ from core.model_runtime.entities.message_entities import (
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContentType,
+    PromptMessageFunction,
     PromptMessageTool,
     SystemPromptMessage,
     TextPromptMessageContent,
@@ -26,7 +26,8 @@ from core.model_runtime.entities.model_entities import AIModelEntity, ModelPrope
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
-from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS, AzureBaseModel
+from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS
+from core.model_runtime.utils import helper
 
 logger = logging.getLogger(__name__)
 
@@ -39,9 +40,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                 stream: bool = True, user: Optional[str] = None) \
             -> Union[LLMResult, Generator]:
 
-        ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
+        base_model_name = credentials.get('base_model_name')
+        if not base_model_name:
+            raise ValueError('Base Model Name is required')
+        ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
 
-        if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
+        if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
             # chat model
             return self._chat_generate(
                 model=model,
@@ -65,18 +69,29 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                 user=user
             )
 
-    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                       tools: Optional[list[PromptMessageTool]] = None) -> int:
-
-        model_mode = self._get_ai_model_entity(credentials.get('base_model_name'), model).entity.model_properties.get(
-            ModelPropertyKey.MODE)
+    def get_num_tokens(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        tools: Optional[list[PromptMessageTool]] = None
+    ) -> int:
+        base_model_name = credentials.get('base_model_name')
+        if not base_model_name:
+            raise ValueError('Base Model Name is required')
+        model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
+        if not model_entity:
+            raise ValueError(f'Base Model Name {base_model_name} is invalid')
+        model_mode = model_entity.entity.model_properties.get(ModelPropertyKey.MODE)
 
         if model_mode == LLMMode.CHAT.value:
             # chat model
             return self._num_tokens_from_messages(credentials, prompt_messages, tools)
         else:
             # text completion model, do not support tool calling
-            return self._num_tokens_from_string(credentials, prompt_messages[0].content)
+            content = prompt_messages[0].content
+            assert isinstance(content, str)
+            return self._num_tokens_from_string(credentials,content)
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         if 'openai_api_base' not in credentials:
@@ -88,7 +103,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         if 'base_model_name' not in credentials:
             raise CredentialsValidateFailedError('Base Model Name is required')
 
-        ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
+        base_model_name = credentials.get('base_model_name')
+        if not base_model_name:
+            raise CredentialsValidateFailedError('Base Model Name is required')
+        ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
 
         if not ai_model_entity:
             raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
@@ -118,7 +136,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             raise CredentialsValidateFailedError(str(ex))
 
     def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
-        ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
+        base_model_name = credentials.get('base_model_name')
+        if not base_model_name:
+            raise ValueError('Base Model Name is required')
+        ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
         return ai_model_entity.entity if ai_model_entity else None
 
     def _generate(self, model: str, credentials: dict,
@@ -149,8 +170,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
 
         return self._handle_generate_response(model, credentials, response, prompt_messages)
 
-    def _handle_generate_response(self, model: str, credentials: dict, response: Completion,
-                                  prompt_messages: list[PromptMessage]) -> LLMResult:
+    def _handle_generate_response(
+        self, model: str, credentials: dict, response: Completion,
+        prompt_messages: list[PromptMessage]
+    ):
         assistant_text = response.choices[0].text
 
         # transform assistant message to prompt message
@@ -165,7 +188,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             completion_tokens = response.usage.completion_tokens
         else:
             # calculate num tokens
-            prompt_tokens = self._num_tokens_from_string(credentials, prompt_messages[0].content)
+            content = prompt_messages[0].content
+            assert isinstance(content, str)
+            prompt_tokens = self._num_tokens_from_string(credentials, content)
             completion_tokens = self._num_tokens_from_string(credentials, assistant_text)
 
         # transform usage
@@ -182,8 +207,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
 
         return result
 
-    def _handle_generate_stream_response(self, model: str, credentials: dict, response: Stream[Completion],
-                                         prompt_messages: list[PromptMessage]) -> Generator:
+    def _handle_generate_stream_response(
+        self, model: str, credentials: dict, response: Stream[Completion],
+        prompt_messages: list[PromptMessage]
+    ) -> Generator:
         full_text = ''
         for chunk in response:
             if len(chunk.choices) == 0:
@@ -210,7 +237,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                     completion_tokens = chunk.usage.completion_tokens
                 else:
                     # calculate num tokens
-                    prompt_tokens = self._num_tokens_from_string(credentials, prompt_messages[0].content)
+                    content = prompt_messages[0].content
+                    assert isinstance(content, str)
+                    prompt_tokens = self._num_tokens_from_string(credentials, content)
                     completion_tokens = self._num_tokens_from_string(credentials, full_text)
 
                 # transform usage
@@ -257,12 +286,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         extra_model_kwargs = {}
 
         if tools:
-            # extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
-            extra_model_kwargs['functions'] = [{
-                "name": tool.name,
-                "description": tool.description,
-                "parameters": tool.parameters
-            } for tool in tools]
+            extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
+            # extra_model_kwargs['functions'] = [{
+            #     "name": tool.name,
+            #     "description": tool.description,
+            #     "parameters": tool.parameters
+            # } for tool in tools]
 
         if stop:
             extra_model_kwargs['stop'] = stop
@@ -271,8 +300,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             extra_model_kwargs['user'] = user
 
         # chat model
+        messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
         response = client.chat.completions.create(
-            messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
+            messages=messages,
             model=model,
             stream=stream,
             **model_parameters,
@@ -284,18 +314,17 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
 
         return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
 
-    def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion,
-                                       prompt_messages: list[PromptMessage],
-                                       tools: Optional[list[PromptMessageTool]] = None) -> LLMResult:
-
+    def _handle_chat_generate_response(
+        self, model: str, credentials: dict, response: ChatCompletion,
+        prompt_messages: list[PromptMessage],
+        tools: Optional[list[PromptMessageTool]] = None
+    ):
         assistant_message = response.choices[0].message
-        # assistant_message_tool_calls = assistant_message.tool_calls
-        assistant_message_function_call = assistant_message.function_call
+        assistant_message_tool_calls = assistant_message.tool_calls
 
         # extract tool calls from response
-        # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
-        function_call = self._extract_response_function_call(assistant_message_function_call)
-        tool_calls = [function_call] if function_call else []
+        tool_calls = []
+        self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=assistant_message_tool_calls)
 
         # transform assistant message to prompt message
         assistant_prompt_message = AssistantPromptMessage(
@@ -317,7 +346,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
 
         # transform response
-        response = LLMResult(
+        result = LLMResult(
             model=response.model or model,
             prompt_messages=prompt_messages,
             message=assistant_prompt_message,
@@ -325,58 +354,34 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             system_fingerprint=response.system_fingerprint,
         )
 
-        return response
+        return result
 
-    def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
-                                              response: Stream[ChatCompletionChunk],
-                                              prompt_messages: list[PromptMessage],
-                                              tools: Optional[list[PromptMessageTool]] = None) -> Generator:
+    def _handle_chat_generate_stream_response(
+        self,
+        model: str,
+        credentials: dict,
+        response: Stream[ChatCompletionChunk],
+        prompt_messages: list[PromptMessage],
+        tools: Optional[list[PromptMessageTool]] = None
+    ):
         index = 0
         full_assistant_content = ''
-        delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None
         real_model = model
         system_fingerprint = None
         completion = ''
+        tool_calls = []
         for chunk in response:
             if len(chunk.choices) == 0:
                 continue
 
             delta = chunk.choices[0]
 
-            # Handling exceptions when content filters' streaming mode is set to asynchronous modified filter
-            if delta.delta is None or (
-                delta.finish_reason is None
-                and (delta.delta.content is None or delta.delta.content == '')
-                and delta.delta.function_call is None
-            ):
-                continue
-            
-            # assistant_message_tool_calls = delta.delta.tool_calls
-            assistant_message_function_call = delta.delta.function_call
-
             # extract tool calls from response
-            if delta_assistant_message_function_call_storage is not None:
-                # handle process of stream function call
-                if assistant_message_function_call:
-                    # message has not ended ever
-                    delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments
-                    continue
-                else:
-                    # message has ended
-                    assistant_message_function_call = delta_assistant_message_function_call_storage
-                    delta_assistant_message_function_call_storage = None
-            else:
-                if assistant_message_function_call:
-                    # start of stream function call
-                    delta_assistant_message_function_call_storage = assistant_message_function_call
-                    if delta_assistant_message_function_call_storage.arguments is None:
-                        delta_assistant_message_function_call_storage.arguments = ''
-                    continue
+            self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=delta.delta.tool_calls)
 
-            # extract tool calls from response
-            # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
-            function_call = self._extract_response_function_call(assistant_message_function_call)
-            tool_calls = [function_call] if function_call else []
+            # Handling exceptions when content filters' streaming mode is set to asynchronous modified filter
+            if delta.finish_reason is None and not delta.delta.content:
+                continue
 
             # transform assistant message to prompt message
             assistant_prompt_message = AssistantPromptMessage(
@@ -426,54 +431,56 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         )
 
     @staticmethod
-    def _extract_response_tool_calls(response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \
-            -> list[AssistantPromptMessage.ToolCall]:
-
-        tool_calls = []
-        if response_tool_calls:
-            for response_tool_call in response_tool_calls:
-                function = AssistantPromptMessage.ToolCall.ToolCallFunction(
-                    name=response_tool_call.function.name,
-                    arguments=response_tool_call.function.arguments
-                )
-
-                tool_call = AssistantPromptMessage.ToolCall(
-                    id=response_tool_call.id,
-                    type=response_tool_call.type,
-                    function=function
-                )
-                tool_calls.append(tool_call)
-
-        return tool_calls
-
-    @staticmethod
-    def _extract_response_function_call(response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \
-            -> AssistantPromptMessage.ToolCall:
-
-        tool_call = None
-        if response_function_call:
-            function = AssistantPromptMessage.ToolCall.ToolCallFunction(
-                name=response_function_call.name,
-                arguments=response_function_call.arguments
-            )
-
-            tool_call = AssistantPromptMessage.ToolCall(
-                id=response_function_call.name,
-                type="function",
-                function=function
-            )
+    def _update_tool_calls(tool_calls: list[AssistantPromptMessage.ToolCall], tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]]) -> None:
+        if tool_calls_response:
+            for response_tool_call in tool_calls_response:
+                if isinstance(response_tool_call, ChatCompletionMessageToolCall):
+                    function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                        name=response_tool_call.function.name,
+                        arguments=response_tool_call.function.arguments
+                    )
 
-        return tool_call
+                    tool_call = AssistantPromptMessage.ToolCall(
+                        id=response_tool_call.id,
+                        type=response_tool_call.type,
+                        function=function
+                    )
+                    tool_calls.append(tool_call)
+                elif isinstance(response_tool_call, ChoiceDeltaToolCall):
+                    index = response_tool_call.index
+                    if index < len(tool_calls):
+                        tool_calls[index].id = response_tool_call.id or tool_calls[index].id
+                        tool_calls[index].type = response_tool_call.type or tool_calls[index].type
+                        if response_tool_call.function:
+                            tool_calls[index].function.name = response_tool_call.function.name or tool_calls[index].function.name
+                            tool_calls[index].function.arguments += response_tool_call.function.arguments or ''
+                    else:
+                        assert response_tool_call.id is not None
+                        assert response_tool_call.type is not None
+                        assert response_tool_call.function is not None
+                        assert response_tool_call.function.name is not None
+                        assert response_tool_call.function.arguments is not None
+
+                        function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            name=response_tool_call.function.name,
+                            arguments=response_tool_call.function.arguments
+                        )
+                        tool_call = AssistantPromptMessage.ToolCall(
+                            id=response_tool_call.id,
+                            type=response_tool_call.type,
+                            function=function
+                        )
+                        tool_calls.append(tool_call)
 
     @staticmethod
-    def _convert_prompt_message_to_dict(message: PromptMessage) -> dict:
-
+    def _convert_prompt_message_to_dict(message: PromptMessage):
         if isinstance(message, UserPromptMessage):
             message = cast(UserPromptMessage, message)
             if isinstance(message.content, str):
                 message_dict = {"role": "user", "content": message.content}
             else:
                 sub_messages = []
+                assert message.content is not None
                 for message_content in message.content:
                     if message_content.type == PromptMessageContentType.TEXT:
                         message_content = cast(TextPromptMessageContent, message_content)
@@ -492,33 +499,22 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                             }
                         }
                         sub_messages.append(sub_message_dict)
-
                 message_dict = {"role": "user", "content": sub_messages}
         elif isinstance(message, AssistantPromptMessage):
             message = cast(AssistantPromptMessage, message)
             message_dict = {"role": "assistant", "content": message.content}
             if message.tool_calls:
-                # message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in
-                #                               message.tool_calls]
-                function_call = message.tool_calls[0]
-                message_dict["function_call"] = {
-                    "name": function_call.function.name,
-                    "arguments": function_call.function.arguments,
-                }
+                message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in message.tool_calls]
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "system", "content": message.content}
         elif isinstance(message, ToolPromptMessage):
             message = cast(ToolPromptMessage, message)
-            # message_dict = {
-            #     "role": "tool",
-            #     "content": message.content,
-            #     "tool_call_id": message.tool_call_id
-            # }
             message_dict = {
-                "role": "function",
+                "role": "tool",
+                "name": message.name,
                 "content": message.content,
-                "name": message.tool_call_id
+                "tool_call_id": message.tool_call_id
             }
         else:
             raise ValueError(f"Got unknown type {message}")
@@ -542,8 +538,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
 
         return num_tokens
 
-    def _num_tokens_from_messages(self, credentials: dict, messages: list[PromptMessage],
-                                  tools: Optional[list[PromptMessageTool]] = None) -> int:
+    def _num_tokens_from_messages(
+        self, credentials: dict, messages: list[PromptMessage],
+        tools: Optional[list[PromptMessageTool]] = None
+    ) -> int:
         """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
 
         Official documentation: https://github.com/openai/openai-cookbook/blob/
@@ -591,6 +589,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
 
                 if key == "tool_calls":
                     for tool_call in value:
+                        assert isinstance(tool_call, dict)
                         for t_key, t_value in tool_call.items():
                             num_tokens += len(encoding.encode(t_key))
                             if t_key == "function":
@@ -631,12 +630,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             num_tokens += len(encoding.encode('parameters'))
             if 'title' in parameters:
                 num_tokens += len(encoding.encode('title'))
-                num_tokens += len(encoding.encode(parameters.get("title")))
+                num_tokens += len(encoding.encode(parameters['title']))
             num_tokens += len(encoding.encode('type'))
-            num_tokens += len(encoding.encode(parameters.get("type")))
+            num_tokens += len(encoding.encode(parameters['type']))
             if 'properties' in parameters:
                 num_tokens += len(encoding.encode('properties'))
-                for key, value in parameters.get('properties').items():
+                for key, value in parameters['properties'].items():
                     num_tokens += len(encoding.encode(key))
                     for field_key, field_value in value.items():
                         num_tokens += len(encoding.encode(field_key))
@@ -656,7 +655,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         return num_tokens
 
     @staticmethod
-    def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
+    def _get_ai_model_entity(base_model_name: str, model: str):
         for ai_model_entity in LLM_BASE_MODELS:
             if ai_model_entity.base_model_name == base_model_name:
                 ai_model_entity_copy = copy.deepcopy(ai_model_entity)
@@ -664,5 +663,3 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                 ai_model_entity_copy.entity.label.en_US = model
                 ai_model_entity_copy.entity.label.zh_Hans = model
                 return ai_model_entity_copy
-
-        return None

+ 4 - 6
api/tests/integration_tests/model_runtime/__mock/openai_chat.py

@@ -73,17 +73,15 @@ class MockChatClass:
         return FunctionCall(name=function_name, arguments=dumps(parameters))
         
     @staticmethod
-    def generate_tool_calls(
-        tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
-    ) -> Optional[list[ChatCompletionMessageToolCall]]:
+    def generate_tool_calls(tools = NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]:
         list_tool_calls = []
         if not tools or len(tools) == 0:
             return None
-        tool: ChatCompletionToolParam = tools[0]
+        tool = tools[0]
 
-        if tools['type'] != 'function':
+        if 'type' in tools and tools['type'] != 'function':
             return None
-        
+
         function = tool['function']
 
         function_call = MockChatClass.generate_function_call(functions=[function])