Jelajahi Sumber

feat: implement asynchronous token counting in GPT2Tokenizer (#12239)

Signed-off-by: -LAN- <laipz8200@outlook.com>
-LAN- 3 bulan lalu
induk
melakukan
6a85960605

+ 6 - 2
api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py

@@ -1,11 +1,13 @@
 from os.path import abspath, dirname, join
 from threading import Lock
-from typing import Any
+from typing import Any, cast
 
+import gevent.threadpool  # type: ignore
 from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer  # type: ignore
 
 _tokenizer: Any = None
 _lock = Lock()
+_pool = gevent.threadpool.ThreadPool(1)
 
 
 class GPT2Tokenizer:
@@ -20,7 +22,9 @@ class GPT2Tokenizer:
 
     @staticmethod
     def get_num_tokens(text: str) -> int:
-        return GPT2Tokenizer._get_num_tokens_by_gpt2(text)
+        future = _pool.spawn(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
+        result = future.get(block=True)
+        return cast(int, result)
 
     @staticmethod
     def get_encoder() -> Any: