Explorar o código

improve: exract Code Node provider for each supported scripting language (#4164)

Bowen Liang hai 11 meses
pai
achega
083ef2e6fc

+ 0 - 0
api/core/helper/code_executor/__init__.py


+ 15 - 6
api/core/helper/code_executor/code_executor.py

@@ -10,9 +10,10 @@ from yarl import URL
 
 from config import get_env
 from core.helper.code_executor.entities import CodeDependency
-from core.helper.code_executor.javascript_transformer import NodeJsTemplateTransformer
-from core.helper.code_executor.jinja2_transformer import Jinja2TemplateTransformer
-from core.helper.code_executor.python_transformer import PYTHON_STANDARD_PACKAGES, PythonTemplateTransformer
+from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer
+from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer
+from core.helper.code_executor.python3.python3_transformer import PYTHON_STANDARD_PACKAGES, Python3TemplateTransformer
+from core.helper.code_executor.template_transformer import TemplateTransformer
 
 logger = logging.getLogger(__name__)
 
@@ -34,6 +35,7 @@ class CodeExecutionResponse(BaseModel):
     message: str
     data: Data
 
+
 class CodeLanguage(str, Enum):
     PYTHON3 = 'python3'
     JINJA2 = 'jinja2'
@@ -44,8 +46,8 @@ class CodeExecutor:
     dependencies_cache = {}
     dependencies_cache_lock = Lock()
 
-    code_template_transformers = {
-        CodeLanguage.PYTHON3: PythonTemplateTransformer,
+    code_template_transformers: dict[CodeLanguage, type[TemplateTransformer]] = {
+        CodeLanguage.PYTHON3: Python3TemplateTransformer,
         CodeLanguage.JINJA2: Jinja2TemplateTransformer,
         CodeLanguage.JAVASCRIPT: NodeJsTemplateTransformer,
     }
@@ -56,6 +58,10 @@ class CodeExecutor:
         CodeLanguage.PYTHON3: CodeLanguage.PYTHON3,
     }
 
+    supported_dependencies_languages: set[CodeLanguage] = {
+        CodeLanguage.PYTHON3
+    }
+
     @classmethod
     def execute_code(cls, 
                      language: Literal['python3', 'javascript', 'jinja2'], 
@@ -133,7 +139,10 @@ class CodeExecutor:
         return template_transformer.transform_response(response)
     
     @classmethod
-    def list_dependencies(cls, language: Literal['python3']) -> list[CodeDependency]:
+    def list_dependencies(cls, language: str) -> list[CodeDependency]:
+        if language not in cls.supported_dependencies_languages:
+            return []
+
         with cls.dependencies_cache_lock:
             if language in cls.dependencies_cache:
                 # check expiration

+ 55 - 0
api/core/helper/code_executor/code_node_provider.py

@@ -0,0 +1,55 @@
+from abc import abstractmethod
+
+from pydantic import BaseModel
+
+from core.helper.code_executor.code_executor import CodeExecutor
+
+
+class CodeNodeProvider(BaseModel):
+    @staticmethod
+    @abstractmethod
+    def get_language() -> str:
+        pass
+
+    @classmethod
+    def is_accept_language(cls, language: str) -> bool:
+        return language == cls.get_language()
+
+    @classmethod
+    @abstractmethod
+    def get_default_code(cls) -> str:
+        """
+        get default code in specific programming language for the code node
+        """
+        pass
+
+    @classmethod
+    def get_default_available_packages(cls) -> list[dict]:
+        return [p.dict() for p in CodeExecutor.list_dependencies(cls.get_language())]
+
+    @classmethod
+    def get_default_config(cls) -> dict:
+        return {
+            "type": "code",
+            "config": {
+                "variables": [
+                    {
+                        "variable": "arg1",
+                        "value_selector": []
+                    },
+                    {
+                        "variable": "arg2",
+                        "value_selector": []
+                    }
+                ],
+                "code_language": cls.get_language(),
+                "code": cls.get_default_code(),
+                "outputs": {
+                    "result": {
+                        "type": "string",
+                        "children": None
+                    }
+                }
+            },
+            "available_dependencies": cls.get_default_available_packages(),
+        }

+ 1 - 1
api/core/helper/code_executor/entities.py

@@ -3,4 +3,4 @@ from pydantic import BaseModel
 
 class CodeDependency(BaseModel):
     name: str
-    version: str
+    version: str

+ 0 - 0
api/core/helper/code_executor/javascript/__init__.py


+ 21 - 0
api/core/helper/code_executor/javascript/javascript_code_provider.py

@@ -0,0 +1,21 @@
+from textwrap import dedent
+
+from core.helper.code_executor.code_executor import CodeLanguage
+from core.helper.code_executor.code_node_provider import CodeNodeProvider
+
+
+class JavascriptCodeProvider(CodeNodeProvider):
+    @staticmethod
+    def get_language() -> str:
+        return CodeLanguage.JAVASCRIPT
+
+    @classmethod
+    def get_default_code(cls) -> str:
+        return dedent(
+            """
+            function main({arg1, arg2}) {
+                return {
+                    result: arg1 + arg2
+                }
+            }
+            """)

+ 1 - 0
api/core/helper/code_executor/javascript_transformer.py → api/core/helper/code_executor/javascript/javascript_transformer.py

@@ -22,6 +22,7 @@ console.log(result)
 
 NODEJS_PRELOAD = """"""
 
+
 class NodeJsTemplateTransformer(TemplateTransformer):
     @classmethod
     def transform_caller(cls, code: str, inputs: dict, 

+ 0 - 0
api/core/helper/code_executor/jinja2/__init__.py


+ 3 - 3
api/core/helper/code_executor/jinja2_formatter.py → api/core/helper/code_executor/jinja2/jinja2_formatter.py

@@ -1,4 +1,4 @@
-from core.helper.code_executor.code_executor import CodeExecutor
+from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
 
 
 class Jinja2Formatter:
@@ -11,7 +11,7 @@ class Jinja2Formatter:
         :return:
         """
         result = CodeExecutor.execute_workflow_code_template(
-            language='jinja2', code=template, inputs=inputs
+            language=CodeLanguage.JINJA2, code=template, inputs=inputs
         )
 
-        return result['result']
+        return result['result']

+ 1 - 1
api/core/helper/code_executor/jinja2_transformer.py → api/core/helper/code_executor/jinja2/jinja2_transformer.py

@@ -4,7 +4,7 @@ from base64 import b64encode
 from typing import Optional
 
 from core.helper.code_executor.entities import CodeDependency
-from core.helper.code_executor.python_transformer import PYTHON_STANDARD_PACKAGES
+from core.helper.code_executor.python3.python3_transformer import PYTHON_STANDARD_PACKAGES
 from core.helper.code_executor.template_transformer import TemplateTransformer
 
 PYTHON_RUNNER = """

+ 0 - 0
api/core/helper/code_executor/python3/__init__.py


+ 20 - 0
api/core/helper/code_executor/python3/python3_code_provider.py

@@ -0,0 +1,20 @@
+from textwrap import dedent
+
+from core.helper.code_executor.code_executor import CodeLanguage
+from core.helper.code_executor.code_node_provider import CodeNodeProvider
+
+
+class Python3CodeProvider(CodeNodeProvider):
+    @staticmethod
+    def get_language() -> str:
+        return CodeLanguage.PYTHON3
+
+    @classmethod
+    def get_default_code(cls) -> str:
+        return dedent(
+            """
+            def main(arg1: int, arg2: int) -> dict:
+                return {
+                    "result": arg1 + arg2,
+                }
+            """)

+ 9 - 6
api/core/helper/code_executor/python_transformer.py → api/core/helper/code_executor/python3/python3_transformer.py

@@ -1,12 +1,14 @@
 import json
 import re
 from base64 import b64encode
+from textwrap import dedent
 from typing import Optional
 
 from core.helper.code_executor.entities import CodeDependency
 from core.helper.code_executor.template_transformer import TemplateTransformer
 
-PYTHON_RUNNER = """# declare main function here
+PYTHON_RUNNER = dedent("""
+# declare main function here
 {{code}}
 
 from json import loads, dumps
@@ -25,16 +27,17 @@ result = f'''<<RESULT>>
 <<RESULT>>'''
 
 print(result)
-"""
+""")
 
 PYTHON_PRELOAD = """"""
 
-PYTHON_STANDARD_PACKAGES = set([
+PYTHON_STANDARD_PACKAGES = {
     'json', 'datetime', 'math', 'random', 're', 'string', 'sys', 'time', 'traceback', 'uuid', 'os', 'base64',
-    'hashlib', 'hmac', 'binascii', 'collections', 'functools', 'operator', 'itertools', 'uuid', 
-])
+    'hashlib', 'hmac', 'binascii', 'collections', 'functools', 'operator', 'itertools', 'uuid',
+}
 
-class PythonTemplateTransformer(TemplateTransformer):
+
+class Python3TemplateTransformer(TemplateTransformer):
     @classmethod
     def transform_caller(cls, code: str, inputs: dict, 
                          dependencies: Optional[list[CodeDependency]] = None) -> tuple[str, str, list[CodeDependency]]:

+ 1 - 1
api/core/prompt/advanced_prompt_transform.py

@@ -2,7 +2,7 @@ from typing import Optional, Union
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.file.file_obj import FileVar
-from core.helper.code_executor.jinja2_formatter import Jinja2Formatter
+from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,

+ 11 - 62
api/core/workflow/nodes/code/code_node.py

@@ -2,7 +2,9 @@ import os
 from typing import Optional, Union, cast
 
 from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage
-from core.model_runtime.utils.encoders import jsonable_encoder
+from core.helper.code_executor.code_node_provider import CodeNodeProvider
+from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
+from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
@@ -18,16 +20,6 @@ MAX_STRING_ARRAY_LENGTH = int(os.environ.get('CODE_MAX_STRING_ARRAY_LENGTH', '30
 MAX_OBJECT_ARRAY_LENGTH = int(os.environ.get('CODE_MAX_OBJECT_ARRAY_LENGTH', '30'))
 MAX_NUMBER_ARRAY_LENGTH = int(os.environ.get('CODE_MAX_NUMBER_ARRAY_LENGTH', '1000'))
 
-JAVASCRIPT_DEFAULT_CODE = """function main({arg1, arg2}) {
-    return {
-        result: arg1 + arg2
-    }
-}"""
-
-PYTHON_DEFAULT_CODE = """def main(arg1: int, arg2: int) -> dict:
-    return {
-        "result": arg1 + arg2,
-    }"""
 
 class CodeNode(BaseNode):
     _node_data_cls = CodeNodeData
@@ -40,58 +32,15 @@ class CodeNode(BaseNode):
         :param filters: filter by node config parameters.
         :return:
         """
-        if filters and filters.get("code_language") == CodeLanguage.JAVASCRIPT:
-            return {
-                "type": "code",
-                "config": {
-                    "variables": [
-                        {
-                            "variable": "arg1",
-                            "value_selector": []
-                        },
-                        {
-                            "variable": "arg2",
-                            "value_selector": []
-                        }
-                    ],
-                    "code_language": CodeLanguage.JAVASCRIPT,
-                    "code": JAVASCRIPT_DEFAULT_CODE,
-                    "outputs": {
-                        "result": {
-                            "type": "string",
-                            "children": None
-                        }
-                    }
-                },
-                "available_dependencies": []
-            }
+        code_language = CodeLanguage.PYTHON3
+        if filters:
+            code_language = (filters.get("code_language", CodeLanguage.PYTHON3))
 
-        return {
-            "type": "code",
-            "config": {
-                "variables": [
-                    {
-                        "variable": "arg1",
-                        "value_selector": []
-                    },
-                    {
-                        "variable": "arg2",
-                        "value_selector": []
-                    }
-                ],
-                "code_language": CodeLanguage.PYTHON3,
-                "code": PYTHON_DEFAULT_CODE,
-                "outputs": {
-                    "result": {
-                        "type": "string",
-                        "children": None
-                    }
-                },
-                "dependencies": [
-                ]
-            },
-            "available_dependencies": jsonable_encoder(CodeExecutor.list_dependencies('python3'))
-        }
+        providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
+        code_provider: type[CodeNodeProvider] = next(p for p in providers
+                                                     if p.is_accept_language(code_language))
+
+        return code_provider.get_default_config()
 
     def _run(self, variable_pool: VariablePool) -> NodeRunResult:
         """

+ 2 - 1
api/core/workflow/nodes/code/entities.py

@@ -2,6 +2,7 @@ from typing import Literal, Optional
 
 from pydantic import BaseModel
 
+from core.helper.code_executor.code_executor import CodeLanguage
 from core.helper.code_executor.entities import CodeDependency
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.variable_entities import VariableSelector
@@ -16,7 +17,7 @@ class CodeNodeData(BaseNodeData):
         children: Optional[dict[str, 'Output']]
 
     variables: list[VariableSelector]
-    code_language: Literal['python3', 'javascript']
+    code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]
     code: str
     outputs: dict[str, Output]
     dependencies: Optional[list[CodeDependency]] = None

+ 2 - 2
api/core/workflow/nodes/template_transform/template_transform_node.py

@@ -1,7 +1,7 @@
 import os
 from typing import Optional, cast
 
-from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor
+from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
@@ -53,7 +53,7 @@ class TemplateTransformNode(BaseNode):
         # Run code
         try:
             result = CodeExecutor.execute_workflow_code_template(
-                language='jinja2',
+                language=CodeLanguage.JINJA2,
                 code=node_data.template,
                 inputs=variables
             )

+ 10 - 9
api/tests/integration_tests/workflow/nodes/__mock/code_executor.py

@@ -5,7 +5,7 @@ import pytest
 from _pytest.monkeypatch import MonkeyPatch
 from jinja2 import Template
 
-from core.helper.code_executor.code_executor import CodeExecutor
+from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
 from core.helper.code_executor.entities import CodeDependency
 
 MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true'
@@ -15,14 +15,15 @@ class MockedCodeExecutor:
     def invoke(cls, language: Literal['python3', 'javascript', 'jinja2'], 
                code: str, inputs: dict, dependencies: Optional[list[CodeDependency]] = None) -> dict:
         # invoke directly
-        if language == 'python3':
-            return {
-                "result": 3
-            }
-        elif language == 'jinja2':
-            return {
-                "result": Template(code).render(inputs)
-            }
+        match language:
+            case CodeLanguage.PYTHON3:
+                return {
+                    "result": 3
+                }
+            case CodeLanguage.JINJA2:
+                return {
+                    "result": Template(code).render(inputs)
+                }
 
 @pytest.fixture
 def setup_code_executor_mock(request, monkeypatch: MonkeyPatch):

+ 9 - 2
api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py

@@ -1,7 +1,7 @@
 from textwrap import dedent
 
 from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
-from core.workflow.nodes.code.code_node import JAVASCRIPT_DEFAULT_CODE
+from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
 
 CODE_LANGUAGE = CodeLanguage.JAVASCRIPT
 
@@ -23,5 +23,12 @@ def test_javascript_json():
 
 def test_javascript_with_code_template():
     result = CodeExecutor.execute_workflow_code_template(
-        language=CODE_LANGUAGE, code=JAVASCRIPT_DEFAULT_CODE, inputs={'arg1': 'Hello', 'arg2': 'World'})
+        language=CODE_LANGUAGE, code=JavascriptCodeProvider.get_default_code(), inputs={'arg1': 'Hello', 'arg2': 'World'})
     assert result == {'result': 'HelloWorld'}
+
+
+def test_javascript_list_default_available_packages():
+    packages = JavascriptCodeProvider.get_default_available_packages()
+
+    # no default packages available for javascript
+    assert len(packages) == 0

+ 1 - 1
api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py

@@ -1,7 +1,7 @@
 import base64
 
 from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
-from core.helper.code_executor.jinja2_transformer import JINJA2_PRELOAD, PYTHON_RUNNER
+from core.helper.code_executor.jinja2.jinja2_transformer import JINJA2_PRELOAD, PYTHON_RUNNER
 
 CODE_LANGUAGE = CodeLanguage.JINJA2
 

+ 12 - 2
api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py

@@ -1,7 +1,8 @@
+import json
 from textwrap import dedent
 
 from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
-from core.workflow.nodes.code.code_node import PYTHON_DEFAULT_CODE
+from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
 
 CODE_LANGUAGE = CodeLanguage.PYTHON3
 
@@ -23,5 +24,14 @@ def test_python3_json():
 
 def test_python3_with_code_template():
     result = CodeExecutor.execute_workflow_code_template(
-        language=CODE_LANGUAGE, code=PYTHON_DEFAULT_CODE, inputs={'arg1': 'Hello', 'arg2': 'World'})
+        language=CODE_LANGUAGE, code=Python3CodeProvider.get_default_code(), inputs={'arg1': 'Hello', 'arg2': 'World'})
     assert result == {'result': 'HelloWorld'}
+
+
+def test_python3_list_default_available_packages():
+    packages = Python3CodeProvider.get_default_available_packages()
+    assert len(packages) > 0
+    assert {'requests', 'httpx'}.issubset(p['name'] for p in packages)
+
+    # check JSON serializable
+    assert len(str(json.dumps(packages))) > 0