cached_embedding.py 3.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. import logging
  2. from typing import List, Optional
  3. import numpy as np
  4. from core.model_manager import ModelInstance
  5. from extensions.ext_database import db
  6. from langchain.embeddings.base import Embeddings
  7. from libs import helper
  8. from models.dataset import Embedding
  9. from sqlalchemy.exc import IntegrityError
  10. logger = logging.getLogger(__name__)
  11. class CacheEmbedding(Embeddings):
  12. def __init__(self, model_instance: ModelInstance, user: Optional[str] = None) -> None:
  13. self._model_instance = model_instance
  14. self._user = user
  15. def embed_documents(self, texts: List[str]) -> List[List[float]]:
  16. """Embed search docs."""
  17. # use doc embedding cache or store if not exists
  18. text_embeddings = [None for _ in range(len(texts))]
  19. embedding_queue_indices = []
  20. for i, text in enumerate(texts):
  21. hash = helper.generate_text_hash(text)
  22. embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first()
  23. if embedding:
  24. text_embeddings[i] = embedding.get_embedding()
  25. else:
  26. embedding_queue_indices.append(i)
  27. if embedding_queue_indices:
  28. try:
  29. embedding_result = self._model_instance.invoke_text_embedding(
  30. texts=[texts[i] for i in embedding_queue_indices],
  31. user=self._user
  32. )
  33. embedding_results = embedding_result.embeddings
  34. except Exception as ex:
  35. logger.error('Failed to embed documents: ', ex)
  36. raise ex
  37. for i, indice in enumerate(embedding_queue_indices):
  38. hash = helper.generate_text_hash(texts[indice])
  39. try:
  40. embedding = Embedding(model_name=self._model_instance.model, hash=hash)
  41. vector = embedding_results[i]
  42. normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
  43. text_embeddings[indice] = normalized_embedding
  44. embedding.set_embedding(normalized_embedding)
  45. db.session.add(embedding)
  46. db.session.commit()
  47. except IntegrityError:
  48. db.session.rollback()
  49. continue
  50. except:
  51. logging.exception('Failed to add embedding to db')
  52. continue
  53. return text_embeddings
  54. def embed_query(self, text: str) -> List[float]:
  55. """Embed query text."""
  56. # use doc embedding cache or store if not exists
  57. hash = helper.generate_text_hash(text)
  58. embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first()
  59. if embedding:
  60. return embedding.get_embedding()
  61. try:
  62. embedding_result = self._model_instance.invoke_text_embedding(
  63. texts=[text],
  64. user=self._user
  65. )
  66. embedding_results = embedding_result.embeddings[0]
  67. embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
  68. except Exception as ex:
  69. raise ex
  70. try:
  71. embedding = Embedding(model_name=self._model_instance.model, hash=hash)
  72. embedding.set_embedding(embedding_results)
  73. db.session.add(embedding)
  74. db.session.commit()
  75. except IntegrityError:
  76. db.session.rollback()
  77. except:
  78. logging.exception('Failed to add embedding to db')
  79. return embedding_results