Przeglądaj źródła

refactor: Simplify plugin and provider ID generation logic and deduplicate plugin_ids (#14041)

Yeuoly 2 miesięcy temu
rodzic
commit
23888398d1

+ 3 - 10
api/services/plugin/dependencies_analysis.py

@@ -1,5 +1,5 @@
 from core.helper import marketplace
-from core.plugin.entities.plugin import GenericProviderID, PluginDependency, PluginInstallationSource
+from core.plugin.entities.plugin import ModelProviderID, PluginDependency, PluginInstallationSource, ToolProviderID
 from core.plugin.manager.plugin import PluginInstallationManager
 
 
@@ -12,10 +12,7 @@ class DependenciesAnalysisService:
         Convert the tool id to the plugin_id
         """
         try:
-            tool_provider_id = GenericProviderID(tool_id)
-            if tool_id in ["jina", "siliconflow"]:
-                tool_provider_id.plugin_name = tool_provider_id.plugin_name + "_tool"
-            return tool_provider_id.plugin_id
+            return ToolProviderID(tool_id).plugin_id
         except Exception as e:
             raise e
 
@@ -27,11 +24,7 @@ class DependenciesAnalysisService:
         Convert the model provider id to the plugin_id
         """
         try:
-            generic_provider_id = GenericProviderID(model_provider_id)
-            if model_provider_id == "google":
-                generic_provider_id.plugin_name = "gemini"
-
-            return generic_provider_id.plugin_id
+            return ModelProviderID(model_provider_id).plugin_id
         except Exception as e:
             raise e
 

+ 5 - 33
api/services/plugin/plugin_migration.py

@@ -14,9 +14,8 @@ from flask import Flask, current_app
 from sqlalchemy.orm import Session
 
 from core.agent.entities import AgentToolEntity
-from core.entities import DEFAULT_PLUGIN_ID
 from core.helper import marketplace
-from core.plugin.entities.plugin import PluginInstallationSource
+from core.plugin.entities.plugin import ModelProviderID, PluginInstallationSource, ToolProviderID
 from core.plugin.entities.plugin_daemon import PluginInstallTaskStatus
 from core.plugin.manager.plugin import PluginInstallationManager
 from core.tools.entities.tool_entities import ToolProviderType
@@ -203,13 +202,7 @@ class PluginMigration:
             result = []
             for row in rs:
                 provider_name = str(row[0])
-                if provider_name and "/" not in provider_name:
-                    if provider_name == "google":
-                        provider_name = "gemini"
-
-                    result.append(DEFAULT_PLUGIN_ID + "/" + provider_name)
-                elif provider_name:
-                    result.append(provider_name)
+                result.append(ModelProviderID(provider_name).plugin_id)
 
             return result
 
@@ -222,30 +215,10 @@ class PluginMigration:
             rs = session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
             result = []
             for row in rs:
-                if "/" not in row.provider:
-                    result.append(DEFAULT_PLUGIN_ID + "/" + row.provider)
-                else:
-                    result.append(row.provider)
+                result.append(ToolProviderID(row.provider).plugin_id)
 
             return result
 
-    @classmethod
-    def _handle_builtin_tool_provider(cls, provider_name: str) -> str:
-        """
-        Handle builtin tool provider.
-        """
-        if provider_name == "jina":
-            provider_name = "jina_tool"
-        elif provider_name == "siliconflow":
-            provider_name = "siliconflow_tool"
-        elif provider_name == "stepfun":
-            provider_name = "stepfun_tool"
-
-        if "/" not in provider_name:
-            return DEFAULT_PLUGIN_ID + "/" + provider_name
-        else:
-            return provider_name
-
     @classmethod
     def extract_workflow_tables(cls, tenant_id: str) -> Sequence[str]:
         """
@@ -266,8 +239,7 @@ class PluginMigration:
                         provider_name = data.get("provider_name")
                         provider_type = data.get("provider_type")
                         if provider_name not in excluded_providers and provider_type == ToolProviderType.BUILT_IN.value:
-                            provider_name = cls._handle_builtin_tool_provider(provider_name)
-                            result.append(provider_name)
+                            result.append(ToolProviderID(provider_name).plugin_id)
 
             return result
 
@@ -298,7 +270,7 @@ class PluginMigration:
                                     tool_entity.provider_type == ToolProviderType.BUILT_IN.value
                                     and tool_entity.provider_id not in excluded_providers
                                 ):
-                                    result.append(cls._handle_builtin_tool_provider(tool_entity.provider_id))
+                                    result.append(ToolProviderID(tool_entity.provider_id).plugin_id)
 
                             except Exception:
                                 logger.exception(f"Failed to process tool {tool}")