Kaynağa Gözat

refactor(api): improve handling of `tools` field and cleanup variable usage (#10553)

-LAN- 5 ay önce
ebeveyn
işleme
16b9665033

+ 7 - 2
api/core/tools/entities/api_entities.py

@@ -1,6 +1,6 @@
 from typing import Literal, Optional
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field, field_validator
 
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.tools.entities.common_entities import I18nObject
@@ -32,9 +32,14 @@ class UserToolProvider(BaseModel):
     original_credentials: Optional[dict] = None
     is_team_authorization: bool = False
     allow_delete: bool = True
-    tools: list[UserTool] | None = None
+    tools: list[UserTool] = Field(default_factory=list)
     labels: list[str] | None = None
 
+    @field_validator("tools", mode="before")
+    @classmethod
+    def convert_none_to_empty_list(cls, v):
+        return v if v is not None else []
+
     def to_dict(self) -> dict:
         # -------------
         # overwrite tool parameter types for temp fix

+ 7 - 8
api/services/tools/api_tools_manage_service.py

@@ -116,7 +116,7 @@ class ApiToolManageService:
         provider_name = provider_name.strip()
 
         # check if the provider exists
-        provider: ApiToolProvider = (
+        provider = (
             db.session.query(ApiToolProvider)
             .filter(
                 ApiToolProvider.tenant_id == tenant_id,
@@ -201,16 +201,15 @@ class ApiToolManageService:
         return {"schema": schema}
 
     @staticmethod
-    def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]:
+    def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[UserTool]:
         """
         list api tool provider tools
         """
-        provider_name = provider
-        provider: ApiToolProvider = (
+        provider = (
             db.session.query(ApiToolProvider)
             .filter(
                 ApiToolProvider.tenant_id == tenant_id,
-                ApiToolProvider.name == provider,
+                ApiToolProvider.name == provider_name,
             )
             .first()
         )
@@ -252,7 +251,7 @@ class ApiToolManageService:
         provider_name = provider_name.strip()
 
         # check if the provider exists
-        provider: ApiToolProvider = (
+        provider = (
             db.session.query(ApiToolProvider)
             .filter(
                 ApiToolProvider.tenant_id == tenant_id,
@@ -319,7 +318,7 @@ class ApiToolManageService:
         """
         delete tool provider
         """
-        provider: ApiToolProvider = (
+        provider = (
             db.session.query(ApiToolProvider)
             .filter(
                 ApiToolProvider.tenant_id == tenant_id,
@@ -369,7 +368,7 @@ class ApiToolManageService:
         if tool_bundle is None:
             raise ValueError(f"invalid tool name {tool_name}")
 
-        db_provider: ApiToolProvider = (
+        db_provider = (
             db.session.query(ApiToolProvider)
             .filter(
                 ApiToolProvider.tenant_id == tenant_id,