Bladeren bron

feat: moonshot function call (#3227)

Yeuoly 1 jaar geleden
bovenliggende
commit
a2c068d949

+ 1 - 1
api/core/agent/cot_agent_runner.py

@@ -687,4 +687,4 @@ class CotAgentRunner(BaseAgentRunner):
         try:
             return json.dumps(tools, ensure_ascii=False)
         except json.JSONDecodeError:
-            return json.dumps(tools)
+            return json.dumps(tools)

+ 12 - 12
api/core/agent/fc_agent_runner.py

@@ -207,19 +207,25 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                     )
                 )
 
+            assistant_message = AssistantPromptMessage(
+                content='',
+                tool_calls=[]
+            )
             if tool_calls:
-                prompt_messages.append(AssistantPromptMessage(
-                    content='',
-                    name='',
-                    tool_calls=[AssistantPromptMessage.ToolCall(
+                assistant_message.tool_calls=[
+                    AssistantPromptMessage.ToolCall(
                         id=tool_call[0],
                         type='function',
                         function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                             name=tool_call[1],
                             arguments=json.dumps(tool_call[2], ensure_ascii=False)
                         )
-                    ) for tool_call in tool_calls]
-                ))
+                    ) for tool_call in tool_calls
+                ]
+            else:
+                assistant_message.content = response
+            
+            prompt_messages.append(assistant_message)
 
             # save thought
             self.save_agent_thought(
@@ -239,12 +245,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             
             final_answer += response + '\n'
 
-            # update prompt messages
-            if response.strip():
-                prompt_messages.append(AssistantPromptMessage(
-                    content=response,
-                ))
-            
             # call tools
             tool_responses = []
             for tool_call_id, tool_call_name, tool_call_args in tool_calls:

+ 315 - 5
api/core/model_runtime/model_providers/moonshot/llm/llm.py

@@ -1,8 +1,31 @@
+import json
 from collections.abc import Generator
-from typing import Optional, Union
+from typing import Optional, Union, cast
 
-from core.model_runtime.entities.llm_entities import LLMResult
-from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
+import requests
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    ImagePromptMessageContent,
+    PromptMessage,
+    PromptMessageContent,
+    PromptMessageContentType,
+    PromptMessageTool,
+    SystemPromptMessage,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    FetchFrom,
+    ModelFeature,
+    ModelPropertyKey,
+    ModelType,
+    ParameterRule,
+    ParameterType,
+)
 from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
 
 
@@ -13,6 +36,7 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
                 stream: bool = True, user: Optional[str] = None) \
             -> Union[LLMResult, Generator]:
         self._add_custom_parameters(credentials)
+        self._add_function_call(model, credentials)
         user = user[:32] if user else None
         return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
 
@@ -20,7 +44,293 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
         self._add_custom_parameters(credentials)
         super().validate_credentials(model, credentials)
 
-    @staticmethod
-    def _add_custom_parameters(credentials: dict) -> None:
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+        return AIModelEntity(
+            model=model,
+            label=I18nObject(en_US=model, zh_Hans=model),
+            model_type=ModelType.LLM,
+            features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] 
+                if credentials.get('function_calling_type') == 'tool_call' 
+                else [],
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 4096)),
+                ModelPropertyKey.MODE: LLMMode.CHAT.value,
+            },
+            parameter_rules=[
+                ParameterRule(
+                    name='temperature',
+                    use_template='temperature',
+                    label=I18nObject(en_US='Temperature', zh_Hans='温度'),
+                    type=ParameterType.FLOAT,
+                ),
+                ParameterRule(
+                    name='max_tokens',
+                    use_template='max_tokens',
+                    default=512,
+                    min=1,
+                    max=int(credentials.get('max_tokens', 4096)),
+                    label=I18nObject(en_US='Max Tokens', zh_Hans='最大标记'),
+                    type=ParameterType.INT,
+                ),
+                ParameterRule(
+                    name='top_p',
+                    use_template='top_p',
+                    label=I18nObject(en_US='Top P', zh_Hans='Top P'),
+                    type=ParameterType.FLOAT,
+                ),
+            ]
+        )
+
+    def _add_custom_parameters(self, credentials: dict) -> None:
         credentials['mode'] = 'chat'
         credentials['endpoint_url'] = 'https://api.moonshot.cn/v1'
+
+    def _add_function_call(self, model: str, credentials: dict) -> None:
+        model_schema = self.get_model_schema(model, credentials)
+        if model_schema and set([
+            ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL
+        ]).intersection(model_schema.features or []):
+            credentials['function_calling_type'] = 'tool_call'
+
+    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+        """
+        Convert PromptMessage to dict for OpenAI API format
+        """
+        if isinstance(message, UserPromptMessage):
+            message = cast(UserPromptMessage, message)
+            if isinstance(message.content, str):
+                message_dict = {"role": "user", "content": message.content}
+            else:
+                sub_messages = []
+                for message_content in message.content:
+                    if message_content.type == PromptMessageContentType.TEXT:
+                        message_content = cast(PromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "text",
+                            "text": message_content.data
+                        }
+                        sub_messages.append(sub_message_dict)
+                    elif message_content.type == PromptMessageContentType.IMAGE:
+                        message_content = cast(ImagePromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": message_content.data,
+                                "detail": message_content.detail.value
+                            }
+                        }
+                        sub_messages.append(sub_message_dict)
+                message_dict = {"role": "user", "content": sub_messages}
+        elif isinstance(message, AssistantPromptMessage):
+            message = cast(AssistantPromptMessage, message)
+            message_dict = {"role": "assistant", "content": message.content}
+            if message.tool_calls:
+                message_dict["tool_calls"] = []
+                for function_call in message.tool_calls:
+                    message_dict["tool_calls"].append({
+                        "id": function_call.id,
+                        "type": function_call.type,
+                        "function": {
+                            "name": f"functions.{function_call.function.name}",
+                            "arguments": function_call.function.arguments
+                        }
+                    })
+        elif isinstance(message, ToolPromptMessage):
+            message = cast(ToolPromptMessage, message)
+            message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id}
+            if not message.name.startswith("functions."):
+                message.name = f"functions.{message.name}"
+        elif isinstance(message, SystemPromptMessage):
+            message = cast(SystemPromptMessage, message)
+            message_dict = {"role": "system", "content": message.content}
+        else:
+            raise ValueError(f"Got unknown type {message}")
+
+        if message.name:
+            message_dict["name"] = message.name
+
+        return message_dict
+
+    def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]:
+        """
+        Extract tool calls from response
+
+        :param response_tool_calls: response tool calls
+        :return: list of tool calls
+        """
+        tool_calls = []
+        if response_tool_calls:
+            for response_tool_call in response_tool_calls:
+                function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                    name=response_tool_call["function"]["name"] if response_tool_call.get("function", {}).get("name") else "",
+                    arguments=response_tool_call["function"]["arguments"] if response_tool_call.get("function", {}).get("arguments") else ""
+                )
+
+                tool_call = AssistantPromptMessage.ToolCall(
+                    id=response_tool_call["id"] if response_tool_call.get("id") else "",
+                    type=response_tool_call["type"] if response_tool_call.get("type") else "",
+                    function=function
+                )
+                tool_calls.append(tool_call)
+
+        return tool_calls
+
+    def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
+                                         prompt_messages: list[PromptMessage]) -> Generator:
+        """
+        Handle llm stream response
+
+        :param model: model name
+        :param credentials: model credentials
+        :param response: streamed response
+        :param prompt_messages: prompt messages
+        :return: llm response chunk generator
+        """
+        full_assistant_content = ''
+        chunk_index = 0
+
+        def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \
+                -> LLMResultChunk:
+            # calculate num tokens
+            prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
+            completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
+
+            # transform usage
+            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+            return LLMResultChunk(
+                model=model,
+                prompt_messages=prompt_messages,
+                delta=LLMResultChunkDelta(
+                    index=index,
+                    message=message,
+                    finish_reason=finish_reason,
+                    usage=usage
+                )
+            )
+
+        tools_calls: list[AssistantPromptMessage.ToolCall] = []
+        finish_reason = "Unknown"
+
+        def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
+            def get_tool_call(tool_name: str):
+                if not tool_name:
+                    return tools_calls[-1]
+
+                tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None)
+                if tool_call is None:
+                    tool_call = AssistantPromptMessage.ToolCall(
+                        id='',
+                        type='',
+                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments="")
+                    )
+                    tools_calls.append(tool_call)
+
+                return tool_call
+
+            for new_tool_call in new_tool_calls:
+                # get tool call
+                tool_call = get_tool_call(new_tool_call.function.name)
+                # update tool call
+                if new_tool_call.id:
+                    tool_call.id = new_tool_call.id
+                if new_tool_call.type:
+                    tool_call.type = new_tool_call.type
+                if new_tool_call.function.name:
+                    # remove the functions. prefix
+                    if new_tool_call.function.name.startswith('functions.'):
+                        parts = new_tool_call.function.name.split('functions.')
+                        if len(parts) > 1:
+                            new_tool_call.function.name = parts[1]
+                    tool_call.function.name = new_tool_call.function.name
+                if new_tool_call.function.arguments:
+                    tool_call.function.arguments += new_tool_call.function.arguments
+
+        for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"):
+            if chunk:
+                # ignore sse comments
+                if chunk.startswith(':'):
+                    continue
+                decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
+                chunk_json = None
+                try:
+                    chunk_json = json.loads(decoded_chunk)
+                # stream ended
+                except json.JSONDecodeError as e:
+                    yield create_final_llm_result_chunk(
+                        index=chunk_index + 1,
+                        message=AssistantPromptMessage(content=""),
+                        finish_reason="Non-JSON encountered."
+                    )
+                    break
+                if not chunk_json or len(chunk_json['choices']) == 0:
+                    continue
+
+                choice = chunk_json['choices'][0]
+                finish_reason = chunk_json['choices'][0].get('finish_reason')
+                chunk_index += 1
+
+                if 'delta' in choice:
+                    delta = choice['delta']
+                    delta_content = delta.get('content')
+
+                    assistant_message_tool_calls = delta.get('tool_calls', None)
+                    # assistant_message_function_call = delta.delta.function_call
+
+                    # extract tool calls from response
+                    if assistant_message_tool_calls:
+                        tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
+                        increase_tool_call(tool_calls)
+
+                    if delta_content is None or delta_content == '':
+                        continue
+
+                    # transform assistant message to prompt message
+                    assistant_prompt_message = AssistantPromptMessage(
+                        content=delta_content,
+                        tool_calls=tool_calls if assistant_message_tool_calls else []
+                    )
+
+                    full_assistant_content += delta_content
+                elif 'text' in choice:
+                    choice_text = choice.get('text', '')
+                    if choice_text == '':
+                        continue
+
+                    # transform assistant message to prompt message
+                    assistant_prompt_message = AssistantPromptMessage(content=choice_text)
+                    full_assistant_content += choice_text
+                else:
+                    continue
+
+                # check payload indicator for completion
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=chunk_index,
+                        message=assistant_prompt_message,
+                    )
+                )
+
+            chunk_index += 1
+        
+        if tools_calls:
+            yield LLMResultChunk(
+                model=model,
+                prompt_messages=prompt_messages,
+                delta=LLMResultChunkDelta(
+                    index=chunk_index,
+                    message=AssistantPromptMessage(
+                        tool_calls=tools_calls,
+                        content=""
+                    ),
+                )
+            )
+
+        yield create_final_llm_result_chunk(
+            index=chunk_index,
+            message=AssistantPromptMessage(content=""),
+            finish_reason=finish_reason
+        )

+ 49 - 0
api/core/model_runtime/model_providers/moonshot/moonshot.yaml

@@ -20,6 +20,7 @@ supported_model_types:
   - llm
 configurate_methods:
   - predefined-model
+  - customizable-model
 provider_credential_schema:
   credential_form_schemas:
     - variable: api_key
@@ -30,3 +31,51 @@ provider_credential_schema:
       placeholder:
         zh_Hans: 在此输入您的 API Key
         en_US: Enter your API Key
+model_credential_schema:
+  model:
+    label:
+      en_US: Model Name
+      zh_Hans: 模型名称
+    placeholder:
+      en_US: Enter your model name
+      zh_Hans: 输入模型名称
+  credential_form_schemas:
+    - variable: api_key
+      label:
+        en_US: API Key
+      type: secret-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入您的 API Key
+        en_US: Enter your API Key
+    - variable: context_size
+      label:
+        zh_Hans: 模型上下文长度
+        en_US: Model context size
+      required: true
+      type: text-input
+      default: '4096'
+      placeholder:
+        zh_Hans: 在此输入您的模型上下文长度
+        en_US: Enter your Model context size
+    - variable: max_tokens
+      label:
+        zh_Hans: 最大 token 上限
+        en_US: Upper bound for max tokens
+      default: '4096'
+      type: text-input
+    - variable: function_calling_type
+      label:
+        en_US: Function calling
+      type: select
+      required: false
+      default: no_call
+      options:
+        - value: no_call
+          label:
+            en_US: Not supported
+            zh_Hans: 不支持
+        - value: tool_call
+          label:
+            en_US: Tool Call
+            zh_Hans: Tool Call

+ 46 - 3
api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py

@@ -378,6 +378,34 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         delimiter = credentials.get("stream_mode_delimiter", "\n\n")
         delimiter = codecs.decode(delimiter, "unicode_escape")
 
+        tools_calls: list[AssistantPromptMessage.ToolCall] = []
+
+        def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
+            def get_tool_call(tool_call_id: str):
+                tool_call = next(
+                    (tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None
+                )
+                if tool_call is None:
+                    tool_call = AssistantPromptMessage.ToolCall(
+                        id='', 
+                        type='function', 
+                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            name='',
+                            arguments=''
+                        )
+                    )
+                    tools_calls.append(tool_call)
+                return tool_call
+
+            for new_tool_call in new_tool_calls:
+                # get tool call
+                tool_call = get_tool_call(new_tool_call.id)
+                # update tool call
+                tool_call.id = new_tool_call.id
+                tool_call.type = new_tool_call.type
+                tool_call.function.name = new_tool_call.function.name
+                tool_call.function.arguments += new_tool_call.function.arguments
+
         for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
             if chunk:
                 # ignore sse comments
@@ -405,8 +433,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                 if 'delta' in choice:
                     delta = choice['delta']
                     delta_content = delta.get('content')
-                    if delta_content is None or delta_content == '':
-                        continue
 
                     assistant_message_tool_calls = delta.get('tool_calls', None)
                     # assistant_message_function_call = delta.delta.function_call
@@ -414,6 +440,11 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                     # extract tool calls from response
                     if assistant_message_tool_calls:
                         tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
+                        increase_tool_call(tool_calls)
+
+                    if delta_content is None or delta_content == '':
+                        continue
+
                     # function_call = self._extract_response_function_call(assistant_message_function_call)
                     # tool_calls = [function_call] if function_call else []
 
@@ -437,6 +468,18 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
 
                 # check payload indicator for completion
                 if finish_reason is not None:
+                    yield LLMResultChunk(
+                        model=model,
+                        prompt_messages=prompt_messages,
+                        delta=LLMResultChunkDelta(
+                            index=chunk_index,
+                            message=AssistantPromptMessage(
+                                tool_calls=tools_calls,
+                            ),
+                            finish_reason=finish_reason
+                        )
+                    )
+
                     yield create_final_llm_result_chunk(
                         index=chunk_index,
                         message=assistant_prompt_message,
@@ -735,4 +778,4 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                 function=function
             )
 
-        return tool_call
+        return tool_call