|
@@ -229,11 +229,18 @@ class ProviderManager:
|
|
|
return None
|
|
|
|
|
|
provider_instance = model_provider_factory.get_provider_instance(default_model.provider_name)
|
|
|
+ provider_schema = provider_instance.get_provider_schema()
|
|
|
|
|
|
return DefaultModelEntity(
|
|
|
model=default_model.model_name,
|
|
|
model_type=model_type,
|
|
|
- provider=DefaultModelProviderEntity(**provider_instance.get_provider_schema().to_simple_provider().dict())
|
|
|
+ provider=DefaultModelProviderEntity(
|
|
|
+ provider=provider_schema.provider,
|
|
|
+ label=provider_schema.label,
|
|
|
+ icon_small=provider_schema.icon_small,
|
|
|
+ icon_large=provider_schema.icon_large,
|
|
|
+ supported_model_types=provider_schema.supported_model_types
|
|
|
+ )
|
|
|
)
|
|
|
|
|
|
def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \
|