
feat: support function call for ollama block chat api (#10784)

GeorgeCaoJ, 5 months ago
parent commit fbfc811a44

+ 63 - 5
api/core/model_runtime/model_providers/ollama/llm/llm.py

@@ -22,6 +22,7 @@ from core.model_runtime.entities.message_entities import (
     PromptMessageTool,
     SystemPromptMessage,
     TextPromptMessageContent,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.entities.model_entities import (
@@ -86,6 +87,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
             credentials=credentials,
             prompt_messages=prompt_messages,
             model_parameters=model_parameters,
+            tools=tools,
             stop=stop,
             stream=stream,
             user=user,
@@ -153,6 +155,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         credentials: dict,
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
+        tools: Optional[list[PromptMessageTool]] = None,
         stop: Optional[list[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
@@ -196,6 +199,8 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         if completion_type is LLMMode.CHAT:
             endpoint_url = urljoin(endpoint_url, "api/chat")
             data["messages"] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
+            if tools:
+                data["tools"] = [self._convert_prompt_message_tool_to_dict(tool) for tool in tools]
         else:
             endpoint_url = urljoin(endpoint_url, "api/generate")
             first_prompt_message = prompt_messages[0]
@@ -232,7 +237,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         if stream:
             return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages)
 
-        return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages)
+        return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages, tools)
 
     def _handle_generate_response(
         self,
@@ -241,6 +246,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         completion_type: LLMMode,
         response: requests.Response,
         prompt_messages: list[PromptMessage],
+        tools: Optional[list[PromptMessageTool]],
     ) -> LLMResult:
         """
         Handle llm completion response
@@ -253,14 +259,16 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         :return: llm result
         """
         response_json = response.json()
-
+        tool_calls = []
         if completion_type is LLMMode.CHAT:
             message = response_json.get("message", {})
             response_content = message.get("content", "")
+            response_tool_calls = message.get("tool_calls", [])
+            tool_calls = [self._extract_response_tool_call(tool_call) for tool_call in response_tool_calls]
         else:
             response_content = response_json["response"]
 
-        assistant_message = AssistantPromptMessage(content=response_content)
+        assistant_message = AssistantPromptMessage(content=response_content, tool_calls=tool_calls)
 
         if "prompt_eval_count" in response_json and "eval_count" in response_json:
             # transform usage
@@ -405,9 +413,28 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
 
             chunk_index += 1
 
+    def _convert_prompt_message_tool_to_dict(self, tool: PromptMessageTool) -> dict:
+        """
+        Convert PromptMessageTool to dict for Ollama API
+
+        :param tool: tool
+        :return: tool dict
+        """
+        return {
+            "type": "function",
+            "function": {
+                "name": tool.name,
+                "description": tool.description,
+                "parameters": tool.parameters,
+            },
+        }
+
     def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
         """
         Convert PromptMessage to dict for Ollama API
+
+        :param message: prompt message
+        :return: message dict
         """
         if isinstance(message, UserPromptMessage):
             message = cast(UserPromptMessage, message)
@@ -432,6 +459,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "system", "content": message.content}
+        elif isinstance(message, ToolPromptMessage):
+            message = cast(ToolPromptMessage, message)
+            message_dict = {"role": "tool", "content": message.content}
         else:
             raise ValueError(f"Got unknown type {message}")
 
@@ -452,6 +482,29 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
 
         return num_tokens
 
+    def _extract_response_tool_call(self, response_tool_call: dict) -> AssistantPromptMessage.ToolCall:
+        """
+        Extract response tool call
+        """
+        tool_call = None
+        if response_tool_call and "function" in response_tool_call:
+            # Convert arguments to JSON string if it's a dict
+            arguments = response_tool_call.get("function").get("arguments")
+            if isinstance(arguments, dict):
+                arguments = json.dumps(arguments)
+
+            function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                name=response_tool_call.get("function").get("name"),
+                arguments=arguments,
+            )
+            tool_call = AssistantPromptMessage.ToolCall(
+                id=response_tool_call.get("function").get("name"),
+                type="function",
+                function=function,
+            )
+
+        return tool_call
+
     def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
         """
         Get customizable model schema.
@@ -461,10 +514,15 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
 
         :return: model schema
         """
-        extras = {}
+        extras = {
+            "features": [],
+        }
 
         if "vision_support" in credentials and credentials["vision_support"] == "true":
-            extras["features"] = [ModelFeature.VISION]
+            extras["features"].append(ModelFeature.VISION)
+        if "function_call_support" in credentials and credentials["function_call_support"] == "true":
+            extras["features"].append(ModelFeature.TOOL_CALL)
+            extras["features"].append(ModelFeature.MULTI_TOOL_CALL)
 
         entity = AIModelEntity(
             model=model,

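For orientation, the sketch below (not part of the commit) shows the request payload shape that the patched `_generate` now sends to Ollama's `/api/chat` endpoint when tools are supplied, and the `tool_calls` structure that `_extract_response_tool_call` parses from the response. The base URL, model name, and `get_weather` tool definition are illustrative assumptions; the arguments-to-JSON-string conversion mirrors the design choice in the diff, since Ollama returns tool arguments as a dict while `AssistantPromptMessage.ToolCall` expects a string.

```python
import json

import requests

# Illustrative assumptions: local Ollama instance, llama3.1, a made-up tool.
payload = {
    "model": "llama3.1",
    "stream": False,
    "messages": [
        {"role": "user", "content": "What is the weather in Paris?"},
    ],
    # Same shape as the output of _convert_prompt_message_tool_to_dict
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
}

resp = requests.post("http://localhost:11434/api/chat", json=payload, timeout=60)
message = resp.json().get("message", {})

# When the model decides to call a tool, Ollama returns the arguments as a
# dict; like _extract_response_tool_call, we serialize them to a JSON string.
for tool_call in message.get("tool_calls", []):
    function = tool_call.get("function", {})
    arguments = function.get("arguments")
    if isinstance(arguments, dict):
        arguments = json.dumps(arguments)
    print(function.get("name"), arguments)
```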
+ 19 - 0
api/core/model_runtime/model_providers/ollama/ollama.yaml

@@ -96,3 +96,22 @@ model_credential_schema:
           label:
             en_US: 'No'
             zh_Hans: 否
+    - variable: function_call_support
+      label:
+        zh_Hans: 是否支持函数调用
+        en_US: Function call support
+      show_on:
+        - variable: __model_type
+          value: llm
+      default: 'false'
+      type: radio
+      required: false
+      options:
+        - value: 'true'
+          label:
+            en_US: 'Yes'
+            zh_Hans: 是
+        - value: 'false'
+          label:
+            en_US: 'No'
+            zh_Hans: 否
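The new `function_call_support` radio maps directly onto the feature flags added in `get_customizable_model_schema`. A rough sketch of that mapping, assuming the repo's `ModelFeature` enum is importable and using a hypothetical credentials dict as saved from this form:

```python
from core.model_runtime.entities.model_entities import ModelFeature

# Hypothetical credentials; the keys mirror the variables in ollama.yaml.
credentials = {"vision_support": "false", "function_call_support": "true"}

features = []
if credentials.get("vision_support") == "true":
    features.append(ModelFeature.VISION)
if credentials.get("function_call_support") == "true":
    # Enabling function call support advertises both single and multi tool call.
    features.extend([ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL])

print(features)  # [ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL]
```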