Parcourir la source

refactor(api/core/app/app_config/entities.py): Move Type to outside and add EXTERNAL_DATA_TOOL. (#7444)

-LAN- il y a 8 mois
Parent
commit
a10b207de2

+ 27 - 36
api/core/app/app_config/easy_ui_based_app/variables/manager.py

@@ -1,6 +1,6 @@
 import re
 
-from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity
+from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType
 from core.external_data_tool.factory import ExternalDataToolFactory
 
 
@@ -13,7 +13,7 @@ class BasicVariablesConfigManager:
         :param config: model config args
         """
         external_data_variables = []
-        variables = []
+        variable_entities = []
 
         # old external_data_tools
         external_data_tools = config.get('external_data_tools', [])
@@ -30,50 +30,41 @@ class BasicVariablesConfigManager:
             )
 
         # variables and external_data_tools
-        for variable in config.get('user_input_form', []):
-            typ = list(variable.keys())[0]
-            if typ == 'external_data_tool':
-                val = variable[typ]
-                if 'config' not in val:
+        for variables in config.get('user_input_form', []):
+            variable_type = list(variables.keys())[0]
+            if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL:
+                variable = variables[variable_type]
+                if 'config' not in variable:
                     continue
 
                 external_data_variables.append(
                     ExternalDataVariableEntity(
-                        variable=val['variable'],
-                        type=val['type'],
-                        config=val['config']
+                        variable=variable['variable'],
+                        type=variable['type'],
+                        config=variable['config']
                     )
                 )
-            elif typ in [
-                VariableEntity.Type.TEXT_INPUT.value,
-                VariableEntity.Type.PARAGRAPH.value,
-                VariableEntity.Type.NUMBER.value,
+            elif variable_type in [
+                VariableEntityType.TEXT_INPUT,
+                VariableEntityType.PARAGRAPH,
+                VariableEntityType.NUMBER,
+                VariableEntityType.SELECT,
             ]:
-                variables.append(
-                    VariableEntity(
-                        type=VariableEntity.Type.value_of(typ),
-                        variable=variable[typ].get('variable'),
-                        description=variable[typ].get('description'),
-                        label=variable[typ].get('label'),
-                        required=variable[typ].get('required', False),
-                        max_length=variable[typ].get('max_length'),
-                        default=variable[typ].get('default'),
-                    )
-                )
-            elif typ == VariableEntity.Type.SELECT.value:
-                variables.append(
+                variable = variables[variable_type]
+                variable_entities.append(
                     VariableEntity(
-                        type=VariableEntity.Type.SELECT,
-                        variable=variable[typ].get('variable'),
-                        description=variable[typ].get('description'),
-                        label=variable[typ].get('label'),
-                        required=variable[typ].get('required', False),
-                        options=variable[typ].get('options'),
-                        default=variable[typ].get('default'),
+                        type=variable_type,
+                        variable=variable.get('variable'),
+                        description=variable.get('description'),
+                        label=variable.get('label'),
+                        required=variable.get('required', False),
+                        max_length=variable.get('max_length'),
+                        options=variable.get('options'),
+                        default=variable.get('default'),
                     )
                 )
 
-        return variables, external_data_variables
+        return variable_entities, external_data_variables
 
     @classmethod
     def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
@@ -183,4 +174,4 @@ class BasicVariablesConfigManager:
                 config=config
             )
 
-        return config, ["external_data_tools"]
+        return config, ["external_data_tools"]

+ 10 - 24
api/core/app/app_config/entities.py

@@ -82,43 +82,29 @@ class PromptTemplateEntity(BaseModel):
     advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
 
 
+class VariableEntityType(str, Enum):
+    TEXT_INPUT = "text-input"
+    SELECT = "select"
+    PARAGRAPH = "paragraph"
+    NUMBER = "number"
+    EXTERNAL_DATA_TOOL = "external-data-tool"
+
+
 class VariableEntity(BaseModel):
     """
     Variable Entity.
     """
-    class Type(Enum):
-        TEXT_INPUT = 'text-input'
-        SELECT = 'select'
-        PARAGRAPH = 'paragraph'
-        NUMBER = 'number'
-
-        @classmethod
-        def value_of(cls, value: str) -> 'VariableEntity.Type':
-            """
-            Get value of given mode.
-
-            :param value: mode value
-            :return: mode
-            """
-            for mode in cls:
-                if mode.value == value:
-                    return mode
-            raise ValueError(f'invalid variable type value {value}')
 
     variable: str
     label: str
     description: Optional[str] = None
-    type: Type
+    type: VariableEntityType
     required: bool = False
     max_length: Optional[int] = None
     options: Optional[list[str]] = None
     default: Optional[str] = None
     hint: Optional[str] = None
 
-    @property
-    def name(self) -> str:
-        return self.variable
-
 
 class ExternalDataVariableEntity(BaseModel):
     """
@@ -252,4 +238,4 @@ class WorkflowUIBasedAppConfig(AppConfig):
     """
     Workflow UI Based App Config Entity.
     """
-    workflow_id: str
+    workflow_id: str

+ 14 - 14
api/core/app/apps/base_app_generator.py

@@ -1,7 +1,7 @@
 from collections.abc import Mapping
 from typing import Any, Optional
 
-from core.app.app_config.entities import AppConfig, VariableEntity
+from core.app.app_config.entities import AppConfig, VariableEntity, VariableEntityType
 
 
 class BaseAppGenerator:
@@ -9,29 +9,29 @@ class BaseAppGenerator:
         user_inputs = user_inputs or {}
         # Filter input variables from form configuration, handle required fields, default values, and option values
         variables = app_config.variables
-        filtered_inputs = {var.name: self._validate_input(inputs=user_inputs, var=var) for var in variables}
+        filtered_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
         filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()}
         return filtered_inputs
 
     def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
-        user_input_value = inputs.get(var.name)
+        user_input_value = inputs.get(var.variable)
         if var.required and not user_input_value:
-            raise ValueError(f'{var.name} is required in input form')
+            raise ValueError(f'{var.variable} is required in input form')
         if not var.required and not user_input_value:
             # TODO: should we return None here if the default value is None?
             return var.default or ''
         if (
             var.type
             in (
-                VariableEntity.Type.TEXT_INPUT,
-                VariableEntity.Type.SELECT,
-                VariableEntity.Type.PARAGRAPH,
+                VariableEntityType.TEXT_INPUT,
+                VariableEntityType.SELECT,
+                VariableEntityType.PARAGRAPH,
             )
             and user_input_value
             and not isinstance(user_input_value, str)
         ):
-            raise ValueError(f"(type '{var.type}') {var.name} in input form must be a string")
-        if var.type == VariableEntity.Type.NUMBER and isinstance(user_input_value, str):
+            raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
+        if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
             # may raise ValueError if user_input_value is not a valid number
             try:
                 if '.' in user_input_value:
@@ -39,14 +39,14 @@ class BaseAppGenerator:
                 else:
                     return int(user_input_value)
             except ValueError:
-                raise ValueError(f"{var.name} in input form must be a valid number")
-        if var.type == VariableEntity.Type.SELECT:
+                raise ValueError(f"{var.variable} in input form must be a valid number")
+        if var.type == VariableEntityType.SELECT:
             options = var.options or []
             if user_input_value not in options:
-                raise ValueError(f'{var.name} in input form must be one of the following: {options}')
-        elif var.type in (VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH):
+                raise ValueError(f'{var.variable} in input form must be one of the following: {options}')
+        elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
             if var.max_length and user_input_value and len(user_input_value) > var.max_length:
-                raise ValueError(f'{var.name} in input form must be less than {var.max_length} characters')
+                raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters')
 
         return user_input_value
 

+ 19 - 24
api/core/tools/provider/workflow_tool_provider.py

@@ -1,6 +1,6 @@
 from typing import Optional
 
-from core.app.app_config.entities import VariableEntity
+from core.app.app_config.entities import VariableEntity, VariableEntityType
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_entities import (
@@ -18,6 +18,13 @@ from models.model import App, AppMode
 from models.tools import WorkflowToolProvider
 from models.workflow import Workflow
 
+VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
+    VariableEntityType.TEXT_INPUT: ToolParameter.ToolParameterType.STRING,
+    VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING,
+    VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT,
+    VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER,
+}
+
 
 class WorkflowToolProviderController(ToolProviderController):
     provider_id: str
@@ -28,7 +35,7 @@ class WorkflowToolProviderController(ToolProviderController):
 
         if not app:
             raise ValueError('app not found')
-        
+
         controller = WorkflowToolProviderController(**{
             'identity': {
                 'author': db_provider.user.name if db_provider.user_id and db_provider.user else '',
@@ -46,7 +53,7 @@ class WorkflowToolProviderController(ToolProviderController):
             'credentials_schema': {},
             'provider_id': db_provider.id or '',
         })
-        
+
         # init tools
 
         controller.tools = [controller._get_db_provider_tool(db_provider, app)]
@@ -56,7 +63,7 @@ class WorkflowToolProviderController(ToolProviderController):
     @property
     def provider_type(self) -> ToolProviderType:
         return ToolProviderType.WORKFLOW
-    
+
     def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool:
         """
             get db provider tool
@@ -93,23 +100,11 @@ class WorkflowToolProviderController(ToolProviderController):
             if variable:
                 parameter_type = None
                 options = None
-                if variable.type in [
-                    VariableEntity.Type.TEXT_INPUT, 
-                    VariableEntity.Type.PARAGRAPH, 
-                ]:
-                    parameter_type = ToolParameter.ToolParameterType.STRING
-                elif variable.type in [
-                    VariableEntity.Type.SELECT
-                ]:
-                    parameter_type = ToolParameter.ToolParameterType.SELECT
-                elif variable.type in [
-                    VariableEntity.Type.NUMBER
-                ]:
-                    parameter_type = ToolParameter.ToolParameterType.NUMBER
-                else:
+                if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING:
                     raise ValueError(f'unsupported variable type {variable.type}')
-                
-                if variable.type == VariableEntity.Type.SELECT and variable.options:
+                parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type]
+
+                if variable.type == VariableEntityType.SELECT and variable.options:
                     options = [
                         ToolParameterOption(
                             value=option,
@@ -200,7 +195,7 @@ class WorkflowToolProviderController(ToolProviderController):
         """
         if self.tools is not None:
             return self.tools
-        
+
         db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
             WorkflowToolProvider.tenant_id == tenant_id,
             WorkflowToolProvider.app_id == self.provider_id,
@@ -208,11 +203,11 @@ class WorkflowToolProviderController(ToolProviderController):
 
         if not db_providers:
             return []
-        
+
         self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)]
 
         return self.tools
-    
+
     def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
         """
             get tool by name
@@ -226,5 +221,5 @@ class WorkflowToolProviderController(ToolProviderController):
         for tool in self.tools:
             if tool.identity.name == tool_name:
                 return tool
-        
+
         return None

+ 5 - 1
api/core/workflow/nodes/start/entities.py

@@ -1,3 +1,7 @@
+from collections.abc import Sequence
+
+from pydantic import Field
+
 from core.app.app_config.entities import VariableEntity
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 
@@ -6,4 +10,4 @@ class StartNodeData(BaseNodeData):
     """
     Start Node Data
     """
-    variables: list[VariableEntity] = []
+    variables: Sequence[VariableEntity] = Field(default_factory=list)

+ 7 - 5
api/tests/unit_tests/services/workflow/test_workflow_converter.py

@@ -14,6 +14,7 @@ from core.app.app_config.entities import (
     ModelConfigEntity,
     PromptTemplateEntity,
     VariableEntity,
+    VariableEntityType,
 )
 from core.helper import encrypter
 from core.model_runtime.entities.llm_entities import LLMMode
@@ -25,23 +26,24 @@ from services.workflow.workflow_converter import WorkflowConverter
 
 @pytest.fixture
 def default_variables():
-    return [
+    value = [
         VariableEntity(
             variable="text_input",
             label="text-input",
-            type=VariableEntity.Type.TEXT_INPUT
+            type=VariableEntityType.TEXT_INPUT,
         ),
         VariableEntity(
             variable="paragraph",
             label="paragraph",
-            type=VariableEntity.Type.PARAGRAPH
+            type=VariableEntityType.PARAGRAPH,
         ),
         VariableEntity(
             variable="select",
             label="select",
-            type=VariableEntity.Type.SELECT
-        )
+            type=VariableEntityType.SELECT,
+        ),
     ]
+    return value
 
 
 def test__convert_to_start_node(default_variables):