浏览代码

chore(provider_manager): Update hosted model's name (#14334)

Signed-off-by: -LAN- <laipz8200@outlook.com>
-LAN- 1 月之前
父节点
当前提交
76bcdc2581
共有 2 个文件被更改,包括 12 次插入15 次删除
  1. 6 13
      api/core/hosting_configuration.py
  2. 6 2
      api/core/provider_manager.py

+ 6 - 13
api/core/hosting_configuration.py

@@ -52,19 +52,12 @@ class HostingConfiguration:
         if dify_config.EDITION != "CLOUD":
             return
 
-        self.provider_map["azure_openai"] = self.init_azure_openai()
-        self.provider_map["openai"] = self.init_openai()
-        self.provider_map["anthropic"] = self.init_anthropic()
-        self.provider_map["minimax"] = self.init_minimax()
-        self.provider_map["spark"] = self.init_spark()
-        self.provider_map["zhipuai"] = self.init_zhipuai()
-        # NOTE: We need to use the new name format after the data migration.
-        # self.provider_map[f"{DEFAULT_PLUGIN_ID}/azure_openai/azure_openai"] = self.init_azure_openai()
-        # self.provider_map[f"{DEFAULT_PLUGIN_ID}/openai/openai"] = self.init_openai()
-        # self.provider_map[f"{DEFAULT_PLUGIN_ID}/anthropic/anthropic"] = self.init_anthropic()
-        # self.provider_map[f"{DEFAULT_PLUGIN_ID}/minimax/minimax"] = self.init_minimax()
-        # self.provider_map[f"{DEFAULT_PLUGIN_ID}/spark/spark"] = self.init_spark()
-        # self.provider_map[f"{DEFAULT_PLUGIN_ID}/zhipuai/zhipuai"] = self.init_zhipuai()
+        self.provider_map[f"{DEFAULT_PLUGIN_ID}/azure_openai/azure_openai"] = self.init_azure_openai()
+        self.provider_map[f"{DEFAULT_PLUGIN_ID}/openai/openai"] = self.init_openai()
+        self.provider_map[f"{DEFAULT_PLUGIN_ID}/anthropic/anthropic"] = self.init_anthropic()
+        self.provider_map[f"{DEFAULT_PLUGIN_ID}/minimax/minimax"] = self.init_minimax()
+        self.provider_map[f"{DEFAULT_PLUGIN_ID}/spark/spark"] = self.init_spark()
+        self.provider_map[f"{DEFAULT_PLUGIN_ID}/zhipuai/zhipuai"] = self.init_zhipuai()
 
         self.moderation_config = self.init_moderation_config()
 

+ 6 - 2
api/core/provider_manager.py

@@ -6,6 +6,7 @@ from typing import Any, Optional, cast
 from sqlalchemy.exc import IntegrityError
 
 from configs import dify_config
+from core.entities import DEFAULT_PLUGIN_ID
 from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
 from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
 from core.entities.provider_entities import (
@@ -504,11 +505,14 @@ class ProviderManager:
                 if quota.quota_type == ProviderQuotaType.TRIAL:
                     # Init trial provider records if not exists
                     if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
+                        if not provider_name.startswith(DEFAULT_PLUGIN_ID):
+                            continue
+                        hosting_provider_name = provider_name.split("/")[-1]
                         try:
                             # FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic
                             provider_record = Provider(
                                 tenant_id=tenant_id,
-                                provider_name=provider_name,
+                                provider_name=hosting_provider_name,
                                 provider_type=ProviderType.SYSTEM.value,
                                 quota_type=ProviderQuotaType.TRIAL.value,
                                 quota_limit=quota.quota_limit,  # type: ignore
@@ -523,7 +527,7 @@ class ProviderManager:
                                 db.session.query(Provider)
                                 .filter(
                                     Provider.tenant_id == tenant_id,
-                                    Provider.provider_name == provider_name,
+                                    Provider.provider_name == hosting_provider_name,
                                     Provider.provider_type == ProviderType.SYSTEM.value,
                                     Provider.quota_type == ProviderQuotaType.TRIAL.value,
                                 )