Browse Source

fix: tool call message role according to credentials (#5625)

Co-authored-by: sunxichen <sun.xc@digitalcnzz.com>
sunxichen 10 months ago
parent
commit
bafc8a0bde

+ 6 - 6
api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py

@@ -88,7 +88,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         :param tools: tools for tool calling
         :return:
         """
-        return self._num_tokens_from_messages(model, prompt_messages, tools)
+        return self._num_tokens_from_messages(model, prompt_messages, tools, credentials)
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
@@ -305,7 +305,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
 
         if completion_type is LLMMode.CHAT:
             endpoint_url = urljoin(endpoint_url, 'chat/completions')
-            data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
+            data['messages'] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages]
         elif completion_type is LLMMode.COMPLETION:
             endpoint_url = urljoin(endpoint_url, 'completions')
             data['prompt'] = prompt_messages[0].content
@@ -582,7 +582,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
 
         return result
 
-    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+    def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: dict = None) -> dict:
         """
         Convert PromptMessage to dict for OpenAI API format
         """
@@ -636,7 +636,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             #     "tool_call_id": message.tool_call_id
             # }
             message_dict = {
-                "role": "function",
+                "role": "tool" if credentials and credentials.get('function_calling_type', 'no_call') == 'tool_call' else "function",
                 "content": message.content,
                 "name": message.tool_call_id
             }
@@ -675,7 +675,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         return num_tokens
 
     def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage],
-                                  tools: Optional[list[PromptMessageTool]] = None) -> int:
+                                  tools: Optional[list[PromptMessageTool]] = None, credentials: dict = None) -> int:
         """
         Approximate num tokens with GPT2 tokenizer.
         """
@@ -684,7 +684,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         tokens_per_name = 1
 
         num_tokens = 0
-        messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
+        messages_dict = [self._convert_prompt_message_to_dict(m, credentials) for m in messages]
         for message in messages_dict:
             num_tokens += tokens_per_message
             for key, value in message.items():