cached_embedding.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. import logging
  2. from typing import List
  3. import numpy as np
  4. from langchain.embeddings.base import Embeddings
  5. from sqlalchemy.exc import IntegrityError
  6. from core.model_providers.models.embedding.base import BaseEmbedding
  7. from extensions.ext_database import db
  8. from libs import helper
  9. from models.dataset import Embedding
  10. class CacheEmbedding(Embeddings):
  11. def __init__(self, embeddings: BaseEmbedding):
  12. self._embeddings = embeddings
  13. def embed_documents(self, texts: List[str]) -> List[List[float]]:
  14. """Embed search docs."""
  15. # use doc embedding cache or store if not exists
  16. text_embeddings = [None for _ in range(len(texts))]
  17. embedding_queue_indices = []
  18. for i, text in enumerate(texts):
  19. hash = helper.generate_text_hash(text)
  20. embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
  21. if embedding:
  22. text_embeddings[i] = embedding.get_embedding()
  23. else:
  24. embedding_queue_indices.append(i)
  25. if embedding_queue_indices:
  26. try:
  27. embedding_results = self._embeddings.client.embed_documents([texts[i] for i in embedding_queue_indices])
  28. except Exception as ex:
  29. raise self._embeddings.handle_exceptions(ex)
  30. for i, indice in enumerate(embedding_queue_indices):
  31. hash = helper.generate_text_hash(texts[indice])
  32. try:
  33. embedding = Embedding(model_name=self._embeddings.name, hash=hash)
  34. vector = embedding_results[i]
  35. normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
  36. text_embeddings[indice] = normalized_embedding
  37. embedding.set_embedding(normalized_embedding)
  38. db.session.add(embedding)
  39. db.session.commit()
  40. except IntegrityError:
  41. db.session.rollback()
  42. continue
  43. except:
  44. logging.exception('Failed to add embedding to db')
  45. continue
  46. return text_embeddings
  47. def embed_query(self, text: str) -> List[float]:
  48. """Embed query text."""
  49. # use doc embedding cache or store if not exists
  50. hash = helper.generate_text_hash(text)
  51. embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
  52. if embedding:
  53. return embedding.get_embedding()
  54. try:
  55. embedding_results = self._embeddings.client.embed_query(text)
  56. embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
  57. except Exception as ex:
  58. raise self._embeddings.handle_exceptions(ex)
  59. try:
  60. embedding = Embedding(model_name=self._embeddings.name, hash=hash)
  61. embedding.set_embedding(embedding_results)
  62. db.session.add(embedding)
  63. db.session.commit()
  64. except IntegrityError:
  65. db.session.rollback()
  66. except:
  67. logging.exception('Failed to add embedding to db')
  68. return embedding_results