
Feat/open ai compatible functioncall (#2783)

Co-authored-by: jyong <jyong@dify.ai>
Jyong 1 year ago
parent
commit
e54c9cd401

+ 1 - 1
api/core/model_runtime/model_providers/cohere/llm/llm.py

@@ -472,7 +472,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         else:
             raise ValueError(f"Got unknown type {message}")
 
-        if message.name is not None:
+        if message.name:
             message_dict["user_name"] = message.name
 
         return message_dict
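
The behavioral change: a message.name equal to the empty string used to pass the "is not None" check and be sent upstream; the truthiness check now drops it along with None. A minimal standalone sketch (values invented, not from the commit):

    name = ""  # hypothetical empty name on a prompt message
    message_dict = {"role": "user", "content": "hi"}

    if name is not None:  # old check: passes even for ""
        pass              # would have set message_dict["user_name"] = ""

    if name:  # new check: false for both "" and None
        message_dict["user_name"] = name

    print(message_dict)  # {'role': 'user', 'content': 'hi'} -- empty name dropped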

+ 1 - 1
api/core/model_runtime/model_providers/mistralai/llm/mistral-large-latest.yaml

@@ -24,7 +24,7 @@ parameter_rules:
     min: 1
     max: 8000
   - name: safe_prompt
-    defulat: false
+    default: false
     type: boolean
     help:
       en_US: Whether to inject a safety prompt before all conversations.

+ 1 - 1
api/core/model_runtime/model_providers/mistralai/llm/mistral-medium-latest.yaml

@@ -24,7 +24,7 @@ parameter_rules:
     min: 1
     max: 8000
   - name: safe_prompt
-    defulat: false
+    default: false
     type: boolean
     help:
       en_US: Whether to inject a safety prompt before all conversations.

+ 1 - 1
api/core/model_runtime/model_providers/mistralai/llm/mistral-small-latest.yaml

@@ -24,7 +24,7 @@ parameter_rules:
     min: 1
     max: 8000
   - name: safe_prompt
-    defulat: false
+    default: false
     type: boolean
     help:
       en_US: Whether to inject a safety prompt before all conversations.

+ 1 - 1
api/core/model_runtime/model_providers/mistralai/llm/open-mistral-7b.yaml

@@ -24,7 +24,7 @@ parameter_rules:
     min: 1
     max: 2048
   - name: safe_prompt
-    defulat: false
+    default: false
     type: boolean
     help:
       en_US: Whether to inject a safety prompt before all conversations.

+ 1 - 1
api/core/model_runtime/model_providers/mistralai/llm/open-mixtral-8x7b.yaml

@@ -24,7 +24,7 @@ parameter_rules:
     min: 1
     max: 8000
   - name: safe_prompt
-    defulat: false
+    default: false
     type: boolean
     help:
       en_US: Whether to inject a safety prompt before all conversations.
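
Why the typo mattered: an unrecognized "defulat" key is silently ignored when the rule is read, so safe_prompt effectively had no default at all. A minimal sketch with plain dicts (not the actual YAML loader):

    typo_rule = {"name": "safe_prompt", "defulat": False, "type": "boolean"}
    print(typo_rule.get("default"))  # None -- the typo'd key is invisible here

    fixed_rule = {"name": "safe_prompt", "default": False, "type": "boolean"}
    print(fixed_rule.get("default"))  # False, as intended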

+ 77 - 34
api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py

@@ -25,6 +25,7 @@ from core.model_runtime.entities.model_entities import (
     AIModelEntity,
     DefaultParameterName,
     FetchFrom,
+    ModelFeature,
     ModelPropertyKey,
     ModelType,
     ParameterRule,
@@ -166,11 +167,23 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         """
             generate custom model entities from credentials
         """
+        support_function_call = False
+        features = []
+        function_calling_type = credentials.get('function_calling_type', 'no_call')
+        if function_calling_type == 'function_call':
+            features = [ModelFeature.TOOL_CALL]
+            support_function_call = True
+        endpoint_url = credentials["endpoint_url"]
+        # if not endpoint_url.endswith('/'):
+        #     endpoint_url += '/'
+        # if 'https://api.openai.com/v1/' == endpoint_url:
+        #     features = [ModelFeature.STREAM_TOOL_CALL]
         entity = AIModelEntity(
             model=model,
             label=I18nObject(en_US=model),
             model_type=ModelType.LLM,
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            features=features if support_function_call else [],
             model_properties={
                 ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', "4096")),
                 ModelPropertyKey.MODE: credentials.get('mode'),
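
In isolation, the credential-to-feature mapping added above works like this (helper name invented, the ModelFeature enum replaced with a plain string for the sketch):

    def features_from_credentials(credentials: dict) -> list:
        # 'no_call' is the default when the option was never configured
        if credentials.get("function_calling_type", "no_call") == "function_call":
            return ["tool-call"]  # stands in for [ModelFeature.TOOL_CALL]
        return []

    assert features_from_credentials({}) == []
    assert features_from_credentials({"function_calling_type": "function_call"}) == ["tool-call"]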
@@ -194,14 +207,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                     max=1,
                     precision=2
                 ),
-                ParameterRule(
-                    name="top_k",
-                    label=I18nObject(en_US="Top K"),
-                    type=ParameterType.INT,
-                    default=int(credentials.get('top_k', 1)),
-                    min=1,
-                    max=100
-                ),
                 ParameterRule(
                     name=DefaultParameterName.FREQUENCY_PENALTY.value,
                     label=I18nObject(en_US="Frequency Penalty"),
@@ -232,7 +237,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                 output=Decimal(credentials.get('output_price', 0)),
                 unit=Decimal(credentials.get('unit', 0)),
                 currency=credentials.get('currency', "USD")
-            )
+            ),
         )
 
         if credentials['mode'] == 'chat':
@@ -292,14 +297,22 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             raise ValueError("Unsupported completion type for model configuration.")
 
         # annotate tools with names, descriptions, etc.
+        function_calling_type = credentials.get('function_calling_type', 'no_call')
         formatted_tools = []
         if tools:
-            data["tool_choice"] = "auto"
+            if function_calling_type == 'function_call':
+                data['functions'] = [{
+                    "name": tool.name,
+                    "description": tool.description,
+                    "parameters": tool.parameters
+                } for tool in tools]
+            elif function_calling_type == 'tool_call':
+                data["tool_choice"] = "auto"
 
-            for tool in tools:
-                formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))
+                for tool in tools:
+                    formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))
 
-            data["tools"] = formatted_tools
+                data["tools"] = formatted_tools
 
         if stop:
             data["stop"] = stop
@@ -367,9 +380,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
 
         for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
             if chunk:
-                #ignore sse comments
+                # ignore sse comments
                 if chunk.startswith(':'):
-                    continue                 
+                    continue
                 decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
                 chunk_json = None
                 try:
@@ -452,10 +465,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
 
         response_content = ''
         tool_calls = None
-
+        function_calling_type = credentials.get('function_calling_type', 'no_call')
         if completion_type is LLMMode.CHAT:
             response_content = output.get('message', {})['content']
-            tool_calls = output.get('message', {}).get('tool_calls')
+            if function_calling_type == 'tool_call':
+                tool_calls = output.get('message', {}).get('tool_calls')
+            elif function_calling_type == 'function_call':
+                tool_calls = output.get('message', {}).get('function_call')
 
         elif completion_type is LLMMode.COMPLETION:
             response_content = output['text']
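
On the response side the two protocols also differ in shape: tool_calls is a list of call objects with ids, while the legacy function_call is a single object without one. Hypothetical response fragments (invented values) showing why the extraction paths are separate:

    tool_call_style = {
        "message": {
            "content": None,
            "tool_calls": [{
                "id": "call_abc123",
                "type": "function",
                "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
            }],
        }
    }

    function_call_style = {
        "message": {
            "content": None,
            # a single object, not a list, and with no call id
            "function_call": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
        }
    }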
@@ -463,7 +479,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         assistant_message = AssistantPromptMessage(content=response_content, tool_calls=[])
 
         if tool_calls:
-            assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls)
+            if function_calling_type == 'tool_call':
+                assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls)
+            elif function_calling_type == 'function_call':
+                assistant_message.tool_calls = [self._extract_response_function_call(tool_calls)]
 
         usage = response_json.get("usage")
         if usage:
@@ -522,33 +541,34 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             message = cast(AssistantPromptMessage, message)
             message_dict = {"role": "assistant", "content": message.content}
             if message.tool_calls:
-                message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
-                                              in
-                                              message.tool_calls]
-                # function_call = message.tool_calls[0]
-                # message_dict["function_call"] = {
-                #     "name": function_call.function.name,
-                #     "arguments": function_call.function.arguments,
-                # }
+                # message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
+                #                               in
+                #                               message.tool_calls]
+
+                function_call = message.tool_calls[0]
+                message_dict["function_call"] = {
+                    "name": function_call.function.name,
+                    "arguments": function_call.function.arguments,
+                }
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "system", "content": message.content}
         elif isinstance(message, ToolPromptMessage):
             message = cast(ToolPromptMessage, message)
-            message_dict = {
-                "role": "tool",
-                "content": message.content,
-                "tool_call_id": message.tool_call_id
-            }
             # message_dict = {
-            #     "role": "function",
+            #     "role": "tool",
             #     "content": message.content,
-            #     "name": message.tool_call_id
+            #     "tool_call_id": message.tool_call_id
             # }
+            message_dict = {
+                "role": "function",
+                "content": message.content,
+                "name": message.tool_call_id
+            }
         else:
             raise ValueError(f"Got unknown type {message}")
 
-        if message.name is not None:
+        if message.name:
             message_dict["name"] = message.name
 
         return message_dict
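
Outbound messages are serialized to the matching legacy shapes: the first tool call becomes a single function_call object, and tool results go out under role "function". Hypothetical serialized dicts (values invented):

    assistant_msg = {
        "role": "assistant",
        "content": None,
        "function_call": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
    }

    tool_result_msg = {
        "role": "function",     # instead of "role": "tool"
        "name": "get_weather",  # instead of "tool_call_id"
        "content": '{"temp_c": 21}',
    }

Note the design consequence of the hunk above: only message.tool_calls[0] is serialized, since the legacy protocol has no list field for multiple calls.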
@@ -693,3 +713,26 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                 tool_calls.append(tool_call)
 
         return tool_calls
+
+    def _extract_response_function_call(self, response_function_call) \
+            -> AssistantPromptMessage.ToolCall:
+        """
+        Extract function call from response
+
+        :param response_function_call: response function call
+        :return: tool call
+        """
+        tool_call = None
+        if response_function_call:
+            function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                name=response_function_call['name'],
+                arguments=response_function_call['arguments']
+            )
+
+            tool_call = AssistantPromptMessage.ToolCall(
+                id=response_function_call['name'],
+                type="function",
+                function=function
+            )
+
+        return tool_call
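
A standalone restatement of the mapping the new helper performs, with plain dicts in place of the pydantic models and an invented payload:

    response_function_call = {"name": "get_weather", "arguments": '{"city": "Paris"}'}

    tool_call = {
        # the legacy protocol carries no call id, so the function
        # name doubles as the ToolCall id
        "id": response_function_call["name"],
        "type": "function",
        "function": {
            "name": response_function_call["name"],
            "arguments": response_function_call["arguments"],
        },
    }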

+ 22 - 0
api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml

@@ -75,6 +75,28 @@ model_credential_schema:
           value: llm
       default: '4096'
       type: text-input
+    - variable: function_calling_type
+      show_on:
+        - variable: __model_type
+          value: llm
+      label:
+        en_US: Function calling
+      type: select
+      required: false
+      default: no_call
+      options:
+        - value: function_call
+          label:
+            en_US: Support
+            zh_Hans: 支持
+#        - value: tool_call
+#          label:
+#            en_US: Tool Call
+#            zh_Hans: Tool Call
+        - value: no_call
+          label:
+            en_US: Not Support
+            zh_Hans: 不支持
     - variable: stream_mode_delimiter
       label:
         zh_Hans: 流模式返回结果的分隔符