Parcourir la source

refactor(code_executor): update input type annotations to use Mapping for better type safety (#10478)

-LAN- il y a 5 mois
Parent
commit
b8b6cd409a

+ 3 - 0
api/core/helper/code_executor/__init__.py

@@ -0,0 +1,3 @@
+from .code_executor import CodeExecutor, CodeLanguage
+
+__all__ = ["CodeExecutor", "CodeLanguage"]

+ 3 - 2
api/core/helper/code_executor/code_executor.py

@@ -1,7 +1,8 @@
 import logging
+from collections.abc import Mapping
 from enum import Enum
 from threading import Lock
-from typing import Optional
+from typing import Any, Optional
 
 from httpx import Timeout, post
 from pydantic import BaseModel
@@ -117,7 +118,7 @@ class CodeExecutor:
         return response.data.stdout or ""
 
     @classmethod
-    def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict) -> dict:
+    def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]) -> dict:
         """
         Execute code
         :param language: code language

+ 5 - 3
api/core/helper/code_executor/template_transformer.py

@@ -2,6 +2,8 @@ import json
 import re
 from abc import ABC, abstractmethod
 from base64 import b64encode
+from collections.abc import Mapping
+from typing import Any
 
 
 class TemplateTransformer(ABC):
@@ -10,7 +12,7 @@ class TemplateTransformer(ABC):
     _result_tag: str = "<<RESULT>>"
 
     @classmethod
-    def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]:
+    def transform_caller(cls, code: str, inputs: Mapping[str, Any]) -> tuple[str, str]:
         """
         Transform code to python runner
         :param code: code
@@ -48,13 +50,13 @@ class TemplateTransformer(ABC):
         pass
 
     @classmethod
-    def serialize_inputs(cls, inputs: dict) -> str:
+    def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str:
         inputs_json_str = json.dumps(inputs, ensure_ascii=False).encode()
         input_base64_encoded = b64encode(inputs_json_str).decode("utf-8")
         return input_base64_encoded
 
     @classmethod
-    def assemble_runner_script(cls, code: str, inputs: dict) -> str:
+    def assemble_runner_script(cls, code: str, inputs: Mapping[str, Any]) -> str:
         # assemble runner script
         script = cls.get_runner_script()
         script = script.replace(cls._code_placeholder, code)