瀏覽代碼

fix: buitin tool aippt (#10234)

Co-authored-by: jinqi.guo <jinqi.guo@ubtrobot.com>
guogeer 5 月之前
父節點
當前提交
971defbbbd
共有 2 個文件被更改,包括 50 次插入30 次删除
  1. 49 29
      api/core/tools/provider/builtin/aippt/tools/aippt.py
  2. 1 1
      api/core/workflow/nodes/tool/tool_node.py

+ 49 - 29
api/core/tools/provider/builtin/aippt/tools/aippt.py

@@ -4,7 +4,7 @@ from hmac import new as hmac_new
 from json import loads as json_loads
 from threading import Lock
 from time import sleep, time
-from typing import Any, Optional
+from typing import Any
 
 from httpx import get, post
 from requests import get as requests_get
@@ -15,27 +15,27 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter,
 from core.tools.tool.builtin_tool import BuiltinTool
 
 
-class AIPPTGenerateTool(BuiltinTool):
+class AIPPTGenerateToolAdapter:
     """
     A tool for generating a ppt
     """
 
     _api_base_url = URL("https://co.aippt.cn/api")
     _api_token_cache = {}
-    _api_token_cache_lock: Optional[Lock] = None
     _style_cache = {}
-    _style_cache_lock: Optional[Lock] = None
+
+    _api_token_cache_lock = Lock()
+    _style_cache_lock = Lock()
 
     _task = {}
     _task_type_map = {
         "auto": 1,
         "markdown": 7,
     }
+    _tool: BuiltinTool
 
-    def __init__(self, **kwargs: Any):
-        super().__init__(**kwargs)
-        self._api_token_cache_lock = Lock()
-        self._style_cache_lock = Lock()
+    def __init__(self, tool: BuiltinTool = None):
+        self._tool = tool
 
     def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
         """
@@ -51,11 +51,11 @@ class AIPPTGenerateTool(BuiltinTool):
         """
         title = tool_parameters.get("title", "")
         if not title:
-            return self.create_text_message("Please provide a title for the ppt")
+            return self._tool.create_text_message("Please provide a title for the ppt")
 
         model = tool_parameters.get("model", "aippt")
         if not model:
-            return self.create_text_message("Please provide a model for the ppt")
+            return self._tool.create_text_message("Please provide a model for the ppt")
 
         outline = tool_parameters.get("outline", "")
 
@@ -68,8 +68,8 @@ class AIPPTGenerateTool(BuiltinTool):
         )
 
         # get suit
-        color = tool_parameters.get("color")
-        style = tool_parameters.get("style")
+        color: str = tool_parameters.get("color")
+        style: str = tool_parameters.get("style")
 
         if color == "__default__":
             color_id = ""
@@ -93,9 +93,9 @@ class AIPPTGenerateTool(BuiltinTool):
         # generate ppt
         _, ppt_url = self._generate_ppt(task_id=task_id, suit_id=suit_id, user_id=user_id)
 
-        return self.create_text_message(
+        return self._tool.create_text_message(
             """the ppt has been created successfully,"""
-            f"""the ppt url is {ppt_url}"""
+            f"""the ppt url is {ppt_url} ."""
             """please give the ppt url to user and direct user to download it."""
         )
 
@@ -111,8 +111,8 @@ class AIPPTGenerateTool(BuiltinTool):
         """
         headers = {
             "x-channel": "",
-            "x-api-key": self.runtime.credentials["aippt_access_key"],
-            "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
+            "x-api-key": self._tool.runtime.credentials["aippt_access_key"],
+            "x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id),
         }
         response = post(
             str(self._api_base_url / "ai" / "chat" / "v2" / "task"),
@@ -139,8 +139,8 @@ class AIPPTGenerateTool(BuiltinTool):
 
         headers = {
             "x-channel": "",
-            "x-api-key": self.runtime.credentials["aippt_access_key"],
-            "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
+            "x-api-key": self._tool.runtime.credentials["aippt_access_key"],
+            "x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id),
         }
 
         response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60))
@@ -183,8 +183,8 @@ class AIPPTGenerateTool(BuiltinTool):
 
         headers = {
             "x-channel": "",
-            "x-api-key": self.runtime.credentials["aippt_access_key"],
-            "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
+            "x-api-key": self._tool.runtime.credentials["aippt_access_key"],
+            "x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id),
         }
 
         response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60))
@@ -236,14 +236,15 @@ class AIPPTGenerateTool(BuiltinTool):
         """
         headers = {
             "x-channel": "",
-            "x-api-key": self.runtime.credentials["aippt_access_key"],
-            "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
+            "x-api-key": self._tool.runtime.credentials["aippt_access_key"],
+            "x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id),
         }
 
         response = post(
             str(self._api_base_url / "design" / "v2" / "save"),
             headers=headers,
             data={"task_id": task_id, "template_id": suit_id},
+            timeout=(10, 60),
         )
 
         if response.status_code != 200:
@@ -350,11 +351,13 @@ class AIPPTGenerateTool(BuiltinTool):
 
         return token
 
-    @classmethod
-    def _calculate_sign(cls, access_key: str, secret_key: str, timestamp: int) -> str:
+    @staticmethod
+    def _calculate_sign(access_key: str, secret_key: str, timestamp: int) -> str:
         return b64encode(
             hmac_new(
-                key=secret_key.encode("utf-8"), msg=f"GET@/api/grant/token/@{timestamp}".encode(), digestmod=sha1
+                key=secret_key.encode("utf-8"),
+                msg=f"GET@/api/grant/token/@{timestamp}".encode(),
+                digestmod=sha1,
             ).digest()
         ).decode("utf-8")
 
@@ -419,10 +422,12 @@ class AIPPTGenerateTool(BuiltinTool):
         :param credentials: the credentials
         :return: Tuple[list[dict[id, color]], list[dict[id, style]]
         """
-        if not self.runtime.credentials.get("aippt_access_key") or not self.runtime.credentials.get("aippt_secret_key"):
+        if not self._tool.runtime.credentials.get("aippt_access_key") or not self._tool.runtime.credentials.get(
+            "aippt_secret_key"
+        ):
             raise Exception("Please provide aippt credentials")
 
-        return self._get_styles(credentials=self.runtime.credentials, user_id=user_id)
+        return self._get_styles(credentials=self._tool.runtime.credentials, user_id=user_id)
 
     def _get_suit(self, style_id: int, colour_id: int) -> int:
         """
@@ -430,8 +435,8 @@ class AIPPTGenerateTool(BuiltinTool):
         """
         headers = {
             "x-channel": "",
-            "x-api-key": self.runtime.credentials["aippt_access_key"],
-            "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id="__dify_system__"),
+            "x-api-key": self._tool.runtime.credentials["aippt_access_key"],
+            "x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id="__dify_system__"),
         }
         response = get(
             str(self._api_base_url / "template_component" / "suit" / "search"),
@@ -496,3 +501,18 @@ class AIPPTGenerateTool(BuiltinTool):
                 ],
             ),
         ]
+
+
+class AIPPTGenerateTool(BuiltinTool):
+    def __init__(self, **kwargs: Any):
+        super().__init__(**kwargs)
+
+    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
+        return AIPPTGenerateToolAdapter(self)._invoke(user_id, tool_parameters)
+
+    def get_runtime_parameters(self) -> list[ToolParameter]:
+        return AIPPTGenerateToolAdapter(self).get_runtime_parameters()
+
+    @classmethod
+    def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str:
+        return AIPPTGenerateToolAdapter()._get_api_token(credentials, user_id)

+ 1 - 1
api/core/workflow/nodes/tool/tool_node.py

@@ -53,7 +53,7 @@ class ToolNode(BaseNode[ToolNodeData]):
             )
 
         # get parameters
-        tool_parameters = tool_runtime.get_runtime_parameters() or []
+        tool_parameters = tool_runtime.parameters or []
         parameters = self._generate_parameters(
             tool_parameters=tool_parameters,
             variable_pool=self.graph_runtime_state.variable_pool,