Ver Fonte

Add search by full text when using Oracle23ai as vector DB (#6559)

tmuife há 9 meses atrás
pai
commit
06fc1bce9e

+ 4 - 4
api/controllers/console/datasets/datasets.py

@@ -543,13 +543,13 @@ class DatasetRetrievalSettingApi(Resource):
     def get(self):
         vector_type = dify_config.VECTOR_STORE
         match vector_type:
-            case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.ORACLE:
+            case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT:
                 return {
                     'retrieval_method': [
                         RetrievalMethod.SEMANTIC_SEARCH.value
                     ]
                 }
-            case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE:
+            case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE:
                 return {
                     'retrieval_method': [
                         RetrievalMethod.SEMANTIC_SEARCH.value,
@@ -567,13 +567,13 @@ class DatasetRetrievalSettingMockApi(Resource):
     @account_initialization_required
     def get(self, vector_type):
         match vector_type:
-            case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.ORACLE:
+            case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT:
                 return {
                     'retrieval_method': [
                         RetrievalMethod.SEMANTIC_SEARCH.value
                     ]
                 }
-            case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE:
+            case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE:
                 return {
                     'retrieval_method': [
                         RetrievalMethod.SEMANTIC_SEARCH.value,

+ 58 - 1
api/core/rag/datasource/vdb/oracle/oraclevector.py

@@ -1,11 +1,15 @@
 import array
 import json
+import re
 import uuid
 from contextlib import contextmanager
 from typing import Any
 
+import jieba.posseg as pseg
+import nltk
 import numpy
 import oracledb
+from nltk.corpus import stopwords
 from pydantic import BaseModel, model_validator
 
 from configs import dify_config
@@ -50,6 +54,11 @@ CREATE TABLE IF NOT EXISTS {table_name} (
     ,embedding vector NOT NULL
 )
 """
+SQL_CREATE_INDEX = """
+CREATE INDEX idx_docs_{table_name} ON {table_name}(text) 
+INDEXTYPE IS CTXSYS.CONTEXT PARAMETERS 
+('FILTER CTXSYS.NULL_FILTER SECTION GROUP CTXSYS.HTML_SECTION_GROUP LEXER sys.my_chinese_vgram_lexer')
+"""
 
 
 class OracleVector(BaseVector):
@@ -188,7 +197,53 @@ class OracleVector(BaseVector):
         return docs
 
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
-        # do not support bm25 search
+        top_k = kwargs.get("top_k", 5)
+        # just not implement fetch by score_threshold now, may be later
+        score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+        if len(query) > 0:
+            # Check which language the query is in
+            zh_pattern = re.compile('[\u4e00-\u9fa5]+')
+            match = zh_pattern.search(query)
+            entities = []
+            #  match: query condition maybe is a chinese sentence, so using Jieba split,else using nltk split
+            if match:
+                words = pseg.cut(query)
+                current_entity = ""
+                for word, pos in words:
+                    if pos == 'nr' or pos == 'Ng' or pos == 'eng' or pos == 'nz' or pos == 'n' or pos == 'ORG' or pos == 'v':  # nr: 人名, ns: 地名, nt: 机构名
+                        current_entity += word
+                    else:
+                        if current_entity:
+                            entities.append(current_entity)
+                            current_entity = ""
+                if current_entity:
+                    entities.append(current_entity)
+            else:
+                try:
+                    nltk.data.find('tokenizers/punkt')
+                    nltk.data.find('corpora/stopwords')
+                except LookupError:
+                    nltk.download('punkt')
+                    nltk.download('stopwords')
+                    print("run download")
+                e_str = re.sub(r'[^\w ]', '', query)
+                all_tokens = nltk.word_tokenize(e_str)
+                stop_words = stopwords.words('english')
+                for token in all_tokens:
+                    if token not in stop_words:
+                        entities.append(token)
+            with self._get_cursor() as cur:
+                cur.execute(
+                    f"select meta, text FROM {self.table_name} WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only",
+                    [" ACCUM ".join(entities)]
+                )
+                docs = []
+                for record in cur:
+                    metadata, text = record
+                    docs.append(Document(page_content=text, metadata=metadata))
+            return docs
+        else:
+            return [Document(page_content="", metadata="")]
         return []
 
     def delete(self) -> None:
@@ -206,6 +261,8 @@ class OracleVector(BaseVector):
             with self._get_cursor() as cur:
                 cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name))
             redis_client.set(collection_exist_cache_key, 1, ex=3600)
+            with self._get_cursor() as cur:
+                cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
 
 
 class OracleVectorFactory(AbstractVectorFactory):

+ 13 - 0
docker/startupscripts/init.sh

@@ -0,0 +1,13 @@
+#!/usr/bin/env bash
+
+DB_INITIALISED="/opt/oracle/oradata/dbinit"
+#[ -f ${DB_INITIALISED} ] && exit
+#touch ${DB_INITIALISED}
+if [ -f ${DB_INITIALISED} ]; then
+  echo 'File exists. Standards for have been Init'
+  exit
+else
+  echo 'File does not exist. Standards for first time Strart up this DB'
+  "$ORACLE_HOME"/bin/sqlplus -s "/ as sysdba" @"/opt/oracle/scripts/startup/init_user.script"; 
+  touch ${DB_INITIALISED}
+fi

+ 5 - 0
docker/startupscripts/create_user.sql → docker/startupscripts/init_user.script

@@ -3,3 +3,8 @@ ALTER SYSTEM SET PROCESSES=500 SCOPE=SPFILE;
 alter session set container= freepdb1;
 create user dify identified by dify DEFAULT TABLESPACE users quota unlimited on users;
 grant DB_DEVELOPER_ROLE to dify;
+
+BEGIN
+CTX_DDL.CREATE_PREFERENCE('my_chinese_vgram_lexer','CHINESE_VGRAM_LEXER');
+END;
+/