# cached_embedding.py
import base64
import logging
from typing import Optional, cast

import numpy as np
from sqlalchemy.exc import IntegrityError

from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.datasource.entity.embedding import Embeddings
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import helper
from models.dataset import Embedding

logger = logging.getLogger(__name__)
  15. class CacheEmbedding(Embeddings):
  16. def __init__(self, model_instance: ModelInstance, user: Optional[str] = None) -> None:
  17. self._model_instance = model_instance
  18. self._user = user
  19. def embed_documents(self, texts: list[str]) -> list[list[float]]:
  20. """Embed search docs in batches of 10."""
  21. # use doc embedding cache or store if not exists
  22. text_embeddings = [None for _ in range(len(texts))]
  23. embedding_queue_indices = []
  24. for i, text in enumerate(texts):
  25. hash = helper.generate_text_hash(text)
  26. embedding = (
  27. db.session.query(Embedding)
  28. .filter_by(
  29. model_name=self._model_instance.model, hash=hash, provider_name=self._model_instance.provider
  30. )
  31. .first()
  32. )
  33. if embedding:
  34. text_embeddings[i] = embedding.get_embedding()
  35. else:
  36. embedding_queue_indices.append(i)
  37. if embedding_queue_indices:
  38. embedding_queue_texts = [texts[i] for i in embedding_queue_indices]
  39. embedding_queue_embeddings = []
  40. try:
  41. model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
  42. model_schema = model_type_instance.get_model_schema(
  43. self._model_instance.model, self._model_instance.credentials
  44. )
  45. max_chunks = (
  46. model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
  47. if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties
  48. else 1
  49. )
  50. for i in range(0, len(embedding_queue_texts), max_chunks):
  51. batch_texts = embedding_queue_texts[i : i + max_chunks]
  52. embedding_result = self._model_instance.invoke_text_embedding(texts=batch_texts, user=self._user)
  53. for vector in embedding_result.embeddings:
  54. try:
  55. normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
  56. embedding_queue_embeddings.append(normalized_embedding)
  57. except IntegrityError:
  58. db.session.rollback()
  59. except Exception as e:
  60. logging.exception("Failed transform embedding: %s", e)
  61. cache_embeddings = []
  62. try:
  63. for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
  64. text_embeddings[i] = embedding
  65. hash = helper.generate_text_hash(texts[i])
  66. if hash not in cache_embeddings:
  67. embedding_cache = Embedding(
  68. model_name=self._model_instance.model,
  69. hash=hash,
  70. provider_name=self._model_instance.provider,
  71. )
  72. embedding_cache.set_embedding(embedding)
  73. db.session.add(embedding_cache)
  74. cache_embeddings.append(hash)
  75. db.session.commit()
  76. except IntegrityError:
  77. db.session.rollback()
  78. except Exception as ex:
  79. db.session.rollback()
  80. logger.error("Failed to embed documents: %s", ex)
  81. raise ex
  82. return text_embeddings
  83. def embed_query(self, text: str) -> list[float]:
  84. """Embed query text."""
  85. # use doc embedding cache or store if not exists
  86. hash = helper.generate_text_hash(text)
  87. embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{hash}"
  88. embedding = redis_client.get(embedding_cache_key)
  89. if embedding:
  90. redis_client.expire(embedding_cache_key, 600)
  91. return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
  92. try:
  93. embedding_result = self._model_instance.invoke_text_embedding(texts=[text], user=self._user)
  94. embedding_results = embedding_result.embeddings[0]
  95. embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
  96. except Exception as ex:
  97. raise ex
  98. try:
  99. # encode embedding to base64
  100. embedding_vector = np.array(embedding_results)
  101. vector_bytes = embedding_vector.tobytes()
  102. # Transform to Base64
  103. encoded_vector = base64.b64encode(vector_bytes)
  104. # Transform to string
  105. encoded_str = encoded_vector.decode("utf-8")
  106. redis_client.setex(embedding_cache_key, 600, encoded_str)
  107. except Exception as ex:
  108. logging.exception("Failed to add embedding to redis %s", ex)
  109. return embedding_results