Bladeren bron

refactor: update builtin tool provider methods to use session management (#11938)

Signed-off-by: -LAN- <laipz8200@outlook.com>
-LAN- 4 maanden geleden
bovenliggende
commit
606aadb891
2 gewijzigde bestanden met toevoegingen van 30 en 27 verwijderingen
  1. 14 10
      api/controllers/console/workspace/tool_providers.py
  2. 16 17
      api/services/tools/builtin_tools_manage_service.py

+ 14 - 10
api/controllers/console/workspace/tool_providers.py

@@ -3,12 +3,14 @@ import io
 from flask import send_file
 from flask_login import current_user
 from flask_restful import Resource, reqparse
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden
 
 from configs import dify_config
 from controllers.console import api
 from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
 from core.model_runtime.utils.encoders import jsonable_encoder
+from extensions.ext_database import db
 from libs.helper import alphanumeric, uuid_value
 from libs.login import login_required
 from services.tools.api_tools_manage_service import ApiToolManageService
@@ -91,12 +93,16 @@ class ToolBuiltinProviderUpdateApi(Resource):
 
         args = parser.parse_args()
 
-        return BuiltinToolManageService.update_builtin_tool_provider(
-            user_id,
-            tenant_id,
-            provider,
-            args["credentials"],
-        )
+        with Session(db.engine) as session:
+            result = BuiltinToolManageService.update_builtin_tool_provider(
+                session=session,
+                user_id=user_id,
+                tenant_id=tenant_id,
+                provider_name=provider,
+                credentials=args["credentials"],
+            )
+            session.commit()
+        return result
 
 
 class ToolBuiltinProviderGetCredentialsApi(Resource):
@@ -104,13 +110,11 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
     @login_required
     @account_initialization_required
     def get(self, provider):
-        user_id = current_user.id
         tenant_id = current_user.current_tenant_id
 
         return BuiltinToolManageService.get_builtin_tool_provider_credentials(
-            user_id,
-            tenant_id,
-            provider,
+            tenant_id=tenant_id,
+            provider_name=provider,
         )
 
 

+ 16 - 17
api/services/tools/builtin_tools_manage_service.py

@@ -2,6 +2,9 @@ import json
 import logging
 from pathlib import Path
 
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
 from configs import dify_config
 from core.helper.position_helper import is_filtered
 from core.model_runtime.utils.encoders import jsonable_encoder
@@ -32,7 +35,7 @@ class BuiltinToolManageService:
             tenant_id=tenant_id, provider_controller=provider_controller
         )
         # check if user has added the provider
-        builtin_provider: BuiltinToolProvider = (
+        builtin_provider = (
             db.session.query(BuiltinToolProvider)
             .filter(
                 BuiltinToolProvider.tenant_id == tenant_id,
@@ -71,19 +74,18 @@ class BuiltinToolManageService:
         return jsonable_encoder([v for _, v in (provider.credentials_schema or {}).items()])
 
     @staticmethod
-    def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict):
+    def update_builtin_tool_provider(
+        session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict
+    ):
         """
         update builtin tool provider
         """
         # get if the provider exists
-        provider: BuiltinToolProvider = (
-            db.session.query(BuiltinToolProvider)
-            .filter(
-                BuiltinToolProvider.tenant_id == tenant_id,
-                BuiltinToolProvider.provider == provider_name,
-            )
-            .first()
+        stmt = select(BuiltinToolProvider).where(
+            BuiltinToolProvider.tenant_id == tenant_id,
+            BuiltinToolProvider.provider == provider_name,
         )
+        provider = session.scalar(stmt)
 
         try:
             # get provider
@@ -115,13 +117,10 @@ class BuiltinToolManageService:
                 encrypted_credentials=json.dumps(credentials),
             )
 
-            db.session.add(provider)
-            db.session.commit()
+            session.add(provider)
 
         else:
             provider.encrypted_credentials = json.dumps(credentials)
-            db.session.add(provider)
-            db.session.commit()
 
             # delete cache
             tool_configuration.delete_tool_credentials_cache()
@@ -129,15 +128,15 @@ class BuiltinToolManageService:
         return {"result": "success"}
 
     @staticmethod
-    def get_builtin_tool_provider_credentials(user_id: str, tenant_id: str, provider: str):
+    def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str):
         """
         get builtin tool provider credentials
         """
-        provider: BuiltinToolProvider = (
+        provider = (
             db.session.query(BuiltinToolProvider)
             .filter(
                 BuiltinToolProvider.tenant_id == tenant_id,
-                BuiltinToolProvider.provider == provider,
+                BuiltinToolProvider.provider == provider_name,
             )
             .first()
         )
@@ -156,7 +155,7 @@ class BuiltinToolManageService:
         """
         delete tool provider
         """
-        provider: BuiltinToolProvider = (
+        provider = (
             db.session.query(BuiltinToolProvider)
             .filter(
                 BuiltinToolProvider.tenant_id == tenant_id,