@@ -1,8 +1,6 @@
 from threading import Lock
 from typing import Any
 
-import tiktoken
-
 _tokenizer: Any = None
 _lock = Lock()
 
@@ -33,9 +31,17 @@ class GPT2Tokenizer:
         if _tokenizer is None:
             # Try to use tiktoken to get the tokenizer because it is faster
-            #
-            _tokenizer = tiktoken.get_encoding("gpt2")
-            # base_path = abspath(__file__)
-            # gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
-            # _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
+            # and fall back to the local transformers GPT-2 tokenizer otherwise.
+            try:
+                import tiktoken
+
+                _tokenizer = tiktoken.get_encoding("gpt2")
+            except Exception:
+                from os.path import abspath, dirname, join
+
+                from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer  # type: ignore
+
+                base_path = abspath(__file__)
+                gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
+                _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
 
         return _tokenizer
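
Net effect of the change: tiktoken becomes an optional dependency, imported lazily inside the initializer, with the vendored transformers GPT-2 tokenizer as the fallback. The accessor wrapping these hunks is not shown in the diff, so the following is only a sketch of how the module-level _tokenizer and _lock presumably combine into a thread-safe lazy getter; the method name _get_tokenizer and the placement of the with _lock block are assumptions, not taken from the source.

# Sketch only, not part of the diff: hypothetical accessor around the changed lines.
from threading import Lock
from typing import Any

_tokenizer: Any = None
_lock = Lock()


class GPT2Tokenizer:
    @staticmethod
    def _get_tokenizer() -> Any:  # hypothetical name; the real method lies outside the hunks
        global _tokenizer
        with _lock:  # assumed use of _lock: serialize first-time initialization across threads
            if _tokenizer is None:
                try:
                    import tiktoken  # optional fast path

                    _tokenizer = tiktoken.get_encoding("gpt2")
                except Exception:
                    # Fall back to the tokenizer files vendored next to this module
                    from os.path import abspath, dirname, join

                    from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer

                    gpt2_tokenizer_path = join(dirname(abspath(__file__)), "gpt2")
                    _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
        return _tokenizer


# Either branch yields an object exposing .encode(), so callers need not care which one won:
token_ids = GPT2Tokenizer._get_tokenizer().encode("hello world")

Catching the broad Exception rather than ImportError alone also covers runtime failures inside tiktoken, such as a failure to fetch its cached BPE file on first use.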