|
@@ -6,7 +6,6 @@ from typing import Any, Optional, cast
|
|
|
from sqlalchemy.exc import IntegrityError
|
|
|
|
|
|
from configs import dify_config
|
|
|
-from core.entities import DEFAULT_PLUGIN_ID
|
|
|
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
|
|
|
from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
|
|
|
from core.entities.provider_entities import (
|
|
@@ -370,7 +369,8 @@ class ProviderManager:
|
|
|
|
|
|
provider_name_to_provider_records_dict = defaultdict(list)
|
|
|
for provider in providers:
|
|
|
- provider_name_to_provider_records_dict[provider.provider_name].append(provider)
|
|
|
+
|
|
|
+ provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)
|
|
|
|
|
|
return provider_name_to_provider_records_dict
|
|
|
|
|
@@ -505,14 +505,12 @@ class ProviderManager:
|
|
|
if quota.quota_type == ProviderQuotaType.TRIAL:
|
|
|
|
|
|
if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
|
|
|
- if not provider_name.startswith(DEFAULT_PLUGIN_ID):
|
|
|
- continue
|
|
|
- hosting_provider_name = provider_name.split("/")[-1]
|
|
|
try:
|
|
|
|
|
|
provider_record = Provider(
|
|
|
tenant_id=tenant_id,
|
|
|
- provider_name=hosting_provider_name,
|
|
|
+
|
|
|
+ provider_name=ModelProviderID(provider_name).provider_name,
|
|
|
provider_type=ProviderType.SYSTEM.value,
|
|
|
quota_type=ProviderQuotaType.TRIAL.value,
|
|
|
quota_limit=quota.quota_limit,
|
|
@@ -527,13 +525,12 @@ class ProviderManager:
|
|
|
db.session.query(Provider)
|
|
|
.filter(
|
|
|
Provider.tenant_id == tenant_id,
|
|
|
- Provider.provider_name == hosting_provider_name,
|
|
|
+ Provider.provider_name == ModelProviderID(provider_name).provider_name,
|
|
|
Provider.provider_type == ProviderType.SYSTEM.value,
|
|
|
Provider.quota_type == ProviderQuotaType.TRIAL.value,
|
|
|
)
|
|
|
.first()
|
|
|
)
|
|
|
-
|
|
|
if provider_record and not provider_record.is_valid:
|
|
|
provider_record.is_valid = True
|
|
|
db.session.commit()
|