|
@@ -1,6 +1,7 @@
|
|
import datetime
|
|
import datetime
|
|
import json
|
|
import json
|
|
import logging
|
|
import logging
|
|
|
|
+from collections import defaultdict
|
|
from collections.abc import Iterator
|
|
from collections.abc import Iterator
|
|
from json import JSONDecodeError
|
|
from json import JSONDecodeError
|
|
from typing import Optional
|
|
from typing import Optional
|
|
@@ -8,7 +9,12 @@ from typing import Optional
|
|
from pydantic import BaseModel
|
|
from pydantic import BaseModel
|
|
|
|
|
|
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
|
|
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
|
|
-from core.entities.provider_entities import CustomConfiguration, SystemConfiguration, SystemConfigurationStatus
|
|
|
|
|
|
+from core.entities.provider_entities import (
|
|
|
|
+ CustomConfiguration,
|
|
|
|
+ ModelSettings,
|
|
|
|
+ SystemConfiguration,
|
|
|
|
+ SystemConfigurationStatus,
|
|
|
|
+)
|
|
from core.helper import encrypter
|
|
from core.helper import encrypter
|
|
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
|
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
|
from core.model_runtime.entities.model_entities import FetchFrom, ModelType
|
|
from core.model_runtime.entities.model_entities import FetchFrom, ModelType
|
|
@@ -22,7 +28,14 @@ from core.model_runtime.model_providers import model_provider_factory
|
|
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
|
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
|
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
|
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
|
from extensions.ext_database import db
|
|
from extensions.ext_database import db
|
|
-from models.provider import Provider, ProviderModel, ProviderType, TenantPreferredModelProvider
|
|
|
|
|
|
+from models.provider import (
|
|
|
|
+ LoadBalancingModelConfig,
|
|
|
|
+ Provider,
|
|
|
|
+ ProviderModel,
|
|
|
|
+ ProviderModelSetting,
|
|
|
|
+ ProviderType,
|
|
|
|
+ TenantPreferredModelProvider,
|
|
|
|
+)
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@@ -39,6 +52,7 @@ class ProviderConfiguration(BaseModel):
|
|
using_provider_type: ProviderType
|
|
using_provider_type: ProviderType
|
|
system_configuration: SystemConfiguration
|
|
system_configuration: SystemConfiguration
|
|
custom_configuration: CustomConfiguration
|
|
custom_configuration: CustomConfiguration
|
|
|
|
+ model_settings: list[ModelSettings]
|
|
|
|
|
|
def __init__(self, **data):
|
|
def __init__(self, **data):
|
|
super().__init__(**data)
|
|
super().__init__(**data)
|
|
@@ -62,6 +76,14 @@ class ProviderConfiguration(BaseModel):
|
|
:param model: model name
|
|
:param model: model name
|
|
:return:
|
|
:return:
|
|
"""
|
|
"""
|
|
|
|
+ if self.model_settings:
|
|
|
|
+ # check if model is disabled by admin
|
|
|
|
+ for model_setting in self.model_settings:
|
|
|
|
+ if (model_setting.model_type == model_type
|
|
|
|
+ and model_setting.model == model):
|
|
|
|
+ if not model_setting.enabled:
|
|
|
|
+ raise ValueError(f'Model {model} is disabled.')
|
|
|
|
+
|
|
if self.using_provider_type == ProviderType.SYSTEM:
|
|
if self.using_provider_type == ProviderType.SYSTEM:
|
|
restrict_models = []
|
|
restrict_models = []
|
|
for quota_configuration in self.system_configuration.quota_configurations:
|
|
for quota_configuration in self.system_configuration.quota_configurations:
|
|
@@ -80,15 +102,17 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
|
return copy_credentials
|
|
return copy_credentials
|
|
else:
|
|
else:
|
|
|
|
+ credentials = None
|
|
if self.custom_configuration.models:
|
|
if self.custom_configuration.models:
|
|
for model_configuration in self.custom_configuration.models:
|
|
for model_configuration in self.custom_configuration.models:
|
|
if model_configuration.model_type == model_type and model_configuration.model == model:
|
|
if model_configuration.model_type == model_type and model_configuration.model == model:
|
|
- return model_configuration.credentials
|
|
|
|
|
|
+ credentials = model_configuration.credentials
|
|
|
|
+ break
|
|
|
|
|
|
if self.custom_configuration.provider:
|
|
if self.custom_configuration.provider:
|
|
- return self.custom_configuration.provider.credentials
|
|
|
|
- else:
|
|
|
|
- return None
|
|
|
|
|
|
+ credentials = self.custom_configuration.provider.credentials
|
|
|
|
+
|
|
|
|
+ return credentials
|
|
|
|
|
|
def get_system_configuration_status(self) -> SystemConfigurationStatus:
|
|
def get_system_configuration_status(self) -> SystemConfigurationStatus:
|
|
"""
|
|
"""
|
|
@@ -130,7 +154,7 @@ class ProviderConfiguration(BaseModel):
|
|
return credentials
|
|
return credentials
|
|
|
|
|
|
# Obfuscate credentials
|
|
# Obfuscate credentials
|
|
- return self._obfuscated_credentials(
|
|
|
|
|
|
+ return self.obfuscated_credentials(
|
|
credentials=credentials,
|
|
credentials=credentials,
|
|
credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
|
|
credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
|
|
if self.provider.provider_credential_schema else []
|
|
if self.provider.provider_credential_schema else []
|
|
@@ -151,7 +175,7 @@ class ProviderConfiguration(BaseModel):
|
|
).first()
|
|
).first()
|
|
|
|
|
|
# Get provider credential secret variables
|
|
# Get provider credential secret variables
|
|
- provider_credential_secret_variables = self._extract_secret_variables(
|
|
|
|
|
|
+ provider_credential_secret_variables = self.extract_secret_variables(
|
|
self.provider.provider_credential_schema.credential_form_schemas
|
|
self.provider.provider_credential_schema.credential_form_schemas
|
|
if self.provider.provider_credential_schema else []
|
|
if self.provider.provider_credential_schema else []
|
|
)
|
|
)
|
|
@@ -274,7 +298,7 @@ class ProviderConfiguration(BaseModel):
|
|
return credentials
|
|
return credentials
|
|
|
|
|
|
# Obfuscate credentials
|
|
# Obfuscate credentials
|
|
- return self._obfuscated_credentials(
|
|
|
|
|
|
+ return self.obfuscated_credentials(
|
|
credentials=credentials,
|
|
credentials=credentials,
|
|
credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
|
|
credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
|
|
if self.provider.model_credential_schema else []
|
|
if self.provider.model_credential_schema else []
|
|
@@ -302,7 +326,7 @@ class ProviderConfiguration(BaseModel):
|
|
).first()
|
|
).first()
|
|
|
|
|
|
# Get provider credential secret variables
|
|
# Get provider credential secret variables
|
|
- provider_credential_secret_variables = self._extract_secret_variables(
|
|
|
|
|
|
+ provider_credential_secret_variables = self.extract_secret_variables(
|
|
self.provider.model_credential_schema.credential_form_schemas
|
|
self.provider.model_credential_schema.credential_form_schemas
|
|
if self.provider.model_credential_schema else []
|
|
if self.provider.model_credential_schema else []
|
|
)
|
|
)
|
|
@@ -402,6 +426,160 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
|
provider_model_credentials_cache.delete()
|
|
provider_model_credentials_cache.delete()
|
|
|
|
|
|
|
|
+ def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
|
|
|
|
+ """
|
|
|
|
+ Enable model.
|
|
|
|
+ :param model_type: model type
|
|
|
|
+ :param model: model name
|
|
|
|
+ :return:
|
|
|
|
+ """
|
|
|
|
+ model_setting = db.session.query(ProviderModelSetting) \
|
|
|
|
+ .filter(
|
|
|
|
+ ProviderModelSetting.tenant_id == self.tenant_id,
|
|
|
|
+ ProviderModelSetting.provider_name == self.provider.provider,
|
|
|
|
+ ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
|
|
|
+ ProviderModelSetting.model_name == model
|
|
|
|
+ ).first()
|
|
|
|
+
|
|
|
|
+ if model_setting:
|
|
|
|
+ model_setting.enabled = True
|
|
|
|
+ model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
|
|
|
+ db.session.commit()
|
|
|
|
+ else:
|
|
|
|
+ model_setting = ProviderModelSetting(
|
|
|
|
+ tenant_id=self.tenant_id,
|
|
|
|
+ provider_name=self.provider.provider,
|
|
|
|
+ model_type=model_type.to_origin_model_type(),
|
|
|
|
+ model_name=model,
|
|
|
|
+ enabled=True
|
|
|
|
+ )
|
|
|
|
+ db.session.add(model_setting)
|
|
|
|
+ db.session.commit()
|
|
|
|
+
|
|
|
|
+ return model_setting
|
|
|
|
+
|
|
|
|
+ def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
|
|
|
|
+ """
|
|
|
|
+ Disable model.
|
|
|
|
+ :param model_type: model type
|
|
|
|
+ :param model: model name
|
|
|
|
+ :return:
|
|
|
|
+ """
|
|
|
|
+ model_setting = db.session.query(ProviderModelSetting) \
|
|
|
|
+ .filter(
|
|
|
|
+ ProviderModelSetting.tenant_id == self.tenant_id,
|
|
|
|
+ ProviderModelSetting.provider_name == self.provider.provider,
|
|
|
|
+ ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
|
|
|
+ ProviderModelSetting.model_name == model
|
|
|
|
+ ).first()
|
|
|
|
+
|
|
|
|
+ if model_setting:
|
|
|
|
+ model_setting.enabled = False
|
|
|
|
+ model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
|
|
|
+ db.session.commit()
|
|
|
|
+ else:
|
|
|
|
+ model_setting = ProviderModelSetting(
|
|
|
|
+ tenant_id=self.tenant_id,
|
|
|
|
+ provider_name=self.provider.provider,
|
|
|
|
+ model_type=model_type.to_origin_model_type(),
|
|
|
|
+ model_name=model,
|
|
|
|
+ enabled=False
|
|
|
|
+ )
|
|
|
|
+ db.session.add(model_setting)
|
|
|
|
+ db.session.commit()
|
|
|
|
+
|
|
|
|
+ return model_setting
|
|
|
|
+
|
|
|
|
+ def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]:
|
|
|
|
+ """
|
|
|
|
+ Get provider model setting.
|
|
|
|
+ :param model_type: model type
|
|
|
|
+ :param model: model name
|
|
|
|
+ :return:
|
|
|
|
+ """
|
|
|
|
+ return db.session.query(ProviderModelSetting) \
|
|
|
|
+ .filter(
|
|
|
|
+ ProviderModelSetting.tenant_id == self.tenant_id,
|
|
|
|
+ ProviderModelSetting.provider_name == self.provider.provider,
|
|
|
|
+ ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
|
|
|
+ ProviderModelSetting.model_name == model
|
|
|
|
+ ).first()
|
|
|
|
+
|
|
|
|
+ def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
|
|
|
|
+ """
|
|
|
|
+ Enable model load balancing.
|
|
|
|
+ :param model_type: model type
|
|
|
|
+ :param model: model name
|
|
|
|
+ :return:
|
|
|
|
+ """
|
|
|
|
+ load_balancing_config_count = db.session.query(LoadBalancingModelConfig) \
|
|
|
|
+ .filter(
|
|
|
|
+ LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
|
|
|
+ LoadBalancingModelConfig.provider_name == self.provider.provider,
|
|
|
|
+ LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
|
|
|
+ LoadBalancingModelConfig.model_name == model
|
|
|
|
+ ).count()
|
|
|
|
+
|
|
|
|
+ if load_balancing_config_count <= 1:
|
|
|
|
+ raise ValueError('Model load balancing configuration must be more than 1.')
|
|
|
|
+
|
|
|
|
+ model_setting = db.session.query(ProviderModelSetting) \
|
|
|
|
+ .filter(
|
|
|
|
+ ProviderModelSetting.tenant_id == self.tenant_id,
|
|
|
|
+ ProviderModelSetting.provider_name == self.provider.provider,
|
|
|
|
+ ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
|
|
|
+ ProviderModelSetting.model_name == model
|
|
|
|
+ ).first()
|
|
|
|
+
|
|
|
|
+ if model_setting:
|
|
|
|
+ model_setting.load_balancing_enabled = True
|
|
|
|
+ model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
|
|
|
+ db.session.commit()
|
|
|
|
+ else:
|
|
|
|
+ model_setting = ProviderModelSetting(
|
|
|
|
+ tenant_id=self.tenant_id,
|
|
|
|
+ provider_name=self.provider.provider,
|
|
|
|
+ model_type=model_type.to_origin_model_type(),
|
|
|
|
+ model_name=model,
|
|
|
|
+ load_balancing_enabled=True
|
|
|
|
+ )
|
|
|
|
+ db.session.add(model_setting)
|
|
|
|
+ db.session.commit()
|
|
|
|
+
|
|
|
|
+ return model_setting
|
|
|
|
+
|
|
|
|
+ def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
|
|
|
|
+ """
|
|
|
|
+ Disable model load balancing.
|
|
|
|
+ :param model_type: model type
|
|
|
|
+ :param model: model name
|
|
|
|
+ :return:
|
|
|
|
+ """
|
|
|
|
+ model_setting = db.session.query(ProviderModelSetting) \
|
|
|
|
+ .filter(
|
|
|
|
+ ProviderModelSetting.tenant_id == self.tenant_id,
|
|
|
|
+ ProviderModelSetting.provider_name == self.provider.provider,
|
|
|
|
+ ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
|
|
|
+ ProviderModelSetting.model_name == model
|
|
|
|
+ ).first()
|
|
|
|
+
|
|
|
|
+ if model_setting:
|
|
|
|
+ model_setting.load_balancing_enabled = False
|
|
|
|
+ model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
|
|
|
+ db.session.commit()
|
|
|
|
+ else:
|
|
|
|
+ model_setting = ProviderModelSetting(
|
|
|
|
+ tenant_id=self.tenant_id,
|
|
|
|
+ provider_name=self.provider.provider,
|
|
|
|
+ model_type=model_type.to_origin_model_type(),
|
|
|
|
+ model_name=model,
|
|
|
|
+ load_balancing_enabled=False
|
|
|
|
+ )
|
|
|
|
+ db.session.add(model_setting)
|
|
|
|
+ db.session.commit()
|
|
|
|
+
|
|
|
|
+ return model_setting
|
|
|
|
+
|
|
def get_provider_instance(self) -> ModelProvider:
|
|
def get_provider_instance(self) -> ModelProvider:
|
|
"""
|
|
"""
|
|
Get provider instance.
|
|
Get provider instance.
|
|
@@ -453,7 +631,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
|
db.session.commit()
|
|
db.session.commit()
|
|
|
|
|
|
- def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
|
|
|
|
|
|
+ def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
|
|
"""
|
|
"""
|
|
Extract secret input form variables.
|
|
Extract secret input form variables.
|
|
|
|
|
|
@@ -467,7 +645,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
|
return secret_input_form_variables
|
|
return secret_input_form_variables
|
|
|
|
|
|
- def _obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
|
|
|
|
|
|
+ def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
|
|
"""
|
|
"""
|
|
Obfuscated credentials.
|
|
Obfuscated credentials.
|
|
|
|
|
|
@@ -476,7 +654,7 @@ class ProviderConfiguration(BaseModel):
|
|
:return:
|
|
:return:
|
|
"""
|
|
"""
|
|
# Get provider credential secret variables
|
|
# Get provider credential secret variables
|
|
- credential_secret_variables = self._extract_secret_variables(
|
|
|
|
|
|
+ credential_secret_variables = self.extract_secret_variables(
|
|
credential_form_schemas
|
|
credential_form_schemas
|
|
)
|
|
)
|
|
|
|
|
|
@@ -522,15 +700,22 @@ class ProviderConfiguration(BaseModel):
|
|
else:
|
|
else:
|
|
model_types = provider_instance.get_provider_schema().supported_model_types
|
|
model_types = provider_instance.get_provider_schema().supported_model_types
|
|
|
|
|
|
|
|
+ # Group model settings by model type and model
|
|
|
|
+ model_setting_map = defaultdict(dict)
|
|
|
|
+ for model_setting in self.model_settings:
|
|
|
|
+ model_setting_map[model_setting.model_type][model_setting.model] = model_setting
|
|
|
|
+
|
|
if self.using_provider_type == ProviderType.SYSTEM:
|
|
if self.using_provider_type == ProviderType.SYSTEM:
|
|
provider_models = self._get_system_provider_models(
|
|
provider_models = self._get_system_provider_models(
|
|
model_types=model_types,
|
|
model_types=model_types,
|
|
- provider_instance=provider_instance
|
|
|
|
|
|
+ provider_instance=provider_instance,
|
|
|
|
+ model_setting_map=model_setting_map
|
|
)
|
|
)
|
|
else:
|
|
else:
|
|
provider_models = self._get_custom_provider_models(
|
|
provider_models = self._get_custom_provider_models(
|
|
model_types=model_types,
|
|
model_types=model_types,
|
|
- provider_instance=provider_instance
|
|
|
|
|
|
+ provider_instance=provider_instance,
|
|
|
|
+ model_setting_map=model_setting_map
|
|
)
|
|
)
|
|
|
|
|
|
if only_active:
|
|
if only_active:
|
|
@@ -541,18 +726,27 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
|
def _get_system_provider_models(self,
|
|
def _get_system_provider_models(self,
|
|
model_types: list[ModelType],
|
|
model_types: list[ModelType],
|
|
- provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
|
|
|
|
|
|
+ provider_instance: ModelProvider,
|
|
|
|
+ model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
|
|
|
|
+ -> list[ModelWithProviderEntity]:
|
|
"""
|
|
"""
|
|
Get system provider models.
|
|
Get system provider models.
|
|
|
|
|
|
:param model_types: model types
|
|
:param model_types: model types
|
|
:param provider_instance: provider instance
|
|
:param provider_instance: provider instance
|
|
|
|
+ :param model_setting_map: model setting map
|
|
:return:
|
|
:return:
|
|
"""
|
|
"""
|
|
provider_models = []
|
|
provider_models = []
|
|
for model_type in model_types:
|
|
for model_type in model_types:
|
|
- provider_models.extend(
|
|
|
|
- [
|
|
|
|
|
|
+ for m in provider_instance.models(model_type):
|
|
|
|
+ status = ModelStatus.ACTIVE
|
|
|
|
+ if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
|
|
|
|
+ model_setting = model_setting_map[m.model_type][m.model]
|
|
|
|
+ if model_setting.enabled is False:
|
|
|
|
+ status = ModelStatus.DISABLED
|
|
|
|
+
|
|
|
|
+ provider_models.append(
|
|
ModelWithProviderEntity(
|
|
ModelWithProviderEntity(
|
|
model=m.model,
|
|
model=m.model,
|
|
label=m.label,
|
|
label=m.label,
|
|
@@ -562,11 +756,9 @@ class ProviderConfiguration(BaseModel):
|
|
model_properties=m.model_properties,
|
|
model_properties=m.model_properties,
|
|
deprecated=m.deprecated,
|
|
deprecated=m.deprecated,
|
|
provider=SimpleModelProviderEntity(self.provider),
|
|
provider=SimpleModelProviderEntity(self.provider),
|
|
- status=ModelStatus.ACTIVE
|
|
|
|
|
|
+ status=status
|
|
)
|
|
)
|
|
- for m in provider_instance.models(model_type)
|
|
|
|
- ]
|
|
|
|
- )
|
|
|
|
|
|
+ )
|
|
|
|
|
|
if self.provider.provider not in original_provider_configurate_methods:
|
|
if self.provider.provider not in original_provider_configurate_methods:
|
|
original_provider_configurate_methods[self.provider.provider] = []
|
|
original_provider_configurate_methods[self.provider.provider] = []
|
|
@@ -586,7 +778,8 @@ class ProviderConfiguration(BaseModel):
|
|
break
|
|
break
|
|
|
|
|
|
if should_use_custom_model:
|
|
if should_use_custom_model:
|
|
- if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
|
|
|
|
|
|
+ if original_provider_configurate_methods[self.provider.provider] == [
|
|
|
|
+ ConfigurateMethod.CUSTOMIZABLE_MODEL]:
|
|
# only customizable model
|
|
# only customizable model
|
|
for restrict_model in restrict_models:
|
|
for restrict_model in restrict_models:
|
|
copy_credentials = self.system_configuration.credentials.copy()
|
|
copy_credentials = self.system_configuration.credentials.copy()
|
|
@@ -611,6 +804,13 @@ class ProviderConfiguration(BaseModel):
|
|
if custom_model_schema.model_type not in model_types:
|
|
if custom_model_schema.model_type not in model_types:
|
|
continue
|
|
continue
|
|
|
|
|
|
|
|
+ status = ModelStatus.ACTIVE
|
|
|
|
+ if (custom_model_schema.model_type in model_setting_map
|
|
|
|
+ and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]):
|
|
|
|
+ model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
|
|
|
|
+ if model_setting.enabled is False:
|
|
|
|
+ status = ModelStatus.DISABLED
|
|
|
|
+
|
|
provider_models.append(
|
|
provider_models.append(
|
|
ModelWithProviderEntity(
|
|
ModelWithProviderEntity(
|
|
model=custom_model_schema.model,
|
|
model=custom_model_schema.model,
|
|
@@ -621,7 +821,7 @@ class ProviderConfiguration(BaseModel):
|
|
model_properties=custom_model_schema.model_properties,
|
|
model_properties=custom_model_schema.model_properties,
|
|
deprecated=custom_model_schema.deprecated,
|
|
deprecated=custom_model_schema.deprecated,
|
|
provider=SimpleModelProviderEntity(self.provider),
|
|
provider=SimpleModelProviderEntity(self.provider),
|
|
- status=ModelStatus.ACTIVE
|
|
|
|
|
|
+ status=status
|
|
)
|
|
)
|
|
)
|
|
)
|
|
|
|
|
|
@@ -632,16 +832,20 @@ class ProviderConfiguration(BaseModel):
|
|
m.status = ModelStatus.NO_PERMISSION
|
|
m.status = ModelStatus.NO_PERMISSION
|
|
elif not quota_configuration.is_valid:
|
|
elif not quota_configuration.is_valid:
|
|
m.status = ModelStatus.QUOTA_EXCEEDED
|
|
m.status = ModelStatus.QUOTA_EXCEEDED
|
|
|
|
+
|
|
return provider_models
|
|
return provider_models
|
|
|
|
|
|
def _get_custom_provider_models(self,
|
|
def _get_custom_provider_models(self,
|
|
model_types: list[ModelType],
|
|
model_types: list[ModelType],
|
|
- provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
|
|
|
|
|
|
+ provider_instance: ModelProvider,
|
|
|
|
+ model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
|
|
|
|
+ -> list[ModelWithProviderEntity]:
|
|
"""
|
|
"""
|
|
Get custom provider models.
|
|
Get custom provider models.
|
|
|
|
|
|
:param model_types: model types
|
|
:param model_types: model types
|
|
:param provider_instance: provider instance
|
|
:param provider_instance: provider instance
|
|
|
|
+ :param model_setting_map: model setting map
|
|
:return:
|
|
:return:
|
|
"""
|
|
"""
|
|
provider_models = []
|
|
provider_models = []
|
|
@@ -656,6 +860,16 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
|
models = provider_instance.models(model_type)
|
|
models = provider_instance.models(model_type)
|
|
for m in models:
|
|
for m in models:
|
|
|
|
+ status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
|
|
|
|
+ load_balancing_enabled = False
|
|
|
|
+ if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
|
|
|
|
+ model_setting = model_setting_map[m.model_type][m.model]
|
|
|
|
+ if model_setting.enabled is False:
|
|
|
|
+ status = ModelStatus.DISABLED
|
|
|
|
+
|
|
|
|
+ if len(model_setting.load_balancing_configs) > 1:
|
|
|
|
+ load_balancing_enabled = True
|
|
|
|
+
|
|
provider_models.append(
|
|
provider_models.append(
|
|
ModelWithProviderEntity(
|
|
ModelWithProviderEntity(
|
|
model=m.model,
|
|
model=m.model,
|
|
@@ -666,7 +880,8 @@ class ProviderConfiguration(BaseModel):
|
|
model_properties=m.model_properties,
|
|
model_properties=m.model_properties,
|
|
deprecated=m.deprecated,
|
|
deprecated=m.deprecated,
|
|
provider=SimpleModelProviderEntity(self.provider),
|
|
provider=SimpleModelProviderEntity(self.provider),
|
|
- status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
|
|
|
|
|
|
+ status=status,
|
|
|
|
+ load_balancing_enabled=load_balancing_enabled
|
|
)
|
|
)
|
|
)
|
|
)
|
|
|
|
|
|
@@ -690,6 +905,17 @@ class ProviderConfiguration(BaseModel):
|
|
if not custom_model_schema:
|
|
if not custom_model_schema:
|
|
continue
|
|
continue
|
|
|
|
|
|
|
|
+ status = ModelStatus.ACTIVE
|
|
|
|
+ load_balancing_enabled = False
|
|
|
|
+ if (custom_model_schema.model_type in model_setting_map
|
|
|
|
+ and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]):
|
|
|
|
+ model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
|
|
|
|
+ if model_setting.enabled is False:
|
|
|
|
+ status = ModelStatus.DISABLED
|
|
|
|
+
|
|
|
|
+ if len(model_setting.load_balancing_configs) > 1:
|
|
|
|
+ load_balancing_enabled = True
|
|
|
|
+
|
|
provider_models.append(
|
|
provider_models.append(
|
|
ModelWithProviderEntity(
|
|
ModelWithProviderEntity(
|
|
model=custom_model_schema.model,
|
|
model=custom_model_schema.model,
|
|
@@ -700,7 +926,8 @@ class ProviderConfiguration(BaseModel):
|
|
model_properties=custom_model_schema.model_properties,
|
|
model_properties=custom_model_schema.model_properties,
|
|
deprecated=custom_model_schema.deprecated,
|
|
deprecated=custom_model_schema.deprecated,
|
|
provider=SimpleModelProviderEntity(self.provider),
|
|
provider=SimpleModelProviderEntity(self.provider),
|
|
- status=ModelStatus.ACTIVE
|
|
|
|
|
|
+ status=status,
|
|
|
|
+ load_balancing_enabled=load_balancing_enabled
|
|
)
|
|
)
|
|
)
|
|
)
|
|
|
|
|