Forráskód Böngészése

chore: improve position map conversion and tolerate empty position yaml file (#6541)

Bowen Liang 8 hónapja
szülő
commit
20268708cc

+ 4 - 12
api/core/helper/position_helper.py

@@ -13,18 +13,10 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") ->
     :param file_name: the YAML file name, default to '_position.yaml'
     :return: a dict with name as key and index as value
     """
-    position_file_name = os.path.join(folder_path, file_name)
-    if not position_file_name or not os.path.exists(position_file_name):
-        return {}
-    
-    positions = load_yaml_file(position_file_name, ignore_error=True)
-    position_map = {}
-    index = 0
-    for _, name in enumerate(positions):
-        if name and isinstance(name, str):
-            position_map[name.strip()] = index
-            index += 1
-    return position_map
+    position_file_path = os.path.join(folder_path, file_name)
+    yaml_content = load_yaml_file(file_path=position_file_path, default_value=[])
+    positions = [item.strip() for item in yaml_content if item and isinstance(item, str) and item.strip()]
+    return {name: index for index, name in enumerate(positions)}
 
 
 def sort_by_position_map(

+ 1 - 1
api/core/model_runtime/model_providers/__base/ai_model.py

@@ -162,7 +162,7 @@ class AIModel(ABC):
         # traverse all model_schema_yaml_paths
         for model_schema_yaml_path in model_schema_yaml_paths:
             # read yaml data from yaml file
-            yaml_data = load_yaml_file(model_schema_yaml_path, ignore_error=True)
+            yaml_data = load_yaml_file(model_schema_yaml_path)
 
             new_parameter_rules = []
             for parameter_rule in yaml_data.get('parameter_rules', []):

+ 1 - 1
api/core/model_runtime/model_providers/__base/model_provider.py

@@ -44,7 +44,7 @@ class ModelProvider(ABC):
     
         # read provider schema from yaml file
         yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
-        yaml_data = load_yaml_file(yaml_path, ignore_error=True)
+        yaml_data = load_yaml_file(yaml_path)
     
         try:
             # yaml_data to entity

+ 2 - 2
api/core/tools/provider/builtin_tool_provider.py

@@ -27,7 +27,7 @@ class BuiltinToolProviderController(ToolProviderController):
         provider = self.__class__.__module__.split('.')[-1]
         yaml_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.yaml')
         try:
-            provider_yaml = load_yaml_file(yaml_path)
+            provider_yaml = load_yaml_file(yaml_path, ignore_error=False)
         except Exception as e:
             raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}: {e}')
 
@@ -58,7 +58,7 @@ class BuiltinToolProviderController(ToolProviderController):
         for tool_file in tool_files:
             # get tool name
             tool_name = tool_file.split(".")[0]
-            tool = load_yaml_file(path.join(tool_path, tool_file))
+            tool = load_yaml_file(path.join(tool_path, tool_file), ignore_error=False)
 
             # get tool class, import the module
             assistant_tool_class = load_single_subclass_from_source(

+ 10 - 14
api/core/tools/utils/yaml_utils.py

@@ -1,35 +1,31 @@
 import logging
-import os
+from typing import Any
 
 import yaml
 from yaml import YAMLError
 
 logger = logging.getLogger(__name__)
 
-def load_yaml_file(file_path: str, ignore_error: bool = False) -> dict:
+
+def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}) -> Any:
     """
-    Safe loading a YAML file to a dict
+    Safe loading a YAML file
     :param file_path: the path of the YAML file
     :param ignore_error:
-        if True, return empty dict if error occurs and the error will be logged in warning level
+        if True, return default_value if error occurs and the error will be logged in warning level
         if False, raise error if error occurs
-    :return: a dict of the YAML content
+    :param default_value: the value returned when errors ignored
+    :return: an object of the YAML content
     """
     try:
-        if not file_path or not os.path.exists(file_path):
-            raise FileNotFoundError(f'Failed to load YAML file {file_path}: file not found')
-
-        with open(file_path, encoding='utf-8') as file:
+        with open(file_path, encoding='utf-8') as yaml_file:
             try:
-                return yaml.safe_load(file)
+                return yaml.safe_load(yaml_file)
             except Exception as e:
                 raise YAMLError(f'Failed to load YAML file {file_path}: {e}')
-    except FileNotFoundError as e:
-        logger.debug(f'Failed to load YAML file {file_path}: {e}')
-        return {}
     except Exception as e:
         if ignore_error:
             logger.warning(f'Failed to load YAML file {file_path}: {e}')
-            return {}
+            return default_value
         else:
             raise e

+ 21 - 0
api/tests/unit_tests/utils/position_helper/test_position_helper.py

@@ -21,6 +21,20 @@ def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str:
     return str(tmp_path)
 
 
+@pytest.fixture
+def prepare_empty_commented_positions_yaml(tmp_path, monkeypatch) -> str:
+    monkeypatch.chdir(tmp_path)
+    tmp_path.joinpath("example_positions_all_commented.yaml").write_text(dedent(
+        """\
+        # - commented1
+        # - commented2
+        - 
+        -   
+        
+        """))
+    return str(tmp_path)
+
+
 def test_position_helper(prepare_example_positions_yaml):
     position_map = get_position_map(
         folder_path=prepare_example_positions_yaml,
@@ -32,3 +46,10 @@ def test_position_helper(prepare_example_positions_yaml):
         'third': 2,
         'forth': 3,
     }
+
+
+def test_position_helper_with_all_commented(prepare_empty_commented_positions_yaml):
+    position_map = get_position_map(
+        folder_path=prepare_empty_commented_positions_yaml,
+        file_name='example_positions_all_commented.yaml')
+    assert position_map == {}

+ 5 - 2
api/tests/unit_tests/utils/yaml/test_yaml_utils.py

@@ -53,6 +53,9 @@ def test_load_yaml_non_existing_file():
     assert load_yaml_file(file_path=NON_EXISTING_YAML_FILE) == {}
     assert load_yaml_file(file_path='') == {}
 
+    with pytest.raises(FileNotFoundError):
+        load_yaml_file(file_path=NON_EXISTING_YAML_FILE, ignore_error=False)
+
 
 def test_load_valid_yaml_file(prepare_example_yaml_file):
     yaml_data = load_yaml_file(file_path=prepare_example_yaml_file)
@@ -68,7 +71,7 @@ def test_load_valid_yaml_file(prepare_example_yaml_file):
 def test_load_invalid_yaml_file(prepare_invalid_yaml_file):
     # yaml syntax error
     with pytest.raises(YAMLError):
-        load_yaml_file(file_path=prepare_invalid_yaml_file)
+        load_yaml_file(file_path=prepare_invalid_yaml_file, ignore_error=False)
 
     # ignore error
-    assert load_yaml_file(file_path=prepare_invalid_yaml_file, ignore_error=True) == {}
+    assert load_yaml_file(file_path=prepare_invalid_yaml_file) == {}