|
@@ -1,10 +1,8 @@
|
|
|
import importlib
|
|
|
import logging
|
|
|
import os
|
|
|
-from collections import OrderedDict
|
|
|
from typing import Optional
|
|
|
|
|
|
-import yaml
|
|
|
from pydantic import BaseModel
|
|
|
|
|
|
from core.model_runtime.entities.model_entities import ModelType
|
|
@@ -12,6 +10,7 @@ from core.model_runtime.entities.provider_entities import ProviderConfig, Provid
|
|
|
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
|
|
from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
|
|
|
from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator
|
|
|
+from core.utils.position_helper import get_position_map, sort_to_dict_by_position_map
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
@@ -200,7 +199,6 @@ class ModelProviderFactory:
|
|
|
if self.model_provider_extensions:
|
|
|
return self.model_provider_extensions
|
|
|
|
|
|
- model_providers = {}
|
|
|
|
|
|
# get the path of current classes
|
|
|
current_path = os.path.abspath(__file__)
|
|
@@ -215,17 +213,10 @@ class ModelProviderFactory:
|
|
|
]
|
|
|
|
|
|
# get _position.yaml file path
|
|
|
- position_file_path = os.path.join(model_providers_path, '_position.yaml')
|
|
|
-
|
|
|
- # read _position.yaml file
|
|
|
- position_map = {}
|
|
|
- if os.path.exists(position_file_path):
|
|
|
- with open(position_file_path, encoding='utf-8') as f:
|
|
|
- positions = yaml.safe_load(f)
|
|
|
- # convert list to dict with key as model provider name, value as index
|
|
|
- position_map = {position: index for index, position in enumerate(positions)}
|
|
|
+ position_map = get_position_map(model_providers_path)
|
|
|
|
|
|
# traverse all model_provider_dir_paths
|
|
|
+ model_providers: list[ModelProviderExtension] = []
|
|
|
for model_provider_dir_path in model_provider_dir_paths:
|
|
|
# get model_provider dir name
|
|
|
model_provider_name = os.path.basename(model_provider_dir_path)
|
|
@@ -256,14 +247,13 @@ class ModelProviderFactory:
|
|
|
logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.")
|
|
|
continue
|
|
|
|
|
|
- model_providers[model_provider_name] = ModelProviderExtension(
|
|
|
+ model_providers.append(ModelProviderExtension(
|
|
|
name=model_provider_name,
|
|
|
provider_instance=model_provider_class(),
|
|
|
position=position_map.get(model_provider_name)
|
|
|
- )
|
|
|
+ ))
|
|
|
|
|
|
- sorted_items = sorted(model_providers.items(), key=lambda x: (x[1].position is None, x[1].position))
|
|
|
- sorted_extensions = OrderedDict(sorted_items)
|
|
|
+ sorted_extensions = sort_to_dict_by_position_map(position_map, model_providers, lambda x: x.name)
|
|
|
|
|
|
self.model_provider_extensions = sorted_extensions
|
|
|
|