|
@@ -7,7 +7,6 @@ from json import JSONDecodeError
|
|
|
from typing import Optional
|
|
|
|
|
|
from pydantic import BaseModel, ConfigDict, Field
|
|
|
-from sqlalchemy import or_
|
|
|
|
|
|
from constants import HIDDEN_VALUE
|
|
|
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
|
|
@@ -180,37 +179,35 @@ class ProviderConfiguration(BaseModel):
|
|
|
else [],
|
|
|
)
|
|
|
|
|
|
- def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]:
|
|
|
+ def _get_custom_provider_credentials(self) -> Provider | None:
|
|
|
"""
|
|
|
- Validate custom credentials.
|
|
|
- :param credentials: provider credentials
|
|
|
- :return:
|
|
|
+ Get custom provider credentials.
|
|
|
"""
|
|
|
# get provider
|
|
|
model_provider_id = ModelProviderID(self.provider.provider)
|
|
|
+ provider_names = [self.provider.provider]
|
|
|
if model_provider_id.is_langgenius():
|
|
|
- provider_record = (
|
|
|
- db.session.query(Provider)
|
|
|
- .filter(
|
|
|
- Provider.tenant_id == self.tenant_id,
|
|
|
- Provider.provider_type == ProviderType.CUSTOM.value,
|
|
|
- or_(
|
|
|
- Provider.provider_name == model_provider_id.provider_name,
|
|
|
- Provider.provider_name == self.provider.provider,
|
|
|
- ),
|
|
|
- )
|
|
|
- .first()
|
|
|
- )
|
|
|
- else:
|
|
|
- provider_record = (
|
|
|
- db.session.query(Provider)
|
|
|
- .filter(
|
|
|
- Provider.tenant_id == self.tenant_id,
|
|
|
- Provider.provider_type == ProviderType.CUSTOM.value,
|
|
|
- Provider.provider_name == self.provider.provider,
|
|
|
- )
|
|
|
- .first()
|
|
|
+ provider_names.append(model_provider_id.provider_name)
|
|
|
+
|
|
|
+ provider_record = (
|
|
|
+ db.session.query(Provider)
|
|
|
+ .filter(
|
|
|
+ Provider.tenant_id == self.tenant_id,
|
|
|
+ Provider.provider_type == ProviderType.CUSTOM.value,
|
|
|
+ Provider.provider_name.in_(provider_names),
|
|
|
)
|
|
|
+ .first()
|
|
|
+ )
|
|
|
+
|
|
|
+ return provider_record
|
|
|
+
|
|
|
+ def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]:
|
|
|
+ """
|
|
|
+ Validate custom credentials.
|
|
|
+ :param credentials: provider credentials
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ provider_record = self._get_custom_provider_credentials()
|
|
|
|
|
|
# Get provider credential secret variables
|
|
|
provider_credential_secret_variables = self.extract_secret_variables(
|
|
@@ -291,18 +288,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
:return:
|
|
|
"""
|
|
|
# get provider
|
|
|
- provider_record = (
|
|
|
- db.session.query(Provider)
|
|
|
- .filter(
|
|
|
- Provider.tenant_id == self.tenant_id,
|
|
|
- or_(
|
|
|
- Provider.provider_name == ModelProviderID(self.provider.provider).plugin_name,
|
|
|
- Provider.provider_name == self.provider.provider,
|
|
|
- ),
|
|
|
- Provider.provider_type == ProviderType.CUSTOM.value,
|
|
|
- )
|
|
|
- .first()
|
|
|
- )
|
|
|
+ provider_record = self._get_custom_provider_credentials()
|
|
|
|
|
|
# delete provider
|
|
|
if provider_record:
|
|
@@ -349,29 +335,47 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
|
return None
|
|
|
|
|
|
- def custom_model_credentials_validate(
|
|
|
- self, model_type: ModelType, model: str, credentials: dict
|
|
|
- ) -> tuple[ProviderModel | None, dict]:
|
|
|
+ def _get_custom_model_credentials(
|
|
|
+ self,
|
|
|
+ model_type: ModelType,
|
|
|
+ model: str,
|
|
|
+ ) -> ProviderModel | None:
|
|
|
"""
|
|
|
- Validate custom model credentials.
|
|
|
-
|
|
|
- :param model_type: model type
|
|
|
- :param model: model name
|
|
|
- :param credentials: model credentials
|
|
|
- :return:
|
|
|
+ Get custom model credentials.
|
|
|
"""
|
|
|
# get provider model
|
|
|
+ model_provider_id = ModelProviderID(self.provider.provider)
|
|
|
+ provider_names = [self.provider.provider]
|
|
|
+ if model_provider_id.is_langgenius():
|
|
|
+ provider_names.append(model_provider_id.provider_name)
|
|
|
+
|
|
|
provider_model_record = (
|
|
|
db.session.query(ProviderModel)
|
|
|
.filter(
|
|
|
ProviderModel.tenant_id == self.tenant_id,
|
|
|
- ProviderModel.provider_name == self.provider.provider,
|
|
|
+ ProviderModel.provider_name.in_(provider_names),
|
|
|
ProviderModel.model_name == model,
|
|
|
ProviderModel.model_type == model_type.to_origin_model_type(),
|
|
|
)
|
|
|
.first()
|
|
|
)
|
|
|
|
|
|
+ return provider_model_record
|
|
|
+
|
|
|
+ def custom_model_credentials_validate(
|
|
|
+ self, model_type: ModelType, model: str, credentials: dict
|
|
|
+ ) -> tuple[ProviderModel | None, dict]:
|
|
|
+ """
|
|
|
+ Validate custom model credentials.
|
|
|
+
|
|
|
+ :param model_type: model type
|
|
|
+ :param model: model name
|
|
|
+ :param credentials: model credentials
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ # get provider model
|
|
|
+ provider_model_record = self._get_custom_model_credentials(model_type, model)
|
|
|
+
|
|
|
# Get provider credential secret variables
|
|
|
provider_credential_secret_variables = self.extract_secret_variables(
|
|
|
self.provider.model_credential_schema.credential_form_schemas
|
|
@@ -451,16 +455,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
:return:
|
|
|
"""
|
|
|
# get provider model
|
|
|
- provider_model_record = (
|
|
|
- db.session.query(ProviderModel)
|
|
|
- .filter(
|
|
|
- ProviderModel.tenant_id == self.tenant_id,
|
|
|
- ProviderModel.provider_name == self.provider.provider,
|
|
|
- ProviderModel.model_name == model,
|
|
|
- ProviderModel.model_type == model_type.to_origin_model_type(),
|
|
|
- )
|
|
|
- .first()
|
|
|
- )
|
|
|
+ provider_model_record = self._get_custom_model_credentials(model_type, model)
|
|
|
|
|
|
# delete provider model
|
|
|
if provider_model_record:
|
|
@@ -475,24 +470,35 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
|
provider_model_credentials_cache.delete()
|
|
|
|
|
|
- def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
|
|
|
+ def _get_provider_model_setting(self, model_type: ModelType, model: str) -> ProviderModelSetting | None:
|
|
|
"""
|
|
|
- Enable model.
|
|
|
- :param model_type: model type
|
|
|
- :param model: model name
|
|
|
- :return:
|
|
|
+ Get provider model setting.
|
|
|
"""
|
|
|
- model_setting = (
|
|
|
+ model_provider_id = ModelProviderID(self.provider.provider)
|
|
|
+ provider_names = [self.provider.provider]
|
|
|
+ if model_provider_id.is_langgenius():
|
|
|
+ provider_names.append(model_provider_id.provider_name)
|
|
|
+
|
|
|
+ return (
|
|
|
db.session.query(ProviderModelSetting)
|
|
|
.filter(
|
|
|
ProviderModelSetting.tenant_id == self.tenant_id,
|
|
|
- ProviderModelSetting.provider_name == self.provider.provider,
|
|
|
+ ProviderModelSetting.provider_name.in_(provider_names),
|
|
|
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
|
|
ProviderModelSetting.model_name == model,
|
|
|
)
|
|
|
.first()
|
|
|
)
|
|
|
|
|
|
+ def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
|
|
|
+ """
|
|
|
+ Enable model.
|
|
|
+ :param model_type: model type
|
|
|
+ :param model: model name
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ model_setting = self._get_provider_model_setting(model_type, model)
|
|
|
+
|
|
|
if model_setting:
|
|
|
model_setting.enabled = True
|
|
|
model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
|
@@ -516,16 +522,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
:param model: model name
|
|
|
:return:
|
|
|
"""
|
|
|
- model_setting = (
|
|
|
- db.session.query(ProviderModelSetting)
|
|
|
- .filter(
|
|
|
- ProviderModelSetting.tenant_id == self.tenant_id,
|
|
|
- ProviderModelSetting.provider_name == self.provider.provider,
|
|
|
- ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
|
|
- ProviderModelSetting.model_name == model,
|
|
|
- )
|
|
|
- .first()
|
|
|
- )
|
|
|
+ model_setting = self._get_provider_model_setting(model_type, model)
|
|
|
|
|
|
if model_setting:
|
|
|
model_setting.enabled = False
|
|
@@ -550,13 +547,24 @@ class ProviderConfiguration(BaseModel):
|
|
|
:param model: model name
|
|
|
:return:
|
|
|
"""
|
|
|
+ return self._get_provider_model_setting(model_type, model)
|
|
|
+
|
|
|
+ def _get_load_balancing_config(self, model_type: ModelType, model: str) -> Optional[LoadBalancingModelConfig]:
|
|
|
+ """
|
|
|
+ Get load balancing config.
|
|
|
+ """
|
|
|
+ model_provider_id = ModelProviderID(self.provider.provider)
|
|
|
+ provider_names = [self.provider.provider]
|
|
|
+ if model_provider_id.is_langgenius():
|
|
|
+ provider_names.append(model_provider_id.provider_name)
|
|
|
+
|
|
|
return (
|
|
|
- db.session.query(ProviderModelSetting)
|
|
|
+ db.session.query(LoadBalancingModelConfig)
|
|
|
.filter(
|
|
|
- ProviderModelSetting.tenant_id == self.tenant_id,
|
|
|
- ProviderModelSetting.provider_name == self.provider.provider,
|
|
|
- ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
|
|
- ProviderModelSetting.model_name == model,
|
|
|
+ LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
|
|
+ LoadBalancingModelConfig.provider_name.in_(provider_names),
|
|
|
+ LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
|
|
+ LoadBalancingModelConfig.model_name == model,
|
|
|
)
|
|
|
.first()
|
|
|
)
|
|
@@ -568,11 +576,16 @@ class ProviderConfiguration(BaseModel):
|
|
|
:param model: model name
|
|
|
:return:
|
|
|
"""
|
|
|
+ model_provider_id = ModelProviderID(self.provider.provider)
|
|
|
+ provider_names = [self.provider.provider]
|
|
|
+ if model_provider_id.is_langgenius():
|
|
|
+ provider_names.append(model_provider_id.provider_name)
|
|
|
+
|
|
|
load_balancing_config_count = (
|
|
|
db.session.query(LoadBalancingModelConfig)
|
|
|
.filter(
|
|
|
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
|
|
- LoadBalancingModelConfig.provider_name == self.provider.provider,
|
|
|
+ LoadBalancingModelConfig.provider_name.in_(provider_names),
|
|
|
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
|
|
LoadBalancingModelConfig.model_name == model,
|
|
|
)
|
|
@@ -582,16 +595,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
if load_balancing_config_count <= 1:
|
|
|
raise ValueError("Model load balancing configuration must be more than 1.")
|
|
|
|
|
|
- model_setting = (
|
|
|
- db.session.query(ProviderModelSetting)
|
|
|
- .filter(
|
|
|
- ProviderModelSetting.tenant_id == self.tenant_id,
|
|
|
- ProviderModelSetting.provider_name == self.provider.provider,
|
|
|
- ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
|
|
- ProviderModelSetting.model_name == model,
|
|
|
- )
|
|
|
- .first()
|
|
|
- )
|
|
|
+ model_setting = self._get_provider_model_setting(model_type, model)
|
|
|
|
|
|
if model_setting:
|
|
|
model_setting.load_balancing_enabled = True
|
|
@@ -616,11 +620,16 @@ class ProviderConfiguration(BaseModel):
|
|
|
:param model: model name
|
|
|
:return:
|
|
|
"""
|
|
|
+ model_provider_id = ModelProviderID(self.provider.provider)
|
|
|
+ provider_names = [self.provider.provider]
|
|
|
+ if model_provider_id.is_langgenius():
|
|
|
+ provider_names.append(model_provider_id.provider_name)
|
|
|
+
|
|
|
model_setting = (
|
|
|
db.session.query(ProviderModelSetting)
|
|
|
.filter(
|
|
|
ProviderModelSetting.tenant_id == self.tenant_id,
|
|
|
- ProviderModelSetting.provider_name == self.provider.provider,
|
|
|
+ ProviderModelSetting.provider_name.in_(provider_names),
|
|
|
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
|
|
ProviderModelSetting.model_name == model,
|
|
|
)
|
|
@@ -677,11 +686,16 @@ class ProviderConfiguration(BaseModel):
|
|
|
return
|
|
|
|
|
|
# get preferred provider
|
|
|
+ model_provider_id = ModelProviderID(self.provider.provider)
|
|
|
+ provider_names = [self.provider.provider]
|
|
|
+ if model_provider_id.is_langgenius():
|
|
|
+ provider_names.append(model_provider_id.provider_name)
|
|
|
+
|
|
|
preferred_model_provider = (
|
|
|
db.session.query(TenantPreferredModelProvider)
|
|
|
.filter(
|
|
|
TenantPreferredModelProvider.tenant_id == self.tenant_id,
|
|
|
- TenantPreferredModelProvider.provider_name == self.provider.provider,
|
|
|
+ TenantPreferredModelProvider.provider_name.in_(provider_names),
|
|
|
)
|
|
|
.first()
|
|
|
)
|