|
@@ -1,5 +1,6 @@
|
|
|
import logging
|
|
|
import os
|
|
|
+from collections.abc import Sequence
|
|
|
from typing import Optional
|
|
|
|
|
|
from pydantic import BaseModel, ConfigDict
|
|
@@ -16,20 +17,21 @@ logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
class ModelProviderExtension(BaseModel):
|
|
|
+ model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
|
+
|
|
|
provider_instance: ModelProvider
|
|
|
name: str
|
|
|
position: Optional[int] = None
|
|
|
- model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
|
|
|
|
|
|
|
class ModelProviderFactory:
|
|
|
- model_provider_extensions: dict[str, ModelProviderExtension] = None
|
|
|
+ model_provider_extensions: Optional[dict[str, ModelProviderExtension]] = None
|
|
|
|
|
|
def __init__(self) -> None:
|
|
|
# for cache in memory
|
|
|
self.get_providers()
|
|
|
|
|
|
- def get_providers(self) -> list[ProviderEntity]:
|
|
|
+ def get_providers(self) -> Sequence[ProviderEntity]:
|
|
|
"""
|
|
|
Get all providers
|
|
|
:return: list of providers
|
|
@@ -39,7 +41,7 @@ class ModelProviderFactory:
|
|
|
|
|
|
# traverse all model_provider_extensions
|
|
|
providers = []
|
|
|
- for name, model_provider_extension in model_provider_extensions.items():
|
|
|
+ for model_provider_extension in model_provider_extensions.values():
|
|
|
# get model_provider instance
|
|
|
model_provider_instance = model_provider_extension.provider_instance
|
|
|
|
|
@@ -57,7 +59,7 @@ class ModelProviderFactory:
|
|
|
# return providers
|
|
|
return providers
|
|
|
|
|
|
- def provider_credentials_validate(self, provider: str, credentials: dict) -> dict:
|
|
|
+ def provider_credentials_validate(self, *, provider: str, credentials: dict) -> dict:
|
|
|
"""
|
|
|
Validate provider credentials
|
|
|
|
|
@@ -74,6 +76,9 @@ class ModelProviderFactory:
|
|
|
# get provider_credential_schema and validate credentials according to the rules
|
|
|
provider_credential_schema = provider_schema.provider_credential_schema
|
|
|
|
|
|
+ if not provider_credential_schema:
|
|
|
+ raise ValueError(f"Provider {provider} does not have provider_credential_schema")
|
|
|
+
|
|
|
# validate provider credential schema
|
|
|
validator = ProviderCredentialSchemaValidator(provider_credential_schema)
|
|
|
filtered_credentials = validator.validate_and_filter(credentials)
|
|
@@ -83,8 +88,9 @@ class ModelProviderFactory:
|
|
|
|
|
|
return filtered_credentials
|
|
|
|
|
|
- def model_credentials_validate(self, provider: str, model_type: ModelType,
|
|
|
- model: str, credentials: dict) -> dict:
|
|
|
+ def model_credentials_validate(
|
|
|
+ self, *, provider: str, model_type: ModelType, model: str, credentials: dict
|
|
|
+ ) -> dict:
|
|
|
"""
|
|
|
Validate model credentials
|
|
|
|
|
@@ -103,6 +109,9 @@ class ModelProviderFactory:
|
|
|
# get model_credential_schema and validate credentials according to the rules
|
|
|
model_credential_schema = provider_schema.model_credential_schema
|
|
|
|
|
|
+ if not model_credential_schema:
|
|
|
+ raise ValueError(f"Provider {provider} does not have model_credential_schema")
|
|
|
+
|
|
|
# validate model credential schema
|
|
|
validator = ModelCredentialSchemaValidator(model_type, model_credential_schema)
|
|
|
filtered_credentials = validator.validate_and_filter(credentials)
|
|
@@ -115,11 +124,13 @@ class ModelProviderFactory:
|
|
|
|
|
|
return filtered_credentials
|
|
|
|
|
|
- def get_models(self,
|
|
|
- provider: Optional[str] = None,
|
|
|
- model_type: Optional[ModelType] = None,
|
|
|
- provider_configs: Optional[list[ProviderConfig]] = None) \
|
|
|
- -> list[SimpleProviderEntity]:
|
|
|
+ def get_models(
|
|
|
+ self,
|
|
|
+ *,
|
|
|
+ provider: Optional[str] = None,
|
|
|
+ model_type: Optional[ModelType] = None,
|
|
|
+ provider_configs: Optional[list[ProviderConfig]] = None,
|
|
|
+ ) -> list[SimpleProviderEntity]:
|
|
|
"""
|
|
|
Get all models for given model type
|
|
|
|
|
@@ -128,6 +139,8 @@ class ModelProviderFactory:
|
|
|
:param provider_configs: list of provider configs
|
|
|
:return: list of models
|
|
|
"""
|
|
|
+ provider_configs = provider_configs or []
|
|
|
+
|
|
|
# scan all providers
|
|
|
model_provider_extensions = self._get_model_provider_map()
|
|
|
|
|
@@ -184,7 +197,7 @@ class ModelProviderFactory:
|
|
|
# get the provider extension
|
|
|
model_provider_extension = model_provider_extensions.get(provider)
|
|
|
if not model_provider_extension:
|
|
|
- raise Exception(f'Invalid provider: {provider}')
|
|
|
+ raise Exception(f"Invalid provider: {provider}")
|
|
|
|
|
|
# get the provider instance
|
|
|
model_provider_instance = model_provider_extension.provider_instance
|
|
@@ -192,10 +205,22 @@ class ModelProviderFactory:
|
|
|
return model_provider_instance
|
|
|
|
|
|
def _get_model_provider_map(self) -> dict[str, ModelProviderExtension]:
|
|
|
+ """
|
|
|
+ Retrieves the model provider map.
|
|
|
+
|
|
|
+ This method retrieves the model provider map, which is a dictionary containing the model provider names as keys
|
|
|
+ and instances of `ModelProviderExtension` as values. The model provider map is used to store information about
|
|
|
+ available model providers.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ A dictionary containing the model provider map.
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ None.
|
|
|
+ """
|
|
|
if self.model_provider_extensions:
|
|
|
return self.model_provider_extensions
|
|
|
|
|
|
-
|
|
|
# get the path of current classes
|
|
|
current_path = os.path.abspath(__file__)
|
|
|
model_providers_path = os.path.dirname(current_path)
|
|
@@ -204,8 +229,8 @@ class ModelProviderFactory:
|
|
|
model_provider_dir_paths = [
|
|
|
os.path.join(model_providers_path, model_provider_dir)
|
|
|
for model_provider_dir in os.listdir(model_providers_path)
|
|
|
- if not model_provider_dir.startswith('__')
|
|
|
- and os.path.isdir(os.path.join(model_providers_path, model_provider_dir))
|
|
|
+ if not model_provider_dir.startswith("__")
|
|
|
+ and os.path.isdir(os.path.join(model_providers_path, model_provider_dir))
|
|
|
]
|
|
|
|
|
|
# get _position.yaml file path
|
|
@@ -219,30 +244,33 @@ class ModelProviderFactory:
|
|
|
|
|
|
file_names = os.listdir(model_provider_dir_path)
|
|
|
|
|
|
- if (model_provider_name + '.py') not in file_names:
|
|
|
+ if (model_provider_name + ".py") not in file_names:
|
|
|
logger.warning(f"Missing {model_provider_name}.py file in {model_provider_dir_path}, Skip.")
|
|
|
continue
|
|
|
|
|
|
# Dynamic loading {model_provider_name}.py file and find the subclass of ModelProvider
|
|
|
- py_path = os.path.join(model_provider_dir_path, model_provider_name + '.py')
|
|
|
+ py_path = os.path.join(model_provider_dir_path, model_provider_name + ".py")
|
|
|
model_provider_class = load_single_subclass_from_source(
|
|
|
- module_name=f'core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}',
|
|
|
+ module_name=f"core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}",
|
|
|
script_path=py_path,
|
|
|
- parent_type=ModelProvider)
|
|
|
+ parent_type=ModelProvider,
|
|
|
+ )
|
|
|
|
|
|
if not model_provider_class:
|
|
|
logger.warning(f"Missing Model Provider Class that extends ModelProvider in {py_path}, Skip.")
|
|
|
continue
|
|
|
|
|
|
- if f'{model_provider_name}.yaml' not in file_names:
|
|
|
+ if f"{model_provider_name}.yaml" not in file_names:
|
|
|
logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.")
|
|
|
continue
|
|
|
|
|
|
- model_providers.append(ModelProviderExtension(
|
|
|
- name=model_provider_name,
|
|
|
- provider_instance=model_provider_class(),
|
|
|
- position=position_map.get(model_provider_name)
|
|
|
- ))
|
|
|
+ model_providers.append(
|
|
|
+ ModelProviderExtension(
|
|
|
+ name=model_provider_name,
|
|
|
+ provider_instance=model_provider_class(),
|
|
|
+ position=position_map.get(model_provider_name),
|
|
|
+ )
|
|
|
+ )
|
|
|
|
|
|
sorted_extensions = sort_to_dict_by_position_map(position_map, model_providers, lambda x: x.name)
|
|
|
|