cached_embedding.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. import logging
  2. from typing import List
  3. from langchain.embeddings.base import Embeddings
  4. from sqlalchemy.exc import IntegrityError
  5. from extensions.ext_database import db
  6. from libs import helper
  7. from models.dataset import Embedding
  8. class CacheEmbedding(Embeddings):
  9. def __init__(self, embeddings: Embeddings):
  10. self._embeddings = embeddings
  11. def embed_documents(self, texts: List[str]) -> List[List[float]]:
  12. """Embed search docs."""
  13. # use doc embedding cache or store if not exists
  14. text_embeddings = []
  15. embedding_queue_texts = []
  16. for text in texts:
  17. hash = helper.generate_text_hash(text)
  18. embedding = db.session.query(Embedding).filter_by(hash=hash).first()
  19. if embedding:
  20. text_embeddings.append(embedding.get_embedding())
  21. else:
  22. embedding_queue_texts.append(text)
  23. embedding_results = self._embeddings.embed_documents(embedding_queue_texts)
  24. i = 0
  25. for text in embedding_queue_texts:
  26. hash = helper.generate_text_hash(text)
  27. try:
  28. embedding = Embedding(hash=hash)
  29. embedding.set_embedding(embedding_results[i])
  30. db.session.add(embedding)
  31. db.session.commit()
  32. except IntegrityError:
  33. db.session.rollback()
  34. continue
  35. except:
  36. logging.exception('Failed to add embedding to db')
  37. continue
  38. i += 1
  39. text_embeddings.extend(embedding_results)
  40. return text_embeddings
  41. def embed_query(self, text: str) -> List[float]:
  42. """Embed query text."""
  43. # use doc embedding cache or store if not exists
  44. hash = helper.generate_text_hash(text)
  45. embedding = db.session.query(Embedding).filter_by(hash=hash).first()
  46. if embedding:
  47. return embedding.get_embedding()
  48. embedding_results = self._embeddings.embed_query(text)
  49. try:
  50. embedding = Embedding(hash=hash)
  51. embedding.set_embedding(embedding_results)
  52. db.session.add(embedding)
  53. db.session.commit()
  54. except IntegrityError:
  55. db.session.rollback()
  56. except:
  57. logging.exception('Failed to add embedding to db')
  58. return embedding_results