瀏覽代碼

Fix/plugin race condition (#14253)

Yeuoly 1 月之前
父節點
當前提交
490b6d092e

+ 7 - 0
api/app_factory.py

@@ -2,6 +2,7 @@ import logging
 import time
 
 from configs import dify_config
+from contexts.wrapper import RecyclableContextVar
 from dify_app import DifyApp
 
 
@@ -16,6 +17,12 @@ def create_flask_app_with_configs() -> DifyApp:
     dify_app = DifyApp(__name__)
     dify_app.config.from_mapping(dify_config.model_dump())
 
+    # add before request hook
+    @dify_app.before_request
+    def before_request():
+        # bump the recycle counter (a per-request identifier) so stale RecyclableContextVar values are ignored
+        RecyclableContextVar.increment_thread_recycles()
+
     return dify_app
 
 

+ 15 - 4
api/contexts/__init__.py

@@ -2,6 +2,8 @@ from contextvars import ContextVar
 from threading import Lock
 from typing import TYPE_CHECKING
 
+from contexts.wrapper import RecyclableContextVar
+
 if TYPE_CHECKING:
     from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
     from core.tools.plugin_tool.provider import PluginToolProviderController
@@ -12,8 +14,17 @@ tenant_id: ContextVar[str] = ContextVar("tenant_id")
 
 workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
 
-plugin_tool_providers: ContextVar[dict[str, "PluginToolProviderController"]] = ContextVar("plugin_tool_providers")
-plugin_tool_providers_lock: ContextVar[Lock] = ContextVar("plugin_tool_providers_lock")
+"""
+To avoid race conditions caused by gunicorn thread recycling, the variables below use RecyclableContextVar
+"""
+plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderController"]] = RecyclableContextVar(
+    ContextVar("plugin_tool_providers")
+)
+plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock"))
 
-plugin_model_providers: ContextVar[list["PluginModelProviderEntity"] | None] = ContextVar("plugin_model_providers")
-plugin_model_providers_lock: ContextVar[Lock] = ContextVar("plugin_model_providers_lock")
+plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar(
+    ContextVar("plugin_model_providers")
+)
+plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
+    ContextVar("plugin_model_providers_lock")
+)

+ 65 - 0
api/contexts/wrapper.py

@@ -0,0 +1,65 @@
+from contextvars import ContextVar
+from typing import Generic, TypeVar
+
+T = TypeVar("T")
+
+
+class HiddenValue:
+    pass
+
+
+_default = HiddenValue()
+
+
+class RecyclableContextVar(Generic[T]):
+    """
+    RecyclableContextVar is a wrapper around ContextVar.
+    It is safe to use in gunicorn with thread recycling, but features like `reset` are not available for now.
+
+    NOTE: you need to call `increment_thread_recycles` at the start of each request.
+    """
+
+    _thread_recycles: ContextVar[int] = ContextVar("thread_recycles")
+
+    @classmethod
+    def increment_thread_recycles(cls):
+        try:
+            recycles = cls._thread_recycles.get()
+            cls._thread_recycles.set(recycles + 1)
+        except LookupError:
+            cls._thread_recycles.set(0)
+
+    def __init__(self, context_var: ContextVar[T]):
+        self._context_var = context_var
+        self._updates = ContextVar[int](context_var.name + "_updates", default=0)
+
+    def get(self, default: T | HiddenValue = _default) -> T:
+        thread_recycles = self._thread_recycles.get(0)
+        self_updates = self._updates.get()
+        if thread_recycles > self_updates:
+            self._updates.set(thread_recycles)
+
+        # the stored value is current only if `set` ran since the last recycle (it leaves _updates above the counter)
+        if thread_recycles < self_updates:
+            return self._context_var.get()
+        else:
+            # thread_recycles >= self_updates, means current context is invalid
+            if isinstance(default, HiddenValue) or default is _default:
+                raise LookupError
+            else:
+                return default
+
+    def set(self, value: T):
+        # if `set` was not called during earlier recycles, self._updates lags behind cls._thread_recycles;
+        # catch it up to the current recycle count first
+        thread_recycles = self._thread_recycles.get(0)
+        self_updates = self._updates.get()
+        if thread_recycles > self_updates:
+            self._updates.set(thread_recycles)
+
+        if self._updates.get() == self._thread_recycles.get(0):
+            # move one past the recycle count so `get` treats the value set below as current
+            self._updates.set(self._updates.get() + 1)
+
+        # set the context
+        self._context_var.set(value)

+ 2 - 2
api/core/agent/entities.py

@@ -1,7 +1,7 @@
 from enum import StrEnum
 from typing import Any, Optional, Union
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 
 from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
 
@@ -14,7 +14,7 @@ class AgentToolEntity(BaseModel):
     provider_type: ToolProviderType
     provider_id: str
     tool_name: str
-    tool_parameters: dict[str, Any] = {}
+    tool_parameters: dict[str, Any] = Field(default_factory=dict)
     plugin_unique_identifier: str | None = None
 
 

+ 2 - 4
api/core/app/app_config/easy_ui_based_app/model_config/manager.py

@@ -2,9 +2,9 @@ from collections.abc import Mapping
 from typing import Any
 
 from core.app.app_config.entities import ModelConfigEntity
-from core.entities import DEFAULT_PLUGIN_ID
 from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
 from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
+from core.plugin.entities.plugin import ModelProviderID
 from core.provider_manager import ProviderManager
 
 
@@ -61,9 +61,7 @@ class ModelConfigManager:
             raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
 
         if "/" not in config["model"]["provider"]:
-            config["model"]["provider"] = (
-                f"{DEFAULT_PLUGIN_ID}/{config['model']['provider']}/{config['model']['provider']}"
-            )
+            config["model"]["provider"] = str(ModelProviderID(config["model"]["provider"]))
 
         if config["model"]["provider"] not in model_provider_names:
             raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")

+ 4 - 4
api/core/app/app_config/entities.py

@@ -17,8 +17,8 @@ class ModelConfigEntity(BaseModel):
     provider: str
     model: str
     mode: Optional[str] = None
-    parameters: dict[str, Any] = {}
-    stop: list[str] = []
+    parameters: dict[str, Any] = Field(default_factory=dict)
+    stop: list[str] = Field(default_factory=list)
 
 
 class AdvancedChatMessageEntity(BaseModel):
@@ -132,7 +132,7 @@ class ExternalDataVariableEntity(BaseModel):
 
     variable: str
     type: str
-    config: dict[str, Any] = {}
+    config: dict[str, Any] = Field(default_factory=dict)
 
 
 class DatasetRetrieveConfigEntity(BaseModel):
@@ -188,7 +188,7 @@ class SensitiveWordAvoidanceEntity(BaseModel):
     """
 
     type: str
-    config: dict[str, Any] = {}
+    config: dict[str, Any] = Field(default_factory=dict)
 
 
 class TextToSpeechEntity(BaseModel):

+ 4 - 4
api/core/app/entities/app_invoke_entities.py

@@ -63,9 +63,9 @@ class ModelConfigWithCredentialsEntity(BaseModel):
     model_schema: AIModelEntity
     mode: str
     provider_model_bundle: ProviderModelBundle
-    credentials: dict[str, Any] = {}
-    parameters: dict[str, Any] = {}
-    stop: list[str] = []
+    credentials: dict[str, Any] = Field(default_factory=dict)
+    parameters: dict[str, Any] = Field(default_factory=dict)
+    stop: list[str] = Field(default_factory=list)
 
     # pydantic configs
     model_config = ConfigDict(protected_namespaces=())
@@ -94,7 +94,7 @@ class AppGenerateEntity(BaseModel):
     call_depth: int = 0
 
     # extra parameters, like: auto_generate_conversation_name
-    extras: dict[str, Any] = {}
+    extras: dict[str, Any] = Field(default_factory=dict)
 
     # tracing instance
     trace_manager: Optional[TraceQueueManager] = None

+ 4 - 5
api/core/entities/provider_configuration.py

@@ -6,11 +6,10 @@ from collections.abc import Iterator, Sequence
 from json import JSONDecodeError
 from typing import Optional
 
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, Field
 from sqlalchemy import or_
 
 from constants import HIDDEN_VALUE
-from core.entities import DEFAULT_PLUGIN_ID
 from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
 from core.entities.provider_entities import (
     CustomConfiguration,
@@ -1004,7 +1003,7 @@ class ProviderConfigurations(BaseModel):
     """
 
     tenant_id: str
-    configurations: dict[str, ProviderConfiguration] = {}
+    configurations: dict[str, ProviderConfiguration] = Field(default_factory=dict)
 
     def __init__(self, tenant_id: str):
         super().__init__(tenant_id=tenant_id)
@@ -1060,7 +1059,7 @@ class ProviderConfigurations(BaseModel):
 
     def __getitem__(self, key):
         if "/" not in key:
-            key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
+            key = str(ModelProviderID(key))
 
         return self.configurations[key]
 
@@ -1075,7 +1074,7 @@ class ProviderConfigurations(BaseModel):
 
     def get(self, key, default=None) -> ProviderConfiguration | None:
         if "/" not in key:
-            key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
+            key = str(ModelProviderID(key))
 
         return self.configurations.get(key, default)  # type: ignore
 

+ 5 - 1
api/core/hosting_configuration.py

@@ -41,9 +41,13 @@ class HostedModerationConfig(BaseModel):
 
 
 class HostingConfiguration:
-    provider_map: dict[str, HostingProvider] = {}
+    provider_map: dict[str, HostingProvider]
     moderation_config: Optional[HostedModerationConfig] = None
 
+    def __init__(self) -> None:
+        self.provider_map = {}
+        self.moderation_config = None
+
     def init_app(self, app: Flask) -> None:
         if dify_config.EDITION != "CLOUD":
             return

+ 5 - 10
api/core/model_runtime/model_providers/model_provider_factory.py

@@ -7,7 +7,6 @@ from typing import Optional
 from pydantic import BaseModel
 
 import contexts
-from core.entities import DEFAULT_PLUGIN_ID
 from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
 from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
 from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
@@ -34,9 +33,11 @@ class ModelProviderExtension(BaseModel):
 
 
 class ModelProviderFactory:
-    provider_position_map: dict[str, int] = {}
+    provider_position_map: dict[str, int]
 
     def __init__(self, tenant_id: str) -> None:
+        self.provider_position_map = {}
+
         self.tenant_id = tenant_id
         self.plugin_model_manager = PluginModelManager()
 
@@ -360,11 +361,5 @@ class ModelProviderFactory:
         :param provider: provider name
         :return: plugin id and provider name
         """
-        plugin_id = DEFAULT_PLUGIN_ID
-        provider_name = provider
-        if "/" in provider:
-            # get the plugin_id before provider
-            plugin_id = "/".join(provider.split("/")[:-1])
-            provider_name = provider.split("/")[-1]
-
-        return str(plugin_id), provider_name
+        provider_id = ModelProviderID(provider)
+        return provider_id.plugin_id, provider_id.provider_name

+ 3 - 7
api/services/dataset_service.py

@@ -13,10 +13,10 @@ from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound
 
 from configs import dify_config
-from core.entities import DEFAULT_PLUGIN_ID
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
+from core.plugin.entities.plugin import ModelProviderID
 from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from events.dataset_event import dataset_was_deleted
@@ -328,14 +328,10 @@ class DatasetService:
             else:
                 # add default plugin id to both setting sets, to make sure the plugin model provider is consistent
                 plugin_model_provider = dataset.embedding_model_provider
-                if "/" not in plugin_model_provider:
-                    plugin_model_provider = f"{DEFAULT_PLUGIN_ID}/{plugin_model_provider}/{plugin_model_provider}"
+                plugin_model_provider = str(ModelProviderID(plugin_model_provider))
 
                 new_plugin_model_provider = data["embedding_model_provider"]
-                if "/" not in new_plugin_model_provider:
-                    new_plugin_model_provider = (
-                        f"{DEFAULT_PLUGIN_ID}/{new_plugin_model_provider}/{new_plugin_model_provider}"
-                    )
+                new_plugin_model_provider = str(ModelProviderID(new_plugin_model_provider))
 
                 if (
                     new_plugin_model_provider != plugin_model_provider