Quellcode durchsuchen

feat: rewrite Elasticsearch index and search code to achieve Elasticsearch vector and full-text search (#7641)

Co-authored-by: haokai <haokai@shuwen.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: Bowen Liang <bowenliang@apache.org>
Co-authored-by: wellCh4n <wellCh4n@foxmail.com>
Kenn vor 7 Monaten
Ursprung
Commit
122ce41020

+ 2 - 0
api/configs/middleware/__init__.py

@@ -13,6 +13,7 @@ from configs.middleware.storage.oci_storage_config import OCIStorageConfig
 from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
 from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
 from configs.middleware.vdb.chroma_config import ChromaConfig
+from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
 from configs.middleware.vdb.milvus_config import MilvusConfig
 from configs.middleware.vdb.myscale_config import MyScaleConfig
 from configs.middleware.vdb.opensearch_config import OpenSearchConfig
@@ -200,5 +201,6 @@ class MiddlewareConfig(
     TencentVectorDBConfig,
     TiDBVectorConfig,
     WeaviateConfig,
+    ElasticsearchConfig,
 ):
     pass

+ 30 - 0
api/configs/middleware/vdb/elasticsearch_config.py

@@ -0,0 +1,30 @@
+from typing import Optional
+
+from pydantic import Field, PositiveInt
+from pydantic_settings import BaseSettings
+
+
+class ElasticsearchConfig(BaseSettings):
+    """
+    Elasticsearch configs
+    """
+
+    ELASTICSEARCH_HOST: Optional[str] = Field(
+        description="Elasticsearch host",
+        default="127.0.0.1",
+    )
+
+    ELASTICSEARCH_PORT: PositiveInt = Field(
+        description="Elasticsearch port",
+        default=9200,
+    )
+
+    ELASTICSEARCH_USERNAME: Optional[str] = Field(
+        description="Elasticsearch username",
+        default="elastic",
+    )
+
+    ELASTICSEARCH_PASSWORD: Optional[str] = Field(
+        description="Elasticsearch password",
+        default="elastic",
+    )

+ 79 - 52
api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py

@@ -1,5 +1,7 @@
 import json
-from typing import Any
+import logging
+from typing import Any, Optional
+from urllib.parse import urlparse
 
 import requests
 from elasticsearch import Elasticsearch
@@ -7,16 +9,20 @@ from flask import current_app
 from pydantic import BaseModel, model_validator
 
 from core.rag.datasource.entity.embedding import Embeddings
+from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
 from models.dataset import Dataset
 
+logger = logging.getLogger(__name__)
+
 
 class ElasticSearchConfig(BaseModel):
     host: str
-    port: str
+    port: int
     username: str
     password: str
 
@@ -37,12 +43,19 @@ class ElasticSearchVector(BaseVector):
     def __init__(self, index_name: str, config: ElasticSearchConfig, attributes: list):
         super().__init__(index_name.lower())
         self._client = self._init_client(config)
+        self._version = self._get_version()
+        self._check_version()
         self._attributes = attributes
 
     def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
         try:
+            parsed_url = urlparse(config.host)
+            if parsed_url.scheme in ['http', 'https']:
+                hosts = f'{config.host}:{config.port}'
+            else:
+                hosts = f'http://{config.host}:{config.port}'
             client = Elasticsearch(
-                hosts=f'{config.host}:{config.port}',
+                hosts=hosts,
                 basic_auth=(config.username, config.password),
                 request_timeout=100000,
                 retry_on_timeout=True,
@@ -53,42 +66,27 @@ class ElasticSearchVector(BaseVector):
 
         return client
 
+    def _get_version(self) -> str:
+        info = self._client.info()
+        return info['version']['number']
+
+    def _check_version(self):
+        if self._version < '8.0.0':
+            raise ValueError("Elasticsearch vector database version must be greater than 8.0.0")
+
     def get_type(self) -> str:
         return 'elasticsearch'
 
     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
         uuids = self._get_uuids(documents)
-        texts = [d.page_content for d in documents]
-        metadatas = [d.metadata for d in documents]
-
-        if not self._client.indices.exists(index=self._collection_name):
-            dim = len(embeddings[0])
-            mapping = {
-                "properties": {
-                    "text": {
-                        "type": "text"
-                    },
-                    "vector": {
-                        "type": "dense_vector",
-                        "index": True,
-                        "dims": dim,
-                        "similarity": "l2_norm"
-                    },
-                }
-            }
-            self._client.indices.create(index=self._collection_name, mappings=mapping)
-
-        added_ids = []
-        for i, text in enumerate(texts):
+        for i in range(len(documents)):
             self._client.index(index=self._collection_name,
                                id=uuids[i],
                                document={
-                                   "text": text,
-                                   "vector": embeddings[i] if embeddings[i] else None,
-                                   "metadata": metadatas[i] if metadatas[i] else {},
+                                   Field.CONTENT_KEY.value: documents[i].page_content,
+                                   Field.VECTOR.value: embeddings[i] if embeddings[i] else None,
+                                   Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {}
                                })
-            added_ids.append(uuids[i])
-
         self._client.indices.refresh(index=self._collection_name)
         return uuids
 
@@ -116,28 +114,21 @@ class ElasticSearchVector(BaseVector):
         self._client.indices.delete(index=self._collection_name)
 
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
-        query_str = {
-            "query": {
-                "script_score": {
-                    "query": {
-                        "match_all": {}
-                    },
-                    "script": {
-                        "source": "cosineSimilarity(params.query_vector, 'vector') + 1.0",
-                        "params": {
-                            "query_vector": query_vector
-                        }
-                    }
-                }
-            }
+        top_k = kwargs.get("top_k", 10)
+        knn = {
+            "field": Field.VECTOR.value,
+            "query_vector": query_vector,
+            "k": top_k
         }
 
-        results = self._client.search(index=self._collection_name, body=query_str)
+        results = self._client.search(index=self._collection_name, knn=knn, size=top_k)
 
         docs_and_scores = []
         for hit in results['hits']['hits']:
             docs_and_scores.append(
-                (Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata']), hit['_score']))
+                (Document(page_content=hit['_source'][Field.CONTENT_KEY.value],
+                          vector=hit['_source'][Field.VECTOR.value],
+                          metadata=hit['_source'][Field.METADATA_KEY.value]), hit['_score']))
 
         docs = []
         for doc, score in docs_and_scores:
@@ -146,25 +137,61 @@ class ElasticSearchVector(BaseVector):
                 doc.metadata['score'] = score
             docs.append(doc)
 
-        # Sort the documents by score in descending order
-        docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
-
         return docs
+
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         query_str = {
             "match": {
-                "text": query
+                Field.CONTENT_KEY.value: query
             }
         }
         results = self._client.search(index=self._collection_name, query=query_str)
         docs = []
         for hit in results['hits']['hits']:
-            docs.append(Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata']))
+            docs.append(Document(
+                page_content=hit['_source'][Field.CONTENT_KEY.value],
+                vector=hit['_source'][Field.VECTOR.value],
+                metadata=hit['_source'][Field.METADATA_KEY.value],
+            ))
 
         return docs
 
     def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
-        return self.add_texts(texts, embeddings, **kwargs)
+        metadatas = [d.metadata for d in texts]
+        self.create_collection(embeddings, metadatas)
+        self.add_texts(texts, embeddings, **kwargs)
+
+    def create_collection(
+            self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
+    ):
+        lock_name = f'vector_indexing_lock_{self._collection_name}'
+        with redis_client.lock(lock_name, timeout=20):
+            collection_exist_cache_key = f'vector_indexing_{self._collection_name}'
+            if redis_client.get(collection_exist_cache_key):
+                logger.info(f"Collection {self._collection_name} already exists.")
+                return
+
+            if not self._client.indices.exists(index=self._collection_name):
+                dim = len(embeddings[0])
+                mappings = {
+                    "properties": {
+                        Field.CONTENT_KEY.value: {"type": "text"},
+                        Field.VECTOR.value: {  # Make sure the dimension is correct here
+                            "type": "dense_vector",
+                            "dims": dim,
+                            "similarity": "cosine"
+                        },
+                        Field.METADATA_KEY.value: {
+                            "type": "object",
+                            "properties": {
+                                "doc_id": {"type": "keyword"}  # Map doc_id to keyword type
+                            }
+                        }
+                    }
+                }
+                self._client.indices.create(index=self._collection_name, mappings=mappings)
+
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
 
 
 class ElasticSearchVectorFactory(AbstractVectorFactory):