|
@@ -2,13 +2,14 @@ import datetime
|
|
|
import json
|
|
|
import logging
|
|
|
from collections import defaultdict
|
|
|
-from collections.abc import Iterator
|
|
|
+from collections.abc import Iterator, Sequence
|
|
|
from json import JSONDecodeError
|
|
|
from typing import Optional
|
|
|
|
|
|
from pydantic import BaseModel, ConfigDict
|
|
|
|
|
|
from constants import HIDDEN_VALUE
|
|
|
+from core.entities import DEFAULT_PLUGIN_ID
|
|
|
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
|
|
|
from core.entities.provider_entities import (
|
|
|
CustomConfiguration,
|
|
@@ -18,16 +19,15 @@ from core.entities.provider_entities import (
|
|
|
)
|
|
|
from core.helper import encrypter
|
|
|
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
|
|
-from core.model_runtime.entities.model_entities import FetchFrom, ModelType
|
|
|
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
|
|
from core.model_runtime.entities.provider_entities import (
|
|
|
ConfigurateMethod,
|
|
|
CredentialFormSchema,
|
|
|
FormType,
|
|
|
ProviderEntity,
|
|
|
)
|
|
|
-from core.model_runtime.model_providers import model_provider_factory
|
|
|
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
|
|
-from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
|
|
+from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
|
|
from extensions.ext_database import db
|
|
|
from models.provider import (
|
|
|
LoadBalancingModelConfig,
|
|
@@ -99,9 +99,10 @@ class ProviderConfiguration(BaseModel):
|
|
|
continue
|
|
|
|
|
|
restrict_models = quota_configuration.restrict_models
|
|
|
- if self.system_configuration.credentials is None:
|
|
|
- return None
|
|
|
- copy_credentials = self.system_configuration.credentials.copy()
|
|
|
+
|
|
|
+ copy_credentials = (
|
|
|
+ self.system_configuration.credentials.copy() if self.system_configuration.credentials else {}
|
|
|
+ )
|
|
|
if restrict_models:
|
|
|
for restrict_model in restrict_models:
|
|
|
if (
|
|
@@ -140,6 +141,9 @@ class ProviderConfiguration(BaseModel):
|
|
|
if current_quota_configuration is None:
|
|
|
return None
|
|
|
|
|
|
+ if not current_quota_configuration:
|
|
|
+ return SystemConfigurationStatus.UNSUPPORTED
|
|
|
+
|
|
|
return (
|
|
|
SystemConfigurationStatus.ACTIVE
|
|
|
if current_quota_configuration.is_valid
|
|
@@ -153,7 +157,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
"""
|
|
|
return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0
|
|
|
|
|
|
- def get_custom_credentials(self, obfuscated: bool = False):
|
|
|
+ def get_custom_credentials(self, obfuscated: bool = False) -> dict | None:
|
|
|
"""
|
|
|
Get custom credentials.
|
|
|
|
|
@@ -175,7 +179,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
else [],
|
|
|
)
|
|
|
|
|
|
- def custom_credentials_validate(self, credentials: dict) -> tuple[Optional[Provider], dict]:
|
|
|
+ def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]:
|
|
|
"""
|
|
|
Validate custom credentials.
|
|
|
:param credentials: provider credentials
|
|
@@ -219,6 +223,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
if value == HIDDEN_VALUE and key in original_credentials:
|
|
|
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
|
|
|
|
|
|
+ model_provider_factory = ModelProviderFactory(self.tenant_id)
|
|
|
credentials = model_provider_factory.provider_credentials_validate(
|
|
|
provider=self.provider.provider, credentials=credentials
|
|
|
)
|
|
@@ -246,13 +251,13 @@ class ProviderConfiguration(BaseModel):
|
|
|
provider_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
|
|
db.session.commit()
|
|
|
else:
|
|
|
- provider_record = Provider(
|
|
|
- tenant_id=self.tenant_id,
|
|
|
- provider_name=self.provider.provider,
|
|
|
- provider_type=ProviderType.CUSTOM.value,
|
|
|
- encrypted_config=json.dumps(credentials),
|
|
|
- is_valid=True,
|
|
|
- )
|
|
|
+ provider_record = Provider()
|
|
|
+ provider_record.tenant_id = self.tenant_id
|
|
|
+ provider_record.provider_name = self.provider.provider
|
|
|
+ provider_record.provider_type = ProviderType.CUSTOM.value
|
|
|
+ provider_record.encrypted_config = json.dumps(credentials)
|
|
|
+ provider_record.is_valid = True
|
|
|
+
|
|
|
db.session.add(provider_record)
|
|
|
db.session.commit()
|
|
|
|
|
@@ -327,7 +332,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
|
def custom_model_credentials_validate(
|
|
|
self, model_type: ModelType, model: str, credentials: dict
|
|
|
- ) -> tuple[Optional[ProviderModel], dict]:
|
|
|
+ ) -> tuple[ProviderModel | None, dict]:
|
|
|
"""
|
|
|
Validate custom model credentials.
|
|
|
|
|
@@ -370,6 +375,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
if value == HIDDEN_VALUE and key in original_credentials:
|
|
|
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
|
|
|
|
|
|
+ model_provider_factory = ModelProviderFactory(self.tenant_id)
|
|
|
credentials = model_provider_factory.model_credentials_validate(
|
|
|
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
|
|
|
)
|
|
@@ -400,14 +406,13 @@ class ProviderConfiguration(BaseModel):
|
|
|
provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
|
|
db.session.commit()
|
|
|
else:
|
|
|
- provider_model_record = ProviderModel(
|
|
|
- tenant_id=self.tenant_id,
|
|
|
- provider_name=self.provider.provider,
|
|
|
- model_name=model,
|
|
|
- model_type=model_type.to_origin_model_type(),
|
|
|
- encrypted_config=json.dumps(credentials),
|
|
|
- is_valid=True,
|
|
|
- )
|
|
|
+ provider_model_record = ProviderModel()
|
|
|
+ provider_model_record.tenant_id = self.tenant_id
|
|
|
+ provider_model_record.provider_name = self.provider.provider
|
|
|
+ provider_model_record.model_name = model
|
|
|
+ provider_model_record.model_type = model_type.to_origin_model_type()
|
|
|
+ provider_model_record.encrypted_config = json.dumps(credentials)
|
|
|
+ provider_model_record.is_valid = True
|
|
|
db.session.add(provider_model_record)
|
|
|
db.session.commit()
|
|
|
|
|
@@ -474,13 +479,12 @@ class ProviderConfiguration(BaseModel):
|
|
|
model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
|
|
db.session.commit()
|
|
|
else:
|
|
|
- model_setting = ProviderModelSetting(
|
|
|
- tenant_id=self.tenant_id,
|
|
|
- provider_name=self.provider.provider,
|
|
|
- model_type=model_type.to_origin_model_type(),
|
|
|
- model_name=model,
|
|
|
- enabled=True,
|
|
|
- )
|
|
|
+ model_setting = ProviderModelSetting()
|
|
|
+ model_setting.tenant_id = self.tenant_id
|
|
|
+ model_setting.provider_name = self.provider.provider
|
|
|
+ model_setting.model_type = model_type.to_origin_model_type()
|
|
|
+ model_setting.model_name = model
|
|
|
+ model_setting.enabled = True
|
|
|
db.session.add(model_setting)
|
|
|
db.session.commit()
|
|
|
|
|
@@ -509,13 +513,12 @@ class ProviderConfiguration(BaseModel):
|
|
|
model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
|
|
db.session.commit()
|
|
|
else:
|
|
|
- model_setting = ProviderModelSetting(
|
|
|
- tenant_id=self.tenant_id,
|
|
|
- provider_name=self.provider.provider,
|
|
|
- model_type=model_type.to_origin_model_type(),
|
|
|
- model_name=model,
|
|
|
- enabled=False,
|
|
|
- )
|
|
|
+ model_setting = ProviderModelSetting()
|
|
|
+ model_setting.tenant_id = self.tenant_id
|
|
|
+ model_setting.provider_name = self.provider.provider
|
|
|
+ model_setting.model_type = model_type.to_origin_model_type()
|
|
|
+ model_setting.model_name = model
|
|
|
+ model_setting.enabled = False
|
|
|
db.session.add(model_setting)
|
|
|
db.session.commit()
|
|
|
|
|
@@ -576,13 +579,12 @@ class ProviderConfiguration(BaseModel):
|
|
|
model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
|
|
db.session.commit()
|
|
|
else:
|
|
|
- model_setting = ProviderModelSetting(
|
|
|
- tenant_id=self.tenant_id,
|
|
|
- provider_name=self.provider.provider,
|
|
|
- model_type=model_type.to_origin_model_type(),
|
|
|
- model_name=model,
|
|
|
- load_balancing_enabled=True,
|
|
|
- )
|
|
|
+ model_setting = ProviderModelSetting()
|
|
|
+ model_setting.tenant_id = self.tenant_id
|
|
|
+ model_setting.provider_name = self.provider.provider
|
|
|
+ model_setting.model_type = model_type.to_origin_model_type()
|
|
|
+ model_setting.model_name = model
|
|
|
+ model_setting.load_balancing_enabled = True
|
|
|
db.session.add(model_setting)
|
|
|
db.session.commit()
|
|
|
|
|
@@ -611,25 +613,17 @@ class ProviderConfiguration(BaseModel):
|
|
|
model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
|
|
db.session.commit()
|
|
|
else:
|
|
|
- model_setting = ProviderModelSetting(
|
|
|
- tenant_id=self.tenant_id,
|
|
|
- provider_name=self.provider.provider,
|
|
|
- model_type=model_type.to_origin_model_type(),
|
|
|
- model_name=model,
|
|
|
- load_balancing_enabled=False,
|
|
|
- )
|
|
|
+ model_setting = ProviderModelSetting()
|
|
|
+ model_setting.tenant_id = self.tenant_id
|
|
|
+ model_setting.provider_name = self.provider.provider
|
|
|
+ model_setting.model_type = model_type.to_origin_model_type()
|
|
|
+ model_setting.model_name = model
|
|
|
+ model_setting.load_balancing_enabled = False
|
|
|
db.session.add(model_setting)
|
|
|
db.session.commit()
|
|
|
|
|
|
return model_setting
|
|
|
|
|
|
- def get_provider_instance(self) -> ModelProvider:
|
|
|
- """
|
|
|
- Get provider instance.
|
|
|
- :return:
|
|
|
- """
|
|
|
- return model_provider_factory.get_provider_instance(self.provider.provider)
|
|
|
-
|
|
|
def get_model_type_instance(self, model_type: ModelType) -> AIModel:
|
|
|
"""
|
|
|
Get current model type instance.
|
|
@@ -637,11 +631,19 @@ class ProviderConfiguration(BaseModel):
|
|
|
:param model_type: model type
|
|
|
:return:
|
|
|
"""
|
|
|
- # Get provider instance
|
|
|
- provider_instance = self.get_provider_instance()
|
|
|
+ model_provider_factory = ModelProviderFactory(self.tenant_id)
|
|
|
|
|
|
# Get model instance of LLM
|
|
|
- return provider_instance.get_model_instance(model_type)
|
|
|
+ return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
|
|
|
+
|
|
|
+ def get_model_schema(self, model_type: ModelType, model: str, credentials: dict) -> AIModelEntity | None:
|
|
|
+ """
|
|
|
+ Get model schema
|
|
|
+ """
|
|
|
+ model_provider_factory = ModelProviderFactory(self.tenant_id)
|
|
|
+ return model_provider_factory.get_model_schema(
|
|
|
+ provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
|
|
|
+ )
|
|
|
|
|
|
def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
|
|
|
"""
|
|
@@ -668,11 +670,10 @@ class ProviderConfiguration(BaseModel):
|
|
|
if preferred_model_provider:
|
|
|
preferred_model_provider.preferred_provider_type = provider_type.value
|
|
|
else:
|
|
|
- preferred_model_provider = TenantPreferredModelProvider(
|
|
|
- tenant_id=self.tenant_id,
|
|
|
- provider_name=self.provider.provider,
|
|
|
- preferred_provider_type=provider_type.value,
|
|
|
- )
|
|
|
+ preferred_model_provider = TenantPreferredModelProvider()
|
|
|
+ preferred_model_provider.tenant_id = self.tenant_id
|
|
|
+ preferred_model_provider.provider_name = self.provider.provider
|
|
|
+ preferred_model_provider.preferred_provider_type = provider_type.value
|
|
|
db.session.add(preferred_model_provider)
|
|
|
|
|
|
db.session.commit()
|
|
@@ -737,13 +738,14 @@ class ProviderConfiguration(BaseModel):
|
|
|
:param only_active: only active models
|
|
|
:return:
|
|
|
"""
|
|
|
- provider_instance = self.get_provider_instance()
|
|
|
+ model_provider_factory = ModelProviderFactory(self.tenant_id)
|
|
|
+ provider_schema = model_provider_factory.get_provider_schema(self.provider.provider)
|
|
|
|
|
|
- model_types = []
|
|
|
+ model_types: list[ModelType] = []
|
|
|
if model_type:
|
|
|
model_types.append(model_type)
|
|
|
else:
|
|
|
- model_types = list(provider_instance.get_provider_schema().supported_model_types)
|
|
|
+ model_types = list(provider_schema.supported_model_types)
|
|
|
|
|
|
# Group model settings by model type and model
|
|
|
model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict)
|
|
@@ -752,11 +754,11 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
|
if self.using_provider_type == ProviderType.SYSTEM:
|
|
|
provider_models = self._get_system_provider_models(
|
|
|
- model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
|
|
|
+ model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map
|
|
|
)
|
|
|
else:
|
|
|
provider_models = self._get_custom_provider_models(
|
|
|
- model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
|
|
|
+ model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map
|
|
|
)
|
|
|
|
|
|
if only_active:
|
|
@@ -767,23 +769,26 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
|
def _get_system_provider_models(
|
|
|
self,
|
|
|
- model_types: list[ModelType],
|
|
|
- provider_instance: ModelProvider,
|
|
|
+ model_types: Sequence[ModelType],
|
|
|
+ provider_schema: ProviderEntity,
|
|
|
model_setting_map: dict[ModelType, dict[str, ModelSettings]],
|
|
|
) -> list[ModelWithProviderEntity]:
|
|
|
"""
|
|
|
Get system provider models.
|
|
|
|
|
|
:param model_types: model types
|
|
|
- :param provider_instance: provider instance
|
|
|
+ :param provider_schema: provider schema
|
|
|
:param model_setting_map: model setting map
|
|
|
:return:
|
|
|
"""
|
|
|
provider_models = []
|
|
|
for model_type in model_types:
|
|
|
- for m in provider_instance.models(model_type):
|
|
|
+ for m in provider_schema.models:
|
|
|
+ if m.model_type != model_type:
|
|
|
+ continue
|
|
|
+
|
|
|
status = ModelStatus.ACTIVE
|
|
|
- if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
|
|
|
+ if m.model in model_setting_map:
|
|
|
model_setting = model_setting_map[m.model_type][m.model]
|
|
|
if model_setting.enabled is False:
|
|
|
status = ModelStatus.DISABLED
|
|
@@ -804,7 +809,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
|
if self.provider.provider not in original_provider_configurate_methods:
|
|
|
original_provider_configurate_methods[self.provider.provider] = []
|
|
|
- for configurate_method in provider_instance.get_provider_schema().configurate_methods:
|
|
|
+ for configurate_method in provider_schema.configurate_methods:
|
|
|
original_provider_configurate_methods[self.provider.provider].append(configurate_method)
|
|
|
|
|
|
should_use_custom_model = False
|
|
@@ -825,18 +830,22 @@ class ProviderConfiguration(BaseModel):
|
|
|
]:
|
|
|
# only customizable model
|
|
|
for restrict_model in restrict_models:
|
|
|
- if self.system_configuration.credentials is not None:
|
|
|
- copy_credentials = self.system_configuration.credentials.copy()
|
|
|
- if restrict_model.base_model_name:
|
|
|
- copy_credentials["base_model_name"] = restrict_model.base_model_name
|
|
|
-
|
|
|
- try:
|
|
|
- custom_model_schema = provider_instance.get_model_instance(
|
|
|
- restrict_model.model_type
|
|
|
- ).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials)
|
|
|
- except Exception as ex:
|
|
|
- logger.warning(f"get custom model schema failed, {ex}")
|
|
|
- continue
|
|
|
+ copy_credentials = (
|
|
|
+ self.system_configuration.credentials.copy()
|
|
|
+ if self.system_configuration.credentials
|
|
|
+ else {}
|
|
|
+ )
|
|
|
+ if restrict_model.base_model_name:
|
|
|
+ copy_credentials["base_model_name"] = restrict_model.base_model_name
|
|
|
+
|
|
|
+ try:
|
|
|
+ custom_model_schema = self.get_model_schema(
|
|
|
+ model_type=restrict_model.model_type,
|
|
|
+ model=restrict_model.model,
|
|
|
+ credentials=copy_credentials,
|
|
|
+ )
|
|
|
+ except Exception as ex:
|
|
|
+ logger.warning(f"get custom model schema failed, {ex}")
|
|
|
|
|
|
if not custom_model_schema:
|
|
|
continue
|
|
@@ -881,15 +890,15 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
|
def _get_custom_provider_models(
|
|
|
self,
|
|
|
- model_types: list[ModelType],
|
|
|
- provider_instance: ModelProvider,
|
|
|
+ model_types: Sequence[ModelType],
|
|
|
+ provider_schema: ProviderEntity,
|
|
|
model_setting_map: dict[ModelType, dict[str, ModelSettings]],
|
|
|
) -> list[ModelWithProviderEntity]:
|
|
|
"""
|
|
|
Get custom provider models.
|
|
|
|
|
|
:param model_types: model types
|
|
|
- :param provider_instance: provider instance
|
|
|
+ :param provider_schema: provider schema
|
|
|
:param model_setting_map: model setting map
|
|
|
:return:
|
|
|
"""
|
|
@@ -903,8 +912,10 @@ class ProviderConfiguration(BaseModel):
|
|
|
if model_type not in self.provider.supported_model_types:
|
|
|
continue
|
|
|
|
|
|
- models = provider_instance.models(model_type)
|
|
|
- for m in models:
|
|
|
+ for m in provider_schema.models:
|
|
|
+ if m.model_type != model_type:
|
|
|
+ continue
|
|
|
+
|
|
|
status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
|
|
|
load_balancing_enabled = False
|
|
|
if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
|
|
@@ -936,10 +947,10 @@ class ProviderConfiguration(BaseModel):
|
|
|
continue
|
|
|
|
|
|
try:
|
|
|
- custom_model_schema = provider_instance.get_model_instance(
|
|
|
- model_configuration.model_type
|
|
|
- ).get_customizable_model_schema_from_credentials(
|
|
|
- model_configuration.model, model_configuration.credentials
|
|
|
+ custom_model_schema = self.get_model_schema(
|
|
|
+ model_type=model_configuration.model_type,
|
|
|
+ model=model_configuration.model,
|
|
|
+ credentials=model_configuration.credentials,
|
|
|
)
|
|
|
except Exception as ex:
|
|
|
logger.warning(f"get custom model schema failed, {ex}")
|
|
@@ -967,7 +978,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
label=custom_model_schema.label,
|
|
|
model_type=custom_model_schema.model_type,
|
|
|
features=custom_model_schema.features,
|
|
|
- fetch_from=custom_model_schema.fetch_from,
|
|
|
+ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
|
|
model_properties=custom_model_schema.model_properties,
|
|
|
deprecated=custom_model_schema.deprecated,
|
|
|
provider=SimpleModelProviderEntity(self.provider),
|
|
@@ -1040,6 +1051,9 @@ class ProviderConfigurations(BaseModel):
|
|
|
return list(self.values())
|
|
|
|
|
|
def __getitem__(self, key):
|
|
|
+ if "/" not in key:
|
|
|
+ key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
|
|
|
+
|
|
|
return self.configurations[key]
|
|
|
|
|
|
def __setitem__(self, key, value):
|
|
@@ -1051,8 +1065,11 @@ class ProviderConfigurations(BaseModel):
|
|
|
def values(self) -> Iterator[ProviderConfiguration]:
|
|
|
return iter(self.configurations.values())
|
|
|
|
|
|
- def get(self, key, default=None):
|
|
|
- return self.configurations.get(key, default)
|
|
|
+ def get(self, key, default=None) -> ProviderConfiguration | None:
|
|
|
+ if "/" not in key:
|
|
|
+ key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
|
|
|
+
|
|
|
+ return self.configurations.get(key, default) # type: ignore
|
|
|
|
|
|
|
|
|
class ProviderModelBundle(BaseModel):
|
|
@@ -1061,7 +1078,6 @@ class ProviderModelBundle(BaseModel):
|
|
|
"""
|
|
|
|
|
|
configuration: ProviderConfiguration
|
|
|
- provider_instance: ModelProvider
|
|
|
model_type_instance: AIModel
|
|
|
|
|
|
# pydantic configs
|