cached_embedding.py 3.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. import base64
  2. import logging
  3. from typing import Optional, cast
  4. import numpy as np
  5. from langchain.embeddings.base import Embeddings
  6. from sqlalchemy.exc import IntegrityError
  7. from core.model_manager import ModelInstance
  8. from core.model_runtime.entities.model_entities import ModelPropertyKey
  9. from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
  10. from extensions.ext_database import db
  11. from extensions.ext_redis import redis_client
  12. from libs import helper
  13. logger = logging.getLogger(__name__)
  14. class CacheEmbedding(Embeddings):
  15. def __init__(self, model_instance: ModelInstance, user: Optional[str] = None) -> None:
  16. self._model_instance = model_instance
  17. self._user = user
  18. def embed_documents(self, texts: list[str]) -> list[list[float]]:
  19. """Embed search docs in batches of 10."""
  20. text_embeddings = []
  21. try:
  22. model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
  23. model_schema = model_type_instance.get_model_schema(self._model_instance.model, self._model_instance.credentials)
  24. max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \
  25. if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1
  26. for i in range(0, len(texts), max_chunks):
  27. batch_texts = texts[i:i + max_chunks]
  28. embedding_result = self._model_instance.invoke_text_embedding(
  29. texts=batch_texts,
  30. user=self._user
  31. )
  32. for vector in embedding_result.embeddings:
  33. try:
  34. normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
  35. text_embeddings.append(normalized_embedding)
  36. except IntegrityError:
  37. db.session.rollback()
  38. except Exception as e:
  39. logging.exception('Failed to add embedding to redis')
  40. except Exception as ex:
  41. logger.error('Failed to embed documents: ', ex)
  42. raise ex
  43. return text_embeddings
  44. def embed_query(self, text: str) -> list[float]:
  45. """Embed query text."""
  46. # use doc embedding cache or store if not exists
  47. hash = helper.generate_text_hash(text)
  48. embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
  49. embedding = redis_client.get(embedding_cache_key)
  50. if embedding:
  51. redis_client.expire(embedding_cache_key, 600)
  52. return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
  53. try:
  54. embedding_result = self._model_instance.invoke_text_embedding(
  55. texts=[text],
  56. user=self._user
  57. )
  58. embedding_results = embedding_result.embeddings[0]
  59. embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
  60. except Exception as ex:
  61. raise ex
  62. try:
  63. # encode embedding to base64
  64. embedding_vector = np.array(embedding_results)
  65. vector_bytes = embedding_vector.tobytes()
  66. # Transform to Base64
  67. encoded_vector = base64.b64encode(vector_bytes)
  68. # Transform to string
  69. encoded_str = encoded_vector.decode("utf-8")
  70. redis_client.setex(embedding_cache_key, 600, encoded_str)
  71. except IntegrityError:
  72. db.session.rollback()
  73. except:
  74. logging.exception('Failed to add embedding to redis')
  75. return embedding_results