Ver Fonte

use redis to cache embeddings (#2085)

Co-authored-by: jyong <jyong@dify.ai>
Jyong há 1 ano atrás
pai
commit
a3c7c07ecc
1 ficheiros alterados com 35 adições e 14 exclusões
  1. 35 14
      api/core/embedding/cached_embedding.py

+ 35 - 14
api/core/embedding/cached_embedding.py

@@ -1,3 +1,5 @@
+import base64
+import json
 import logging
 from typing import List, Optional
 
@@ -5,6 +7,8 @@ import numpy as np
 from core.model_manager import ModelInstance
 from extensions.ext_database import db
 from langchain.embeddings.base import Embeddings
+
+from extensions.ext_redis import redis_client
 from libs import helper
 from models.dataset import Embedding
 from sqlalchemy.exc import IntegrityError
@@ -24,9 +28,12 @@ class CacheEmbedding(Embeddings):
         embedding_queue_indices = []
         for i, text in enumerate(texts):
             hash = helper.generate_text_hash(text)
-            embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first()
+            embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
+            embedding = redis_client.get(embedding_cache_key)
             if embedding:
-                text_embeddings[i] = embedding.get_embedding()
+                redis_client.expire(embedding_cache_key, 3600)
+                text_embeddings[i] = list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
+
             else:
                 embedding_queue_indices.append(i)
 
@@ -46,18 +53,24 @@ class CacheEmbedding(Embeddings):
                 hash = helper.generate_text_hash(texts[indice])
 
                 try:
-                    embedding = Embedding(model_name=self._model_instance.model, hash=hash)
+                    embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
                     vector = embedding_results[i]
                     normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
                     text_embeddings[indice] = normalized_embedding
-                    embedding.set_embedding(normalized_embedding)
-                    db.session.add(embedding)
-                    db.session.commit()
+                    # encode embedding to base64
+                    embedding_vector = np.array(normalized_embedding)
+                    vector_bytes = embedding_vector.tobytes()
+                    # Transform to Base64
+                    encoded_vector = base64.b64encode(vector_bytes)
+                    # Transform to string
+                    encoded_str = encoded_vector.decode("utf-8")
+                    redis_client.setex(embedding_cache_key, 3600, encoded_str)
+
                 except IntegrityError:
                     db.session.rollback()
                     continue
                 except:
-                    logging.exception('Failed to add embedding to db')
+                    logging.exception('Failed to add embedding to redis')
                     continue
 
         return text_embeddings
@@ -66,9 +79,12 @@ class CacheEmbedding(Embeddings):
         """Embed query text."""
         # use doc embedding cache or store if not exists
         hash = helper.generate_text_hash(text)
-        embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first()
+        embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
+        embedding = redis_client.get(embedding_cache_key)
         if embedding:
-            return embedding.get_embedding()
+            redis_client.expire(embedding_cache_key, 3600)
+            return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
+
 
         try:
             embedding_result = self._model_instance.invoke_text_embedding(
@@ -82,13 +98,18 @@ class CacheEmbedding(Embeddings):
             raise ex
 
         try:
-            embedding = Embedding(model_name=self._model_instance.model, hash=hash)
-            embedding.set_embedding(embedding_results)
-            db.session.add(embedding)
-            db.session.commit()
+            # encode embedding to base64
+            embedding_vector = np.array(embedding_results)
+            vector_bytes = embedding_vector.tobytes()
+            # Transform to Base64
+            encoded_vector = base64.b64encode(vector_bytes)
+            # Transform to string
+            encoded_str = encoded_vector.decode("utf-8")
+            redis_client.setex(embedding_cache_key, 3600, encoded_str)
+
         except IntegrityError:
             db.session.rollback()
         except:
-            logging.exception('Failed to add embedding to db')
+            logging.exception('Failed to add embedding to redis')
 
         return embedding_results