cached_embedding.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. import base64
  2. import logging
  3. from typing import Optional, cast
  4. import numpy as np
  5. from sqlalchemy.exc import IntegrityError
  6. from core.embedding.embedding_constant import EmbeddingInputType
  7. from core.model_manager import ModelInstance
  8. from core.model_runtime.entities.model_entities import ModelPropertyKey
  9. from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
  10. from core.rag.datasource.entity.embedding import Embeddings
  11. from extensions.ext_database import db
  12. from extensions.ext_redis import redis_client
  13. from libs import helper
  14. from models.dataset import Embedding
# Module-level logger named after this module (standard logging convention).
logger = logging.getLogger(__name__)
  16. class CacheEmbedding(Embeddings):
  17. def __init__(self, model_instance: ModelInstance, user: Optional[str] = None) -> None:
  18. self._model_instance = model_instance
  19. self._user = user
  20. def embed_documents(self, texts: list[str]) -> list[list[float]]:
  21. """Embed search docs in batches of 10."""
  22. # use doc embedding cache or store if not exists
  23. text_embeddings = [None for _ in range(len(texts))]
  24. embedding_queue_indices = []
  25. for i, text in enumerate(texts):
  26. hash = helper.generate_text_hash(text)
  27. embedding = (
  28. db.session.query(Embedding)
  29. .filter_by(
  30. model_name=self._model_instance.model, hash=hash, provider_name=self._model_instance.provider
  31. )
  32. .first()
  33. )
  34. if embedding:
  35. text_embeddings[i] = embedding.get_embedding()
  36. else:
  37. embedding_queue_indices.append(i)
  38. if embedding_queue_indices:
  39. embedding_queue_texts = [texts[i] for i in embedding_queue_indices]
  40. embedding_queue_embeddings = []
  41. try:
  42. model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
  43. model_schema = model_type_instance.get_model_schema(
  44. self._model_instance.model, self._model_instance.credentials
  45. )
  46. max_chunks = (
  47. model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
  48. if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties
  49. else 1
  50. )
  51. for i in range(0, len(embedding_queue_texts), max_chunks):
  52. batch_texts = embedding_queue_texts[i : i + max_chunks]
  53. embedding_result = self._model_instance.invoke_text_embedding(
  54. texts=batch_texts, user=self._user, input_type=EmbeddingInputType.DOCUMENT
  55. )
  56. for vector in embedding_result.embeddings:
  57. try:
  58. normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
  59. embedding_queue_embeddings.append(normalized_embedding)
  60. except IntegrityError:
  61. db.session.rollback()
  62. except Exception as e:
  63. logging.exception("Failed transform embedding: %s", e)
  64. cache_embeddings = []
  65. try:
  66. for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
  67. text_embeddings[i] = embedding
  68. hash = helper.generate_text_hash(texts[i])
  69. if hash not in cache_embeddings:
  70. embedding_cache = Embedding(
  71. model_name=self._model_instance.model,
  72. hash=hash,
  73. provider_name=self._model_instance.provider,
  74. )
  75. embedding_cache.set_embedding(embedding)
  76. db.session.add(embedding_cache)
  77. cache_embeddings.append(hash)
  78. db.session.commit()
  79. except IntegrityError:
  80. db.session.rollback()
  81. except Exception as ex:
  82. db.session.rollback()
  83. logger.error("Failed to embed documents: %s", ex)
  84. raise ex
  85. return text_embeddings
  86. def embed_query(self, text: str) -> list[float]:
  87. """Embed query text."""
  88. # use doc embedding cache or store if not exists
  89. hash = helper.generate_text_hash(text)
  90. embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{hash}"
  91. embedding = redis_client.get(embedding_cache_key)
  92. if embedding:
  93. redis_client.expire(embedding_cache_key, 600)
  94. return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
  95. try:
  96. embedding_result = self._model_instance.invoke_text_embedding(
  97. texts=[text], user=self._user, input_type=EmbeddingInputType.QUERY
  98. )
  99. embedding_results = embedding_result.embeddings[0]
  100. embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
  101. except Exception as ex:
  102. raise ex
  103. try:
  104. # encode embedding to base64
  105. embedding_vector = np.array(embedding_results)
  106. vector_bytes = embedding_vector.tobytes()
  107. # Transform to Base64
  108. encoded_vector = base64.b64encode(vector_bytes)
  109. # Transform to string
  110. encoded_str = encoded_vector.decode("utf-8")
  111. redis_client.setex(embedding_cache_key, 600, encoded_str)
  112. except Exception as ex:
  113. logging.exception("Failed to add embedding to redis %s", ex)
  114. return embedding_results