Преглед на файлове

Fix/jina tokenizer cache (#2735)

Yeuoly преди 1 година
родител
ревизия
8fe83750b7
променени са 1 файла, в които са добавени 20 реда и са изтрити 8 реда
  1. 20 8
      api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py

+ 20 - 8
api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py

@@ -1,20 +1,32 @@
 from os.path import abspath, dirname, join
+from threading import Lock
 
 from transformers import AutoTokenizer
 
 
 class JinaTokenizer:
-    @staticmethod
-    def _get_num_tokens_by_jina_base(text: str) -> int:
+    _tokenizer = None
+    _lock = Lock()
+
+    @classmethod
+    def _get_tokenizer(cls):
+        if cls._tokenizer is None:
+            with cls._lock:
+                if cls._tokenizer is None:
+                    base_path = abspath(__file__)
+                    gpt2_tokenizer_path = join(dirname(base_path), 'tokenizer')
+                    cls._tokenizer = AutoTokenizer.from_pretrained(gpt2_tokenizer_path)
+        return cls._tokenizer
+
+    @classmethod
+    def _get_num_tokens_by_jina_base(cls, text: str) -> int:
         """
             use jina tokenizer to get num tokens
         """
-        base_path = abspath(__file__)
-        gpt2_tokenizer_path = join(dirname(base_path), 'tokenizer')
-        tokenizer = AutoTokenizer.from_pretrained(gpt2_tokenizer_path)
+        tokenizer = cls._get_tokenizer()
         tokens = tokenizer.encode(text)
         return len(tokens)
     
-    @staticmethod
-    def get_num_tokens(text: str) -> int:
-        return JinaTokenizer._get_num_tokens_by_jina_base(text)
+    @classmethod
+    def get_num_tokens(cls, text: str) -> int:
+        return cls._get_num_tokens_by_jina_base(text)