@@ -1,13 +1,10 @@
-from concurrent.futures import ProcessPoolExecutor
-from os.path import abspath, dirname, join
 from threading import Lock
-from typing import Any, cast
+from typing import Any
 
-from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer  # type: ignore
+import tiktoken
 
 _tokenizer: Any = None
 _lock = Lock()
-_executor = ProcessPoolExecutor(max_workers=1)
 
 
 class GPT2Tokenizer:
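For context, here is a minimal standalone sketch of the new counting path, assuming `tiktoken` is installed; `tiktoken.get_encoding("gpt2")` returns the GPT-2 BPE encoding, so counts should line up with the transformers tokenizer this replaces:

```python
import tiktoken

# Hypothetical standalone example of the tiktoken-based approach.
enc = tiktoken.get_encoding("gpt2")   # GPT-2 BPE, same vocabulary the class used before
tokens = enc.encode("Hello, world!")  # -> list of token ids
print(len(tokens))                    # the token count the class exposes
```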
@@ -17,22 +14,28 @@ class GPT2Tokenizer:
         use gpt2 tokenizer to get num tokens
         """
         _tokenizer = GPT2Tokenizer.get_encoder()
-        tokens = _tokenizer.encode(text, verbose=False)
+        tokens = _tokenizer.encode(text)
         return len(tokens)
 
     @staticmethod
     def get_num_tokens(text: str) -> int:
-        future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
-        result = future.result()
-        return cast(int, result)
+        # Because the subprocess approach consumed more CPU, we revert to an in-process call until we find a better way to handle it.
+        #
+        # future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
+        # result = future.result()
+        # return cast(int, result)
+        return GPT2Tokenizer._get_num_tokens_by_gpt2(text)
 
     @staticmethod
     def get_encoder() -> Any:
         global _tokenizer, _lock
         with _lock:
             if _tokenizer is None:
-                base_path = abspath(__file__)
-                gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
-                _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
+                # Use tiktoken to get the tokenizer because it is faster.
+                #
+                _tokenizer = tiktoken.get_encoding("gpt2")
+                # base_path = abspath(__file__)
+                # gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
+                # _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
 
         return _tokenizer
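On the `get_num_tokens` revert: offloading to a single-worker process pool means every call pays pickling and inter-process round-trip costs for a cheap CPU-bound task, which is the extra CPU cost the comment refers to. A rough standalone reconstruction of the removed path, for reference only (the `_count` helper is illustrative, not from the original module):

```python
from concurrent.futures import ProcessPoolExecutor
from typing import cast

import tiktoken

_executor = ProcessPoolExecutor(max_workers=1)

def _count(text: str) -> int:
    # Runs in the worker process; tiktoken caches the encoding per process.
    return len(tiktoken.get_encoding("gpt2").encode(text))

def get_num_tokens(text: str) -> int:
    # Each call serializes `text`, ships it to the worker, and waits for
    # the result: pure overhead compared with calling _count directly.
    future = _executor.submit(_count, text)
    return cast(int, future.result())

if __name__ == "__main__":  # guard needed for process pools on spawn platforms
    print(get_num_tokens("Hello, world!"))
```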
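`get_encoder()` keeps the lock-guarded lazy singleton, so the encoder is built at most once even with concurrent callers. A hypothetical usage sketch (the import path is made up; point it at wherever this module actually lives):

```python
from threading import Thread

from gpt2_tokenizer import GPT2Tokenizer  # hypothetical module path

def report(text: str) -> None:
    print(GPT2Tokenizer.get_num_tokens(text))

# The first thread through get_encoder() builds the encoder under _lock;
# the others block briefly, then reuse the same instance.
threads = [Thread(target=report, args=("Hello, world!",)) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```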