Parcourir la source

refactor(tool-engine): Improve tool provider handling with session ma… (#14291)

Yeuoly il y a 1 mois
Parent
commit
9fb78ce827
1 fichiers modifiés avec 79 ajouts et 78 suppressions
  1. 79 78
      api/core/tools/tool_manager.py

+ 79 - 78
api/core/tools/tool_manager.py

@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Union, cast
 from yarl import URL
 
 import contexts
-from core.plugin.entities.plugin import GenericProviderID
+from core.plugin.entities.plugin import ToolProviderID
 from core.plugin.manager.tool import PluginToolManager
 from core.tools.__base.tool_provider import ToolProviderController
 from core.tools.__base.tool_runtime import ToolRuntime
@@ -188,7 +188,7 @@ class ToolManager:
                 )
 
             if isinstance(provider_controller, PluginToolProviderController):
-                provider_id_entity = GenericProviderID(provider_id)
+                provider_id_entity = ToolProviderID(provider_id)
                 # get credentials
                 builtin_provider: BuiltinToolProvider | None = (
                     db.session.query(BuiltinToolProvider)
@@ -572,95 +572,96 @@ class ToolManager:
         else:
             filters.append(typ)
 
-        if "builtin" in filters:
-            # get builtin providers
-            builtin_providers = cls.list_builtin_providers(tenant_id)
+        with db.session.no_autoflush:
+            if "builtin" in filters:
+                # get builtin providers
+                builtin_providers = cls.list_builtin_providers(tenant_id)
 
-            # get db builtin providers
-            db_builtin_providers: list[BuiltinToolProvider] = (
-                db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
-            )
-
-            # rewrite db_builtin_providers
-            for db_provider in db_builtin_providers:
-                tool_provider_id = GenericProviderID(db_provider.provider)
-                db_provider.provider = tool_provider_id.to_string()
-
-            def find_db_builtin_provider(provider):
-                return next((x for x in db_builtin_providers if x.provider == provider), None)
-
-            # append builtin providers
-            for provider in builtin_providers:
-                # handle include, exclude
-                if is_filtered(
-                    include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
-                    exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
-                    data=provider,
-                    name_func=lambda x: x.identity.name,
-                ):
-                    continue
-
-                user_provider = ToolTransformService.builtin_provider_to_user_provider(
-                    provider_controller=provider,
-                    db_provider=find_db_builtin_provider(provider.entity.identity.name),
-                    decrypt_credentials=False,
+                # get db builtin providers
+                db_builtin_providers: list[BuiltinToolProvider] = (
+                    db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
                 )
 
-                if isinstance(provider, PluginToolProviderController):
-                    result_providers[f"plugin_provider.{user_provider.name}"] = user_provider
-                else:
-                    result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
-
-        # get db api providers
-
-        if "api" in filters:
-            db_api_providers: list[ApiToolProvider] = (
-                db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
-            )
+                # rewrite db_builtin_providers
+                for db_provider in db_builtin_providers:
+                    tool_provider_id = str(ToolProviderID(db_provider.provider))
+                    db_provider.provider = tool_provider_id
+
+                def find_db_builtin_provider(provider):
+                    return next((x for x in db_builtin_providers if x.provider == provider), None)
+
+                # append builtin providers
+                for provider in builtin_providers:
+                    # handle include, exclude
+                    if is_filtered(
+                        include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
+                        exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
+                        data=provider,
+                        name_func=lambda x: x.identity.name,
+                    ):
+                        continue
+
+                    user_provider = ToolTransformService.builtin_provider_to_user_provider(
+                        provider_controller=provider,
+                        db_provider=find_db_builtin_provider(provider.entity.identity.name),
+                        decrypt_credentials=False,
+                    )
 
-            api_provider_controllers: list[dict[str, Any]] = [
-                {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
-                for provider in db_api_providers
-            ]
+                    if isinstance(provider, PluginToolProviderController):
+                        result_providers[f"plugin_provider.{user_provider.name}"] = user_provider
+                    else:
+                        result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
 
-            # get labels
-            labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
+            # get db api providers
 
-            for api_provider_controller in api_provider_controllers:
-                user_provider = ToolTransformService.api_provider_to_user_provider(
-                    provider_controller=api_provider_controller["controller"],
-                    db_provider=api_provider_controller["provider"],
-                    decrypt_credentials=False,
-                    labels=labels.get(api_provider_controller["controller"].provider_id, []),
+            if "api" in filters:
+                db_api_providers: list[ApiToolProvider] = (
+                    db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
                 )
-                result_providers[f"api_provider.{user_provider.name}"] = user_provider
 
-        if "workflow" in filters:
-            # get workflow providers
-            workflow_providers: list[WorkflowToolProvider] = (
-                db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
-            )
+                api_provider_controllers: list[dict[str, Any]] = [
+                    {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
+                    for provider in db_api_providers
+                ]
 
-            workflow_provider_controllers: list[WorkflowToolProviderController] = []
-            for provider in workflow_providers:
-                try:
-                    workflow_provider_controllers.append(
-                        ToolTransformService.workflow_provider_to_controller(db_provider=provider)
+                # get labels
+                labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
+
+                for api_provider_controller in api_provider_controllers:
+                    user_provider = ToolTransformService.api_provider_to_user_provider(
+                        provider_controller=api_provider_controller["controller"],
+                        db_provider=api_provider_controller["provider"],
+                        decrypt_credentials=False,
+                        labels=labels.get(api_provider_controller["controller"].provider_id, []),
                     )
-                except Exception:
-                    # app has been deleted
-                    pass
+                    result_providers[f"api_provider.{user_provider.name}"] = user_provider
 
-            labels = ToolLabelManager.get_tools_labels(
-                [cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
-            )
+            if "workflow" in filters:
+                # get workflow providers
+                workflow_providers: list[WorkflowToolProvider] = (
+                    db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
+                )
 
-            for provider_controller in workflow_provider_controllers:
-                user_provider = ToolTransformService.workflow_provider_to_user_provider(
-                    provider_controller=provider_controller,
-                    labels=labels.get(provider_controller.provider_id, []),
+                workflow_provider_controllers: list[WorkflowToolProviderController] = []
+                for provider in workflow_providers:
+                    try:
+                        workflow_provider_controllers.append(
+                            ToolTransformService.workflow_provider_to_controller(db_provider=provider)
+                        )
+                    except Exception:
+                        # app has been deleted
+                        pass
+
+                labels = ToolLabelManager.get_tools_labels(
+                    [cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
                 )
-                result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
+
+                for provider_controller in workflow_provider_controllers:
+                    user_provider = ToolTransformService.workflow_provider_to_user_provider(
+                        provider_controller=provider_controller,
+                        labels=labels.get(provider_controller.provider_id, []),
+                    )
+                    result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
 
         return BuiltinToolProviderSort.sort(list(result_providers.values()))