Bläddra i källkod

Fix tool provider credential caching issue (#2433)

Yeuoly 1 år sedan
förälder
incheckning
23e95fd7ab
2 ändrade filer med 42 tillägg och 12 borttagningar
  1. 9 1
      api/core/tools/utils/configuration.py
  2. 33 11
      api/services/tools_manage_service.py

+ 9 - 1
api/core/tools/utils/configuration.py

@@ -85,4 +85,12 @@ class ToolConfiguration(BaseModel):
                         pass
 
         cache.set(credentials)
-        return credentials
+        return credentials
+    
+    def delete_tool_credentials_cache(self):
+        cache = ToolProviderCredentialsCache(
+            tenant_id=self.tenant_id, 
+            identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
+            cache_type=ToolProviderCredentialsCacheType.PROVIDER
+        )
+        cache.delete()

+ 33 - 11
api/services/tools_manage_service.py

@@ -355,10 +355,12 @@ class ToolManageService:
 
         else:
             provider.encrypted_credentials = json.dumps(credentials)
-
             db.session.add(provider)
             db.session.commit()
 
+            # delete cache
+            tool_configuration.delete_tool_credentials_cache()
+
         return { 'result': 'success' }
     
     @staticmethod
@@ -393,7 +395,6 @@ class ToolManageService:
         provider.description = extra_info.get('description', '')
         provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value
         provider.tools_str = serialize_base_model_array(tool_bundles)
-        provider.credentials_str = json.dumps(credentials)
         provider.privacy_policy = privacy_policy
 
         if 'auth_type' not in credentials:
@@ -403,33 +404,54 @@ class ToolManageService:
         auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
 
         # create provider entity
-        provider_entity = ApiBasedToolProviderController.from_db(provider, auth_type)
+        provider_controller = ApiBasedToolProviderController.from_db(provider, auth_type)
         # load tools into provider entity
-        provider_entity.load_bundled_tools(tool_bundles)
+        provider_controller.load_bundled_tools(tool_bundles)
+
+        # get original credentials if exists
+        tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller)
+
+        original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
+        masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
+        # check if the credential has changed, save the original credential
+        for name, value in credentials.items():
+            if name in masked_credentials and value == masked_credentials[name]:
+                credentials[name] = original_credentials[name]
+
+        credentials = tool_configuration.encrypt_tool_credentials(credentials)
+        provider.credentials_str = json.dumps(credentials)
 
         db.session.add(provider)
         db.session.commit()
 
+        # delete cache
+        tool_configuration.delete_tool_credentials_cache()
+
         return { 'result': 'success' }
     
     @staticmethod
     def delete_builtin_tool_provider(
-        user_id: str, tenant_id: str, provider: str
+        user_id: str, tenant_id: str, provider_name: str
     ):
         """
             delete tool provider
         """
         provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
             BuiltinToolProvider.tenant_id == tenant_id,
-            BuiltinToolProvider.provider == provider,
+            BuiltinToolProvider.provider == provider_name,
         ).first()
 
         if provider is None:
-            raise ValueError(f'you have not added provider {provider}')
+            raise ValueError(f'you have not added provider {provider_name}')
         
         db.session.delete(provider)
         db.session.commit()
 
+        # delete cache
+        provider_controller = ToolManager.get_builtin_provider(provider_name)
+        tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller)
+        tool_configuration.delete_tool_credentials_cache()
+
         return { 'result': 'success' }
     
     @staticmethod
@@ -437,7 +459,7 @@ class ToolManageService:
         provider: str
     ):
         """
-            get tool provider icon and it's minetype
+            get tool provider icon and it's mimetype
         """
         icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider)
         with open(icon_path, 'rb') as f:
@@ -447,18 +469,18 @@ class ToolManageService:
     
     @staticmethod
     def delete_api_tool_provider(
-        user_id: str, tenant_id: str, provider: str
+        user_id: str, tenant_id: str, provider_name: str
     ):
         """
             delete tool provider
         """
         provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
             ApiToolProvider.tenant_id == tenant_id,
-            ApiToolProvider.name == provider,
+            ApiToolProvider.name == provider_name,
         ).first()
 
         if provider is None:
-            raise ValueError(f'you have not added provider {provider}')
+            raise ValueError(f'you have not added provider {provider_name}')
         
         db.session.delete(provider)
         db.session.commit()