|
@@ -47,6 +47,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
"""
|
|
|
Model class for provider configuration.
|
|
|
"""
|
|
|
+
|
|
|
tenant_id: str
|
|
|
provider: ProviderEntity
|
|
|
preferred_provider_type: ProviderType
|
|
@@ -67,9 +68,13 @@ class ProviderConfiguration(BaseModel):
|
|
|
original_provider_configurate_methods[self.provider.provider].append(configurate_method)
|
|
|
|
|
|
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
|
|
|
- if (any(len(quota_configuration.restrict_models) > 0
|
|
|
- for quota_configuration in self.system_configuration.quota_configurations)
|
|
|
- and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods):
|
|
|
+ if (
|
|
|
+ any(
|
|
|
+ len(quota_configuration.restrict_models) > 0
|
|
|
+ for quota_configuration in self.system_configuration.quota_configurations
|
|
|
+ )
|
|
|
+ and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods
|
|
|
+ ):
|
|
|
self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
|
|
|
|
|
|
def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
|
|
@@ -83,10 +88,9 @@ class ProviderConfiguration(BaseModel):
|
|
|
if self.model_settings:
|
|
|
# check if model is disabled by admin
|
|
|
for model_setting in self.model_settings:
|
|
|
- if (model_setting.model_type == model_type
|
|
|
- and model_setting.model == model):
|
|
|
+ if model_setting.model_type == model_type and model_setting.model == model:
|
|
|
if not model_setting.enabled:
|
|
|
- raise ValueError(f'Model {model} is disabled.')
|
|
|
+ raise ValueError(f"Model {model} is disabled.")
|
|
|
|
|
|
if self.using_provider_type == ProviderType.SYSTEM:
|
|
|
restrict_models = []
|
|
@@ -99,10 +103,12 @@ class ProviderConfiguration(BaseModel):
|
|
|
copy_credentials = self.system_configuration.credentials.copy()
|
|
|
if restrict_models:
|
|
|
for restrict_model in restrict_models:
|
|
|
- if (restrict_model.model_type == model_type
|
|
|
- and restrict_model.model == model
|
|
|
- and restrict_model.base_model_name):
|
|
|
- copy_credentials['base_model_name'] = restrict_model.base_model_name
|
|
|
+ if (
|
|
|
+ restrict_model.model_type == model_type
|
|
|
+ and restrict_model.model == model
|
|
|
+ and restrict_model.base_model_name
|
|
|
+ ):
|
|
|
+ copy_credentials["base_model_name"] = restrict_model.base_model_name
|
|
|
|
|
|
return copy_credentials
|
|
|
else:
|
|
@@ -128,20 +134,21 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
|
current_quota_type = self.system_configuration.current_quota_type
|
|
|
current_quota_configuration = next(
|
|
|
- (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type),
|
|
|
- None
|
|
|
+ (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None
|
|
|
)
|
|
|
|
|
|
- return SystemConfigurationStatus.ACTIVE if current_quota_configuration.is_valid else \
|
|
|
- SystemConfigurationStatus.QUOTA_EXCEEDED
|
|
|
+ return (
|
|
|
+ SystemConfigurationStatus.ACTIVE
|
|
|
+ if current_quota_configuration.is_valid
|
|
|
+ else SystemConfigurationStatus.QUOTA_EXCEEDED
|
|
|
+ )
|
|
|
|
|
|
def is_custom_configuration_available(self) -> bool:
|
|
|
"""
|
|
|
Check custom configuration available.
|
|
|
:return:
|
|
|
"""
|
|
|
- return (self.custom_configuration.provider is not None
|
|
|
- or len(self.custom_configuration.models) > 0)
|
|
|
+ return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0
|
|
|
|
|
|
def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
|
|
|
"""
|
|
@@ -161,7 +168,8 @@ class ProviderConfiguration(BaseModel):
|
|
|
return self.obfuscated_credentials(
|
|
|
credentials=credentials,
|
|
|
credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
|
|
|
- if self.provider.provider_credential_schema else []
|
|
|
+ if self.provider.provider_credential_schema
|
|
|
+ else [],
|
|
|
)
|
|
|
|
|
|
def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]:
|
|
@@ -171,17 +179,21 @@ class ProviderConfiguration(BaseModel):
|
|
|
:return:
|
|
|
"""
|
|
|
# get provider
|
|
|
- provider_record = db.session.query(Provider) \
|
|
|
+ provider_record = (
|
|
|
+ db.session.query(Provider)
|
|
|
.filter(
|
|
|
- Provider.tenant_id == self.tenant_id,
|
|
|
- Provider.provider_name == self.provider.provider,
|
|
|
- Provider.provider_type == ProviderType.CUSTOM.value
|
|
|
- ).first()
|
|
|
+ Provider.tenant_id == self.tenant_id,
|
|
|
+ Provider.provider_name == self.provider.provider,
|
|
|
+ Provider.provider_type == ProviderType.CUSTOM.value,
|
|
|
+ )
|
|
|
+ .first()
|
|
|
+ )
|
|
|
|
|
|
# Get provider credential secret variables
|
|
|
provider_credential_secret_variables = self.extract_secret_variables(
|
|
|
self.provider.provider_credential_schema.credential_form_schemas
|
|
|
- if self.provider.provider_credential_schema else []
|
|
|
+ if self.provider.provider_credential_schema
|
|
|
+ else []
|
|
|
)
|
|
|
|
|
|
if provider_record:
|
|
@@ -189,9 +201,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
# fix origin data
|
|
|
if provider_record.encrypted_config:
|
|
|
if not provider_record.encrypted_config.startswith("{"):
|
|
|
- original_credentials = {
|
|
|
- "openai_api_key": provider_record.encrypted_config
|
|
|
- }
|
|
|
+ original_credentials = {"openai_api_key": provider_record.encrypted_config}
|
|
|
else:
|
|
|
original_credentials = json.loads(provider_record.encrypted_config)
|
|
|
else:
|
|
@@ -207,8 +217,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
|
|
|
|
|
|
credentials = model_provider_factory.provider_credentials_validate(
|
|
|
- provider=self.provider.provider,
|
|
|
- credentials=credentials
|
|
|
+ provider=self.provider.provider, credentials=credentials
|
|
|
)
|
|
|
|
|
|
for key, value in credentials.items():
|
|
@@ -239,15 +248,13 @@ class ProviderConfiguration(BaseModel):
|
|
|
provider_name=self.provider.provider,
|
|
|
provider_type=ProviderType.CUSTOM.value,
|
|
|
encrypted_config=json.dumps(credentials),
|
|
|
- is_valid=True
|
|
|
+ is_valid=True,
|
|
|
)
|
|
|
db.session.add(provider_record)
|
|
|
db.session.commit()
|
|
|
|
|
|
provider_model_credentials_cache = ProviderCredentialsCache(
|
|
|
- tenant_id=self.tenant_id,
|
|
|
- identity_id=provider_record.id,
|
|
|
- cache_type=ProviderCredentialsCacheType.PROVIDER
|
|
|
+ tenant_id=self.tenant_id, identity_id=provider_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER
|
|
|
)
|
|
|
|
|
|
provider_model_credentials_cache.delete()
|
|
@@ -260,12 +267,15 @@ class ProviderConfiguration(BaseModel):
|
|
|
:return:
|
|
|
"""
|
|
|
# get provider
|
|
|
- provider_record = db.session.query(Provider) \
|
|
|
+ provider_record = (
|
|
|
+ db.session.query(Provider)
|
|
|
.filter(
|
|
|
- Provider.tenant_id == self.tenant_id,
|
|
|
- Provider.provider_name == self.provider.provider,
|
|
|
- Provider.provider_type == ProviderType.CUSTOM.value
|
|
|
- ).first()
|
|
|
+ Provider.tenant_id == self.tenant_id,
|
|
|
+ Provider.provider_name == self.provider.provider,
|
|
|
+ Provider.provider_type == ProviderType.CUSTOM.value,
|
|
|
+ )
|
|
|
+ .first()
|
|
|
+ )
|
|
|
|
|
|
# delete provider
|
|
|
if provider_record:
|
|
@@ -277,13 +287,14 @@ class ProviderConfiguration(BaseModel):
|
|
|
provider_model_credentials_cache = ProviderCredentialsCache(
|
|
|
tenant_id=self.tenant_id,
|
|
|
identity_id=provider_record.id,
|
|
|
- cache_type=ProviderCredentialsCacheType.PROVIDER
|
|
|
+ cache_type=ProviderCredentialsCacheType.PROVIDER,
|
|
|
)
|
|
|
|
|
|
provider_model_credentials_cache.delete()
|
|
|
|
|
|
- def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
|
|
|
- -> Optional[dict]:
|
|
|
+ def get_custom_model_credentials(
|
|
|
+ self, model_type: ModelType, model: str, obfuscated: bool = False
|
|
|
+ ) -> Optional[dict]:
|
|
|
"""
|
|
|
Get custom model credentials.
|
|
|
|
|
@@ -305,13 +316,15 @@ class ProviderConfiguration(BaseModel):
|
|
|
return self.obfuscated_credentials(
|
|
|
credentials=credentials,
|
|
|
credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
|
|
|
- if self.provider.model_credential_schema else []
|
|
|
+ if self.provider.model_credential_schema
|
|
|
+ else [],
|
|
|
)
|
|
|
|
|
|
return None
|
|
|
|
|
|
- def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
|
|
|
- -> tuple[ProviderModel, dict]:
|
|
|
+ def custom_model_credentials_validate(
|
|
|
+ self, model_type: ModelType, model: str, credentials: dict
|
|
|
+ ) -> tuple[ProviderModel, dict]:
|
|
|
"""
|
|
|
Validate custom model credentials.
|
|
|
|
|
@@ -321,24 +334,29 @@ class ProviderConfiguration(BaseModel):
|
|
|
:return:
|
|
|
"""
|
|
|
# get provider model
|
|
|
- provider_model_record = db.session.query(ProviderModel) \
|
|
|
+ provider_model_record = (
|
|
|
+ db.session.query(ProviderModel)
|
|
|
.filter(
|
|
|
- ProviderModel.tenant_id == self.tenant_id,
|
|
|
- ProviderModel.provider_name == self.provider.provider,
|
|
|
- ProviderModel.model_name == model,
|
|
|
- ProviderModel.model_type == model_type.to_origin_model_type()
|
|
|
- ).first()
|
|
|
+ ProviderModel.tenant_id == self.tenant_id,
|
|
|
+ ProviderModel.provider_name == self.provider.provider,
|
|
|
+ ProviderModel.model_name == model,
|
|
|
+ ProviderModel.model_type == model_type.to_origin_model_type(),
|
|
|
+ )
|
|
|
+ .first()
|
|
|
+ )
|
|
|
|
|
|
# Get provider credential secret variables
|
|
|
provider_credential_secret_variables = self.extract_secret_variables(
|
|
|
self.provider.model_credential_schema.credential_form_schemas
|
|
|
- if self.provider.model_credential_schema else []
|
|
|
+ if self.provider.model_credential_schema
|
|
|
+ else []
|
|
|
)
|
|
|
|
|
|
if provider_model_record:
|
|
|
try:
|
|
|
- original_credentials = json.loads(
|
|
|
- provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
|
|
|
+ original_credentials = (
|
|
|
+ json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
|
|
|
+ )
|
|
|
except JSONDecodeError:
|
|
|
original_credentials = {}
|
|
|
|
|
@@ -350,10 +368,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
|
|
|
|
|
|
credentials = model_provider_factory.model_credentials_validate(
|
|
|
- provider=self.provider.provider,
|
|
|
- model_type=model_type,
|
|
|
- model=model,
|
|
|
- credentials=credentials
|
|
|
+ provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
|
|
|
)
|
|
|
|
|
|
for key, value in credentials.items():
|
|
@@ -388,7 +403,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
model_name=model,
|
|
|
model_type=model_type.to_origin_model_type(),
|
|
|
encrypted_config=json.dumps(credentials),
|
|
|
- is_valid=True
|
|
|
+ is_valid=True,
|
|
|
)
|
|
|
db.session.add(provider_model_record)
|
|
|
db.session.commit()
|
|
@@ -396,7 +411,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
provider_model_credentials_cache = ProviderCredentialsCache(
|
|
|
tenant_id=self.tenant_id,
|
|
|
identity_id=provider_model_record.id,
|
|
|
- cache_type=ProviderCredentialsCacheType.MODEL
|
|
|
+ cache_type=ProviderCredentialsCacheType.MODEL,
|
|
|
)
|
|
|
|
|
|
provider_model_credentials_cache.delete()
|
|
@@ -409,13 +424,16 @@ class ProviderConfiguration(BaseModel):
|
|
|
:return:
|
|
|
"""
|
|
|
# get provider model
|
|
|
- provider_model_record = db.session.query(ProviderModel) \
|
|
|
+ provider_model_record = (
|
|
|
+ db.session.query(ProviderModel)
|
|
|
.filter(
|
|
|
- ProviderModel.tenant_id == self.tenant_id,
|
|
|
- ProviderModel.provider_name == self.provider.provider,
|
|
|
- ProviderModel.model_name == model,
|
|
|
- ProviderModel.model_type == model_type.to_origin_model_type()
|
|
|
- ).first()
|
|
|
+ ProviderModel.tenant_id == self.tenant_id,
|
|
|
+ ProviderModel.provider_name == self.provider.provider,
|
|
|
+ ProviderModel.model_name == model,
|
|
|
+ ProviderModel.model_type == model_type.to_origin_model_type(),
|
|
|
+ )
|
|
|
+ .first()
|
|
|
+ )
|
|
|
|
|
|
# delete provider model
|
|
|
if provider_model_record:
|
|
@@ -425,7 +443,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
provider_model_credentials_cache = ProviderCredentialsCache(
|
|
|
tenant_id=self.tenant_id,
|
|
|
identity_id=provider_model_record.id,
|
|
|
- cache_type=ProviderCredentialsCacheType.MODEL
|
|
|
+ cache_type=ProviderCredentialsCacheType.MODEL,
|
|
|
)
|
|
|
|
|
|
provider_model_credentials_cache.delete()
|
|
@@ -437,13 +455,16 @@ class ProviderConfiguration(BaseModel):
|
|
|
:param model: model name
|
|
|
:return:
|
|
|
"""
|
|
|
- model_setting = db.session.query(ProviderModelSetting) \
|
|
|
+ model_setting = (
|
|
|
+ db.session.query(ProviderModelSetting)
|
|
|
.filter(
|
|
|
- ProviderModelSetting.tenant_id == self.tenant_id,
|
|
|
- ProviderModelSetting.provider_name == self.provider.provider,
|
|
|
- ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
|
|
- ProviderModelSetting.model_name == model
|
|
|
- ).first()
|
|
|
+ ProviderModelSetting.tenant_id == self.tenant_id,
|
|
|
+ ProviderModelSetting.provider_name == self.provider.provider,
|
|
|
+ ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
|
|
+ ProviderModelSetting.model_name == model,
|
|
|
+ )
|
|
|
+ .first()
|
|
|
+ )
|
|
|
|
|
|
if model_setting:
|
|
|
model_setting.enabled = True
|
|
@@ -455,7 +476,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
provider_name=self.provider.provider,
|
|
|
model_type=model_type.to_origin_model_type(),
|
|
|
model_name=model,
|
|
|
- enabled=True
|
|
|
+ enabled=True,
|
|
|
)
|
|
|
db.session.add(model_setting)
|
|
|
db.session.commit()
|
|
@@ -469,13 +490,16 @@ class ProviderConfiguration(BaseModel):
|
|
|
:param model: model name
|
|
|
:return:
|
|
|
"""
|
|
|
- model_setting = db.session.query(ProviderModelSetting) \
|
|
|
+ model_setting = (
|
|
|
+ db.session.query(ProviderModelSetting)
|
|
|
.filter(
|
|
|
- ProviderModelSetting.tenant_id == self.tenant_id,
|
|
|
- ProviderModelSetting.provider_name == self.provider.provider,
|
|
|
- ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
|
|
- ProviderModelSetting.model_name == model
|
|
|
- ).first()
|
|
|
+ ProviderModelSetting.tenant_id == self.tenant_id,
|
|
|
+ ProviderModelSetting.provider_name == self.provider.provider,
|
|
|
+ ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
|
|
+ ProviderModelSetting.model_name == model,
|
|
|
+ )
|
|
|
+ .first()
|
|
|
+ )
|
|
|
|
|
|
if model_setting:
|
|
|
model_setting.enabled = False
|
|
@@ -487,7 +511,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
provider_name=self.provider.provider,
|
|
|
model_type=model_type.to_origin_model_type(),
|
|
|
model_name=model,
|
|
|
- enabled=False
|
|
|
+ enabled=False,
|
|
|
)
|
|
|
db.session.add(model_setting)
|
|
|
db.session.commit()
|
|
@@ -501,13 +525,16 @@ class ProviderConfiguration(BaseModel):
|
|
|
:param model: model name
|
|
|
:return:
|
|
|
"""
|
|
|
- return db.session.query(ProviderModelSetting) \
|
|
|
+ return (
|
|
|
+ db.session.query(ProviderModelSetting)
|
|
|
.filter(
|
|
|
- ProviderModelSetting.tenant_id == self.tenant_id,
|
|
|
- ProviderModelSetting.provider_name == self.provider.provider,
|
|
|
- ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
|
|
- ProviderModelSetting.model_name == model
|
|
|
- ).first()
|
|
|
+ ProviderModelSetting.tenant_id == self.tenant_id,
|
|
|
+ ProviderModelSetting.provider_name == self.provider.provider,
|
|
|
+ ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
|
|
+ ProviderModelSetting.model_name == model,
|
|
|
+ )
|
|
|
+ .first()
|
|
|
+ )
|
|
|
|
|
|
def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
|
|
|
"""
|
|
@@ -516,24 +543,30 @@ class ProviderConfiguration(BaseModel):
|
|
|
:param model: model name
|
|
|
:return:
|
|
|
"""
|
|
|
- load_balancing_config_count = db.session.query(LoadBalancingModelConfig) \
|
|
|
+ load_balancing_config_count = (
|
|
|
+ db.session.query(LoadBalancingModelConfig)
|
|
|
.filter(
|
|
|
- LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
|
|
- LoadBalancingModelConfig.provider_name == self.provider.provider,
|
|
|
- LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
|
|
- LoadBalancingModelConfig.model_name == model
|
|
|
- ).count()
|
|
|
+ LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
|
|
+ LoadBalancingModelConfig.provider_name == self.provider.provider,
|
|
|
+ LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
|
|
+ LoadBalancingModelConfig.model_name == model,
|
|
|
+ )
|
|
|
+ .count()
|
|
|
+ )
|
|
|
|
|
|
if load_balancing_config_count <= 1:
|
|
|
- raise ValueError('Model load balancing configuration must be more than 1.')
|
|
|
+ raise ValueError("Model load balancing configuration must be more than 1.")
|
|
|
|
|
|
- model_setting = db.session.query(ProviderModelSetting) \
|
|
|
+ model_setting = (
|
|
|
+ db.session.query(ProviderModelSetting)
|
|
|
.filter(
|
|
|
- ProviderModelSetting.tenant_id == self.tenant_id,
|
|
|
- ProviderModelSetting.provider_name == self.provider.provider,
|
|
|
- ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
|
|
- ProviderModelSetting.model_name == model
|
|
|
- ).first()
|
|
|
+ ProviderModelSetting.tenant_id == self.tenant_id,
|
|
|
+ ProviderModelSetting.provider_name == self.provider.provider,
|
|
|
+ ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
|
|
+ ProviderModelSetting.model_name == model,
|
|
|
+ )
|
|
|
+ .first()
|
|
|
+ )
|
|
|
|
|
|
if model_setting:
|
|
|
model_setting.load_balancing_enabled = True
|
|
@@ -545,7 +578,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
provider_name=self.provider.provider,
|
|
|
model_type=model_type.to_origin_model_type(),
|
|
|
model_name=model,
|
|
|
- load_balancing_enabled=True
|
|
|
+ load_balancing_enabled=True,
|
|
|
)
|
|
|
db.session.add(model_setting)
|
|
|
db.session.commit()
|
|
@@ -559,13 +592,16 @@ class ProviderConfiguration(BaseModel):
|
|
|
:param model: model name
|
|
|
:return:
|
|
|
"""
|
|
|
- model_setting = db.session.query(ProviderModelSetting) \
|
|
|
+ model_setting = (
|
|
|
+ db.session.query(ProviderModelSetting)
|
|
|
.filter(
|
|
|
- ProviderModelSetting.tenant_id == self.tenant_id,
|
|
|
- ProviderModelSetting.provider_name == self.provider.provider,
|
|
|
- ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
|
|
- ProviderModelSetting.model_name == model
|
|
|
- ).first()
|
|
|
+ ProviderModelSetting.tenant_id == self.tenant_id,
|
|
|
+ ProviderModelSetting.provider_name == self.provider.provider,
|
|
|
+ ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
|
|
+ ProviderModelSetting.model_name == model,
|
|
|
+ )
|
|
|
+ .first()
|
|
|
+ )
|
|
|
|
|
|
if model_setting:
|
|
|
model_setting.load_balancing_enabled = False
|
|
@@ -577,7 +613,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
provider_name=self.provider.provider,
|
|
|
model_type=model_type.to_origin_model_type(),
|
|
|
model_name=model,
|
|
|
- load_balancing_enabled=False
|
|
|
+ load_balancing_enabled=False,
|
|
|
)
|
|
|
db.session.add(model_setting)
|
|
|
db.session.commit()
|
|
@@ -617,11 +653,14 @@ class ProviderConfiguration(BaseModel):
|
|
|
return
|
|
|
|
|
|
# get preferred provider
|
|
|
- preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
|
|
|
+ preferred_model_provider = (
|
|
|
+ db.session.query(TenantPreferredModelProvider)
|
|
|
.filter(
|
|
|
- TenantPreferredModelProvider.tenant_id == self.tenant_id,
|
|
|
- TenantPreferredModelProvider.provider_name == self.provider.provider
|
|
|
- ).first()
|
|
|
+ TenantPreferredModelProvider.tenant_id == self.tenant_id,
|
|
|
+ TenantPreferredModelProvider.provider_name == self.provider.provider,
|
|
|
+ )
|
|
|
+ .first()
|
|
|
+ )
|
|
|
|
|
|
if preferred_model_provider:
|
|
|
preferred_model_provider.preferred_provider_type = provider_type.value
|
|
@@ -629,7 +668,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
preferred_model_provider = TenantPreferredModelProvider(
|
|
|
tenant_id=self.tenant_id,
|
|
|
provider_name=self.provider.provider,
|
|
|
- preferred_provider_type=provider_type.value
|
|
|
+ preferred_provider_type=provider_type.value,
|
|
|
)
|
|
|
db.session.add(preferred_model_provider)
|
|
|
|
|
@@ -658,9 +697,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
:return:
|
|
|
"""
|
|
|
# Get provider credential secret variables
|
|
|
- credential_secret_variables = self.extract_secret_variables(
|
|
|
- credential_form_schemas
|
|
|
- )
|
|
|
+ credential_secret_variables = self.extract_secret_variables(credential_form_schemas)
|
|
|
|
|
|
# Obfuscate provider credentials
|
|
|
copy_credentials = credentials.copy()
|
|
@@ -670,9 +707,9 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
|
return copy_credentials
|
|
|
|
|
|
- def get_provider_model(self, model_type: ModelType,
|
|
|
- model: str,
|
|
|
- only_active: bool = False) -> Optional[ModelWithProviderEntity]:
|
|
|
+ def get_provider_model(
|
|
|
+ self, model_type: ModelType, model: str, only_active: bool = False
|
|
|
+ ) -> Optional[ModelWithProviderEntity]:
|
|
|
"""
|
|
|
Get provider model.
|
|
|
:param model_type: model type
|
|
@@ -688,8 +725,9 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
|
return None
|
|
|
|
|
|
- def get_provider_models(self, model_type: Optional[ModelType] = None,
|
|
|
- only_active: bool = False) -> list[ModelWithProviderEntity]:
|
|
|
+ def get_provider_models(
|
|
|
+ self, model_type: Optional[ModelType] = None, only_active: bool = False
|
|
|
+ ) -> list[ModelWithProviderEntity]:
|
|
|
"""
|
|
|
Get provider models.
|
|
|
:param model_type: model type
|
|
@@ -711,15 +749,11 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
|
if self.using_provider_type == ProviderType.SYSTEM:
|
|
|
provider_models = self._get_system_provider_models(
|
|
|
- model_types=model_types,
|
|
|
- provider_instance=provider_instance,
|
|
|
- model_setting_map=model_setting_map
|
|
|
+ model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
|
|
|
)
|
|
|
else:
|
|
|
provider_models = self._get_custom_provider_models(
|
|
|
- model_types=model_types,
|
|
|
- provider_instance=provider_instance,
|
|
|
- model_setting_map=model_setting_map
|
|
|
+ model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
|
|
|
)
|
|
|
|
|
|
if only_active:
|
|
@@ -728,11 +762,12 @@ class ProviderConfiguration(BaseModel):
|
|
|
# resort provider_models
|
|
|
return sorted(provider_models, key=lambda x: x.model_type.value)
|
|
|
|
|
|
- def _get_system_provider_models(self,
|
|
|
- model_types: list[ModelType],
|
|
|
- provider_instance: ModelProvider,
|
|
|
- model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
|
|
|
- -> list[ModelWithProviderEntity]:
|
|
|
+ def _get_system_provider_models(
|
|
|
+ self,
|
|
|
+ model_types: list[ModelType],
|
|
|
+ provider_instance: ModelProvider,
|
|
|
+ model_setting_map: dict[ModelType, dict[str, ModelSettings]],
|
|
|
+ ) -> list[ModelWithProviderEntity]:
|
|
|
"""
|
|
|
Get system provider models.
|
|
|
|
|
@@ -760,7 +795,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
model_properties=m.model_properties,
|
|
|
deprecated=m.deprecated,
|
|
|
provider=SimpleModelProviderEntity(self.provider),
|
|
|
- status=status
|
|
|
+ status=status,
|
|
|
)
|
|
|
)
|
|
|
|
|
@@ -783,23 +818,20 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
|
if should_use_custom_model:
|
|
|
if original_provider_configurate_methods[self.provider.provider] == [
|
|
|
- ConfigurateMethod.CUSTOMIZABLE_MODEL]:
|
|
|
+ ConfigurateMethod.CUSTOMIZABLE_MODEL
|
|
|
+ ]:
|
|
|
# only customizable model
|
|
|
for restrict_model in restrict_models:
|
|
|
copy_credentials = self.system_configuration.credentials.copy()
|
|
|
if restrict_model.base_model_name:
|
|
|
- copy_credentials['base_model_name'] = restrict_model.base_model_name
|
|
|
+ copy_credentials["base_model_name"] = restrict_model.base_model_name
|
|
|
|
|
|
try:
|
|
|
- custom_model_schema = (
|
|
|
- provider_instance.get_model_instance(restrict_model.model_type)
|
|
|
- .get_customizable_model_schema_from_credentials(
|
|
|
- restrict_model.model,
|
|
|
- copy_credentials
|
|
|
- )
|
|
|
- )
|
|
|
+ custom_model_schema = provider_instance.get_model_instance(
|
|
|
+ restrict_model.model_type
|
|
|
+ ).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials)
|
|
|
except Exception as ex:
|
|
|
- logger.warning(f'get custom model schema failed, {ex}')
|
|
|
+ logger.warning(f"get custom model schema failed, {ex}")
|
|
|
continue
|
|
|
|
|
|
if not custom_model_schema:
|
|
@@ -809,8 +841,10 @@ class ProviderConfiguration(BaseModel):
|
|
|
continue
|
|
|
|
|
|
status = ModelStatus.ACTIVE
|
|
|
- if (custom_model_schema.model_type in model_setting_map
|
|
|
- and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]):
|
|
|
+ if (
|
|
|
+ custom_model_schema.model_type in model_setting_map
|
|
|
+ and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
|
|
|
+ ):
|
|
|
model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
|
|
|
if model_setting.enabled is False:
|
|
|
status = ModelStatus.DISABLED
|
|
@@ -825,7 +859,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
model_properties=custom_model_schema.model_properties,
|
|
|
deprecated=custom_model_schema.deprecated,
|
|
|
provider=SimpleModelProviderEntity(self.provider),
|
|
|
- status=status
|
|
|
+ status=status,
|
|
|
)
|
|
|
)
|
|
|
|
|
@@ -839,11 +873,12 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
|
return provider_models
|
|
|
|
|
|
- def _get_custom_provider_models(self,
|
|
|
- model_types: list[ModelType],
|
|
|
- provider_instance: ModelProvider,
|
|
|
- model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
|
|
|
- -> list[ModelWithProviderEntity]:
|
|
|
+ def _get_custom_provider_models(
|
|
|
+ self,
|
|
|
+ model_types: list[ModelType],
|
|
|
+ provider_instance: ModelProvider,
|
|
|
+ model_setting_map: dict[ModelType, dict[str, ModelSettings]],
|
|
|
+ ) -> list[ModelWithProviderEntity]:
|
|
|
"""
|
|
|
Get custom provider models.
|
|
|
|
|
@@ -885,7 +920,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
deprecated=m.deprecated,
|
|
|
provider=SimpleModelProviderEntity(self.provider),
|
|
|
status=status,
|
|
|
- load_balancing_enabled=load_balancing_enabled
|
|
|
+ load_balancing_enabled=load_balancing_enabled,
|
|
|
)
|
|
|
)
|
|
|
|
|
@@ -895,15 +930,13 @@ class ProviderConfiguration(BaseModel):
|
|
|
continue
|
|
|
|
|
|
try:
|
|
|
- custom_model_schema = (
|
|
|
- provider_instance.get_model_instance(model_configuration.model_type)
|
|
|
- .get_customizable_model_schema_from_credentials(
|
|
|
- model_configuration.model,
|
|
|
- model_configuration.credentials
|
|
|
- )
|
|
|
+ custom_model_schema = provider_instance.get_model_instance(
|
|
|
+ model_configuration.model_type
|
|
|
+ ).get_customizable_model_schema_from_credentials(
|
|
|
+ model_configuration.model, model_configuration.credentials
|
|
|
)
|
|
|
except Exception as ex:
|
|
|
- logger.warning(f'get custom model schema failed, {ex}')
|
|
|
+ logger.warning(f"get custom model schema failed, {ex}")
|
|
|
continue
|
|
|
|
|
|
if not custom_model_schema:
|
|
@@ -911,8 +944,10 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
|
status = ModelStatus.ACTIVE
|
|
|
load_balancing_enabled = False
|
|
|
- if (custom_model_schema.model_type in model_setting_map
|
|
|
- and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]):
|
|
|
+ if (
|
|
|
+ custom_model_schema.model_type in model_setting_map
|
|
|
+ and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
|
|
|
+ ):
|
|
|
model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
|
|
|
if model_setting.enabled is False:
|
|
|
status = ModelStatus.DISABLED
|
|
@@ -931,7 +966,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
deprecated=custom_model_schema.deprecated,
|
|
|
provider=SimpleModelProviderEntity(self.provider),
|
|
|
status=status,
|
|
|
- load_balancing_enabled=load_balancing_enabled
|
|
|
+ load_balancing_enabled=load_balancing_enabled,
|
|
|
)
|
|
|
)
|
|
|
|
|
@@ -942,17 +977,16 @@ class ProviderConfigurations(BaseModel):
|
|
|
"""
|
|
|
Model class for provider configuration dict.
|
|
|
"""
|
|
|
+
|
|
|
tenant_id: str
|
|
|
configurations: dict[str, ProviderConfiguration] = {}
|
|
|
|
|
|
def __init__(self, tenant_id: str):
|
|
|
super().__init__(tenant_id=tenant_id)
|
|
|
|
|
|
- def get_models(self,
|
|
|
- provider: Optional[str] = None,
|
|
|
- model_type: Optional[ModelType] = None,
|
|
|
- only_active: bool = False) \
|
|
|
- -> list[ModelWithProviderEntity]:
|
|
|
+ def get_models(
|
|
|
+ self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False
|
|
|
+ ) -> list[ModelWithProviderEntity]:
|
|
|
"""
|
|
|
Get available models.
|
|
|
|
|
@@ -1019,10 +1053,10 @@ class ProviderModelBundle(BaseModel):
|
|
|
"""
|
|
|
Provider model bundle.
|
|
|
"""
|
|
|
+
|
|
|
configuration: ProviderConfiguration
|
|
|
provider_instance: ModelProvider
|
|
|
model_type_instance: AIModel
|
|
|
|
|
|
# pydantic configs
|
|
|
- model_config = ConfigDict(arbitrary_types_allowed=True,
|
|
|
- protected_namespaces=())
|
|
|
+ model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
|