Ver código fonte

feat: support pinning, including, and excluding for model providers and tools (#7419)

Co-authored-by: GareArc <chen4851@purude.edu>
Xiyuan Chen 8 meses atrás
pai
commit
4e7b6aec3a

+ 10 - 1
api/.env.example

@@ -267,4 +267,13 @@ APP_MAX_ACTIVE_REQUESTS=0
 
 
 # Celery beat configuration
-CELERY_BEAT_SCHEDULER_TIME=1
+CELERY_BEAT_SCHEDULER_TIME=1
+
+# Position configuration
+POSITION_TOOL_PINS=
+POSITION_TOOL_INCLUDES=
+POSITION_TOOL_EXCLUDES=
+
+POSITION_PROVIDER_PINS=
+POSITION_PROVIDER_INCLUDES=
+POSITION_PROVIDER_EXCLUDES=

+ 59 - 0
api/configs/feature/__init__.py

@@ -406,6 +406,7 @@ class DataSetConfig(BaseSettings):
         default=False,
     )
 
+
 class WorkspaceConfig(BaseSettings):
     """
     Workspace configs
@@ -442,6 +443,63 @@ class CeleryBeatConfig(BaseSettings):
     )
 
 
+class PositionConfig(BaseSettings):
+
+    POSITION_PROVIDER_PINS: str = Field(
+        description='The heads of model providers',
+        default='',
+    )
+
+    POSITION_PROVIDER_INCLUDES: str = Field(
+        description='The included model providers',
+        default='',
+    )
+
+    POSITION_PROVIDER_EXCLUDES: str = Field(
+        description='The excluded model providers',
+        default='',
+    )
+
+    POSITION_TOOL_PINS: str = Field(
+        description='The heads of tools',
+        default='',
+    )
+
+    POSITION_TOOL_INCLUDES: str = Field(
+        description='The included tools',
+        default='',
+    )
+
+    POSITION_TOOL_EXCLUDES: str = Field(
+        description='The excluded tools',
+        default='',
+    )
+
+    @computed_field
+    def POSITION_PROVIDER_PINS_LIST(self) -> list[str]:
+        return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(',') if item.strip() != '']
+
+    @computed_field
+    def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]:
+        return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(',') if item.strip() != ''}
+
+    @computed_field
+    def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]:
+        return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(',') if item.strip() != ''}
+
+    @computed_field
+    def POSITION_TOOL_PINS_LIST(self) -> list[str]:
+        return [item.strip() for item in self.POSITION_TOOL_PINS.split(',') if item.strip() != '']
+
+    @computed_field
+    def POSITION_TOOL_INCLUDES_SET(self) -> set[str]:
+        return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(',') if item.strip() != ''}
+
+    @computed_field
+    def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]:
+        return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(',') if item.strip() != ''}
+
+
 class FeatureConfig(
     # place the configs in alphabet order
     AppExecutionConfig,
@@ -466,6 +524,7 @@ class FeatureConfig(
     UpdateConfig,
     WorkflowConfig,
     WorkspaceConfig,
+    PositionConfig,
 
     # hosted services config
     HostedServiceConfig,

+ 82 - 0
api/core/helper/position_helper.py

@@ -3,6 +3,7 @@ from collections import OrderedDict
 from collections.abc import Callable
 from typing import Any
 
+from configs import dify_config
 from core.tools.utils.yaml_utils import load_yaml_file
 
 
@@ -19,6 +20,87 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") ->
     return {name: index for index, name in enumerate(positions)}
 
 
+def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]:
+    """
+    Get the mapping for tools from name to index from a YAML file.
+    :param folder_path:
+    :param file_name: the YAML file name, default to '_position.yaml'
+    :return: a dict with name as key and index as value
+    """
+    position_map = get_position_map(folder_path, file_name=file_name)
+
+    return pin_position_map(
+        position_map,
+        pin_list=dify_config.POSITION_TOOL_PINS_LIST,
+    )
+
+
+def get_provider_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]:
+    """
+    Get the mapping for providers from name to index from a YAML file.
+    :param folder_path:
+    :param file_name: the YAML file name, default to '_position.yaml'
+    :return: a dict with name as key and index as value
+    """
+    position_map = get_position_map(folder_path, file_name=file_name)
+    return pin_position_map(
+        position_map,
+        pin_list=dify_config.POSITION_PROVIDER_PINS_LIST,
+    )
+
+
+def pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) -> dict[str, int]:
+    """
+    Pin the items in the pin list to the beginning of the position map.
+    Overall logic: exclude > include > pin
+    :param position_map: the position map to be sorted and filtered
+    :param pin_list: the list of pins to be put at the beginning
+    :return: the sorted position map
+    """
+    positions = sorted(original_position_map.keys(), key=lambda x: original_position_map[x])
+
+    # Add pins to position map
+    position_map = {name: idx for idx, name in enumerate(pin_list)}
+
+    # Add remaining positions to position map
+    start_idx = len(position_map)
+    for name in positions:
+        if name not in position_map:
+            position_map[name] = start_idx
+            start_idx += 1
+
+    return position_map
+
+
+def is_filtered(
+        include_set: set[str],
+        exclude_set: set[str],
+        data: Any,
+        name_func: Callable[[Any], str],
+) -> bool:
+    """
+    Chcek if the object should be filtered out.
+    Overall logic: exclude > include > pin
+    :param include_set: the set of names to be included
+    :param exclude_set: the set of names to be excluded
+    :param name_func: the function to get the name of the object
+    :param data: the data to be filtered
+    :return: True if the object should be filtered out, False otherwise
+    """
+    if not data:
+        return False
+    if not include_set and not exclude_set:
+        return False
+
+    name = name_func(data)
+
+    if name in exclude_set:  # exclude_set is prioritized
+        return True
+    if include_set and name not in include_set:  # filter out only if include_set is not empty
+        return True
+    return False
+
+
 def sort_by_position_map(
         position_map: dict[str, int],
         data: list[Any],

+ 9 - 1
api/core/model_manager.py

@@ -368,6 +368,15 @@ class ModelManager:
 
         return ModelInstance(provider_model_bundle, model)
 
+    def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]:
+        """
+        Return first provider and the first model in the provider
+        :param tenant_id: tenant id
+        :param model_type: model type
+        :return: provider name, model name
+        """
+        return self._provider_manager.get_first_provider_first_model(tenant_id, model_type)
+
     def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance:
         """
         Get default model instance
@@ -502,7 +511,6 @@ class LBModelManager:
             config.id
         )
 
-
         res = redis_client.exists(cooldown_cache_key)
         res = cast(bool, res)
         return res

+ 3 - 3
api/core/model_runtime/model_providers/__base/ai_model.py

@@ -151,9 +151,9 @@ class AIModel(ABC):
             os.path.join(provider_model_type_path, model_schema_yaml)
             for model_schema_yaml in os.listdir(provider_model_type_path)
             if not model_schema_yaml.startswith('__')
-               and not model_schema_yaml.startswith('_')
-               and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml))
-               and model_schema_yaml.endswith('.yaml')
+            and not model_schema_yaml.startswith('_')
+            and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml))
+            and model_schema_yaml.endswith('.yaml')
         ]
 
         # get _position.yaml file path

+ 2 - 2
api/core/model_runtime/model_providers/model_provider_factory.py

@@ -6,7 +6,7 @@ from typing import Optional
 from pydantic import BaseModel, ConfigDict
 
 from core.helper.module_import_helper import load_single_subclass_from_source
-from core.helper.position_helper import get_position_map, sort_to_dict_by_position_map
+from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
 from core.model_runtime.model_providers.__base.model_provider import ModelProvider
@@ -234,7 +234,7 @@ class ModelProviderFactory:
         ]
 
         # get _position.yaml file path
-        position_map = get_position_map(model_providers_path)
+        position_map = get_provider_position_map(model_providers_path)
 
         # traverse all model_provider_dir_paths
         model_providers: list[ModelProviderExtension] = []

+ 32 - 5
api/core/provider_manager.py

@@ -5,6 +5,7 @@ from typing import Optional
 
 from sqlalchemy.exc import IntegrityError
 
+from configs import dify_config
 from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
 from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
 from core.entities.provider_entities import (
@@ -18,12 +19,9 @@ from core.entities.provider_entities import (
 )
 from core.helper import encrypter
 from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
+from core.helper.position_helper import is_filtered
 from core.model_runtime.entities.model_entities import ModelType
-from core.model_runtime.entities.provider_entities import (
-    CredentialFormSchema,
-    FormType,
-    ProviderEntity,
-)
+from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderEntity
 from core.model_runtime.model_providers import model_provider_factory
 from extensions import ext_hosting_provider
 from extensions.ext_database import db
@@ -45,6 +43,7 @@ class ProviderManager:
     """
     ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers.
     """
+
     def __init__(self) -> None:
         self.decoding_rsa_key = None
         self.decoding_cipher_rsa = None
@@ -117,6 +116,16 @@ class ProviderManager:
 
         # Construct ProviderConfiguration objects for each provider
         for provider_entity in provider_entities:
+
+            # handle include, exclude
+            if is_filtered(
+                    include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
+                    exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
+                    data=provider_entity,
+                    name_func=lambda x: x.provider,
+            ):
+                continue
+
             provider_name = provider_entity.provider
             provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, [])
             provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, [])
@@ -271,6 +280,24 @@ class ProviderManager:
             )
         )
 
+    def get_first_provider_first_model(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]:
+        """
+        Get names of first model and its provider
+
+        :param tenant_id: workspace id
+        :param model_type: model type
+        :return: provider name, model name
+        """
+        provider_configurations = self.get_configurations(tenant_id)
+
+        # get available models from provider_configurations
+        all_models = provider_configurations.get_models(
+            model_type=model_type,
+            only_active=False
+        )
+
+        return all_models[0].provider.provider, all_models[0].model
+
     def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \
             -> TenantDefaultModel:
         """

+ 3 - 3
api/core/tools/provider/builtin/_positions.py

@@ -1,6 +1,6 @@
 import os.path
 
-from core.helper.position_helper import get_position_map, sort_by_position_map
+from core.helper.position_helper import get_tool_position_map, sort_by_position_map
 from core.tools.entities.api_entities import UserToolProvider
 
 
@@ -10,11 +10,11 @@ class BuiltinToolProviderSort:
     @classmethod
     def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
         if not cls._position:
-            cls._position = get_position_map(os.path.join(os.path.dirname(__file__), '..'))
+            cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), '..'))
 
         def name_func(provider: UserToolProvider) -> str:
             return provider.name
 
         sorted_providers = sort_by_position_map(cls._position, providers, name_func)
 
-        return sorted_providers
+        return sorted_providers

+ 17 - 12
api/core/tools/tool_manager.py

@@ -10,14 +10,11 @@ from configs import dify_config
 from core.agent.entities import AgentToolEntity
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.helper.module_import_helper import load_single_subclass_from_source
+from core.helper.position_helper import is_filtered
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral
 from core.tools.entities.common_entities import I18nObject
-from core.tools.entities.tool_entities import (
-    ApiProviderAuthType,
-    ToolInvokeFrom,
-    ToolParameter,
-)
+from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter
 from core.tools.errors import ToolProviderNotFoundError
 from core.tools.provider.api_tool_provider import ApiToolProviderController
 from core.tools.provider.builtin._positions import BuiltinToolProviderSort
@@ -26,10 +23,7 @@ from core.tools.tool.api_tool import ApiTool
 from core.tools.tool.builtin_tool import BuiltinTool
 from core.tools.tool.tool import Tool
 from core.tools.tool_label_manager import ToolLabelManager
-from core.tools.utils.configuration import (
-    ToolConfigurationManager,
-    ToolParameterConfigurationManager,
-)
+from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
 from core.tools.utils.tool_parameter_converter import ToolParameterConverter
 from core.workflow.nodes.tool.entities import ToolEntity
 from extensions.ext_database import db
@@ -38,6 +32,7 @@ from services.tools.tools_transform_service import ToolTransformService
 
 logger = logging.getLogger(__name__)
 
+
 class ToolManager:
     _builtin_provider_lock = Lock()
     _builtin_providers = {}
@@ -107,7 +102,7 @@ class ToolManager:
                          tenant_id: str,
                          invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
                          tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \
-        -> Union[BuiltinTool, ApiTool]:
+            -> Union[BuiltinTool, ApiTool]:
         """
             get the tool runtime
 
@@ -346,7 +341,7 @@ class ToolManager:
                     provider_class = load_single_subclass_from_source(
                         module_name=f'core.tools.provider.builtin.{provider}.{provider}',
                         script_path=path.join(path.dirname(path.realpath(__file__)),
-                                            'provider', 'builtin', provider, f'{provider}.py'),
+                                              'provider', 'builtin', provider, f'{provider}.py'),
                         parent_type=BuiltinToolProviderController)
                     provider: BuiltinToolProviderController = provider_class()
                     cls._builtin_providers[provider.identity.name] = provider
@@ -414,6 +409,15 @@ class ToolManager:
 
             # append builtin providers
             for provider in builtin_providers:
+                # handle include, exclude
+                if is_filtered(
+                        include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
+                        exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
+                        data=provider,
+                        name_func=lambda x: x.identity.name
+                ):
+                    continue
+
                 user_provider = ToolTransformService.builtin_provider_to_user_provider(
                     provider_controller=provider,
                     db_provider=find_db_builtin_provider(provider.identity.name),
@@ -473,7 +477,7 @@ class ToolManager:
 
     @classmethod
     def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[
-        ApiToolProviderController, dict[str, Any]]:
+            ApiToolProviderController, dict[str, Any]]:
         """
             get the api provider
 
@@ -593,4 +597,5 @@ class ToolManager:
         else:
             raise ValueError(f"provider type {provider_type} not found")
 
+
 ToolManager.load_builtin_providers_cache()

+ 8 - 1
api/services/app_service.py

@@ -111,6 +111,12 @@ class AppService:
                         'completion_params': {}
                     }
             else:
+                provider, model = model_manager.get_default_provider_model_name(
+                    tenant_id=account.current_tenant_id,
+                    model_type=ModelType.LLM
+                )
+                default_model_config['model']['provider'] = provider
+                default_model_config['model']['name'] = model
                 default_model_dict = default_model_config['model']
 
             default_model_config['model'] = json.dumps(default_model_dict)
@@ -190,13 +196,14 @@ class AppService:
                 """
                 Modified App class
                 """
+
                 def __init__(self, app):
                     self.__dict__.update(app.__dict__)
 
                 @property
                 def app_model_config(self):
                     return model_config
-                
+
             app = ModifiedApp(app)
 
         return app

+ 16 - 12
api/services/model_provider_service.py

@@ -30,6 +30,7 @@ class ModelProviderService:
     """
     Model Provider Service
     """
+
     def __init__(self) -> None:
         self.provider_manager = ProviderManager()
 
@@ -387,18 +388,21 @@ class ModelProviderService:
             tenant_id=tenant_id,
             model_type=model_type_enum
         )
-
-        return DefaultModelResponse(
-            model=result.model,
-            model_type=result.model_type,
-            provider=SimpleProviderEntityResponse(
-                provider=result.provider.provider,
-                label=result.provider.label,
-                icon_small=result.provider.icon_small,
-                icon_large=result.provider.icon_large,
-                supported_model_types=result.provider.supported_model_types
-            )
-        ) if result else None
+        try:
+            return DefaultModelResponse(
+                model=result.model,
+                model_type=result.model_type,
+                provider=SimpleProviderEntityResponse(
+                    provider=result.provider.provider,
+                    label=result.provider.label,
+                    icon_small=result.provider.icon_small,
+                    icon_large=result.provider.icon_large,
+                    supported_model_types=result.provider.supported_model_types
+                )
+            ) if result else None
+        except Exception as e:
+            logger.info(f"get_default_model_of_model_type error: {e}")
+            return None
 
     def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None:
         """

+ 22 - 12
api/services/tools/builtin_tools_manage_service.py

@@ -1,6 +1,8 @@
 import json
 import logging
 
+from configs import dify_config
+from core.helper.position_helper import is_filtered
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.tools.entities.api_entities import UserTool, UserToolProvider
 from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
@@ -43,14 +45,14 @@ class BuiltinToolManageService:
         result = []
         for tool in tools:
             result.append(ToolTransformService.tool_to_user_tool(
-                tool=tool, 
-                credentials=credentials, 
+                tool=tool,
+                credentials=credentials,
                 tenant_id=tenant_id,
                 labels=ToolLabelManager.get_tool_labels(provider_controller)
             ))
 
         return result
-    
+
     @staticmethod
     def list_builtin_provider_credentials_schema(
         provider_name
@@ -78,7 +80,7 @@ class BuiltinToolManageService:
             BuiltinToolProvider.provider == provider_name,
         ).first()
 
-        try: 
+        try:
             # get provider
             provider_controller = ToolManager.get_builtin_provider(provider_name)
             if not provider_controller.need_credentials:
@@ -119,8 +121,8 @@ class BuiltinToolManageService:
             # delete cache
             tool_configuration.delete_tool_credentials_cache()
 
-        return { 'result': 'success' }
-    
+        return {'result': 'success'}
+
     @staticmethod
     def get_builtin_tool_provider_credentials(
         user_id: str, tenant_id: str, provider: str
@@ -135,7 +137,7 @@ class BuiltinToolManageService:
 
         if provider is None:
             return {}
-        
+
         provider_controller = ToolManager.get_builtin_provider(provider.provider)
         tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
         credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
@@ -156,7 +158,7 @@ class BuiltinToolManageService:
 
         if provider is None:
             raise ValueError(f'you have not added provider {provider_name}')
-        
+
         db.session.delete(provider)
         db.session.commit()
 
@@ -165,8 +167,8 @@ class BuiltinToolManageService:
         tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
         tool_configuration.delete_tool_credentials_cache()
 
-        return { 'result': 'success' }
-    
+        return {'result': 'success'}
+
     @staticmethod
     def get_builtin_tool_provider_icon(
         provider: str
@@ -179,7 +181,7 @@ class BuiltinToolManageService:
             icon_bytes = f.read()
 
         return icon_bytes, mime_type
-    
+
     @staticmethod
     def list_builtin_tools(
         user_id: str, tenant_id: str
@@ -202,6 +204,15 @@ class BuiltinToolManageService:
 
         for provider_controller in provider_controllers:
             try:
+                # handle include, exclude
+                if is_filtered(
+                    include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
+                    exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
+                    data=provider_controller,
+                    name_func=lambda x: x.identity.name
+                ):
+                    continue
+
                 # convert provider controller to user provider
                 user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
                     provider_controller=provider_controller,
@@ -226,4 +237,3 @@ class BuiltinToolManageService:
                 raise e
 
         return BuiltinToolProviderSort.sort(result)
-    

+ 81 - 5
api/tests/unit_tests/utils/position_helper/test_position_helper.py

@@ -2,7 +2,7 @@ from textwrap import dedent
 
 import pytest
 
-from core.helper.position_helper import get_position_map
+from core.helper.position_helper import get_position_map, is_filtered, pin_position_map, sort_by_position_map
 
 
 @pytest.fixture
@@ -14,7 +14,7 @@ def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str:
         - second
         # - commented
         - third
-        
+
         - 9999999999999
         - forth
         """))
@@ -28,9 +28,9 @@ def prepare_empty_commented_positions_yaml(tmp_path, monkeypatch) -> str:
         """\
         # - commented1
         # - commented2
-        - 
-        -   
-        
+        -
+        -
+
         """))
     return str(tmp_path)
 
@@ -53,3 +53,79 @@ def test_position_helper_with_all_commented(prepare_empty_commented_positions_ya
         folder_path=prepare_empty_commented_positions_yaml,
         file_name='example_positions_all_commented.yaml')
     assert position_map == {}
+
+
+def test_excluded_position_data(prepare_example_positions_yaml):
+    position_map = get_position_map(
+        folder_path=prepare_example_positions_yaml,
+        file_name='example_positions.yaml'
+    )
+    pin_list = ['forth', 'first']
+    include_set = set()
+    exclude_set = {'9999999999999'}
+
+    position_map = pin_position_map(
+        original_position_map=position_map,
+        pin_list=pin_list
+    )
+
+    data = [
+        "forth",
+        "first",
+        "second",
+        "third",
+        "9999999999999",
+        "extra1",
+        "extra2",
+    ]
+
+    # filter out the data
+    data = [item for item in data if not is_filtered(include_set, exclude_set, item, lambda x: x)]
+
+    # sort data by position map
+    sorted_data = sort_by_position_map(
+        position_map=position_map,
+        data=data,
+        name_func=lambda x: x,
+    )
+
+    # assert the result in the correct order
+    assert sorted_data == ['forth', 'first', 'second', 'third', 'extra1', 'extra2']
+
+
+def test_included_position_data(prepare_example_positions_yaml):
+    position_map = get_position_map(
+        folder_path=prepare_example_positions_yaml,
+        file_name='example_positions.yaml'
+    )
+    pin_list = ['forth', 'first']
+    include_set = {'forth', 'first'}
+    exclude_set = {}
+
+    position_map = pin_position_map(
+        original_position_map=position_map,
+        pin_list=pin_list
+    )
+
+    data = [
+        "forth",
+        "first",
+        "second",
+        "third",
+        "9999999999999",
+        "extra1",
+        "extra2",
+    ]
+
+    # filter out the data
+    data = [item for item in data if not is_filtered(include_set, exclude_set, item, lambda x: x)]
+
+    # sort data by position map
+    sorted_data = sort_by_position_map(
+        position_map=position_map,
+        data=data,
+        name_func=lambda x: x,
+    )
+
+    # assert the result in the correct order
+    assert sorted_data == ['forth', 'first']

+ 19 - 0
docker/.env.example

@@ -701,3 +701,22 @@ COMPOSE_PROFILES=${VECTOR_STORE:-weaviate}
 # ------------------------------
 EXPOSE_NGINX_PORT=80
 EXPOSE_NGINX_SSL_PORT=443
+
+# ----------------------------------------------------------------------------
+# ModelProvider & Tool Position Configuration
+# Used to specify the model providers and tools that can be used in the app.
+# ----------------------------------------------------------------------------
+
+# Pin, include, and exclude tools
+# Use comma-separated values with no spaces between items.
+# Example: POSITION_TOOL_PINS=bing,google
+POSITION_TOOL_PINS=
+POSITION_TOOL_INCLUDES=
+POSITION_TOOL_EXCLUDES=
+
+# Pin, include, and exclude model providers
+# Use comma-separated values with no spaces between items.
+# Example: POSITION_PROVIDER_PINS=openai,openllm
+POSITION_PROVIDER_PINS=
+POSITION_PROVIDER_INCLUDES=
+POSITION_PROVIDER_EXCLUDES=