cached_embedding.py 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. import logging
  2. from typing import List
  3. import numpy as np
  4. from langchain.embeddings.base import Embeddings
  5. from sqlalchemy.exc import IntegrityError
  6. from core.model_providers.models.embedding.base import BaseEmbedding
  7. from extensions.ext_database import db
  8. from libs import helper
  9. from models.dataset import Embedding
  10. class CacheEmbedding(Embeddings):
  11. def __init__(self, embeddings: BaseEmbedding):
  12. self._embeddings = embeddings
  13. def embed_documents(self, texts: List[str]) -> List[List[float]]:
  14. """Embed search docs."""
  15. # use doc embedding cache or store if not exists
  16. text_embeddings = []
  17. embedding_queue_texts = []
  18. for text in texts:
  19. hash = helper.generate_text_hash(text)
  20. embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
  21. if embedding:
  22. text_embeddings.append(embedding.get_embedding())
  23. else:
  24. embedding_queue_texts.append(text)
  25. if embedding_queue_texts:
  26. try:
  27. embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts)
  28. except Exception as ex:
  29. raise self._embeddings.handle_exceptions(ex)
  30. i = 0
  31. normalized_embedding_results = []
  32. for text in embedding_queue_texts:
  33. hash = helper.generate_text_hash(text)
  34. try:
  35. embedding = Embedding(model_name=self._embeddings.name, hash=hash)
  36. vector = embedding_results[i]
  37. normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
  38. normalized_embedding_results.append(normalized_embedding)
  39. embedding.set_embedding(normalized_embedding)
  40. db.session.add(embedding)
  41. db.session.commit()
  42. except IntegrityError:
  43. db.session.rollback()
  44. continue
  45. except:
  46. logging.exception('Failed to add embedding to db')
  47. continue
  48. finally:
  49. i += 1
  50. text_embeddings.extend(normalized_embedding_results)
  51. return text_embeddings
  52. def embed_query(self, text: str) -> List[float]:
  53. """Embed query text."""
  54. # use doc embedding cache or store if not exists
  55. hash = helper.generate_text_hash(text)
  56. embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
  57. if embedding:
  58. return embedding.get_embedding()
  59. try:
  60. embedding_results = self._embeddings.client.embed_query(text)
  61. embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
  62. except Exception as ex:
  63. raise self._embeddings.handle_exceptions(ex)
  64. try:
  65. embedding = Embedding(model_name=self._embeddings.name, hash=hash)
  66. embedding.set_embedding(embedding_results)
  67. db.session.add(embedding)
  68. db.session.commit()
  69. except IntegrityError:
  70. db.session.rollback()
  71. except:
  72. logging.exception('Failed to add embedding to db')
  73. return embedding_results