|
@@ -1,7 +1,7 @@
|
|
|
import datetime
|
|
|
import json
|
|
|
import logging
|
|
|
-import time
|
|
|
+
|
|
|
from json import JSONDecodeError
|
|
|
from typing import Optional, List, Dict, Tuple, Iterator
|
|
|
|
|
@@ -11,8 +11,9 @@ from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, S
|
|
|
from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus
|
|
|
from core.helper import encrypter
|
|
|
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
|
|
-from core.model_runtime.entities.model_entities import ModelType
|
|
|
-from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType
|
|
|
+from core.model_runtime.entities.model_entities import ModelType, FetchFrom
|
|
|
+from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType, \
|
|
|
+ ConfigurateMethod
|
|
|
from core.model_runtime.model_providers import model_provider_factory
|
|
|
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
|
|
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
|
@@ -22,6 +23,8 @@ from models.provider import ProviderType, Provider, ProviderModel, TenantPreferr
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
+original_provider_configurate_methods = {}
|
|
|
+
|
|
|
|
|
|
class ProviderConfiguration(BaseModel):
|
|
|
"""
|
|
@@ -34,6 +37,20 @@ class ProviderConfiguration(BaseModel):
|
|
|
system_configuration: SystemConfiguration
|
|
|
custom_configuration: CustomConfiguration
|
|
|
|
|
|
+ def __init__(self, **data):
|
|
|
+ super().__init__(**data)
|
|
|
+
|
|
|
+ if self.provider.provider not in original_provider_configurate_methods:
|
|
|
+ original_provider_configurate_methods[self.provider.provider] = []
|
|
|
+ for configurate_method in self.provider.configurate_methods:
|
|
|
+ original_provider_configurate_methods[self.provider.provider].append(configurate_method)
|
|
|
+
|
|
|
+ if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
|
|
|
+ if (any([len(quota_configuration.restrict_models) > 0
|
|
|
+ for quota_configuration in self.system_configuration.quota_configurations])
|
|
|
+ and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods):
|
|
|
+ self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
|
|
|
+
|
|
|
def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
|
|
|
"""
|
|
|
Get current credentials.
|
|
@@ -123,7 +140,8 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
|
if provider_record:
|
|
|
try:
|
|
|
- original_credentials = json.loads(provider_record.encrypted_config) if provider_record.encrypted_config else {}
|
|
|
+ original_credentials = json.loads(
|
|
|
+ provider_record.encrypted_config) if provider_record.encrypted_config else {}
|
|
|
except JSONDecodeError:
|
|
|
original_credentials = {}
|
|
|
|
|
@@ -265,7 +283,8 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
|
if provider_model_record:
|
|
|
try:
|
|
|
- original_credentials = json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
|
|
|
+ original_credentials = json.loads(
|
|
|
+ provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
|
|
|
except JSONDecodeError:
|
|
|
original_credentials = {}
|
|
|
|
|
@@ -534,21 +553,70 @@ class ProviderConfiguration(BaseModel):
|
|
|
]
|
|
|
)
|
|
|
|
|
|
+ if self.provider.provider not in original_provider_configurate_methods:
|
|
|
+ original_provider_configurate_methods[self.provider.provider] = []
|
|
|
+ for configurate_method in provider_instance.get_provider_schema().configurate_methods:
|
|
|
+ original_provider_configurate_methods[self.provider.provider].append(configurate_method)
|
|
|
+
|
|
|
+ should_use_custom_model = False
|
|
|
+ if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
|
|
|
+ should_use_custom_model = True
|
|
|
+
|
|
|
for quota_configuration in self.system_configuration.quota_configurations:
|
|
|
if self.system_configuration.current_quota_type != quota_configuration.quota_type:
|
|
|
continue
|
|
|
|
|
|
- restrict_llms = quota_configuration.restrict_llms
|
|
|
- if not restrict_llms:
|
|
|
+ restrict_models = quota_configuration.restrict_models
|
|
|
+ if len(restrict_models) == 0:
|
|
|
break
|
|
|
|
|
|
+ if should_use_custom_model:
|
|
|
+ if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
|
|
|
+ # only customizable model
|
|
|
+ for restrict_model in restrict_models:
|
|
|
+ copy_credentials = self.system_configuration.credentials.copy()
|
|
|
+ if restrict_model.base_model_name:
|
|
|
+ copy_credentials['base_model_name'] = restrict_model.base_model_name
|
|
|
+
|
|
|
+ try:
|
|
|
+ custom_model_schema = (
|
|
|
+ provider_instance.get_model_instance(restrict_model.model_type)
|
|
|
+ .get_customizable_model_schema_from_credentials(
|
|
|
+ restrict_model.model,
|
|
|
+ copy_credentials
|
|
|
+ )
|
|
|
+ )
|
|
|
+ except Exception as ex:
|
|
|
+ logger.warning(f'get custom model schema failed, {ex}')
|
|
|
+ continue
|
|
|
+
|
|
|
+ if not custom_model_schema:
|
|
|
+ continue
|
|
|
+
|
|
|
+ if custom_model_schema.model_type not in model_types:
|
|
|
+ continue
|
|
|
+
|
|
|
+ provider_models.append(
|
|
|
+ ModelWithProviderEntity(
|
|
|
+ model=custom_model_schema.model,
|
|
|
+ label=custom_model_schema.label,
|
|
|
+ model_type=custom_model_schema.model_type,
|
|
|
+ features=custom_model_schema.features,
|
|
|
+ fetch_from=FetchFrom.PREDEFINED_MODEL,
|
|
|
+ model_properties=custom_model_schema.model_properties,
|
|
|
+ deprecated=custom_model_schema.deprecated,
|
|
|
+ provider=SimpleModelProviderEntity(self.provider),
|
|
|
+ status=ModelStatus.ACTIVE
|
|
|
+ )
|
|
|
+ )
|
|
|
+
|
|
|
# if llm name not in restricted llm list, remove it
|
|
|
+ restrict_model_names = [rm.model for rm in restrict_models]
|
|
|
for m in provider_models:
|
|
|
- if m.model_type == ModelType.LLM and m.model not in restrict_llms:
|
|
|
+ if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
|
|
|
m.status = ModelStatus.NO_PERMISSION
|
|
|
elif not quota_configuration.is_valid:
|
|
|
m.status = ModelStatus.QUOTA_EXCEEDED
|
|
|
-
|
|
|
return provider_models
|
|
|
|
|
|
def _get_custom_provider_models(self,
|