@@ -1,11 +1,15 @@
 import array
 import json
+import re
 import uuid
 from contextlib import contextmanager
 from typing import Any
 
+import jieba.posseg as pseg
+import nltk
 import numpy
 import oracledb
+from nltk.corpus import stopwords
 from pydantic import BaseModel, model_validator
 
 from configs import dify_config
@@ -50,6 +54,13 @@ CREATE TABLE IF NOT EXISTS {table_name} (
     ,embedding vector NOT NULL
 )
 """
+SQL_CREATE_INDEX = """
+CREATE INDEX idx_docs_{table_name} ON {table_name}(text)
+INDEXTYPE IS CTXSYS.CONTEXT PARAMETERS
+('FILTER CTXSYS.NULL_FILTER SECTION GROUP CTXSYS.HTML_SECTION_GROUP LEXER sys.my_chinese_vgram_lexer')
+"""
 
 
 class OracleVector(BaseVector):
@@ -188,7 +199,58 @@ class OracleVector(BaseVector):
         return docs
 
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
-        # do not support bm25 search
+        top_k = kwargs.get("top_k", 5)
+        # fetching by score_threshold is not implemented yet; may be added later
+        score_threshold = kwargs.get("score_threshold") or 0.0
+        if len(query) > 0:
+            # Check whether the query contains Chinese characters
+            zh_pattern = re.compile("[\u4e00-\u9fa5]+")
+            match = zh_pattern.search(query)
+            entities = []
+            # A Chinese query is segmented with jieba; anything else is tokenized with NLTK
+            if match:
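+                # Merge consecutive content-bearing tokens (names, nouns, verbs, ...)
+                # into one phrase so multi-word terms survive intact in the CONTAINS query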
+                words = pseg.cut(query)
+                current_entity = ""
+                for word, pos in words:
+                    if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}:  # nr: person name, nz: other proper noun, n: noun, v: verb
+                        current_entity += word
+                    else:
+                        if current_entity:
+                            entities.append(current_entity)
+                        current_entity = ""
+                if current_entity:
+                    entities.append(current_entity)
+            else:
+                try:
+                    nltk.data.find("tokenizers/punkt")
+                    nltk.data.find("corpora/stopwords")
+                except LookupError:
+                    # First run on this machine: fetch the tokenizer and stopword data
+                    nltk.download("punkt")
+                    nltk.download("stopwords")
+                e_str = re.sub(r"[^\w ]", "", query)
+                all_tokens = nltk.word_tokenize(e_str)
+                stop_words = stopwords.words("english")
+                for token in all_tokens:
+                    if token not in stop_words:
+                        entities.append(token)
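+            # ACCUM matches rows containing any of the terms; SCORE(1) accumulates
+            # per-term hits, which is what the ORDER BY below ranks on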
+            with self._get_cursor() as cur:
+                cur.execute(
+                    f"SELECT meta, text FROM {self.table_name} WHERE CONTAINS(text, :1, 1) > 0"
+                    f" ORDER BY SCORE(1) DESC FETCH FIRST {top_k} ROWS ONLY",
+                    [" ACCUM ".join(entities)],
+                )
+                docs = []
+                for record in cur:
+                    metadata, text = record
+                    docs.append(Document(page_content=text, metadata=metadata))
+            return docs
+        else:
+            return [Document(page_content="", metadata={})]
         return []
 
     def delete(self) -> None:
@@ -206,6 +268,10 @@ class OracleVector(BaseVector):
         with self._get_cursor() as cur:
             cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name))
         redis_client.set(collection_exist_cache_key, 1, ex=3600)
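+        # NOTE: the collection-exists check earlier in this method keeps this from
+        # running twice, so the CONTEXT index is built once, right after the table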
+        with self._get_cursor() as cur:
+            cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
 
 
 class OracleVectorFactory(AbstractVectorFactory):