# cached_embedding.py
import base64
import json
import logging
from typing import List, Optional, cast

import numpy as np
from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError

from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import helper
from models.dataset import Embedding

logger = logging.getLogger(__name__)
  16. class CacheEmbedding(Embeddings):
  17. def __init__(self, model_instance: ModelInstance, user: Optional[str] = None) -> None:
  18. self._model_instance = model_instance
  19. self._user = user
  20. def embed_documents(self, texts: List[str]) -> List[List[float]]:
  21. """Embed search docs in batches of 10."""
  22. text_embeddings = []
  23. try:
  24. model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
  25. model_schema = model_type_instance.get_model_schema(self._model_instance.model, self._model_instance.credentials)
  26. max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \
  27. if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1
  28. for i in range(0, len(texts), max_chunks):
  29. batch_texts = texts[i:i + max_chunks]
  30. embedding_result = self._model_instance.invoke_text_embedding(
  31. texts=batch_texts,
  32. user=self._user
  33. )
  34. for vector in embedding_result.embeddings:
  35. try:
  36. normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
  37. text_embeddings.append(normalized_embedding)
  38. except IntegrityError:
  39. db.session.rollback()
  40. except Exception as e:
  41. logging.exception('Failed to add embedding to redis')
  42. except Exception as ex:
  43. logger.error('Failed to embed documents: ', ex)
  44. raise ex
  45. return text_embeddings
  46. def embed_query(self, text: str) -> List[float]:
  47. """Embed query text."""
  48. # use doc embedding cache or store if not exists
  49. hash = helper.generate_text_hash(text)
  50. embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
  51. embedding = redis_client.get(embedding_cache_key)
  52. if embedding:
  53. redis_client.expire(embedding_cache_key, 600)
  54. return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
  55. try:
  56. embedding_result = self._model_instance.invoke_text_embedding(
  57. texts=[text],
  58. user=self._user
  59. )
  60. embedding_results = embedding_result.embeddings[0]
  61. embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
  62. except Exception as ex:
  63. raise ex
  64. try:
  65. # encode embedding to base64
  66. embedding_vector = np.array(embedding_results)
  67. vector_bytes = embedding_vector.tobytes()
  68. # Transform to Base64
  69. encoded_vector = base64.b64encode(vector_bytes)
  70. # Transform to string
  71. encoded_str = encoded_vector.decode("utf-8")
  72. redis_client.setex(embedding_cache_key, 600, encoded_str)
  73. except IntegrityError:
  74. db.session.rollback()
  75. except:
  76. logging.exception('Failed to add embedding to redis')
  77. return embedding_results