cached_embedding.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. import base64
  2. import json
  3. import logging
  4. from typing import List, Optional
  5. import numpy as np
  6. from core.model_manager import ModelInstance
  7. from extensions.ext_database import db
  8. from langchain.embeddings.base import Embeddings
  9. from extensions.ext_redis import redis_client
  10. from libs import helper
  11. from models.dataset import Embedding
  12. from sqlalchemy.exc import IntegrityError
  13. logger = logging.getLogger(__name__)
  14. class CacheEmbedding(Embeddings):
  15. def __init__(self, model_instance: ModelInstance, user: Optional[str] = None) -> None:
  16. self._model_instance = model_instance
  17. self._user = user
  18. def embed_documents(self, texts: List[str]) -> List[List[float]]:
  19. """Embed search docs."""
  20. # use doc embedding cache or store if not exists
  21. text_embeddings = [None for _ in range(len(texts))]
  22. embedding_queue_indices = []
  23. for i, text in enumerate(texts):
  24. hash = helper.generate_text_hash(text)
  25. embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
  26. embedding = redis_client.get(embedding_cache_key)
  27. if embedding:
  28. redis_client.expire(embedding_cache_key, 3600)
  29. text_embeddings[i] = list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
  30. else:
  31. embedding_queue_indices.append(i)
  32. if embedding_queue_indices:
  33. try:
  34. embedding_result = self._model_instance.invoke_text_embedding(
  35. texts=[texts[i] for i in embedding_queue_indices],
  36. user=self._user
  37. )
  38. embedding_results = embedding_result.embeddings
  39. except Exception as ex:
  40. logger.error('Failed to embed documents: ', ex)
  41. raise ex
  42. for i, indice in enumerate(embedding_queue_indices):
  43. hash = helper.generate_text_hash(texts[indice])
  44. try:
  45. embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
  46. vector = embedding_results[i]
  47. normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
  48. text_embeddings[indice] = normalized_embedding
  49. # encode embedding to base64
  50. embedding_vector = np.array(normalized_embedding)
  51. vector_bytes = embedding_vector.tobytes()
  52. # Transform to Base64
  53. encoded_vector = base64.b64encode(vector_bytes)
  54. # Transform to string
  55. encoded_str = encoded_vector.decode("utf-8")
  56. redis_client.setex(embedding_cache_key, 3600, encoded_str)
  57. except IntegrityError:
  58. db.session.rollback()
  59. continue
  60. except:
  61. logging.exception('Failed to add embedding to redis')
  62. continue
  63. return text_embeddings
  64. def embed_query(self, text: str) -> List[float]:
  65. """Embed query text."""
  66. # use doc embedding cache or store if not exists
  67. hash = helper.generate_text_hash(text)
  68. embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
  69. embedding = redis_client.get(embedding_cache_key)
  70. if embedding:
  71. redis_client.expire(embedding_cache_key, 3600)
  72. return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
  73. try:
  74. embedding_result = self._model_instance.invoke_text_embedding(
  75. texts=[text],
  76. user=self._user
  77. )
  78. embedding_results = embedding_result.embeddings[0]
  79. embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
  80. except Exception as ex:
  81. raise ex
  82. try:
  83. # encode embedding to base64
  84. embedding_vector = np.array(embedding_results)
  85. vector_bytes = embedding_vector.tobytes()
  86. # Transform to Base64
  87. encoded_vector = base64.b64encode(vector_bytes)
  88. # Transform to string
  89. encoded_str = encoded_vector.decode("utf-8")
  90. redis_client.setex(embedding_cache_key, 3600, encoded_str)
  91. except IntegrityError:
  92. db.session.rollback()
  93. except:
  94. logging.exception('Failed to add embedding to redis')
  95. return embedding_results