|
@@ -1,20 +1,32 @@
|
|
|
from os.path import abspath, dirname, join
|
|
|
+from threading import Lock
|
|
|
|
|
|
from transformers import AutoTokenizer
|
|
|
|
|
|
|
|
|
class JinaTokenizer:
|
|
|
- @staticmethod
|
|
|
- def _get_num_tokens_by_jina_base(text: str) -> int:
|
|
|
+ _tokenizer = None
|
|
|
+ _lock = Lock()
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _get_tokenizer(cls):
|
|
|
+ if cls._tokenizer is None:
|
|
|
+ with cls._lock:
|
|
|
+ if cls._tokenizer is None:
|
|
|
+ base_path = abspath(__file__)
|
|
|
+ gpt2_tokenizer_path = join(dirname(base_path), 'tokenizer')
|
|
|
+ cls._tokenizer = AutoTokenizer.from_pretrained(gpt2_tokenizer_path)
|
|
|
+ return cls._tokenizer
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _get_num_tokens_by_jina_base(cls, text: str) -> int:
|
|
|
"""
|
|
|
use jina tokenizer to get num tokens
|
|
|
"""
|
|
|
- base_path = abspath(__file__)
|
|
|
- gpt2_tokenizer_path = join(dirname(base_path), 'tokenizer')
|
|
|
- tokenizer = AutoTokenizer.from_pretrained(gpt2_tokenizer_path)
|
|
|
+ tokenizer = cls._get_tokenizer()
|
|
|
tokens = tokenizer.encode(text)
|
|
|
return len(tokens)
|
|
|
|
|
|
- @staticmethod
|
|
|
- def get_num_tokens(text: str) -> int:
|
|
|
- return JinaTokenizer._get_num_tokens_by_jina_base(text)
|
|
|
+ @classmethod
|
|
|
+ def get_num_tokens(cls, text: str) -> int:
|
|
|
+ return cls._get_num_tokens_by_jina_base(text)
|