|
@@ -6,11 +6,17 @@ from os import listdir, path
|
|
|
from typing import Any, Union
|
|
|
|
|
|
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
|
|
+from core.entities.application_entities import AgentToolEntity
|
|
|
from core.model_runtime.entities.message_entities import PromptMessage
|
|
|
from core.provider_manager import ProviderManager
|
|
|
from core.tools.entities.common_entities import I18nObject
|
|
|
from core.tools.entities.constant import DEFAULT_PROVIDERS
|
|
|
-from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeMessage, ToolProviderCredentials
|
|
|
+from core.tools.entities.tool_entities import (
|
|
|
+ ApiProviderAuthType,
|
|
|
+ ToolInvokeMessage,
|
|
|
+ ToolParameter,
|
|
|
+ ToolProviderCredentials,
|
|
|
+)
|
|
|
from core.tools.entities.user_entities import UserToolProvider
|
|
|
from core.tools.errors import ToolProviderNotFoundError
|
|
|
from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
|
|
@@ -21,7 +27,12 @@ from core.tools.provider.model_tool_provider import ModelToolProviderController
|
|
|
from core.tools.provider.tool_provider import ToolProviderController
|
|
|
from core.tools.tool.api_tool import ApiTool
|
|
|
from core.tools.tool.builtin_tool import BuiltinTool
|
|
|
-from core.tools.utils.configuration import ModelToolConfigurationManager, ToolConfiguration
|
|
|
+from core.tools.tool.tool import Tool
|
|
|
+from core.tools.utils.configuration import (
|
|
|
+ ModelToolConfigurationManager,
|
|
|
+ ToolConfigurationManager,
|
|
|
+ ToolParameterConfigurationManager,
|
|
|
+)
|
|
|
from core.tools.utils.encoder import serialize_base_model_dict
|
|
|
from extensions.ext_database import db
|
|
|
from models.tools import ApiToolProvider, BuiltinToolProvider
|
|
@@ -172,7 +183,7 @@ class ToolManager:
|
|
|
# decrypt the credentials
|
|
|
credentials = builtin_provider.credentials
|
|
|
controller = ToolManager.get_builtin_provider(provider_name)
|
|
|
- tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
|
|
|
+ tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
|
|
|
|
|
|
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
|
|
|
|
@@ -189,7 +200,7 @@ class ToolManager:
|
|
|
api_provider, credentials = ToolManager.get_api_provider_controller(tenant_id, provider_name)
|
|
|
|
|
|
# decrypt the credentials
|
|
|
- tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=api_provider)
|
|
|
+ tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider)
|
|
|
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
|
|
|
|
|
return api_provider.get_tool(tool_name).fork_tool_runtime(meta={
|
|
@@ -214,6 +225,71 @@ class ToolManager:
|
|
|
else:
|
|
|
raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
|
|
|
|
|
|
+ @staticmethod
|
|
|
+ def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity, agent_callback: DifyAgentCallbackHandler) -> Tool:
|
|
|
+ """
|
|
|
+ get the agent tool runtime
|
|
|
+ """
|
|
|
+ tool_entity = ToolManager.get_tool_runtime(
|
|
|
+ provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id, tool_name=agent_tool.tool_name,
|
|
|
+ tenant_id=tenant_id,
|
|
|
+ agent_callback=agent_callback
|
|
|
+ )
|
|
|
+ runtime_parameters = {}
|
|
|
+ parameters = tool_entity.get_all_runtime_parameters()
|
|
|
+ for parameter in parameters:
|
|
|
+ if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
|
|
+ # get tool parameter from form
|
|
|
+ tool_parameter_config = agent_tool.tool_parameters.get(parameter.name)
|
|
|
+ if not tool_parameter_config:
|
|
|
+ # get default value
|
|
|
+ tool_parameter_config = parameter.default
|
|
|
+ if not tool_parameter_config and parameter.required:
|
|
|
+ raise ValueError(f"tool parameter {parameter.name} not found in tool config")
|
|
|
+
|
|
|
+ if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
|
|
+ # check if tool_parameter_config in options
|
|
|
+ options = list(map(lambda x: x.value, parameter.options))
|
|
|
+ if tool_parameter_config not in options:
|
|
|
+ raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}")
|
|
|
+
|
|
|
+ # convert tool parameter config to correct type
|
|
|
+ try:
|
|
|
+ if parameter.type == ToolParameter.ToolParameterType.NUMBER:
|
|
|
+ # check if tool parameter is integer
|
|
|
+ if isinstance(tool_parameter_config, int):
|
|
|
+ tool_parameter_config = tool_parameter_config
|
|
|
+ elif isinstance(tool_parameter_config, float):
|
|
|
+ tool_parameter_config = tool_parameter_config
|
|
|
+ elif isinstance(tool_parameter_config, str):
|
|
|
+ if '.' in tool_parameter_config:
|
|
|
+ tool_parameter_config = float(tool_parameter_config)
|
|
|
+ else:
|
|
|
+ tool_parameter_config = int(tool_parameter_config)
|
|
|
+ elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
|
|
|
+ tool_parameter_config = bool(tool_parameter_config)
|
|
|
+ elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]:
|
|
|
+ tool_parameter_config = str(tool_parameter_config)
|
|
|
+ elif parameter.type == ToolParameter.ToolParameterType:
|
|
|
+ tool_parameter_config = str(tool_parameter_config)
|
|
|
+ except Exception as e:
|
|
|
+ raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type")
|
|
|
+
|
|
|
+ # save tool parameter to tool entity memory
|
|
|
+ runtime_parameters[parameter.name] = tool_parameter_config
|
|
|
+
|
|
|
+ # decrypt runtime parameters
|
|
|
+ encryption_manager = ToolParameterConfigurationManager(
|
|
|
+ tenant_id=tenant_id,
|
|
|
+ tool_runtime=tool_entity,
|
|
|
+ provider_name=agent_tool.provider_id,
|
|
|
+ provider_type=agent_tool.provider_type,
|
|
|
+ )
|
|
|
+ runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
|
|
|
+
|
|
|
+ tool_entity.runtime.runtime_parameters.update(runtime_parameters)
|
|
|
+ return tool_entity
|
|
|
+
|
|
|
@staticmethod
|
|
|
def get_builtin_provider_icon(provider: str) -> tuple[str, str]:
|
|
|
"""
|
|
@@ -396,7 +472,7 @@ class ToolManager:
|
|
|
controller = ToolManager.get_builtin_provider(provider_name)
|
|
|
|
|
|
# init tool configuration
|
|
|
- tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
|
|
|
+ tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
|
|
|
# decrypt the credentials and mask the credentials
|
|
|
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
|
|
|
masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials)
|
|
@@ -463,7 +539,7 @@ class ToolManager:
|
|
|
)
|
|
|
|
|
|
# init tool configuration
|
|
|
- tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
|
|
|
+ tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
|
|
|
|
|
|
# decrypt the credentials and mask the credentials
|
|
|
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
|
|
@@ -523,7 +599,7 @@ class ToolManager:
|
|
|
provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
|
|
|
)
|
|
|
# init tool configuration
|
|
|
- tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
|
|
|
+ tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
|
|
|
|
|
|
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
|
|
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
|