Forráskód Böngészése

feat: tool credentials cache and introduce _position.yaml (#2386)

Yeuoly 1 éve
szülő
commit
5010706d8b

+ 49 - 0
api/core/helper/tool_provider_cache.py

@@ -0,0 +1,49 @@
+import json
+from enum import Enum
+from json import JSONDecodeError
+from typing import Optional
+
+from extensions.ext_redis import redis_client
+
+
+class ToolProviderCredentialsCacheType(Enum):
+    PROVIDER = "tool_provider"
+
+class ToolProviderCredentialsCache:
+    def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType):
+        self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
+
+    def get(self) -> Optional[dict]:
+        """
+        Get cached model provider credentials.
+
+        :return:
+        """
+        cached_provider_credentials = redis_client.get(self.cache_key)
+        if cached_provider_credentials:
+            try:
+                cached_provider_credentials = cached_provider_credentials.decode('utf-8')
+                cached_provider_credentials = json.loads(cached_provider_credentials)
+            except JSONDecodeError:
+                return None
+
+            return cached_provider_credentials
+        else:
+            return None
+
+    def set(self, credentials: dict) -> None:
+        """
+        Cache model provider credentials.
+
+        :param credentials: provider credentials
+        :return:
+        """
+        redis_client.setex(self.cache_key, 86400, json.dumps(credentials))
+
+    def delete(self) -> None:
+        """
+        Delete cached model provider credentials.
+
+        :return:
+        """
+        redis_client.delete(self.cache_key)

+ 15 - 0
api/core/tools/provider/_position.yaml

@@ -0,0 +1,15 @@
+- google
+- bing
+- wikipedia
+- dalle
+- azuredalle
+- webscraper
+- wolframalpha
+- github
+- chart
+- time
+- yahoo
+- stablediffusion
+- vectorizer
+- youtube
+- gaode

+ 17 - 19
api/core/tools/provider/builtin/_positions.py

@@ -1,31 +1,29 @@
-from typing import List
-
 from core.tools.entities.user_entities import UserToolProvider
+from core.tools.entities.tool_entities import ToolProviderType
+from typing import List
+from yaml import load, FullLoader
 
-position = {
-    'google': 1,
-    'bing': 2,
-    'wikipedia': 2,
-    'dalle': 3,
-    'webscraper': 4,
-    'wolframalpha': 5,
-    'chart': 6,
-    'time': 7,
-    'yahoo': 8,
-    'stablediffusion': 9,
-    'vectorizer': 10,
-    'youtube': 11,
-    'github': 12,
-    'gaode': 13
-}
+import os.path
 
+position = {}
 
 class BuiltinToolProviderSort:
     @staticmethod
     def sort(providers: List[UserToolProvider]) -> List[UserToolProvider]:
+        global position
+        if not position:
+            tmp_position = {}
+            file_path = os.path.join(os.path.dirname(__file__), '..', '_position.yaml')
+            with open(file_path, 'r') as f:
+                for pos, val in enumerate(load(f, Loader=FullLoader)):
+                    tmp_position[val] = pos
+            position = tmp_position
+
         def sort_compare(provider: UserToolProvider) -> int:
+            # if provider.type == UserToolProvider.ProviderType.MODEL:
+            #     return position.get(f'model_provider.{provider.name}', 10000)
             return position.get(provider.name, 10000)
         
         sorted_providers = sorted(providers, key=sort_compare)
 
-        return sorted_providers
+        return sorted_providers

+ 14 - 6
api/core/tools/utils/configuration.py

@@ -1,10 +1,10 @@
-from typing import Any, Dict
+from typing import Dict, Any
+from pydantic import BaseModel
 
-from core.helper import encrypter
 from core.tools.entities.tool_entities import ToolProviderCredentials
 from core.tools.provider.tool_provider import ToolProviderController
-from pydantic import BaseModel
-
+from core.helper import encrypter
+from core.helper.tool_provider_cache import ToolProviderCredentialsCacheType, ToolProviderCredentialsCache
 
 class ToolConfiguration(BaseModel):
     tenant_id: str
@@ -63,8 +63,15 @@ class ToolConfiguration(BaseModel):
 
         return a deep copy of credentials with decrypted values
         """
+        cache = ToolProviderCredentialsCache(
+            tenant_id=self.tenant_id, 
+            identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
+            cache_type=ToolProviderCredentialsCacheType.PROVIDER
+        )
+        cached_credentials = cache.get()
+        if cached_credentials:
+            return cached_credentials
         credentials = self._deep_copy(credentials)
-
         # get fields need to be decrypted
         fields = self.provider_controller.get_credentials_schema()
         for field_name, field in fields.items():
@@ -74,5 +81,6 @@ class ToolConfiguration(BaseModel):
                         credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
                     except:
                         pass
-        
+
+        cache.set(credentials)
         return credentials