Преглед изворни кода

fix(provider_manager): fix custom provider (#14340)

Signed-off-by: -LAN- <laipz8200@outlook.com>
-LAN- пре 2 месеци
родитељ
комит
57b60dd51f
1 измењених фајлова са 5 додато и 8 уклоњено
  1. 5 8
      api/core/provider_manager.py

+ 5 - 8
api/core/provider_manager.py

@@ -6,7 +6,6 @@ from typing import Any, Optional, cast
 from sqlalchemy.exc import IntegrityError
 
 from configs import dify_config
-from core.entities import DEFAULT_PLUGIN_ID
 from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
 from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
 from core.entities.provider_entities import (
@@ -370,7 +369,8 @@ class ProviderManager:
 
         provider_name_to_provider_records_dict = defaultdict(list)
         for provider in providers:
-            provider_name_to_provider_records_dict[provider.provider_name].append(provider)
+            # TODO: Use provider name with prefix after the data migration
+            provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)
 
         return provider_name_to_provider_records_dict
 
@@ -505,14 +505,12 @@ class ProviderManager:
                 if quota.quota_type == ProviderQuotaType.TRIAL:
                     # Init trial provider records if not exists
                     if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
-                        if not provider_name.startswith(DEFAULT_PLUGIN_ID):
-                            continue
-                        hosting_provider_name = provider_name.split("/")[-1]
                         try:
                             # FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic
                             provider_record = Provider(
                                 tenant_id=tenant_id,
-                                provider_name=hosting_provider_name,
+                                # TODO: Use provider name with prefix after the data migration.
+                                provider_name=ModelProviderID(provider_name).provider_name,
                                 provider_type=ProviderType.SYSTEM.value,
                                 quota_type=ProviderQuotaType.TRIAL.value,
                                 quota_limit=quota.quota_limit,  # type: ignore
@@ -527,13 +525,12 @@ class ProviderManager:
                                 db.session.query(Provider)
                                 .filter(
                                     Provider.tenant_id == tenant_id,
-                                    Provider.provider_name == hosting_provider_name,
+                                    Provider.provider_name == ModelProviderID(provider_name).provider_name,
                                     Provider.provider_type == ProviderType.SYSTEM.value,
                                     Provider.quota_type == ProviderQuotaType.TRIAL.value,
                                 )
                                 .first()
                             )
-
                             if provider_record and not provider_record.is_valid:
                                 provider_record.is_valid = True
                                 db.session.commit()