ソースを参照

improve: generalize tool parameter converter (#4786)

Bowen Liang 10 ヶ月 前
コミット
3542d55e67

+ 8 - 27
api/core/agent/base_agent_runner.py

@@ -39,6 +39,7 @@ from core.tools.entities.tool_entities import (
 from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
 from core.tools.tool.tool import Tool
 from core.tools.tool_manager import ToolManager
+from core.tools.utils.tool_parameter_converter import ToolParameterConverter
 from extensions.ext_database import db
 from models.model import Conversation, Message, MessageAgentThought
 from models.tools import ToolConversationVariables
@@ -186,21 +187,11 @@ class BaseAgentRunner(AppRunner):
             if parameter.form != ToolParameter.ToolParameterForm.LLM:
                 continue
 
-            parameter_type = 'string'
+            parameter_type = ToolParameterConverter.get_parameter_type(parameter.type)
             enum = []
-            if parameter.type == ToolParameter.ToolParameterType.STRING:
-                parameter_type = 'string'
-            elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
-                parameter_type = 'boolean'
-            elif parameter.type == ToolParameter.ToolParameterType.NUMBER:
-                parameter_type = 'number'
-            elif parameter.type == ToolParameter.ToolParameterType.SELECT:
-                for option in parameter.options:
-                    enum.append(option.value)
-                parameter_type = 'string'
-            else:
-                raise ValueError(f"parameter type {parameter.type} is not supported")
-            
+            if parameter.type == ToolParameter.ToolParameterType.SELECT:
+                enum = [option.value for option in parameter.options]
+
             message_tool.parameters['properties'][parameter.name] = {
                 "type": parameter_type,
                 "description": parameter.llm_description or '',
@@ -281,20 +272,10 @@ class BaseAgentRunner(AppRunner):
             if parameter.form != ToolParameter.ToolParameterForm.LLM:
                 continue
 
-            parameter_type = 'string'
+            parameter_type = ToolParameterConverter.get_parameter_type(parameter.type)
             enum = []
-            if parameter.type == ToolParameter.ToolParameterType.STRING:
-                parameter_type = 'string'
-            elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
-                parameter_type = 'boolean'
-            elif parameter.type == ToolParameter.ToolParameterType.NUMBER:
-                parameter_type = 'number'
-            elif parameter.type == ToolParameter.ToolParameterType.SELECT:
-                for option in parameter.options:
-                    enum.append(option.value)
-                parameter_type = 'string'
-            else:
-                raise ValueError(f"parameter type {parameter.type} is not supported")
+            if parameter.type == ToolParameter.ToolParameterType.SELECT:
+                enum = [option.value for option in parameter.options]
         
             prompt_tool.parameters['properties'][parameter.name] = {
                 "type": parameter_type,

+ 0 - 0
api/core/tools/__init__.py


+ 2 - 1
api/core/tools/entities/tool_entities.py

@@ -116,8 +116,9 @@ class ToolParameterOption(BaseModel):
     value: str = Field(..., description="The value of the option")
     label: I18nObject = Field(..., description="The label of the option")
 
+
 class ToolParameter(BaseModel):
-    class ToolParameterType(Enum):
+    class ToolParameterType(str, Enum):
         STRING = "string"
         NUMBER = "number"
         BOOLEAN = "boolean"

+ 3 - 10
api/core/tools/provider/builtin_tool_provider.py

@@ -12,6 +12,7 @@ from core.tools.errors import (
 from core.tools.provider.tool_provider import ToolProviderController
 from core.tools.tool.builtin_tool import BuiltinTool
 from core.tools.tool.tool import Tool
+from core.tools.utils.tool_parameter_converter import ToolParameterConverter
 from core.tools.utils.yaml_utils import load_yaml_file
 from core.utils.module_import_helper import load_single_subclass_from_source
 
@@ -200,16 +201,8 @@ class BuiltinToolProviderController(ToolProviderController):
             
             # the parameter is not set currently, set the default value if needed
             if parameter_schema.default is not None:
-                default_value = parameter_schema.default
-                # parse default value into the correct type
-                if parameter_schema.type == ToolParameter.ToolParameterType.STRING or \
-                    parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
-                    default_value = str(default_value)
-                elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
-                    default_value = float(default_value)
-                elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
-                    default_value = bool(default_value)
-
+                default_value = ToolParameterConverter.cast_parameter_by_type(parameter_schema.default,
+                                                                              parameter_schema.type)
                 tool_parameters[parameter] = default_value
     
     def validate_credentials(self, credentials: dict[str, Any]) -> None:

+ 3 - 11
api/core/tools/provider/tool_provider.py

@@ -11,6 +11,7 @@ from core.tools.entities.tool_entities import (
 )
 from core.tools.errors import ToolNotFoundError, ToolParameterValidationError, ToolProviderCredentialValidationError
 from core.tools.tool.tool import Tool
+from core.tools.utils.tool_parameter_converter import ToolParameterConverter
 
 
 class ToolProviderController(BaseModel, ABC):
@@ -122,17 +123,8 @@ class ToolProviderController(BaseModel, ABC):
             
             # the parameter is not set currently, set the default value if needed
             if parameter_schema.default is not None:
-                default_value = parameter_schema.default
-                # parse default value into the correct type
-                if parameter_schema.type == ToolParameter.ToolParameterType.STRING or \
-                    parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
-                    default_value = str(default_value)
-                elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
-                    default_value = float(default_value)
-                elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
-                    default_value = bool(default_value)
-
-                tool_parameters[parameter] = default_value
+                tool_parameters[parameter] = ToolParameterConverter.cast_parameter_by_type(parameter_schema.default,
+                                                                                           parameter_schema.type)
 
     def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
         """

+ 3 - 40
api/core/tools/tool/tool.py

@@ -18,6 +18,7 @@ from core.tools.entities.tool_entities import (
     ToolRuntimeVariablePool,
 )
 from core.tools.tool_file_manager import ToolFileManager
+from core.tools.utils.tool_parameter_converter import ToolParameterConverter
 
 
 class Tool(BaseModel, ABC):
@@ -228,46 +229,8 @@ class Tool(BaseModel, ABC):
         """
         Transform tool parameters type
         """
-        for parameter in self.parameters:
-            if parameter.name in tool_parameters:
-                if parameter.type in [
-                    ToolParameter.ToolParameterType.SECRET_INPUT, 
-                    ToolParameter.ToolParameterType.STRING, 
-                    ToolParameter.ToolParameterType.SELECT,
-                ] and not isinstance(tool_parameters[parameter.name], str):
-                    if tool_parameters[parameter.name] is None:
-                        tool_parameters[parameter.name] = ''
-                    else:
-                        tool_parameters[parameter.name] = str(tool_parameters[parameter.name])
-                elif parameter.type == ToolParameter.ToolParameterType.NUMBER \
-                    and not isinstance(tool_parameters[parameter.name], int | float):
-                    if isinstance(tool_parameters[parameter.name], str):
-                        try:
-                            tool_parameters[parameter.name] = int(tool_parameters[parameter.name])
-                        except ValueError:
-                            tool_parameters[parameter.name] = float(tool_parameters[parameter.name])
-                    elif isinstance(tool_parameters[parameter.name], bool):
-                        tool_parameters[parameter.name] = int(tool_parameters[parameter.name])
-                    elif tool_parameters[parameter.name] is None:
-                        tool_parameters[parameter.name] = 0
-                elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
-                    if not isinstance(tool_parameters[parameter.name], bool):
-                        # check if it is a string
-                        if isinstance(tool_parameters[parameter.name], str):
-                            # check true false
-                            if tool_parameters[parameter.name].lower() in ['true', 'false']:
-                                tool_parameters[parameter.name] = tool_parameters[parameter.name].lower() == 'true'
-                            # check 1 0
-                            elif tool_parameters[parameter.name] in ['1', '0']:
-                                tool_parameters[parameter.name] = tool_parameters[parameter.name] == '1'
-                            else:
-                                tool_parameters[parameter.name] = bool(tool_parameters[parameter.name])
-                        elif isinstance(tool_parameters[parameter.name], int | float):
-                            tool_parameters[parameter.name] = tool_parameters[parameter.name] != 0
-                        else:
-                            tool_parameters[parameter.name] = bool(tool_parameters[parameter.name])
-                            
-        return tool_parameters
+        return {p.name: ToolParameterConverter.cast_parameter_by_type(tool_parameters[p.name], p.type)
+                for p in self.parameters if p.name in tool_parameters}
 
     @abstractmethod
     def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:

+ 2 - 25
api/core/tools/tool_manager.py

@@ -11,7 +11,6 @@ from flask import current_app
 from core.agent.entities import AgentToolEntity
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.tools import *
 from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_entities import (
@@ -31,6 +30,7 @@ from core.tools.utils.configuration import (
     ToolConfigurationManager,
     ToolParameterConfigurationManager,
 )
+from core.tools.utils.tool_parameter_converter import ToolParameterConverter
 from core.utils.module_import_helper import load_single_subclass_from_source
 from core.workflow.nodes.tool.entities import ToolEntity
 from extensions.ext_database import db
@@ -214,30 +214,7 @@ class ToolManager:
                 raise ValueError(
                     f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}")
 
-        # convert tool parameter config to correct type
-        try:
-            if parameter_rule.type == ToolParameter.ToolParameterType.NUMBER:
-                # check if tool parameter is integer
-                if isinstance(parameter_value, int):
-                    parameter_value = parameter_value
-                elif isinstance(parameter_value, float):
-                    parameter_value = parameter_value
-                elif isinstance(parameter_value, str):
-                    if '.' in parameter_value:
-                        parameter_value = float(parameter_value)
-                    else:
-                        parameter_value = int(parameter_value)
-            elif parameter_rule.type == ToolParameter.ToolParameterType.BOOLEAN:
-                parameter_value = bool(parameter_value)
-            elif parameter_rule.type not in [ToolParameter.ToolParameterType.SELECT,
-                                             ToolParameter.ToolParameterType.STRING]:
-                parameter_value = str(parameter_value)
-            elif parameter_rule.type == ToolParameter.ToolParameterType:
-                parameter_value = str(parameter_value)
-        except Exception as e:
-            raise ValueError(f"tool parameter {parameter_rule.name} value {parameter_value} is not correct type")
-
-        return parameter_value
+        return ToolParameterConverter.cast_parameter_by_type(parameter_value, parameter_rule.type)
 
     @classmethod
     def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool:

+ 0 - 0
api/core/tools/utils/__init__.py


+ 66 - 0
api/core/tools/utils/tool_parameter_converter.py

@@ -0,0 +1,66 @@
+from typing import Any
+
+from core.tools.entities.tool_entities import ToolParameter
+
+
+class ToolParameterConverter:
+    @staticmethod
+    def get_parameter_type(parameter_type: str | ToolParameter.ToolParameterType) -> str:
+        match parameter_type:
+            case ToolParameter.ToolParameterType.STRING \
+                 | ToolParameter.ToolParameterType.SECRET_INPUT \
+                 | ToolParameter.ToolParameterType.SELECT:
+                return 'string'
+
+            case ToolParameter.ToolParameterType.BOOLEAN:
+                return 'boolean'
+
+            case ToolParameter.ToolParameterType.NUMBER:
+                return 'number'
+
+            case _:
+                raise ValueError(f"Unsupported parameter type {parameter_type}")
+
+    @staticmethod
+    def cast_parameter_by_type(value: Any, parameter_type: str) -> Any:
+        # convert tool parameter config to correct type
+        try:
+            match parameter_type:
+                case ToolParameter.ToolParameterType.STRING \
+                     | ToolParameter.ToolParameterType.SECRET_INPUT \
+                     | ToolParameter.ToolParameterType.SELECT:
+                    if value is None:
+                        return ''
+                    else:
+                        return value if isinstance(value, str) else str(value)
+
+                case ToolParameter.ToolParameterType.BOOLEAN:
+                    if value is None:
+                        return False
+                    elif isinstance(value, str):
+                        # Allowed YAML boolean value strings: https://yaml.org/type/bool.html
+                        # and also '0' for False and '1' for True
+                        match value.lower():
+                            case 'true' | 'yes' | 'y' | '1':
+                                return True
+                            case 'false' | 'no' | 'n' | '0':
+                                return False
+                            case _:
+                                return bool(value)
+                    else:
+                        return value if isinstance(value, bool) else bool(value)
+
+                case ToolParameter.ToolParameterType.NUMBER:
+                    if isinstance(value, int) | isinstance(value, float):
+                        return value
+                    elif isinstance(value, str):
+                        if '.' in value:
+                            return float(value)
+                        else:
+                            return int(value)
+
+                case _:
+                    return str(value)
+
+        except Exception:
+            raise ValueError(f"The tool parameter value {value} is not in correct type of {parameter_type}.")

+ 56 - 0
api/tests/unit_tests/core/tools/test_tool_parameter_converter.py

@@ -0,0 +1,56 @@
+import pytest
+
+from core.tools.entities.tool_entities import ToolParameter
+from core.tools.utils.tool_parameter_converter import ToolParameterConverter
+
+
+def test_get_parameter_type():
+    assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.STRING) == 'string'
+    assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.SELECT) == 'string'
+    assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.BOOLEAN) == 'boolean'
+    assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.NUMBER) == 'number'
+    with pytest.raises(ValueError):
+        ToolParameterConverter.get_parameter_type('unsupported_type')
+
+
+def test_cast_parameter_by_type():
+    # string
+    assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.STRING) == 'test'
+    assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.STRING) == '1'
+    assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.STRING) == '1.0'
+    assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.STRING) == ''
+
+    # secret input
+    assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.SECRET_INPUT) == 'test'
+    assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SECRET_INPUT) == '1'
+    assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SECRET_INPUT) == '1.0'
+    assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SECRET_INPUT) == ''
+
+    # select
+    assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.SELECT) == 'test'
+    assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SELECT) == '1'
+    assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SELECT) == '1.0'
+    assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SELECT) == ''
+
+    # boolean
+    true_values = [True, 'True', 'true', '1', 'YES', 'Yes', 'yes', 'y', 'something']
+    for value in true_values:
+        assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is True
+
+    false_values = [False, 'False', 'false', '0', 'NO', 'No', 'no', 'n', None, '']
+    for value in false_values:
+        assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is False
+
+    # number
+    assert ToolParameterConverter.cast_parameter_by_type('1', ToolParameter.ToolParameterType.NUMBER) == 1
+    assert ToolParameterConverter.cast_parameter_by_type('1.0', ToolParameter.ToolParameterType.NUMBER) == 1.0
+    assert ToolParameterConverter.cast_parameter_by_type('-1.0', ToolParameter.ToolParameterType.NUMBER) == -1.0
+    assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.NUMBER) == 1
+    assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.NUMBER) == 1.0
+    assert ToolParameterConverter.cast_parameter_by_type(-1.0, ToolParameter.ToolParameterType.NUMBER) == -1.0
+    assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None
+
+    # unknown
+    assert ToolParameterConverter.cast_parameter_by_type('1', 'unknown_type') == '1'
+    assert ToolParameterConverter.cast_parameter_by_type(1, 'unknown_type') == '1'
+    assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None