openai_embedding.py

from typing import Optional, Any, List

import openai
from llama_index.embeddings.base import BaseEmbedding
from llama_index.embeddings.openai import OpenAIEmbeddingMode, OpenAIEmbeddingModelType, _QUERY_MODE_MODEL_DICT, \
    _TEXT_MODE_MODEL_DICT
from tenacity import wait_random_exponential, retry, stop_after_attempt

from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embedding(
    text: str,
    engine: Optional[str] = None,
    openai_api_key: Optional[str] = None,
) -> List[float]:
    """Get embedding.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.
    """
    text = text.replace("\n", " ")

    return openai.Embedding.create(input=[text], engine=engine, api_key=openai_api_key)["data"][0]["embedding"]


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embedding(text: str, engine: Optional[str] = None, openai_api_key: Optional[str] = None) -> List[float]:
    """Asynchronously get embedding.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.
    """
    # replace newlines, which can negatively affect performance.
    text = text.replace("\n", " ")

    return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=openai_api_key))["data"][0][
        "embedding"
    ]


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embeddings(
    list_of_text: List[str],
    engine: Optional[str] = None,
    openai_api_key: Optional[str] = None
) -> List[List[float]]:
    """Get embeddings.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.
    """
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

    # replace newlines, which can negatively affect performance.
    list_of_text = [text.replace("\n", " ") for text in list_of_text]

    data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=openai_api_key).data
    data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
    return [d["embedding"] for d in data]


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embeddings(
    list_of_text: List[str], engine: Optional[str] = None, openai_api_key: Optional[str] = None
) -> List[List[float]]:
    """Asynchronously get embeddings.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.
    """
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

    # replace newlines, which can negatively affect performance.
    list_of_text = [text.replace("\n", " ") for text in list_of_text]

    data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=openai_api_key)).data
    data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
    return [d["embedding"] for d in data]

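
# Example (hypothetical values: "sk-..." stands in for a real API key and
# "text-embedding-ada-002" is just an assumed engine/deployment name):
#
#     vectors = get_embeddings(
#         ["hello world", "goodbye world"],
#         engine="text-embedding-ada-002",
#         openai_api_key="sk-...",
#     )
#     # vectors[0] and vectors[1] come back in the same order as the input texts.
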

class OpenAIEmbedding(BaseEmbedding):
    def __init__(
        self,
        mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
        model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
        deployment_name: Optional[str] = None,
        openai_api_key: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Init params."""
        super().__init__(**kwargs)
        self.mode = OpenAIEmbeddingMode(mode)
        self.model = OpenAIEmbeddingModelType(model)
        self.deployment_name = deployment_name
        self.openai_api_key = openai_api_key

    @handle_llm_exceptions
    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _QUERY_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _QUERY_MODE_MODEL_DICT[key]
        return get_embedding(query, engine=engine, openai_api_key=self.openai_api_key)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _TEXT_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _TEXT_MODE_MODEL_DICT[key]
        return get_embedding(text, engine=engine, openai_api_key=self.openai_api_key)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding."""
        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _TEXT_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _TEXT_MODE_MODEL_DICT[key]
        return await aget_embedding(text, engine=engine, openai_api_key=self.openai_api_key)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings.

        By default, this is a wrapper around _get_text_embedding.
        Can be overridden for batch queries.
        """
        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _TEXT_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _TEXT_MODE_MODEL_DICT[key]
        embeddings = get_embeddings(texts, engine=engine, openai_api_key=self.openai_api_key)
        return embeddings

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously get text embeddings."""
        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _TEXT_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _TEXT_MODE_MODEL_DICT[key]
        embeddings = await aget_embeddings(texts, engine=engine, openai_api_key=self.openai_api_key)
        return embeddings
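

# Minimal usage sketch. This block is illustrative only: it assumes an OpenAI API
# key is available in the OPENAI_API_KEY environment variable, and it calls the
# private _get_*_embedding helpers directly for brevity; in normal use these are
# invoked through llama_index's BaseEmbedding interface.
if __name__ == "__main__":
    import os

    embedder = OpenAIEmbedding(openai_api_key=os.environ.get("OPENAI_API_KEY"))

    # Embed a query and a small batch of documents with the default
    # text-search mode and text-embedding-ada-002 model.
    query_vector = embedder._get_query_embedding("What is a vector database?")
    doc_vectors = embedder._get_text_embeddings(["Document one.", "Document two."])

    print(len(query_vector), [len(v) for v in doc_vectors])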