
feat: gemini pro function call (#3406)

Yeuoly 1 year ago
parent
commit
a258a90291

+ 2 - 0
api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-latest.yaml

@@ -5,6 +5,8 @@ model_type: llm
 features:
   - agent-thought
   - vision
+  - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 1048576

+ 2 - 0
api/core/model_runtime/model_providers/google/llm/gemini-pro.yaml

@@ -4,6 +4,8 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 30720

+ 138 - 62
api/core/model_runtime/model_providers/google/llm/llm.py

@@ -1,7 +1,9 @@
+import json
 import logging
 from collections.abc import Generator
 from typing import Optional, Union
 
+import google.ai.generativelanguage as glm
 import google.api_core.exceptions as exceptions
 import google.generativeai as genai
 import google.generativeai.client as client
@@ -13,9 +15,9 @@ from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
     PromptMessageContentType,
-    PromptMessageRole,
     PromptMessageTool,
     SystemPromptMessage,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.errors.invoke import (
@@ -62,7 +64,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         :return: full response or stream response chunk generator result
         """
         # invoke model
-        return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
+        return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
     
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
@@ -94,6 +96,32 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         )
 
         return text.rstrip()
+    
+    def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
+        """
+        Convert tool messages to glm tools
+
+        :param tools: tool messages
+        :return: glm tools
+        """
+        return glm.Tool(
+            function_declarations=[
+                glm.FunctionDeclaration(
+                    name=tool.name,
+                    parameters=glm.Schema(
+                        type=glm.Type.OBJECT,
+                        properties={
+                            key: {
+                                'type_': value.get('type', 'string').upper(),
+                                'description': value.get('description', ''),
+                                'enum': value.get('enum', [])
+                            } for key, value in tool.parameters.get('properties', {}).items()
+                        },
+                        required=tool.parameters.get('required', [])
+                    ),
+                ) for tool in tools
+            ]
+        )
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
@@ -105,7 +133,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         """
         
         try:
-            ping_message = PromptMessage(content="ping", role="system")
+            ping_message = SystemPromptMessage(content="ping")
             self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
             
         except Exception as ex:
@@ -114,8 +142,9 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
 
     def _generate(self, model: str, credentials: dict,
                   prompt_messages: list[PromptMessage], model_parameters: dict,
-                  stop: Optional[list[str]] = None, stream: bool = True,
-                  user: Optional[str] = None) -> Union[LLMResult, Generator]:
+                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, 
+                  stream: bool = True, user: Optional[str] = None
+        ) -> Union[LLMResult, Generator]:
         """
         Invoke large language model
 
@@ -153,7 +182,6 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
                 else:
                     history.append(content)
 
-
         # Create a new ClientManager with tenant's API key
         new_client_manager = client._ClientManager()
         new_client_manager.configure(api_key=credentials["google_api_key"])
@@ -167,14 +195,15 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
             HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
             HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
         }
-        
+
         response = google_model.generate_content(
             contents=history,
             generation_config=genai.types.GenerationConfig(
                 **config_kwargs
             ),
             stream=stream,
-            safety_settings=safety_settings
+            safety_settings=safety_settings,
+            tools=self._convert_tools_to_glm_tool(tools) if tools else None,
         )
 
         if stream:
@@ -228,43 +257,61 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         """
         index = -1
         for chunk in response:
-            content = chunk.text
-            index += 1
-           
-            assistant_prompt_message = AssistantPromptMessage(
-                content=content if content else '',
-            )
-  
-            if not response._done:
-                
-                # transform assistant message to prompt message
-                yield LLMResultChunk(
-                    model=model,
-                    prompt_messages=prompt_messages,
-                    delta=LLMResultChunkDelta(
-                        index=index,
-                        message=assistant_prompt_message
-                    )
+            for part in chunk.parts:
+                assistant_prompt_message = AssistantPromptMessage(
+                    content=''
                 )
-            else:
-                
-                # calculate num tokens
-                prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
-                completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
-
-                # transform usage
-                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
-                
-                yield LLMResultChunk(
-                    model=model,
-                    prompt_messages=prompt_messages,
-                    delta=LLMResultChunkDelta(
-                        index=index,
-                        message=assistant_prompt_message,
-                        finish_reason=chunk.candidates[0].finish_reason,
-                        usage=usage
+
+                if part.text:
+                    assistant_prompt_message.content += part.text
+
+                if part.function_call:
+                    assistant_prompt_message.tool_calls = [
+                        AssistantPromptMessage.ToolCall(
+                            id=part.function_call.name,
+                            type='function',
+                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                name=part.function_call.name,
+                                arguments=json.dumps({
+                                    key: value 
+                                    for key, value in part.function_call.args.items()
+                                })
+                            )
+                        )
+                    ]
+
+                index += 1
+    
+                if not response._done:
+                    
+                    # transform assistant message to prompt message
+                    yield LLMResultChunk(
+                        model=model,
+                        prompt_messages=prompt_messages,
+                        delta=LLMResultChunkDelta(
+                            index=index,
+                            message=assistant_prompt_message
+                        )
+                    )
+                else:
+                    
+                    # calculate num tokens
+                    prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+                    completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+
+                    # transform usage
+                    usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+                    
+                    yield LLMResultChunk(
+                        model=model,
+                        prompt_messages=prompt_messages,
+                        delta=LLMResultChunkDelta(
+                            index=index,
+                            message=assistant_prompt_message,
+                            finish_reason=chunk.candidates[0].finish_reason,
+                            usage=usage
+                        )
                     )
-                )
 
     def _convert_one_message_to_text(self, message: PromptMessage) -> str:
         """
@@ -288,6 +335,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
             message_text = f"{ai_prompt} {content}"
         elif isinstance(message, SystemPromptMessage):
             message_text = f"{human_prompt} {content}"
+        elif isinstance(message, ToolPromptMessage):
+            message_text = f"{human_prompt} {content}"
         else:
             raise ValueError(f"Got unknown type {message}")
 
@@ -300,26 +349,53 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         :param message: one PromptMessage
         :return: glm Content representation of message
         """
-
-        parts = []
-        if (isinstance(message.content, str)):
-            parts.append(to_part(message.content))
+        if isinstance(message, UserPromptMessage):
+            glm_content = {
+                "role": "user",
+                "parts": []
+            }
+            if (isinstance(message.content, str)):
+                glm_content['parts'].append(to_part(message.content))
+            else:
+                for c in message.content:
+                    if c.type == PromptMessageContentType.TEXT:
+                        glm_content['parts'].append(to_part(c.data))
+                    else:
+                        metadata, data = c.data.split(',', 1)
+                        mime_type = metadata.split(';', 1)[0].split(':')[1]
+                        blob = {"inline_data":{"mime_type":mime_type,"data":data}}
+                        glm_content['parts'].append(blob)
+            return glm_content
+        elif isinstance(message, AssistantPromptMessage):
+            glm_content = {
+                "role": "model",
+                "parts": []
+            }
+            if message.content:
+                glm_content['parts'].append(to_part(message.content))
+            if message.tool_calls:
+                glm_content["parts"].append(to_part(glm.FunctionCall(
+                    name=message.tool_calls[0].function.name,
+                    args=json.loads(message.tool_calls[0].function.arguments),
+                )))
+            return glm_content
+        elif isinstance(message, SystemPromptMessage):
+            return {
+                "role": "user",
+                "parts": [to_part(message.content)]
+            }
+        elif isinstance(message, ToolPromptMessage):
+            return {
+                "role": "function",
+                "parts": [glm.Part(function_response=glm.FunctionResponse(
+                    name=message.name,
+                    response={
+                        "response": message.content
+                    }
+                ))]
+            }
         else:
-            for c in message.content:
-                if c.type == PromptMessageContentType.TEXT:
-                    parts.append(to_part(c.data))
-                else:
-                    metadata, data = c.data.split(',', 1)
-                    mime_type = metadata.split(';', 1)[0].split(':')[1]
-                    blob = {"inline_data":{"mime_type":mime_type,"data":data}}
-                    parts.append(blob)
-
-        glm_content = {
-            "role": "user" if message.role in (PromptMessageRole.USER, PromptMessageRole.SYSTEM) else "model",
-            "parts": parts
-        }
-        
-        return glm_content
+            raise ValueError(f"Got unknown type {message}")
     
     @property
     def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:

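As a quick sketch of the new per-type mapping (the messages below are hypothetical; the comments show roughly what _format_message_to_glm_content returns for each): system messages are folded into the user role, and tool results go back as a function role carrying a FunctionResponse part.

from core.model_runtime.entities.message_entities import (
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)

system = SystemPromptMessage(content='You are a helpful assistant.')
user = UserPromptMessage(content='What is the weather in Berlin?')
tool_result = ToolPromptMessage(
    name='get_current_weather',
    tool_call_id='get_current_weather',
    content='{"temperature": 21}',
)

# system      -> {"role": "user", "parts": [to_part('You are a helpful assistant.')]}
# user        -> {"role": "user", "parts": [to_part('What is the weather in Berlin?')]}
# tool_result -> {"role": "function",
#                 "parts": [glm.Part(function_response=glm.FunctionResponse(
#                     name='get_current_weather',
#                     response={"response": '{"temperature": 21}'}))]}
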
+ 10 - 1
api/tests/integration_tests/model_runtime/__mock/google.py

@@ -10,6 +10,7 @@ from google.generativeai import GenerativeModel
 from google.generativeai.client import _ClientManager, configure
 from google.generativeai.types import GenerateContentResponse
 from google.generativeai.types.generation_types import BaseGenerateContentResponse
+from google.ai.generativelanguage_v1beta.types import content as gag_content
 
 current_api_key = ''
 
@@ -29,7 +30,7 @@ class MockGoogleResponseClass(object):
 
                     }),
                     chunks=[]
-                )                
+                )
             else:
                 yield GenerateContentResponse(
                     done=False,
@@ -43,6 +44,14 @@ class MockGoogleResponseClass(object):
 class MockGoogleResponseCandidateClass(object):
     finish_reason = 'stop'
 
+    @property
+    def content(self) -> gag_content.Content:
+        return gag_content.Content(
+            parts=[
+                gag_content.Part(text='it\'s google!')
+            ]
+        )
+
 class MockGoogleClass(object):
     @staticmethod
     def generate_content_sync() -> GenerateContentResponse:
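Finally, a rough sketch (not necessarily the fixture this repo actually ships) of how a test could route GenerativeModel.generate_content to the mock above for the non-streaming case, so no real Google API key or network call is needed; the import path assumes tests are run from the api/ directory.

import pytest
from google.generativeai import GenerativeModel

from tests.integration_tests.model_runtime.__mock.google import MockGoogleClass

@pytest.fixture
def mock_gemini_sync(monkeypatch: pytest.MonkeyPatch):
    # Replace the real API call with the canned response defined above.
    monkeypatch.setattr(
        GenerativeModel,
        'generate_content',
        lambda self, *args, **kwargs: MockGoogleClass.generate_content_sync(),
    )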