|
@@ -17,11 +17,12 @@ from models.provider import Provider, ProviderModel, TenantPreferredModelProvide
|
|
|
|
|
|
class ProviderService:
|
|
|
|
|
|
- def get_provider_list(self, tenant_id: str):
|
|
|
+ def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> list:
|
|
|
"""
|
|
|
get provider list of tenant.
|
|
|
|
|
|
- :param tenant_id:
|
|
|
+ :param tenant_id: workspace id
|
|
|
+ :param model_type: filter by model type
|
|
|
:return:
|
|
|
"""
|
|
|
# get rules for all providers
|
|
@@ -79,6 +80,9 @@ class ProviderService:
|
|
|
providers_list = {}
|
|
|
|
|
|
for model_provider_name, model_provider_rule in model_provider_rules.items():
|
|
|
+ if model_type and model_type not in model_provider_rule.get('supported_model_types', []):
|
|
|
+ continue
|
|
|
+
|
|
|
# get preferred provider type
|
|
|
preferred_model_provider = provider_name_to_preferred_provider_type_dict.get(model_provider_name)
|
|
|
preferred_provider_type = ModelProviderFactory.get_preferred_type_by_preferred_model_provider(
|
|
@@ -90,6 +94,7 @@ class ProviderService:
|
|
|
provider_config_dict = {
|
|
|
"preferred_provider_type": preferred_provider_type,
|
|
|
"model_flexibility": model_provider_rule['model_flexibility'],
|
|
|
+ "supported_model_types": model_provider_rule.get("supported_model_types", []),
|
|
|
}
|
|
|
|
|
|
provider_parameter_dict = {}
|