Browse Source

generalize position helper for parsing _position.yaml and sorting objects by name (#2803)

Bowen Liang 1 year ago
parent
commit
8b15b742ad

+ 8 - 6
api/core/extension/extensible.py

@@ -3,11 +3,12 @@ import importlib.util
 import json
 import logging
 import os
-from collections import OrderedDict
 from typing import Any, Optional
 
 from pydantic import BaseModel
 
+from core.utils.position_helper import sort_to_dict_by_position_map
+
 
 class ExtensionModule(enum.Enum):
     MODERATION = 'moderation'
@@ -36,7 +37,8 @@ class Extensible:
 
     @classmethod
     def scan_extensions(cls):
-        extensions = {}
+        extensions: list[ModuleExtension] = []
+        position_map = {}
 
         # get the path of the current class
         current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py')
@@ -63,6 +65,7 @@ class Extensible:
                     if os.path.exists(builtin_file_path):
                         with open(builtin_file_path, encoding='utf-8') as f:
                             position = int(f.read().strip())
+                position_map[extension_name] = position
 
                 if (extension_name + '.py') not in file_names:
                     logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
@@ -96,16 +99,15 @@ class Extensible:
                         with open(json_path, encoding='utf-8') as f:
                             json_data = json.load(f)
 
-                extensions[extension_name] = ModuleExtension(
+                extensions.append(ModuleExtension(
                     extension_class=extension_class,
                     name=extension_name,
                     label=json_data.get('label'),
                     form_schema=json_data.get('form_schema'),
                     builtin=builtin,
                     position=position
-                )
+                ))
 
-        sorted_items = sorted(extensions.items(), key=lambda x: (x[1].position is None, x[1].position))
-        sorted_extensions = OrderedDict(sorted_items)
+        sorted_extensions = sort_to_dict_by_position_map(position_map, extensions, lambda x: x.name)
 
         return sorted_extensions

+ 3 - 11
api/core/model_runtime/model_providers/__base/ai_model.py

@@ -18,6 +18,7 @@ from core.model_runtime.entities.model_entities import (
 )
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
+from core.utils.position_helper import get_position_map, sort_by_position_map
 
 
 class AIModel(ABC):
@@ -148,15 +149,7 @@ class AIModel(ABC):
         ]
 
         # get _position.yaml file path
-        position_file_path = os.path.join(provider_model_type_path, '_position.yaml')
-
-        # read _position.yaml file
-        position_map = {}
-        if os.path.exists(position_file_path):
-            with open(position_file_path, encoding='utf-8') as f:
-                positions = yaml.safe_load(f)
-                # convert list to dict with key as model provider name, value as index
-                position_map = {position: index for index, position in enumerate(positions)}
+        position_map = get_position_map(provider_model_type_path)
 
         # traverse all model_schema_yaml_paths
         for model_schema_yaml_path in model_schema_yaml_paths:
@@ -206,8 +199,7 @@ class AIModel(ABC):
             model_schemas.append(model_schema)
 
         # resort model schemas by position
-        if position_map:
-            model_schemas.sort(key=lambda x: position_map.get(x.model, 999))
+        model_schemas = sort_by_position_map(position_map, model_schemas, lambda x: x.model)
 
         # cache model schemas
         self.model_schemas = model_schemas

+ 6 - 16
api/core/model_runtime/model_providers/model_provider_factory.py

@@ -1,10 +1,8 @@
 import importlib
 import logging
 import os
-from collections import OrderedDict
 from typing import Optional
 
-import yaml
 from pydantic import BaseModel
 
 from core.model_runtime.entities.model_entities import ModelType
@@ -12,6 +10,7 @@ from core.model_runtime.entities.provider_entities import ProviderConfig, Provid
 from core.model_runtime.model_providers.__base.model_provider import ModelProvider
 from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
 from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator
+from core.utils.position_helper import get_position_map, sort_to_dict_by_position_map
 
 logger = logging.getLogger(__name__)
 
@@ -200,7 +199,6 @@ class ModelProviderFactory:
         if self.model_provider_extensions:
             return self.model_provider_extensions
 
-        model_providers = {}
 
         # get the path of current classes
         current_path = os.path.abspath(__file__)
@@ -215,17 +213,10 @@ class ModelProviderFactory:
         ]
 
         # get _position.yaml file path
-        position_file_path = os.path.join(model_providers_path, '_position.yaml')
-
-        # read _position.yaml file
-        position_map = {}
-        if os.path.exists(position_file_path):
-            with open(position_file_path, encoding='utf-8') as f:
-                positions = yaml.safe_load(f)
-                # convert list to dict with key as model provider name, value as index
-                position_map = {position: index for index, position in enumerate(positions)}
+        position_map = get_position_map(model_providers_path)
 
         # traverse all model_provider_dir_paths
+        model_providers: list[ModelProviderExtension] = []
         for model_provider_dir_path in model_provider_dir_paths:
             # get model_provider dir name
             model_provider_name = os.path.basename(model_provider_dir_path)
@@ -256,14 +247,13 @@ class ModelProviderFactory:
                 logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.")
                 continue
 
-            model_providers[model_provider_name] = ModelProviderExtension(
+            model_providers.append(ModelProviderExtension(
                 name=model_provider_name,
                 provider_instance=model_provider_class(),
                 position=position_map.get(model_provider_name)
-            )
+            ))
 
-        sorted_items = sorted(model_providers.items(), key=lambda x: (x[1].position is None, x[1].position))
-        sorted_extensions = OrderedDict(sorted_items)
+        sorted_extensions = sort_to_dict_by_position_map(position_map, model_providers, lambda x: x.name)
 
         self.model_provider_extensions = sorted_extensions
 

+ 8 - 13
api/core/tools/provider/builtin/_positions.py

@@ -1,8 +1,7 @@
 import os.path
 
-from yaml import FullLoader, load
-
 from core.tools.entities.user_entities import UserToolProvider
+from core.utils.position_helper import get_position_map, sort_by_position_map
 
 
 class BuiltinToolProviderSort:
@@ -11,18 +10,14 @@ class BuiltinToolProviderSort:
     @classmethod
     def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
         if not cls._position:
-            tmp_position = {}
-            file_path = os.path.join(os.path.dirname(__file__), '..', '_position.yaml')
-            with open(file_path) as f:
-                for pos, val in enumerate(load(f, Loader=FullLoader)):
-                    tmp_position[val] = pos
-            cls._position = tmp_position
+            cls._position = get_position_map(os.path.join(os.path.dirname(__file__), '..'))
 
-        def sort_compare(provider: UserToolProvider) -> int:
+        def name_func(provider: UserToolProvider) -> str:
             if provider.type == UserToolProvider.ProviderType.MODEL:
-                return cls._position.get(f'model.{provider.name}', 10000)
-            return cls._position.get(provider.name, 10000)
-        
-        sorted_providers = sorted(providers, key=sort_compare)
+                return f'model.{provider.name}'
+            else:
+                return provider.name
+
+        sorted_providers = sort_by_position_map(cls._position, providers, name_func)
 
         return sorted_providers

+ 70 - 0
api/core/utils/position_helper.py

@@ -0,0 +1,70 @@
+import logging
+import os
+from collections import OrderedDict
+from collections.abc import Callable
+from typing import Any, AnyStr
+
+import yaml
+
+
+def get_position_map(
+        folder_path: AnyStr,
+        file_name: str = '_position.yaml',
+) -> dict[str, int]:
+    """
+    Get the mapping from name to index from a YAML file.
+
+    The file is expected to contain a YAML list of names. Each name is stripped of
+    surrounding whitespace and mapped to its index in the list. Entries that are not
+    non-empty strings are skipped; their index is simply left unused.
+
+    :param folder_path: the folder that contains the position file
+    :param file_name: the YAML file name, default to '_position.yaml'
+    :return: a dict with name as key and index as value; an empty dict if the file
+        is missing, empty, not a YAML list, or cannot be read or parsed
+    """
+    try:
+        position_file_name = os.path.join(folder_path, file_name)
+        if not os.path.exists(position_file_name):
+            return {}
+
+        with open(position_file_name, encoding='utf-8') as f:
+            positions = yaml.safe_load(f)
+    except Exception:
+        # Loading positions is best-effort: callers fall back to their default order.
+        # `Exception` (instead of a bare `except`) lets KeyboardInterrupt/SystemExit
+        # propagate, and exc_info keeps the root cause visible in the log.
+        logging.warning(f'Failed to load the YAML position file {folder_path}/{file_name}.', exc_info=True)
+        return {}
+
+    # An empty file parses to None: there is nothing to order, and nothing to warn about.
+    if positions is None:
+        return {}
+
+    # A scalar or mapping document would otherwise be enumerated silently
+    # (e.g. a plain string would yield a map of its characters).
+    if not isinstance(positions, list):
+        logging.warning(f'The YAML position file {folder_path}/{file_name} does not contain a list.')
+        return {}
+
+    position_map = {}
+    for index, name in enumerate(positions):
+        if not isinstance(name, str):
+            continue
+        stripped_name = name.strip()
+        # skip blank entries so that '' never becomes a key
+        if stripped_name:
+            position_map[stripped_name] = index
+    return position_map
+
+
+def sort_by_position_map(
+        position_map: dict[str, int],
+        data: list[Any],
+        name_func: Callable[[Any], str],
+) -> list[Any]:
+    """
+    Order the given objects according to a position map.
+    Objects whose name is missing from the position map are placed after all the
+    mapped ones, keeping their relative input order.
+    When the position map or the data is empty, the data is returned untouched.
+    :param position_map: the map holding positions in the form of {name: index}
+    :param data: the data to be sorted
+    :param name_func: the function to get the name of the object
+    :return: the sorted objects
+    """
+    if position_map and data:
+        def position_of(item: Any) -> float:
+            # unknown names rank after every known position
+            return position_map.get(name_func(item), float('inf'))
+
+        return sorted(data, key=position_of)
+
+    # nothing to sort by (or nothing to sort): hand back the input as is
+    return data
+
+
+def sort_to_dict_by_position_map(
+        position_map: dict[str, int],
+        data: list[Any],
+        name_func: Callable[[Any], str],
+) -> OrderedDict[str, Any]:
+    """
+    Sort the objects by the position map and collect them into an ordered dict.
+    Objects whose name is missing from the position map are placed at the end.
+    :param position_map: the map holding positions in the form of {name: index}
+    :param data: the data to be sorted
+    :param name_func: the function to get the name of the object
+    :return: an OrderedDict of name to object, in sorted order
+    """
+    ordered = OrderedDict()
+    for item in sort_by_position_map(position_map, data, name_func):
+        # the name doubles as the dict key
+        ordered[name_func(item)] = item
+    return ordered