浏览代码

fix: import jieba.analyse (#12133)

Signed-off-by: -LAN- <laipz8200@outlook.com>
-LAN- 3 月之前
父节点
当前提交
dae1b5a619
共有 1 个文件被更改,包括 6 次插入4 次删除
  1. 6 4
      api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py

+ 6 - 4
api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py

@@ -1,5 +1,5 @@
 import re
-from typing import Optional
+from typing import Optional, cast
 
 
 class JiebaKeywordTableHandler:
@@ -8,18 +8,20 @@ class JiebaKeywordTableHandler:
 
         from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
 
-        jieba.analyse.default_tfidf.stop_words = STOPWORDS
+        jieba.analyse.default_tfidf.stop_words = STOPWORDS  # type: ignore
 
     def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]:
         """Extract keywords with JIEBA tfidf."""
-        import jieba  # type: ignore
+        import jieba.analyse  # type: ignore
 
         keywords = jieba.analyse.extract_tags(
             sentence=text,
             topK=max_keywords_per_chunk,
         )
+        # jieba.analyse.extract_tags returns list[Any] when withFlag is False by default.
+        keywords = cast(list[str], keywords)
 
-        return set(self._expand_tokens_with_subtokens(keywords))
+        return set(self._expand_tokens_with_subtokens(set(keywords)))
 
     def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]:
         """Get subtokens from a list of tokens., filtering for stopwords."""