Parcourir la source

feat:add apollo configuration to load env file (#11210)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: huanshare <liuhuan101@longfor.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
huanshare il y a 4 mois
Parent
commit
967b7d89e3

+ 66 - 7
api/configs/app_config.py

@@ -1,11 +1,51 @@
-from pydantic_settings import SettingsConfigDict
+import logging
+from typing import Any
 
-from configs.deploy import DeploymentConfig
-from configs.enterprise import EnterpriseFeatureConfig
-from configs.extra import ExtraServiceConfig
-from configs.feature import FeatureConfig
-from configs.middleware import MiddlewareConfig
-from configs.packaging import PackagingInfo
+from pydantic.fields import FieldInfo
+from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict
+
+from .deploy import DeploymentConfig
+from .enterprise import EnterpriseFeatureConfig
+from .extra import ExtraServiceConfig
+from .feature import FeatureConfig
+from .middleware import MiddlewareConfig
+from .packaging import PackagingInfo
+from .remote_settings_sources import RemoteSettingsSource, RemoteSettingsSourceConfig, RemoteSettingsSourceName
+from .remote_settings_sources.apollo import ApolloSettingsSource
+
+logger = logging.getLogger(__name__)
+
+
+class RemoteSettingsSourceFactory(PydanticBaseSettingsSource):
+    """
+    Settings source that delegates to the remote backend named by ``REMOTE_SETTINGS_SOURCE_NAME``.
+
+    The backend name is looked up in ``self.current_state``. When it is missing, empty or not
+    supported, this source contributes no values.
+    """
+
+    def __init__(self, settings_cls: type[BaseSettings]):
+        super().__init__(settings_cls)
+
+    def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
+        # Per-field lookups are done by the selected remote source inside ``__call__``.
+        raise NotImplementedError
+
+    def __call__(self) -> dict[str, Any]:
+        """Return the non-None values supplied by the selected remote source, keyed by field key."""
+        state = self.current_state
+        source_name = state.get("REMOTE_SETTINGS_SOURCE_NAME")
+        if not source_name:
+            return {}
+
+        if source_name == RemoteSettingsSourceName.APOLLO:
+            remote_source: RemoteSettingsSource = ApolloSettingsSource(state)
+        else:
+            logger.warning(f"Unsupported remote source: {source_name}")
+            return {}
+
+        resolved: dict[str, Any] = {}
+        for name, info in self.settings_cls.model_fields.items():
+            raw_value, key, is_complex = remote_source.get_field_value(info, name)
+            prepared = remote_source.prepare_field_value(name, info, raw_value, is_complex)
+            if prepared is None:
+                # The remote source has nothing for this field; leave it to the other sources.
+                continue
+            resolved[key] = prepared
+        return resolved
 
 
 class DifyConfig(
@@ -19,6 +59,8 @@ class DifyConfig(
     MiddlewareConfig,
     # Extra service configs
     ExtraServiceConfig,
+    # Remote source configs
+    RemoteSettingsSourceConfig,
     # Enterprise feature configs
     # **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
     EnterpriseFeatureConfig,
@@ -35,3 +77,20 @@ class DifyConfig(
     # please consider to arrange it in the proper config group of existed or added
     # for better readability and maintainability.
     # Thanks for your concentration and consideration.
+
+    @classmethod
+    def settings_customise_sources(
+        cls,
+        settings_cls: type[BaseSettings],
+        init_settings: PydanticBaseSettingsSource,
+        env_settings: PydanticBaseSettingsSource,
+        dotenv_settings: PydanticBaseSettingsSource,
+        file_secret_settings: PydanticBaseSettingsSource,
+    ) -> tuple[PydanticBaseSettingsSource, ...]:
+        """
+        Return the settings sources in the order they are handed to pydantic-settings.
+
+        The remote source is placed after the init and environment sources and before the
+        dotenv and secret-file sources.
+        """
+        remote_settings = RemoteSettingsSourceFactory(settings_cls)
+        sources = (
+            init_settings,
+            env_settings,
+            remote_settings,
+            dotenv_settings,
+            file_secret_settings,
+        )
+        return sources

+ 1 - 1
api/configs/middleware/__init__.py

@@ -73,7 +73,7 @@ class KeywordStoreConfig(BaseSettings):
     )
 
 
-class DatabaseConfig:
+class DatabaseConfig(BaseSettings):
     DB_HOST: str = Field(
         description="Hostname or IP address of the database server.",
         default="localhost",

+ 3 - 2
api/configs/middleware/storage/baidu_obs_storage_config.py

@@ -1,9 +1,10 @@
 from typing import Optional
 
-from pydantic import BaseModel, Field
+from pydantic import Field
+from pydantic_settings import BaseSettings
 
 
-class BaiduOBSStorageConfig(BaseModel):
+class BaiduOBSStorageConfig(BaseSettings):
     """
     Configuration settings for Baidu Object Storage Service (OBS)
     """

+ 3 - 2
api/configs/middleware/storage/huawei_obs_storage_config.py

@@ -1,9 +1,10 @@
 from typing import Optional
 
-from pydantic import BaseModel, Field
+from pydantic import Field
+from pydantic_settings import BaseSettings
 
 
-class HuaweiCloudOBSStorageConfig(BaseModel):
+class HuaweiCloudOBSStorageConfig(BaseSettings):
     """
     Configuration settings for Huawei Cloud Object Storage Service (OBS)
     """

+ 3 - 2
api/configs/middleware/storage/supabase_storage_config.py

@@ -1,9 +1,10 @@
 from typing import Optional
 
-from pydantic import BaseModel, Field
+from pydantic import Field
+from pydantic_settings import BaseSettings
 
 
-class SupabaseStorageConfig(BaseModel):
+class SupabaseStorageConfig(BaseSettings):
     """
     Configuration settings for Supabase Object Storage Service
     """

+ 3 - 2
api/configs/middleware/storage/volcengine_tos_storage_config.py

@@ -1,9 +1,10 @@
 from typing import Optional
 
-from pydantic import BaseModel, Field
+from pydantic import Field
+from pydantic_settings import BaseSettings
 
 
-class VolcengineTOSStorageConfig(BaseModel):
+class VolcengineTOSStorageConfig(BaseSettings):
     """
     Configuration settings for Volcengine Tinder Object Storage (TOS)
     """

+ 3 - 2
api/configs/middleware/vdb/analyticdb_config.py

@@ -1,9 +1,10 @@
 from typing import Optional
 
-from pydantic import BaseModel, Field, PositiveInt
+from pydantic import Field, PositiveInt
+from pydantic_settings import BaseSettings
 
 
-class AnalyticdbConfig(BaseModel):
+class AnalyticdbConfig(BaseSettings):
     """
     Configuration for connecting to Alibaba Cloud AnalyticDB for PostgreSQL.
     Refer to the following documentation for details on obtaining credentials:

+ 3 - 2
api/configs/middleware/vdb/couchbase_config.py

@@ -1,9 +1,10 @@
 from typing import Optional
 
-from pydantic import BaseModel, Field
+from pydantic import Field
+from pydantic_settings import BaseSettings
 
 
-class CouchbaseConfig(BaseModel):
+class CouchbaseConfig(BaseSettings):
     """
     Couchbase configs
     """

+ 3 - 2
api/configs/middleware/vdb/myscale_config.py

@@ -1,7 +1,8 @@
-from pydantic import BaseModel, Field, PositiveInt
+from pydantic import Field, PositiveInt
+from pydantic_settings import BaseSettings
 
 
-class MyScaleConfig(BaseModel):
+class MyScaleConfig(BaseSettings):
     """
     Configuration settings for MyScale vector database
     """

+ 3 - 2
api/configs/middleware/vdb/vikingdb_config.py

@@ -1,9 +1,10 @@
 from typing import Optional
 
-from pydantic import BaseModel, Field
+from pydantic import Field
+from pydantic_settings import BaseSettings
 
 
-class VikingDBConfig(BaseModel):
+class VikingDBConfig(BaseSettings):
     """
     Configuration for connecting to Volcengine VikingDB.
     Refer to the following documentation for details on obtaining credentials:

+ 17 - 0
api/configs/remote_settings_sources/__init__.py

@@ -0,0 +1,17 @@
+from typing import Optional
+
+from pydantic import Field
+
+from .apollo import ApolloSettingsSourceInfo
+from .base import RemoteSettingsSource
+from .enums import RemoteSettingsSourceName
+
+
+class RemoteSettingsSourceConfig(ApolloSettingsSourceInfo):
+    """
+    Settings that select a remote settings source, combined with the Apollo connection settings
+    """
+
+    # Name of the remote source to load settings from; None leaves remote loading switched off.
+    REMOTE_SETTINGS_SOURCE_NAME: Optional[RemoteSettingsSourceName] = Field(
+        description="name of remote config source",
+        default=None,
+    )
+
+
+__all__ = ["RemoteSettingsSource", "RemoteSettingsSourceConfig", "RemoteSettingsSourceName"]

+ 55 - 0
api/configs/remote_settings_sources/apollo/__init__.py

@@ -0,0 +1,55 @@
+from collections.abc import Mapping
+from typing import Any, Optional
+
+from pydantic import Field
+from pydantic.fields import FieldInfo
+from pydantic_settings import BaseSettings
+
+from configs.remote_settings_sources.base import RemoteSettingsSource
+
+from .client import ApolloClient
+
+
+class ApolloSettingsSourceInfo(BaseSettings):
+    """
+    Connection settings for the Apollo remote configuration source
+    """
+
+    APOLLO_APP_ID: Optional[str] = Field(
+        description="apollo app_id",
+        default=None,
+    )
+
+    APOLLO_CLUSTER: Optional[str] = Field(
+        description="apollo cluster",
+        default=None,
+    )
+
+    APOLLO_CONFIG_URL: Optional[str] = Field(
+        description="apollo config url",
+        default=None,
+    )
+
+    APOLLO_NAMESPACE: Optional[str] = Field(
+        description="apollo namespace",
+        default=None,
+    )
+
+
+class ApolloSettingsSource(RemoteSettingsSource):
+    """
+    Remote settings source that reads one namespace from an Apollo config service.
+    """
+
+    def __init__(self, configs: Mapping[str, Any]):
+        """
+        Build an Apollo client and fetch the namespace once.
+
+        :param configs: mapping that must hold ``APOLLO_APP_ID`` and ``APOLLO_CONFIG_URL`` and may
+            hold ``APOLLO_CLUSTER`` and ``APOLLO_NAMESPACE``.
+        :raises KeyError: if ``APOLLO_APP_ID`` or ``APOLLO_CONFIG_URL`` is absent.
+        """
+        super().__init__(configs)
+        # Cluster and namespace are optional settings. Fall back to the defaults ApolloClient
+        # itself uses instead of raising KeyError or putting "None" into the request URL.
+        cluster = configs.get("APOLLO_CLUSTER") or "default"
+        self.namespace = configs.get("APOLLO_NAMESPACE") or "application"
+        self.client = ApolloClient(
+            app_id=configs["APOLLO_APP_ID"],
+            cluster=cluster,
+            config_url=configs["APOLLO_CONFIG_URL"],
+            start_hot_update=False,
+            _notification_map={self.namespace: -1},
+        )
+        self.remote_configs = self.client.get_all_dicts(self.namespace)
+
+    def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
+        """
+        Look up ``field_name`` in the fetched namespace.
+
+        :return: ``(value, field_name, False)``; ``value`` is None when the key is not present.
+        :raises ValueError: if the namespace could not be loaded as a dict.
+        """
+        if not isinstance(self.remote_configs, dict):
+            raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")
+        field_value = self.remote_configs.get(field_name)
+        return field_value, field_name, False

+ 303 - 0
api/configs/remote_settings_sources/apollo/client.py

@@ -0,0 +1,303 @@
+import hashlib
+import json
+import logging
+import os
+import threading
+import time
+from pathlib import Path
+
+from .python_3x import http_request, makedirs_wrapper
+from .utils import (
+    CONFIGURATIONS,
+    NAMESPACE_NAME,
+    NOTIFICATION_ID,
+    get_value_from_dict,
+    init_ip,
+    no_key_cache_key,
+    signature,
+    url_encode_wrapper,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class ApolloClient:
+    """
+    Client for an Apollo config service with an in-memory cache and an on-disk cache.
+
+    Every namespace entry in ``self._cache`` (and in the cache files) has the shape
+    ``{CONFIGURATIONS: {...}}``, optionally with a ``NOTIFICATION_ID`` entry.
+    """
+
+    def __init__(
+        self,
+        config_url,
+        app_id,
+        cluster="default",
+        secret="",
+        start_hot_update=True,
+        change_listener=None,
+        _notification_map=None,
+    ):
+        """
+        Create the client, make sure the cache directory exists and start the background threads.
+
+        :param config_url: base URL of the Apollo config service.
+        :param app_id: Apollo application id.
+        :param cluster: Apollo cluster name.
+        :param secret: access key secret; when empty, requests are not signed.
+        :param start_hot_update: whether to start the long-polling listener thread.
+        :param change_listener: optional callable invoked as ``(action, namespace, key, value)``
+            with action ``"add"``, ``"delete"`` or ``"update"``.
+        :param _notification_map: mapping whose keys are the namespaces refreshed by the heartbeat
+            thread; defaults to ``{"application": -1}``.
+        """
+        # Core routing parameters
+        self.config_url = config_url
+        self.cluster = cluster
+        self.app_id = app_id
+
+        # Non-core parameters
+        self.ip = init_ip()
+        self.secret = secret
+
+        # Private control variables
+        self._cycle_time = 5
+        self._stopping = False
+        self._cache = {}
+        self._no_key = {}
+        self._hash = {}
+        self._pull_timeout = 75
+        self._cache_file_path = os.path.expanduser("~") + "/.dify/config/remote-settings/apollo/cache/"
+        self._long_poll_thread = None
+        self._change_listener = change_listener  # "add" "delete" "update"
+        if _notification_map is None:
+            _notification_map = {"application": -1}
+        self._notification_map = _notification_map
+        # NOTE(review): a single release key is shared by all namespaces in the notification map.
+        self.last_release_key = None
+        # Private startup method
+        self._path_checker()
+        if start_hot_update:
+            self._start_hot_update()
+
+        # The heartbeat thread is started regardless of ``start_hot_update``.
+        heartbeat = threading.Thread(target=self._heart_beat)
+        heartbeat.daemon = True
+        heartbeat.start()
+
+    def get_json_from_net(self, namespace="application"):
+        """
+        Fetch ``namespace`` from the config service.
+
+        :return: ``{CONFIGURATIONS: {...}}`` on success; None on a non-200 status, an empty body
+            or any exception (which is logged).
+        """
+        url = "{}/configs/{}/{}/{}?releaseKey={}&ip={}".format(
+            self.config_url, self.app_id, self.cluster, namespace, "", self.ip
+        )
+        try:
+            code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
+            if code == 200:
+                if not body:
+                    logger.error(f"get_json_from_net load configs failed, body is {body}")
+                    return None
+                data = json.loads(body)
+                data = data["configurations"]
+                return_data = {CONFIGURATIONS: data}
+                return return_data
+            else:
+                return None
+        except Exception:
+            logger.exception("an error occurred in get_json_from_net")
+            return None
+
+    def get_value(self, key, default_val=None, namespace="application"):
+        """
+        Return the value of ``key`` in ``namespace``.
+
+        The in-memory cache is tried first, then the config service, then the cache file. If none
+        of them has the key, the miss is remembered and ``default_val`` is returned. Any exception
+        is logged and also results in ``default_val``.
+        """
+        try:
+            # read memory configuration
+            namespace_cache = self._cache.get(namespace)
+            val = get_value_from_dict(namespace_cache, key)
+            if val is not None:
+                return val
+
+            no_key = no_key_cache_key(namespace, key)
+            if no_key in self._no_key:
+                return default_val
+
+            # read the network configuration
+            namespace_data = self.get_json_from_net(namespace)
+            val = get_value_from_dict(namespace_data, key)
+            if val is not None:
+                self._update_cache_and_file(namespace_data, namespace)
+                return val
+
+            # read the file configuration
+            namespace_cache = self._get_local_cache(namespace)
+            val = get_value_from_dict(namespace_cache, key)
+            if val is not None:
+                self._update_cache_and_file(namespace_cache, namespace)
+                return val
+
+            # Nothing found anywhere: remember the miss and return the default.
+            self._set_local_cache_none(namespace, key)
+            return default_val
+        except Exception:
+            logger.exception("get_value has error, [key is %s], [namespace is %s]", key, namespace)
+            return default_val
+
+    def _set_local_cache_none(self, namespace, key):
+        """
+        Record that ``key`` is absent from ``namespace``.
+
+        Only the miss is stored, not a default value, because callers may pass a different
+        ``default_val`` on each ``get_value`` call.
+        """
+        no_key = no_key_cache_key(namespace, key)
+        self._no_key[no_key] = key
+
+    def _start_hot_update(self):
+        """Start the long-polling listener in a background thread."""
+        self._long_poll_thread = threading.Thread(target=self._listener)
+        # Daemon thread: it does not keep the process alive once the main thread exits.
+        self._long_poll_thread.daemon = True
+        self._long_poll_thread.start()
+
+    def stop(self):
+        """Ask the background loops to finish after their current iteration."""
+        self._stopping = True
+        logger.info("Stopping listener...")
+
+    def _call_listener(self, namespace, old_kv, new_kv):
+        """
+        Report the differences between ``old_kv`` and ``new_kv`` to the change listener.
+
+        Exceptions raised by the listener are logged and swallowed so that a faulty callback
+        cannot stop the polling loop.
+        """
+        if self._change_listener is None:
+            return
+        if old_kv is None:
+            old_kv = {}
+        if new_kv is None:
+            new_kv = {}
+        try:
+            for key in old_kv:
+                new_value = new_kv.get(key)
+                old_value = old_kv.get(key)
+                if new_value is None:
+                    # No new value means the key was deleted.
+                    self._change_listener("delete", namespace, key, old_value)
+                    continue
+                if new_value != old_value:
+                    self._change_listener("update", namespace, key, new_value)
+                    continue
+            for key in new_kv:
+                new_value = new_kv.get(key)
+                old_value = old_kv.get(key)
+                if old_value is None:
+                    self._change_listener("add", namespace, key, new_value)
+        except Exception as e:
+            # Exception, not BaseException: KeyboardInterrupt and SystemExit must propagate.
+            logger.warning(str(e))
+
+    def _path_checker(self):
+        """Create the cache directory if it does not exist yet."""
+        if not os.path.isdir(self._cache_file_path):
+            makedirs_wrapper(self._cache_file_path)
+
+    def _update_cache_and_file(self, namespace_data, namespace="application"):
+        """
+        Store ``namespace_data`` in memory and, when its content changed, in the cache file.
+
+        The MD5 digest is used only to detect changes between writes.
+        """
+        # update the local cache
+        self._cache[namespace] = namespace_data
+        # update the file cache
+        new_string = json.dumps(namespace_data)
+        new_hash = hashlib.md5(new_string.encode("utf-8")).hexdigest()
+        if self._hash.get(namespace) == new_hash:
+            return
+        file_path = Path(self._cache_file_path) / f"{self.app_id}_configuration_{namespace}.txt"
+        file_path.write_text(new_string)
+        self._hash[namespace] = new_hash
+
+    def _get_local_cache(self, namespace="application"):
+        """Return the namespace data stored in the cache file, or ``{}`` if there is no file."""
+        cache_file_path = os.path.join(self._cache_file_path, f"{self.app_id}_configuration_{namespace}.txt")
+        if os.path.isfile(cache_file_path):
+            with open(cache_file_path) as f:
+                # The file is written by json.dumps without indentation, i.e. as a single line.
+                result = json.loads(f.readline())
+            return result
+        return {}
+
+    def _long_poll(self):
+        """
+        Ask the service which cached namespaces changed and refresh every one of them.
+
+        Errors are logged; the caller simply polls again on the next cycle.
+        """
+        try:
+            # Snapshot the items so that a concurrent cache update cannot change the dict
+            # while it is being iterated.
+            notifications = [
+                {NAMESPACE_NAME: name, NOTIFICATION_ID: data.get(NOTIFICATION_ID, -1)}
+                for name, data in list(self._cache.items())
+            ]
+            # nothing cached yet, nothing to watch
+            if not notifications:
+                return
+            url = "{}/notifications/v2".format(self.config_url)
+            params = {
+                "appId": self.app_id,
+                "cluster": self.cluster,
+                "notifications": json.dumps(notifications, ensure_ascii=False),
+            }
+            param_str = url_encode_wrapper(params)
+            url = url + "?" + param_str
+            code, body = http_request(url, self._pull_timeout, headers=self._sign_headers(url))
+            http_code = code
+            if http_code == 304:
+                logger.debug("No change, loop...")
+                return
+            if http_code == 200:
+                if not body:
+                    logger.error(f"_long_poll load configs failed,body is {body}")
+                    return
+                data = json.loads(body)
+                # Refresh every namespace the service reports, not only the first one.
+                for entry in data:
+                    namespace = entry[NAMESPACE_NAME]
+                    n_id = entry[NOTIFICATION_ID]
+                    logger.info("%s has changes: notificationId=%d", namespace, n_id)
+                    self._get_net_and_set_local(namespace, n_id, call_change=True)
+            else:
+                logger.warning("Sleep...")
+        except Exception as e:
+            logger.warning(str(e))
+
+    def _get_net_and_set_local(self, namespace, n_id, call_change=False):
+        """
+        Re-fetch ``namespace``, tag it with notification id ``n_id`` and update the caches.
+
+        When ``call_change`` is true and an older copy was cached, the change listener is
+        notified of the differences.
+        """
+        namespace_data = self.get_json_from_net(namespace)
+        if not namespace_data:
+            return
+        namespace_data[NOTIFICATION_ID] = n_id
+        old_namespace = self._cache.get(namespace)
+        self._update_cache_and_file(namespace_data, namespace)
+        if self._change_listener is not None and call_change and old_namespace:
+            old_kv = old_namespace.get(CONFIGURATIONS)
+            new_kv = namespace_data.get(CONFIGURATIONS)
+            self._call_listener(namespace, old_kv, new_kv)
+
+    def _listener(self):
+        """Long-poll loop run by the hot-update thread until ``stop`` is called."""
+        logger.info("start long_poll")
+        while not self._stopping:
+            self._long_poll()
+            time.sleep(self._cycle_time)
+        logger.info("stopped, long_poll")
+
+    def _sign_headers(self, url):
+        """
+        Build the signature headers for ``url``.
+
+        :return: ``{}`` when no secret is configured, otherwise the ``Authorization`` and
+            ``Timestamp`` headers.
+        """
+        headers = {}
+        if self.secret == "":
+            return headers
+        # The signature covers the part of the URL that follows the config service base URL.
+        uri = url[len(self.config_url) : len(url)]
+        time_unix_now = str(int(round(time.time() * 1000)))
+        headers["Authorization"] = "Apollo " + self.app_id + ":" + signature(time_unix_now, uri, self.secret)
+        headers["Timestamp"] = time_unix_now
+        return headers
+
+    def _heart_beat(self):
+        """Refresh every namespace of the notification map, then sleep, until stopped."""
+        while not self._stopping:
+            for namespace in self._notification_map:
+                self._do_heart_beat(namespace)
+            time.sleep(60 * 10)  # 10 minutes
+
+    def _do_heart_beat(self, namespace):
+        """
+        Fetch ``namespace`` and update the caches when the release key changed.
+
+        The data is stored in the same ``{CONFIGURATIONS: {...}}`` shape the other methods read,
+        and a notification id already cached for the namespace is kept.
+        """
+        url = "{}/configs/{}/{}/{}?ip={}".format(self.config_url, self.app_id, self.cluster, namespace, self.ip)
+        try:
+            code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
+            if code == 200:
+                if not body:
+                    logger.error(f"_do_heart_beat load configs failed,body is {body}")
+                    return None
+                data = json.loads(body)
+                if self.last_release_key == data["releaseKey"]:
+                    return None
+                self.last_release_key = data["releaseKey"]
+                namespace_data = {CONFIGURATIONS: data["configurations"]}
+                cached = self._cache.get(namespace)
+                if cached and NOTIFICATION_ID in cached:
+                    namespace_data[NOTIFICATION_ID] = cached[NOTIFICATION_ID]
+                self._update_cache_and_file(namespace_data, namespace)
+            else:
+                return None
+        except Exception:
+            logger.exception("an error occurred in _do_heart_beat")
+            return None
+
+    def get_all_dicts(self, namespace):
+        """
+        Return all key/value pairs of ``namespace``.
+
+        The in-memory cache is used when it has the namespace; otherwise the namespace is fetched
+        from the config service and cached.
+
+        :return: the configuration dict, or None when the namespace could not be loaded.
+        """
+        namespace_data = self._cache.get(namespace)
+        if namespace_data is None:
+            namespace_data = self.get_json_from_net(namespace)
+            if not namespace_data:
+                return None
+            self._update_cache_and_file(namespace_data, namespace)
+        return namespace_data.get(CONFIGURATIONS)

+ 41 - 0
api/configs/remote_settings_sources/apollo/python_3x.py

@@ -0,0 +1,41 @@
+import logging
+import os
+import ssl
+import urllib.request
+from urllib import parse
+from urllib.error import HTTPError
+
+# Create an SSL context that allows for a lower level of security
+# NOTE(review): certificate verification and hostname checking are both disabled below, so HTTPS
+# connections made through this context are not authenticated - confirm this is intended.
+ssl_context = ssl.create_default_context()
+ssl_context.set_ciphers("HIGH:!DH:!aNULL")
+ssl_context.check_hostname = False
+ssl_context.verify_mode = ssl.CERT_NONE
+
+# Create an opener object and pass in a custom SSL context
+opener = urllib.request.build_opener(urllib.request.HTTPSHandler(context=ssl_context))
+
+# NOTE(review): install_opener replaces the process-wide default opener, so every later
+# urllib.request.urlopen call in this process uses the context above, not only Apollo requests.
+urllib.request.install_opener(opener)
+
+logger = logging.getLogger(__name__)
+
+
+def http_request(url, timeout, headers=None):
+    """
+    Send a GET request and return ``(status_code, body_text)``.
+
+    :param url: URL to fetch.
+    :param timeout: timeout in seconds passed to ``urlopen``.
+    :param headers: optional mapping of request headers.
+    :return: the response code and the body decoded as UTF-8, or ``(304, None)`` when the
+        server answers 304.
+    :raises HTTPError: for any other HTTP error status.
+    """
+    # A fresh dict per call avoids sharing one mutable default between calls.
+    request = urllib.request.Request(url, headers=headers or {})
+    try:
+        # The context manager closes the response once the body has been read.
+        with urllib.request.urlopen(request, timeout=timeout) as res:
+            return res.code, res.read().decode("utf-8")
+    except HTTPError as e:
+        if e.code == 304:
+            logger.warning("http_request error,code is 304, maybe you should check secret")
+            return 304, None
+        logger.warning("http_request error,code is %d, msg is %s", e.code, e.msg)
+        raise
+
+
+def url_encode(params):
+    """Encode ``params`` as a URL query string."""
+    query_string = parse.urlencode(params)
+    return query_string
+
+
+def makedirs_wrapper(path):
+    """Create directory ``path`` with any missing parents; an existing directory is not an error."""
+    os.makedirs(name=path, exist_ok=True)

+ 51 - 0
api/configs/remote_settings_sources/apollo/utils.py

@@ -0,0 +1,51 @@
+import hashlib
+import socket
+
+from .python_3x import url_encode
+
+# Dictionary keys shared by the Apollo client code: the key/value payload of a namespace,
+# the notification id of a namespace, and the namespace name in notification entries.
+CONFIGURATIONS = "configurations"
+NOTIFICATION_ID = "notificationId"
+NAMESPACE_NAME = "namespaceName"
+
+
+# Sign a request from its timestamp, URI and the access key secret
+def signature(timestamp, uri, secret):
+    """Return the Base64-encoded HMAC-SHA1 of the timestamp and URI joined by a newline."""
+    import base64
+    import hmac
+
+    message = "\n".join((timestamp, uri))
+    digest = hmac.new(secret.encode(), message.encode(), digestmod=hashlib.sha1).digest()
+    encoded = base64.b64encode(digest)
+    return encoded.decode()
+
+
+def url_encode_wrapper(params):
+    """Encode ``params`` as a URL query string by delegating to ``url_encode``."""
+    encoded = url_encode(params)
+    return encoded
+
+
+def no_key_cache_key(namespace, key):
+    """Build the miss-cache key for a pair: the namespace, the namespace's length, then the key."""
+    return f"{namespace}{len(namespace)}{key}"
+
+
+# Look up a key in a namespace entry; None means the key was not found
+def get_value_from_dict(namespace_cache, key):
+    """
+    Return the value stored for ``key`` under the ``CONFIGURATIONS`` entry of ``namespace_cache``.
+
+    None is returned when ``namespace_cache`` is empty or None, when it has no
+    ``CONFIGURATIONS`` entry, or when the key is not present.
+    """
+    if not namespace_cache:
+        return None
+    kv_data = namespace_cache.get(CONFIGURATIONS)
+    if kv_data is None or key not in kv_data:
+        return None
+    return kv_data[key]
+
+
+def init_ip():
+    """
+    Best-effort detection of the local IPv4 address used for outgoing traffic.
+
+    :return: the address as a string, or ``""`` when it cannot be determined (for example
+        when the host has no route to the probe address).
+    """
+    try:
+        # connect() on a UDP socket sends no datagram; it only makes the OS pick the local
+        # address that would be used to reach the probe address.
+        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
+            s.connect(("8.8.8.8", 53))
+            return s.getsockname()[0]
+    except OSError:
+        # The caller treats the IP as optional, so fall back to an empty string instead of
+        # aborting client construction.
+        return ""

+ 15 - 0
api/configs/remote_settings_sources/base.py

@@ -0,0 +1,15 @@
+from collections.abc import Mapping
+from typing import Any
+
+from pydantic.fields import FieldInfo
+
+
+class RemoteSettingsSource:
+    """Base class for settings sources backed by a remote service."""
+
+    def __init__(self, configs: Mapping[str, Any]):
+        """Accept the configuration mapping; the base class itself keeps nothing from it."""
+
+    def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
+        """Return ``(value, key, value_is_complex)`` for a field; subclasses must implement this."""
+        raise NotImplementedError
+
+    def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
+        """Return ``value`` unchanged."""
+        return value

+ 5 - 0
api/configs/remote_settings_sources/enums.py

@@ -0,0 +1,5 @@
+from enum import StrEnum
+
+
+class RemoteSettingsSourceName(StrEnum):
+    """Names of the remote settings sources that can be selected."""
+
+    APOLLO = "apollo"