|
@@ -1,3 +1,5 @@
|
|
|
+import base64
|
|
|
+import json
|
|
|
import logging
|
|
|
from typing import List, Optional
|
|
|
|
|
@@ -5,6 +7,8 @@ import numpy as np
|
|
|
from core.model_manager import ModelInstance
|
|
|
from extensions.ext_database import db
|
|
|
from langchain.embeddings.base import Embeddings
|
|
|
+
|
|
|
+from extensions.ext_redis import redis_client
|
|
|
from libs import helper
|
|
|
from models.dataset import Embedding
|
|
|
from sqlalchemy.exc import IntegrityError
|
|
@@ -24,9 +28,12 @@ class CacheEmbedding(Embeddings):
|
|
|
embedding_queue_indices = []
|
|
|
for i, text in enumerate(texts):
|
|
|
hash = helper.generate_text_hash(text)
|
|
|
- embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first()
|
|
|
+ embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
|
|
|
+ embedding = redis_client.get(embedding_cache_key)
|
|
|
if embedding:
|
|
|
- text_embeddings[i] = embedding.get_embedding()
|
|
|
+ redis_client.expire(embedding_cache_key, 3600)
|
|
|
+ text_embeddings[i] = list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
|
|
|
+
|
|
|
else:
|
|
|
embedding_queue_indices.append(i)
|
|
|
|
|
@@ -46,18 +53,24 @@ class CacheEmbedding(Embeddings):
|
|
|
hash = helper.generate_text_hash(texts[indice])
|
|
|
|
|
|
try:
|
|
|
- embedding = Embedding(model_name=self._model_instance.model, hash=hash)
|
|
|
+ embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
|
|
|
vector = embedding_results[i]
|
|
|
normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
|
|
|
text_embeddings[indice] = normalized_embedding
|
|
|
- embedding.set_embedding(normalized_embedding)
|
|
|
- db.session.add(embedding)
|
|
|
- db.session.commit()
|
|
|
+ # Serialize the normalized embedding for Redis: NumPy array -> raw bytes -> Base64 text
|
|
|
+ embedding_vector = np.array(normalized_embedding)
|
|
|
+ vector_bytes = embedding_vector.tobytes()
|
|
|
+ # Base64-encode the raw array bytes
|
|
|
+ encoded_vector = base64.b64encode(vector_bytes)
|
|
|
+ # Decode the Base64 bytes into a UTF-8 string
|
|
|
+ encoded_str = encoded_vector.decode("utf-8")
|
|
|
+ redis_client.setex(embedding_cache_key, 3600, encoded_str)
|
|
|
+
|
|
|
except IntegrityError:
|
|
|
db.session.rollback()
|
|
|
continue
|
|
|
except:
|
|
|
- logging.exception('Failed to add embedding to db')
|
|
|
+ logging.exception('Failed to add embedding to redis')
|
|
|
continue
|
|
|
|
|
|
return text_embeddings
|
|
@@ -66,9 +79,12 @@ class CacheEmbedding(Embeddings):
|
|
|
"""Embed query text."""
|
|
|
# use doc embedding cache or store if not exists
|
|
|
hash = helper.generate_text_hash(text)
|
|
|
- embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first()
|
|
|
+ embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
|
|
|
+ embedding = redis_client.get(embedding_cache_key)
|
|
|
if embedding:
|
|
|
- return embedding.get_embedding()
|
|
|
+ redis_client.expire(embedding_cache_key, 3600)
|
|
|
+ return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
|
|
|
+
|
|
|
|
|
|
try:
|
|
|
embedding_result = self._model_instance.invoke_text_embedding(
|
|
@@ -82,13 +98,18 @@ class CacheEmbedding(Embeddings):
|
|
|
raise ex
|
|
|
|
|
|
try:
|
|
|
- embedding = Embedding(model_name=self._model_instance.model, hash=hash)
|
|
|
- embedding.set_embedding(embedding_results)
|
|
|
- db.session.add(embedding)
|
|
|
- db.session.commit()
|
|
|
+ # Serialize the embedding for Redis: NumPy array -> raw bytes -> Base64 text
|
|
|
+ embedding_vector = np.array(embedding_results)
|
|
|
+ vector_bytes = embedding_vector.tobytes()
|
|
|
+ # Base64-encode the raw array bytes
|
|
|
+ encoded_vector = base64.b64encode(vector_bytes)
|
|
|
+ # Decode the Base64 bytes into a UTF-8 string
|
|
|
+ encoded_str = encoded_vector.decode("utf-8")
|
|
|
+ redis_client.setex(embedding_cache_key, 3600, encoded_str)
|
|
|
+
|
|
|
except IntegrityError:
|
|
|
db.session.rollback()
|
|
|
except:
|
|
|
- logging.exception('Failed to add embedding to db')
|
|
|
+ logging.exception('Failed to add embedding to redis')
|
|
|
|
|
|
return embedding_results
|