cached_embedding.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. import base64
  2. import logging
  3. from typing import Optional, cast
  4. import numpy as np
  5. from sqlalchemy.exc import IntegrityError
  6. from core.model_manager import ModelInstance
  7. from core.model_runtime.entities.model_entities import ModelPropertyKey
  8. from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
  9. from core.rag.datasource.entity.embedding import Embeddings
  10. from extensions.ext_database import db
  11. from extensions.ext_redis import redis_client
  12. from libs import helper
  13. from models.dataset import Embedding
  # Module-level logger, named after this module via __name__.
  14. logger = logging.getLogger(__name__)
  15. class CacheEmbedding(Embeddings):
  16. def __init__(self, model_instance: ModelInstance, user: Optional[str] = None) -> None:
  17. self._model_instance = model_instance
  18. self._user = user
  19. def embed_documents(self, texts: list[str]) -> list[list[float]]:
  20. """Embed search docs in batches of 10."""
  21. # use doc embedding cache or store if not exists
  22. text_embeddings = [None for _ in range(len(texts))]
  23. embedding_queue_indices = []
  24. for i, text in enumerate(texts):
  25. hash = helper.generate_text_hash(text)
  26. embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model,
  27. hash=hash,
  28. provider_name=self._model_instance.provider).first()
  29. if embedding:
  30. text_embeddings[i] = embedding.get_embedding()
  31. else:
  32. embedding_queue_indices.append(i)
  33. if embedding_queue_indices:
  34. embedding_queue_texts = [texts[i] for i in embedding_queue_indices]
  35. embedding_queue_embeddings = []
  36. try:
  37. model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
  38. model_schema = model_type_instance.get_model_schema(self._model_instance.model, self._model_instance.credentials)
  39. max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \
  40. if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1
  41. for i in range(0, len(embedding_queue_texts), max_chunks):
  42. batch_texts = embedding_queue_texts[i:i + max_chunks]
  43. embedding_result = self._model_instance.invoke_text_embedding(
  44. texts=batch_texts,
  45. user=self._user
  46. )
  47. for vector in embedding_result.embeddings:
  48. try:
  49. normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
  50. embedding_queue_embeddings.append(normalized_embedding)
  51. except IntegrityError:
  52. db.session.rollback()
  53. except Exception as e:
  54. logging.exception('Failed transform embedding: ', e)
  55. for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
  56. text_embeddings[i] = embedding
  57. hash = helper.generate_text_hash(texts[i])
  58. embedding_cache = Embedding(model_name=self._model_instance.model,
  59. hash=hash,
  60. provider_name=self._model_instance.provider)
  61. embedding_cache.set_embedding(embedding)
  62. db.session.add(embedding_cache)
  63. db.session.commit()
  64. except Exception as ex:
  65. db.session.rollback()
  66. logger.error('Failed to embed documents: ', ex)
  67. raise ex
  68. return text_embeddings
  69. def embed_query(self, text: str) -> list[float]:
  70. """Embed query text."""
  71. # use doc embedding cache or store if not exists
  72. hash = helper.generate_text_hash(text)
  73. embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
  74. embedding = redis_client.get(embedding_cache_key)
  75. if embedding:
  76. redis_client.expire(embedding_cache_key, 600)
  77. return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
  78. try:
  79. embedding_result = self._model_instance.invoke_text_embedding(
  80. texts=[text],
  81. user=self._user
  82. )
  83. embedding_results = embedding_result.embeddings[0]
  84. embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
  85. except Exception as ex:
  86. raise ex
  87. try:
  88. # encode embedding to base64
  89. embedding_vector = np.array(embedding_results)
  90. vector_bytes = embedding_vector.tobytes()
  91. # Transform to Base64
  92. encoded_vector = base64.b64encode(vector_bytes)
  93. # Transform to string
  94. encoded_str = encoded_vector.decode("utf-8")
  95. redis_client.setex(embedding_cache_key, 600, encoded_str)
  96. except IntegrityError:
  97. db.session.rollback()
  98. except:
  99. logging.exception('Failed to add embedding to redis')
  100. return embedding_results