Bläddra i källkod

refactor: tool parameter cache (#3703)

Yeuoly 1 år sedan
förälder
incheckning
3480f1c59e

+ 3 - 44
api/controllers/console/app/app.py

@@ -1,5 +1,3 @@
-import json
-
 from flask_login import current_user
 from flask_restful import Resource, inputs, marshal_with, reqparse
 from werkzeug.exceptions import BadRequest, Forbidden
@@ -8,17 +6,12 @@ from controllers.console import api
 from controllers.console.app.wraps import get_app_model
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
-from core.agent.entities import AgentToolEntity
-from core.tools.tool_manager import ToolManager
-from core.tools.utils.configuration import ToolParameterConfigurationManager
-from extensions.ext_database import db
 from fields.app_fields import (
     app_detail_fields,
     app_detail_fields_with_site,
     app_pagination_fields,
 )
 from libs.login import login_required
-from models.model import App, AppMode, AppModelConfig
 from services.app_service import AppService
 
 ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion']
@@ -108,43 +101,9 @@ class AppApi(Resource):
     @marshal_with(app_detail_fields_with_site)
     def get(self, app_model):
         """Get app detail"""
-        # get original app model config
-        if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
-            model_config: AppModelConfig = app_model.app_model_config
-            agent_mode = model_config.agent_mode_dict
-            # decrypt agent tool parameters if it's secret-input
-            for tool in agent_mode.get('tools') or []:
-                if not isinstance(tool, dict) or len(tool.keys()) <= 3:
-                    continue
-                agent_tool_entity = AgentToolEntity(**tool)
-                # get tool
-                try:
-                    tool_runtime = ToolManager.get_agent_tool_runtime(
-                        tenant_id=current_user.current_tenant_id,
-                        agent_tool=agent_tool_entity,
-                    )
-                    manager = ToolParameterConfigurationManager(
-                        tenant_id=current_user.current_tenant_id,
-                        tool_runtime=tool_runtime,
-                        provider_name=agent_tool_entity.provider_id,
-                        provider_type=agent_tool_entity.provider_type,
-                    )
-
-                    # get decrypted parameters
-                    if agent_tool_entity.tool_parameters:
-                        parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
-                        masked_parameter = manager.mask_tool_parameters(parameters or {})
-                    else:
-                        masked_parameter = {}
-
-                    # override tool parameters
-                    tool['tool_parameters'] = masked_parameter
-                except Exception as e:
-                    pass
-
-            # override agent mode
-            model_config.agent_mode = json.dumps(agent_mode)
-            db.session.commit()
+        app_service = AppService()
+
+        app_model = app_service.get_app(app_model)
 
         return app_model
 

+ 9 - 3
api/controllers/console/app/model_config.py

@@ -57,6 +57,7 @@ class ModelConfigResource(Resource):
                 try:
                     tool_runtime = ToolManager.get_agent_tool_runtime(
                         tenant_id=current_user.current_tenant_id,
+                        app_id=app_model.id,
                         agent_tool=agent_tool_entity,
                     )
                     manager = ToolParameterConfigurationManager(
@@ -64,6 +65,7 @@ class ModelConfigResource(Resource):
                         tool_runtime=tool_runtime,
                         provider_name=agent_tool_entity.provider_id,
                         provider_type=agent_tool_entity.provider_type,
+                        identity_id=f'AGENT.{app_model.id}'
                     )
                 except Exception as e:
                     continue
@@ -94,6 +96,7 @@ class ModelConfigResource(Resource):
                     try:
                         tool_runtime = ToolManager.get_agent_tool_runtime(
                             tenant_id=current_user.current_tenant_id,
+                            app_id=app_model.id,
                             agent_tool=agent_tool_entity,
                         )
                     except Exception as e:
@@ -104,6 +107,7 @@ class ModelConfigResource(Resource):
                     tool_runtime=tool_runtime,
                     provider_name=agent_tool_entity.provider_id,
                     provider_type=agent_tool_entity.provider_type,
+                    identity_id=f'AGENT.{app_model.id}'
                 )
                 manager.delete_tool_parameters_cache()
 
@@ -111,9 +115,11 @@ class ModelConfigResource(Resource):
                 if agent_tool_entity.tool_parameters:
                     if key not in masked_parameter_map:
                         continue
-
-                    if agent_tool_entity.tool_parameters == masked_parameter_map[key]:
-                        agent_tool_entity.tool_parameters = parameter_map[key]
+                    
+                    for masked_key, masked_value in masked_parameter_map[key].items():
+                        if masked_key in agent_tool_entity.tool_parameters and \
+                                agent_tool_entity.tool_parameters[masked_key] == masked_value:
+                            agent_tool_entity.tool_parameters[masked_key] = parameter_map[key].get(masked_key)
 
                 # encrypt parameters
                 if agent_tool_entity.tool_parameters:

+ 1 - 0
api/core/agent/base_agent_runner.py

@@ -163,6 +163,7 @@ class BaseAgentRunner(AppRunner):
         """
         tool_entity = ToolManager.get_agent_tool_runtime(
             tenant_id=self.tenant_id,
+            app_id=self.app_config.app_id,
             agent_tool=tool,
         )
         tool_entity.load_variables(self.variables_pool)

+ 6 - 5
api/core/helper/tool_parameter_cache.py

@@ -11,12 +11,13 @@ class ToolParameterCacheType(Enum):
 
 class ToolParameterCache:
     def __init__(self, 
-                 tenant_id: str, 
-                 provider: str, 
-                 tool_name: str, 
-                 cache_type: ToolParameterCacheType
+            tenant_id: str, 
+            provider: str, 
+            tool_name: str, 
+            cache_type: ToolParameterCacheType,
+            identity_id: str
         ):
-        self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
+        self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}:identity_id:{identity_id}"
 
     def get(self) -> Optional[dict]:
         """

+ 4 - 2
api/core/tools/tool_manager.py

@@ -222,7 +222,7 @@ class ToolManager:
         return parameter_value
 
     @classmethod
-    def get_agent_tool_runtime(cls, tenant_id: str, agent_tool: AgentToolEntity) -> Tool:
+    def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity) -> Tool:
         """
             get the agent tool runtime
         """
@@ -245,6 +245,7 @@ class ToolManager:
             tool_runtime=tool_entity,
             provider_name=agent_tool.provider_id,
             provider_type=agent_tool.provider_type,
+            identity_id=f'AGENT.{app_id}'
         )
         runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
 
@@ -252,7 +253,7 @@ class ToolManager:
         return tool_entity
 
     @classmethod
-    def get_workflow_tool_runtime(cls, tenant_id: str, workflow_tool: ToolEntity):
+    def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity):
         """
             get the workflow tool runtime
         """
@@ -277,6 +278,7 @@ class ToolManager:
             tool_runtime=tool_entity,
             provider_name=workflow_tool.provider_id,
             provider_type=workflow_tool.provider_type,
+            identity_id=f'WORKFLOW.{app_id}.{node_id}'
         )
 
         if runtime_parameters:

+ 8 - 3
api/core/tools/utils/configuration.py

@@ -113,12 +113,13 @@ class ToolParameterConfigurationManager(BaseModel):
     tool_runtime: Tool
     provider_name: str
     provider_type: str
+    identity_id: str
 
     def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
         """
         deep copy parameters
         """
-        return {key: value for key, value in parameters.items()}
+        return deepcopy(parameters)
     
     def _merge_parameters(self) -> list[ToolParameter]:
         """
@@ -176,6 +177,8 @@ class ToolParameterConfigurationManager(BaseModel):
         # override parameters
         current_parameters = self._merge_parameters()
 
+        parameters = self._deep_copy(parameters)
+
         for parameter in current_parameters:
             if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
                 if parameter.name in parameters:
@@ -194,7 +197,8 @@ class ToolParameterConfigurationManager(BaseModel):
             tenant_id=self.tenant_id, 
             provider=f'{self.provider_type}.{self.provider_name}',
             tool_name=self.tool_runtime.identity.name,
-            cache_type=ToolParameterCacheType.PARAMETER
+            cache_type=ToolParameterCacheType.PARAMETER,
+            identity_id=self.identity_id
         )
         cached_parameters = cache.get()
         if cached_parameters:
@@ -223,7 +227,8 @@ class ToolParameterConfigurationManager(BaseModel):
             tenant_id=self.tenant_id, 
             provider=f'{self.provider_type}.{self.provider_name}',
             tool_name=self.tool_runtime.identity.name,
-            cache_type=ToolParameterCacheType.PARAMETER
+            cache_type=ToolParameterCacheType.PARAMETER,
+            identity_id=self.identity_id
         )
         cache.delete()
 

+ 2 - 1
api/core/workflow/nodes/tool/tool_node.py

@@ -39,7 +39,8 @@ class ToolNode(BaseNode):
         parameters = self._generate_parameters(variable_pool, node_data)
         # get tool runtime
         try:
-            tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, node_data)
+            self.app_id
+            tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, self.app_id, self.node_id, node_data)
         except Exception as e:
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED,

+ 1 - 0
api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py

@@ -22,5 +22,6 @@ def handle(sender, **kwargs):
                 tool_runtime=tool_runtime,
                 provider_name=tool_entity.provider_name,
                 provider_type=tool_entity.provider_type,
+                identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}'
             )
             manager.delete_tool_parameters_cache()

+ 62 - 0
api/services/app_service.py

@@ -5,13 +5,17 @@ from typing import cast
 
 import yaml
 from flask import current_app
+from flask_login import current_user
 from flask_sqlalchemy.pagination import Pagination
 
 from constants.model_template import default_app_templates
+from core.agent.entities import AgentToolEntity
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.tools.tool_manager import ToolManager
+from core.tools.utils.configuration import ToolParameterConfigurationManager
 from events.app_event import app_model_config_was_updated, app_was_created, app_was_deleted
 from extensions.ext_database import db
 from models.account import Account
@@ -240,6 +244,64 @@ class AppService:
 
         return yaml.dump(export_data)
 
+    def get_app(self, app: App) -> App:
+        """
+        Get App
+        """
+        # get original app model config
+        if app.mode == AppMode.AGENT_CHAT.value or app.is_agent:
+            model_config: AppModelConfig = app.app_model_config
+            agent_mode = model_config.agent_mode_dict
+            # decrypt agent tool parameters if it's secret-input
+            for tool in agent_mode.get('tools') or []:
+                if not isinstance(tool, dict) or len(tool.keys()) <= 3:
+                    continue
+                agent_tool_entity = AgentToolEntity(**tool)
+                # get tool
+                try:
+                    tool_runtime = ToolManager.get_agent_tool_runtime(
+                        tenant_id=current_user.current_tenant_id,
+                        app_id=app.id,
+                        agent_tool=agent_tool_entity,
+                    )
+                    manager = ToolParameterConfigurationManager(
+                        tenant_id=current_user.current_tenant_id,
+                        tool_runtime=tool_runtime,
+                        provider_name=agent_tool_entity.provider_id,
+                        provider_type=agent_tool_entity.provider_type,
+                        identity_id=f'AGENT.{app.id}'
+                    )
+
+                    # get decrypted parameters
+                    if agent_tool_entity.tool_parameters:
+                        parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
+                        masked_parameter = manager.mask_tool_parameters(parameters or {})
+                    else:
+                        masked_parameter = {}
+
+                    # override tool parameters
+                    tool['tool_parameters'] = masked_parameter
+                except Exception as e:
+                    pass
+
+            # override agent mode
+            model_config.agent_mode = json.dumps(agent_mode)
+
+            class ModifiedApp(App):
+                """
+                Modified App class
+                """
+                def __init__(self, app):
+                    self.__dict__.update(app.__dict__)
+
+                @property
+                def app_model_config(self):
+                    return model_config
+                
+            app = ModifiedApp(app)
+
+        return app
+
     def update_app(self, app: App, args: dict) -> App:
         """
         Update app