|
@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Union, cast
|
|
|
from yarl import URL
|
|
|
|
|
|
import contexts
|
|
|
-from core.plugin.entities.plugin import GenericProviderID
|
|
|
+from core.plugin.entities.plugin import ToolProviderID
|
|
|
from core.plugin.manager.tool import PluginToolManager
|
|
|
from core.tools.__base.tool_provider import ToolProviderController
|
|
|
from core.tools.__base.tool_runtime import ToolRuntime
|
|
@@ -188,7 +188,7 @@ class ToolManager:
|
|
|
)
|
|
|
|
|
|
if isinstance(provider_controller, PluginToolProviderController):
|
|
|
- provider_id_entity = GenericProviderID(provider_id)
|
|
|
+ provider_id_entity = ToolProviderID(provider_id)
|
|
|
# get credentials
|
|
|
builtin_provider: BuiltinToolProvider | None = (
|
|
|
db.session.query(BuiltinToolProvider)
|
|
@@ -572,95 +572,96 @@ class ToolManager:
|
|
|
else:
|
|
|
filters.append(typ)
|
|
|
|
|
|
- if "builtin" in filters:
|
|
|
- # get builtin providers
|
|
|
- builtin_providers = cls.list_builtin_providers(tenant_id)
|
|
|
+ with db.session.no_autoflush:
|
|
|
+ if "builtin" in filters:
|
|
|
+ # get builtin providers
|
|
|
+ builtin_providers = cls.list_builtin_providers(tenant_id)
|
|
|
|
|
|
- # get db builtin providers
|
|
|
- db_builtin_providers: list[BuiltinToolProvider] = (
|
|
|
- db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
|
|
|
- )
|
|
|
-
|
|
|
- # rewrite db_builtin_providers
|
|
|
- for db_provider in db_builtin_providers:
|
|
|
- tool_provider_id = GenericProviderID(db_provider.provider)
|
|
|
- db_provider.provider = tool_provider_id.to_string()
|
|
|
-
|
|
|
- def find_db_builtin_provider(provider):
|
|
|
- return next((x for x in db_builtin_providers if x.provider == provider), None)
|
|
|
-
|
|
|
- # append builtin providers
|
|
|
- for provider in builtin_providers:
|
|
|
- # handle include, exclude
|
|
|
- if is_filtered(
|
|
|
- include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
|
|
|
- exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
|
|
|
- data=provider,
|
|
|
- name_func=lambda x: x.identity.name,
|
|
|
- ):
|
|
|
- continue
|
|
|
-
|
|
|
- user_provider = ToolTransformService.builtin_provider_to_user_provider(
|
|
|
- provider_controller=provider,
|
|
|
- db_provider=find_db_builtin_provider(provider.entity.identity.name),
|
|
|
- decrypt_credentials=False,
|
|
|
+ # get db builtin providers
|
|
|
+ db_builtin_providers: list[BuiltinToolProvider] = (
|
|
|
+ db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
|
|
|
)
|
|
|
|
|
|
- if isinstance(provider, PluginToolProviderController):
|
|
|
- result_providers[f"plugin_provider.{user_provider.name}"] = user_provider
|
|
|
- else:
|
|
|
- result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
|
|
|
-
|
|
|
- # get db api providers
|
|
|
-
|
|
|
- if "api" in filters:
|
|
|
- db_api_providers: list[ApiToolProvider] = (
|
|
|
- db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
|
|
|
- )
|
|
|
+ # rewrite db_builtin_providers
|
|
|
+ for db_provider in db_builtin_providers:
|
|
|
+ tool_provider_id = str(ToolProviderID(db_provider.provider))
|
|
|
+ db_provider.provider = tool_provider_id
|
|
|
+
|
|
|
+ def find_db_builtin_provider(provider):
|
|
|
+ return next((x for x in db_builtin_providers if x.provider == provider), None)
|
|
|
+
|
|
|
+ # append builtin providers
|
|
|
+ for provider in builtin_providers:
|
|
|
+ # handle include, exclude
|
|
|
+ if is_filtered(
|
|
|
+ include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
|
|
|
+ exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
|
|
|
+ data=provider,
|
|
|
+ name_func=lambda x: x.identity.name,
|
|
|
+ ):
|
|
|
+ continue
|
|
|
+
|
|
|
+ user_provider = ToolTransformService.builtin_provider_to_user_provider(
|
|
|
+ provider_controller=provider,
|
|
|
+ db_provider=find_db_builtin_provider(provider.entity.identity.name),
|
|
|
+ decrypt_credentials=False,
|
|
|
+ )
|
|
|
|
|
|
- api_provider_controllers: list[dict[str, Any]] = [
|
|
|
- {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
|
|
|
- for provider in db_api_providers
|
|
|
- ]
|
|
|
+ if isinstance(provider, PluginToolProviderController):
|
|
|
+ result_providers[f"plugin_provider.{user_provider.name}"] = user_provider
|
|
|
+ else:
|
|
|
+ result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
|
|
|
|
|
|
- # get labels
|
|
|
- labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
|
|
|
+ # get db api providers
|
|
|
|
|
|
- for api_provider_controller in api_provider_controllers:
|
|
|
- user_provider = ToolTransformService.api_provider_to_user_provider(
|
|
|
- provider_controller=api_provider_controller["controller"],
|
|
|
- db_provider=api_provider_controller["provider"],
|
|
|
- decrypt_credentials=False,
|
|
|
- labels=labels.get(api_provider_controller["controller"].provider_id, []),
|
|
|
+ if "api" in filters:
|
|
|
+ db_api_providers: list[ApiToolProvider] = (
|
|
|
+ db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
|
|
|
)
|
|
|
- result_providers[f"api_provider.{user_provider.name}"] = user_provider
|
|
|
|
|
|
- if "workflow" in filters:
|
|
|
- # get workflow providers
|
|
|
- workflow_providers: list[WorkflowToolProvider] = (
|
|
|
- db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
|
|
|
- )
|
|
|
+ api_provider_controllers: list[dict[str, Any]] = [
|
|
|
+ {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
|
|
|
+ for provider in db_api_providers
|
|
|
+ ]
|
|
|
|
|
|
- workflow_provider_controllers: list[WorkflowToolProviderController] = []
|
|
|
- for provider in workflow_providers:
|
|
|
- try:
|
|
|
- workflow_provider_controllers.append(
|
|
|
- ToolTransformService.workflow_provider_to_controller(db_provider=provider)
|
|
|
+ # get labels
|
|
|
+ labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
|
|
|
+
|
|
|
+ for api_provider_controller in api_provider_controllers:
|
|
|
+ user_provider = ToolTransformService.api_provider_to_user_provider(
|
|
|
+ provider_controller=api_provider_controller["controller"],
|
|
|
+ db_provider=api_provider_controller["provider"],
|
|
|
+ decrypt_credentials=False,
|
|
|
+ labels=labels.get(api_provider_controller["controller"].provider_id, []),
|
|
|
)
|
|
|
- except Exception:
|
|
|
- # app has been deleted
|
|
|
- pass
|
|
|
+ result_providers[f"api_provider.{user_provider.name}"] = user_provider
|
|
|
|
|
|
- labels = ToolLabelManager.get_tools_labels(
|
|
|
- [cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
|
|
|
- )
|
|
|
+ if "workflow" in filters:
|
|
|
+ # get workflow providers
|
|
|
+ workflow_providers: list[WorkflowToolProvider] = (
|
|
|
+ db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
|
|
|
+ )
|
|
|
|
|
|
- for provider_controller in workflow_provider_controllers:
|
|
|
- user_provider = ToolTransformService.workflow_provider_to_user_provider(
|
|
|
- provider_controller=provider_controller,
|
|
|
- labels=labels.get(provider_controller.provider_id, []),
|
|
|
+ workflow_provider_controllers: list[WorkflowToolProviderController] = []
|
|
|
+ for provider in workflow_providers:
|
|
|
+ try:
|
|
|
+ workflow_provider_controllers.append(
|
|
|
+ ToolTransformService.workflow_provider_to_controller(db_provider=provider)
|
|
|
+ )
|
|
|
+ except Exception:
|
|
|
+ # app has been deleted
|
|
|
+ pass
|
|
|
+
|
|
|
+ labels = ToolLabelManager.get_tools_labels(
|
|
|
+ [cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
|
|
|
)
|
|
|
- result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
|
|
|
+
|
|
|
+ for provider_controller in workflow_provider_controllers:
|
|
|
+ user_provider = ToolTransformService.workflow_provider_to_user_provider(
|
|
|
+ provider_controller=provider_controller,
|
|
|
+ labels=labels.get(provider_controller.provider_id, []),
|
|
|
+ )
|
|
|
+ result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
|
|
|
|
|
|
return BuiltinToolProviderSort.sort(list(result_providers.values()))
|
|
|
|