@@ -1,11 +1,13 @@
 from os.path import abspath, dirname, join
 from threading import Lock
-from typing import Any
+from typing import Any, cast
 
+import gevent.threadpool  # type: ignore
 from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer  # type: ignore
 
 _tokenizer: Any = None
 _lock = Lock()
+_pool = gevent.threadpool.ThreadPool(1)
 
 
 class GPT2Tokenizer:
@@ -20,7 +22,9 @@ class GPT2Tokenizer:
 
     @staticmethod
     def get_num_tokens(text: str) -> int:
-        return GPT2Tokenizer._get_num_tokens_by_gpt2(text)
+        future = _pool.spawn(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
+        result = future.get(block=True)
+        return cast(int, result)
 
     @staticmethod
     def get_encoder() -> Any: