Sfoglia il codice sorgente

refactor: use tiktoken for token calculation (#12416)

Signed-off-by: -LAN- <laipz8200@outlook.com>
-LAN- 3 mesi fa
parent
commit
d3f5b1cbb6

+ 15 - 12
api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py

@@ -1,13 +1,10 @@
-from concurrent.futures import ProcessPoolExecutor
-from os.path import abspath, dirname, join
 from threading import Lock
-from typing import Any, cast
+from typing import Any
 
-from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer  # type: ignore
+import tiktoken
 
 _tokenizer: Any = None
 _lock = Lock()
-_executor = ProcessPoolExecutor(max_workers=1)
 
 
 class GPT2Tokenizer:
@@ -17,22 +14,28 @@ class GPT2Tokenizer:
         use gpt2 tokenizer to get num tokens
         """
         _tokenizer = GPT2Tokenizer.get_encoder()
-        tokens = _tokenizer.encode(text, verbose=False)
+        tokens = _tokenizer.encode(text)
         return len(tokens)
 
     @staticmethod
     def get_num_tokens(text: str) -> int:
-        future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
-        result = future.result()
-        return cast(int, result)
+        # Because this process needs more cpu resource, we turn this back before we find a better way to handle it.
+        #
+        # future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
+        # result = future.result()
+        # return cast(int, result)
+        return GPT2Tokenizer._get_num_tokens_by_gpt2(text)
 
     @staticmethod
     def get_encoder() -> Any:
         global _tokenizer, _lock
         with _lock:
             if _tokenizer is None:
-                base_path = abspath(__file__)
-                gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
-                _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
+                # Try to use tiktoken to get the tokenizer because it is faster
+                #
+                _tokenizer = tiktoken.get_encoding("gpt2")
+                # base_path = abspath(__file__)
+                # gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
+                # _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
 
             return _tokenizer