Sfoglia il codice sorgente

improve: extract method for safe loading yaml file and avoid using PyYaml's FullLoader (#4031)

Bowen Liang 11 mesi fa
parent
commit
3fda2245a4

+ 2 - 4
api/core/model_runtime/model_providers/__base/ai_model.py

@@ -3,8 +3,6 @@ import os
 from abc import ABC, abstractmethod
 from typing import Optional
 
-import yaml
-
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
 from core.model_runtime.entities.model_entities import (
@@ -18,6 +16,7 @@ from core.model_runtime.entities.model_entities import (
 )
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
+from core.tools.utils.yaml_utils import load_yaml_file
 from core.utils.position_helper import get_position_map, sort_by_position_map
 
 
@@ -154,8 +153,7 @@ class AIModel(ABC):
         # traverse all model_schema_yaml_paths
         for model_schema_yaml_path in model_schema_yaml_paths:
             # read yaml data from yaml file
-            with open(model_schema_yaml_path, encoding='utf-8') as f:
-                yaml_data = yaml.safe_load(f)
+            yaml_data = load_yaml_file(model_schema_yaml_path, ignore_error=True)
 
             new_parameter_rules = []
             for parameter_rule in yaml_data.get('parameter_rules', []):

+ 2 - 6
api/core/model_runtime/model_providers/__base/model_provider.py

@@ -1,11 +1,10 @@
 import os
 from abc import ABC, abstractmethod
 
-import yaml
-
 from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
 from core.model_runtime.entities.provider_entities import ProviderEntity
 from core.model_runtime.model_providers.__base.ai_model import AIModel
+from core.tools.utils.yaml_utils import load_yaml_file
 from core.utils.module_import_helper import get_subclasses_from_module, import_module_from_source
 
 
@@ -44,10 +43,7 @@ class ModelProvider(ABC):
 
         # read provider schema from yaml file
         yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
-        yaml_data = {}
-        if os.path.exists(yaml_path):
-            with open(yaml_path, encoding='utf-8') as f:
-                yaml_data = yaml.safe_load(f)
+        yaml_data = load_yaml_file(yaml_path, ignore_error=True)
 
         try:
             # yaml_data to entity

+ 16 - 18
api/core/tools/provider/builtin_tool_provider.py

@@ -2,8 +2,6 @@ from abc import abstractmethod
 from os import listdir, path
 from typing import Any
 
-from yaml import FullLoader, load
-
 from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType
 from core.tools.entities.user_entities import UserToolProviderCredentials
 from core.tools.errors import (
@@ -15,6 +13,7 @@ from core.tools.errors import (
 from core.tools.provider.tool_provider import ToolProviderController
 from core.tools.tool.builtin_tool import BuiltinTool
 from core.tools.tool.tool import Tool
+from core.tools.utils.yaml_utils import load_yaml_file
 from core.utils.module_import_helper import load_single_subclass_from_source
 
 
@@ -28,10 +27,9 @@ class BuiltinToolProviderController(ToolProviderController):
         provider = self.__class__.__module__.split('.')[-1]
         yaml_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.yaml')
         try:
-            with open(yaml_path, 'rb') as f:
-                provider_yaml = load(f.read(), FullLoader)
-        except:
-            raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}')
+            provider_yaml = load_yaml_file(yaml_path)
+        except Exception as e:
+            raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}: {e}')
 
         if 'credentials_for_provider' in provider_yaml and provider_yaml['credentials_for_provider'] is not None:
             # set credentials name
@@ -58,18 +56,18 @@ class BuiltinToolProviderController(ToolProviderController):
         tool_files = list(filter(lambda x: x.endswith(".yaml") and not x.startswith("__"), listdir(tool_path)))
         tools = []
         for tool_file in tool_files:
-            with open(path.join(tool_path, tool_file), encoding='utf-8') as f:
-                # get tool name
-                tool_name = tool_file.split(".")[0]
-                tool = load(f.read(), FullLoader)
-                # get tool class, import the module
-                assistant_tool_class = load_single_subclass_from_source(
-                    module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}',
-                    script_path=path.join(path.dirname(path.realpath(__file__)),
-                                           'builtin', provider, 'tools', f'{tool_name}.py'),
-                    parent_type=BuiltinTool)
-                tool["identity"]["provider"] = provider
-                tools.append(assistant_tool_class(**tool))
+            # get tool name
+            tool_name = tool_file.split(".")[0]
+            tool = load_yaml_file(path.join(tool_path, tool_file))
+
+            # get tool class, import the module
+            assistant_tool_class = load_single_subclass_from_source(
+                module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}',
+                script_path=path.join(path.dirname(path.realpath(__file__)),
+                                       'builtin', provider, 'tools', f'{tool_name}.py'),
+                parent_type=BuiltinTool)
+            tool["identity"]["provider"] = provider
+            tools.append(assistant_tool_class(**tool))
 
         self.tools = tools
         return tools

+ 17 - 17
api/core/tools/utils/configuration.py

@@ -23,7 +23,7 @@ class ToolConfigurationManager(BaseModel):
         deep copy credentials
         """
         return deepcopy(credentials)
-    
+
     def encrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
         """
         encrypt tool credentials with tenant id
@@ -39,9 +39,9 @@ class ToolConfigurationManager(BaseModel):
                 if field_name in credentials:
                     encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
                     credentials[field_name] = encrypted
-        
+
         return credentials
-    
+
     def mask_tool_credentials(self, credentials: dict[str, Any]) -> dict[str, Any]:
         """
         mask tool credentials
@@ -58,7 +58,7 @@ class ToolConfigurationManager(BaseModel):
                     if len(credentials[field_name]) > 6:
                         credentials[field_name] = \
                             credentials[field_name][:2] + \
-                            '*' * (len(credentials[field_name]) - 4) +\
+                            '*' * (len(credentials[field_name]) - 4) + \
                             credentials[field_name][-2:]
                     else:
                         credentials[field_name] = '*' * len(credentials[field_name])
@@ -72,7 +72,7 @@ class ToolConfigurationManager(BaseModel):
         return a deep copy of credentials with decrypted values
         """
         cache = ToolProviderCredentialsCache(
-            tenant_id=self.tenant_id, 
+            tenant_id=self.tenant_id,
             identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
             cache_type=ToolProviderCredentialsCacheType.PROVIDER
         )
@@ -92,10 +92,10 @@ class ToolConfigurationManager(BaseModel):
 
         cache.set(credentials)
         return credentials
-    
+
     def delete_tool_credentials_cache(self):
         cache = ToolProviderCredentialsCache(
-            tenant_id=self.tenant_id, 
+            tenant_id=self.tenant_id,
             identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
             cache_type=ToolProviderCredentialsCacheType.PROVIDER
         )
@@ -116,7 +116,7 @@ class ToolParameterConfigurationManager(BaseModel):
         deep copy parameters
         """
         return deepcopy(parameters)
-    
+
     def _merge_parameters(self) -> list[ToolParameter]:
         """
         merge parameters
@@ -139,7 +139,7 @@ class ToolParameterConfigurationManager(BaseModel):
                 current_parameters.append(runtime_parameter)
 
         return current_parameters
-    
+
     def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
         """
         mask tool parameters
@@ -157,13 +157,13 @@ class ToolParameterConfigurationManager(BaseModel):
                     if len(parameters[parameter.name]) > 6:
                         parameters[parameter.name] = \
                             parameters[parameter.name][:2] + \
-                            '*' * (len(parameters[parameter.name]) - 4) +\
+                            '*' * (len(parameters[parameter.name]) - 4) + \
                             parameters[parameter.name][-2:]
                     else:
                         parameters[parameter.name] = '*' * len(parameters[parameter.name])
 
         return parameters
-    
+
     def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
         """
         encrypt tool parameters with tenant id
@@ -180,9 +180,9 @@ class ToolParameterConfigurationManager(BaseModel):
                 if parameter.name in parameters:
                     encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
                     parameters[parameter.name] = encrypted
-        
+
         return parameters
-    
+
     def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
         """
         decrypt tool parameters with tenant id
@@ -190,7 +190,7 @@ class ToolParameterConfigurationManager(BaseModel):
         return a deep copy of parameters with decrypted values
         """
         cache = ToolParameterCache(
-            tenant_id=self.tenant_id, 
+            tenant_id=self.tenant_id,
             provider=f'{self.provider_type}.{self.provider_name}',
             tool_name=self.tool_runtime.identity.name,
             cache_type=ToolParameterCacheType.PARAMETER,
@@ -212,15 +212,15 @@ class ToolParameterConfigurationManager(BaseModel):
                         parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
                     except:
                         pass
-        
+
         if has_secret_input:
             cache.set(parameters)
 
         return parameters
-    
+
     def delete_tool_parameters_cache(self):
         cache = ToolParameterCache(
-            tenant_id=self.tenant_id, 
+            tenant_id=self.tenant_id,
             provider=f'{self.provider_type}.{self.provider_name}',
             tool_name=self.tool_runtime.identity.name,
             cache_type=ToolParameterCacheType.PARAMETER,

+ 34 - 0
api/core/tools/utils/yaml_utils.py

@@ -0,0 +1,34 @@
+import logging
+import os
+
+import yaml
+from yaml import YAMLError
+
+
+def load_yaml_file(file_path: str, ignore_error: bool = False) -> dict:
+    """
+    Safe loading a YAML file to a dict
+    :param file_path: the path of the YAML file
+    :param ignore_error:
+        if True, return empty dict if error occurs and the error will be logged in warning level
+        if False, raise error if error occurs
+    :return: a dict of the YAML content
+    """
+    try:
+        if not file_path or not os.path.exists(file_path):
+            raise FileNotFoundError(f'Failed to load YAML file {file_path}: file not found')
+
+        with open(file_path, encoding='utf-8') as file:
+            try:
+                return yaml.safe_load(file)
+            except Exception as e:
+                raise YAMLError(f'Failed to load YAML file {file_path}: {e}')
+    except FileNotFoundError as e:
+        logging.debug(f'Failed to load YAML file {file_path}: {e}')
+        return {}
+    except Exception as e:
+        if ignore_error:
+            logging.warning(f'Failed to load YAML file {file_path}: {e}')
+            return {}
+        else:
+            raise e

+ 10 - 17
api/core/utils/position_helper.py

@@ -1,10 +1,9 @@
-import logging
 import os
 from collections import OrderedDict
 from collections.abc import Callable
 from typing import Any, AnyStr
 
-import yaml
+from core.tools.utils.yaml_utils import load_yaml_file
 
 
 def get_position_map(
@@ -17,21 +16,15 @@ def get_position_map(
     :param file_name: the YAML file name, default to '_position.yaml'
     :return: a dict with name as key and index as value
     """
-    try:
-        position_file_name = os.path.join(folder_path, file_name)
-        if not os.path.exists(position_file_name):
-            return {}
-
-        with open(position_file_name, encoding='utf-8') as f:
-            positions = yaml.safe_load(f)
-        position_map = {}
-        for index, name in enumerate(positions):
-            if name and isinstance(name, str):
-                position_map[name.strip()] = index
-        return position_map
-    except:
-        logging.warning(f'Failed to load the YAML position file {folder_path}/{file_name}.')
-        return {}
+    position_file_name = os.path.join(folder_path, file_name)
+    positions = load_yaml_file(position_file_name, ignore_error=True)
+    position_map = {}
+    index = 0
+    for _, name in enumerate(positions):
+        if name and isinstance(name, str):
+            position_map[name.strip()] = index
+            index += 1
+    return position_map
 
 
 def sort_by_position_map(

+ 1 - 0
api/pyproject.toml

@@ -14,6 +14,7 @@ select = [
     "I", # isort rules
     "UP",   # pyupgrade rules
     "RUF019", # unnecessary-key-check
+    "S506", # unsafe-yaml-load
 ]
 ignore = [
     "F403", # undefined-local-with-import-star

+ 0 - 0
api/tests/unit_tests/utils/__init__.py


+ 34 - 0
api/tests/unit_tests/utils/position_helper/test_position_helper.py

@@ -0,0 +1,34 @@
+from textwrap import dedent
+
+import pytest
+
+from core.utils.position_helper import get_position_map
+
+
+@pytest.fixture
+def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str:
+    monkeypatch.chdir(tmp_path)
+    tmp_path.joinpath("example_positions.yaml").write_text(dedent(
+        """\
+        - first
+        - second
+        # - commented
+        - third
+        
+        - 9999999999999
+        - forth
+        """))
+    return str(tmp_path)
+
+
+def test_position_helper(prepare_example_positions_yaml):
+    position_map = get_position_map(
+        folder_path=prepare_example_positions_yaml,
+        file_name='example_positions.yaml')
+    assert len(position_map) == 4
+    assert position_map == {
+        'first': 0,
+        'second': 1,
+        'third': 2,
+        'forth': 3,
+    }

+ 0 - 0
api/tests/unit_tests/utils/yaml/__init__.py


+ 74 - 0
api/tests/unit_tests/utils/yaml/test_yaml_utils.py

@@ -0,0 +1,74 @@
+from textwrap import dedent
+
+import pytest
+from yaml import YAMLError
+
+from core.tools.utils.yaml_utils import load_yaml_file
+
+EXAMPLE_YAML_FILE = 'example_yaml.yaml'
+INVALID_YAML_FILE = 'invalid_yaml.yaml'
+NON_EXISTING_YAML_FILE = 'non_existing_file.yaml'
+
+
+@pytest.fixture
+def prepare_example_yaml_file(tmp_path, monkeypatch) -> str:
+    monkeypatch.chdir(tmp_path)
+    file_path = tmp_path.joinpath(EXAMPLE_YAML_FILE)
+    file_path.write_text(dedent(
+        """\
+        address:
+            city: Example City
+            country: Example Country
+        age: 30
+        gender: male
+        languages:
+            - Python
+            - Java
+            - C++
+        empty_key:
+        """))
+    return str(file_path)
+
+
+@pytest.fixture
+def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str:
+    monkeypatch.chdir(tmp_path)
+    file_path = tmp_path.joinpath(INVALID_YAML_FILE)
+    file_path.write_text(dedent(
+        """\
+        address:
+                   city: Example City
+            country: Example Country
+        age: 30
+        gender: male
+        languages:
+        - Python
+        - Java
+        - C++
+        """))
+    return str(file_path)
+
+
+def test_load_yaml_non_existing_file():
+    assert load_yaml_file(file_path=NON_EXISTING_YAML_FILE) == {}
+    assert load_yaml_file(file_path='') == {}
+
+
+def test_load_valid_yaml_file(prepare_example_yaml_file):
+    yaml_data = load_yaml_file(file_path=prepare_example_yaml_file)
+    assert len(yaml_data) > 0
+    assert yaml_data['age'] == 30
+    assert yaml_data['gender'] == 'male'
+    assert yaml_data['address']['city'] == 'Example City'
+    assert set(yaml_data['languages']) == {'Python', 'Java', 'C++'}
+    assert yaml_data.get('empty_key') is None
+    assert yaml_data.get('non_existed_key') is None
+
+
+def test_load_invalid_yaml_file(prepare_invalid_yaml_file):
+    # yaml syntax error
+    with pytest.raises(YAMLError):
+        load_yaml_file(file_path=prepare_invalid_yaml_file)
+
+    # ignore error
+    assert load_yaml_file(file_path=prepare_invalid_yaml_file, ignore_error=True) == {}