from typing import Optional, Any, List

import openai
from llama_index.embeddings.base import BaseEmbedding
from llama_index.embeddings.openai import OpenAIEmbeddingMode, OpenAIEmbeddingModelType, _QUERY_MODE_MODEL_DICT, \
    _TEXT_MODE_MODEL_DICT
from tenacity import wait_random_exponential, retry, stop_after_attempt

from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embedding(
        text: str,
        engine: Optional[str] = None,
        api_key: Optional[str] = None,
        **kwargs
) -> List[float]:
    """Get embedding.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.
    """
    # replace newlines, which can negatively affect performance.
    text = text.replace("\n", " ")

    return openai.Embedding.create(input=[text], engine=engine, api_key=api_key, **kwargs)["data"][0]["embedding"]


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embedding(
        text: str,
        engine: Optional[str] = None,
        api_key: Optional[str] = None,
        **kwargs
) -> List[float]:
    """Asynchronously get embedding.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.
    """
    # replace newlines, which can negatively affect performance.
    text = text.replace("\n", " ")

    return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=api_key, **kwargs))["data"][0][
        "embedding"
    ]


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embeddings(
        list_of_text: List[str],
        engine: Optional[str] = None,
        api_key: Optional[str] = None,
        **kwargs
) -> List[List[float]]:
    """Get embeddings.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.
    """
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

    # replace newlines, which can negatively affect performance.
    list_of_text = [text.replace("\n", " ") for text in list_of_text]

    data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=api_key, **kwargs).data
    data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
    return [d["embedding"] for d in data]


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embeddings(
        list_of_text: List[str],
        engine: Optional[str] = None,
        api_key: Optional[str] = None,
        **kwargs
) -> List[List[float]]:
    """Asynchronously get embeddings.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.
    """
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

    # replace newlines, which can negatively affect performance.
    list_of_text = [text.replace("\n", " ") for text in list_of_text]

    data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=api_key, **kwargs)).data
    data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
    return [d["embedding"] for d in data]
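
# --- Usage sketch (illustrative, not part of the module) --------------------
# A minimal example of calling the helpers above directly, left as comments
# so it never runs on import. It assumes the pre-1.0 `openai` SDK and a valid
# key; the engine name "text-embedding-ada-002" is an assumption for
# illustration only.
#
#     import os
#
#     key = os.environ["OPENAI_API_KEY"]
#     vector = get_embedding("hello world", engine="text-embedding-ada-002", api_key=key)
#     vectors = get_embeddings(
#         ["first text", "second text"],
#         engine="text-embedding-ada-002",
#         api_key=key,
#     )
#     # One embedding per input, all with the same dimensionality.
#     assert len(vectors) == 2 and len(vector) == len(vectors[0])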


class OpenAIEmbedding(BaseEmbedding):
    def __init__(
            self,
            mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
            model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
            deployment_name: Optional[str] = None,
            openai_api_key: Optional[str] = None,
            **kwargs: Any,
    ) -> None:
        """Init params."""
        new_kwargs = {}

        if 'embed_batch_size' in kwargs:
            new_kwargs['embed_batch_size'] = kwargs['embed_batch_size']

        if 'tokenizer' in kwargs:
            new_kwargs['tokenizer'] = kwargs['tokenizer']

        super().__init__(**new_kwargs)

        self.mode = OpenAIEmbeddingMode(mode)
        self.model = OpenAIEmbeddingModelType(model)
        self.deployment_name = deployment_name
        self.openai_api_key = openai_api_key
        self.openai_api_type = kwargs.get('openai_api_type')
        self.openai_api_version = kwargs.get('openai_api_version')
        self.openai_api_base = kwargs.get('openai_api_base')

    @handle_llm_exceptions
    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _QUERY_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _QUERY_MODE_MODEL_DICT[key]
        return get_embedding(query, engine=engine, api_key=self.openai_api_key,
                             api_type=self.openai_api_type, api_version=self.openai_api_version,
                             api_base=self.openai_api_base)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _TEXT_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _TEXT_MODE_MODEL_DICT[key]
        return get_embedding(text, engine=engine, api_key=self.openai_api_key,
                             api_type=self.openai_api_type, api_version=self.openai_api_version,
                             api_base=self.openai_api_base)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding."""
        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _TEXT_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _TEXT_MODE_MODEL_DICT[key]
        return await aget_embedding(text, engine=engine, api_key=self.openai_api_key,
                                    api_type=self.openai_api_type, api_version=self.openai_api_version,
                                    api_base=self.openai_api_base)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings.

        By default, this is a wrapper around _get_text_embedding.
        Can be overridden for batch queries.
""" if self.openai_api_type and self.openai_api_type == 'azure': embeddings = [] for text in texts: embeddings.append(self._get_text_embedding(text)) return embeddings if self.deployment_name is not None: engine = self.deployment_name else: key = (self.mode, self.model) if key not in _TEXT_MODE_MODEL_DICT: raise ValueError(f"Invalid mode, model combination: {key}") engine = _TEXT_MODE_MODEL_DICT[key] embeddings = get_embeddings(texts, engine=engine, api_key=self.openai_api_key, api_type=self.openai_api_type, api_version=self.openai_api_version, api_base=self.openai_api_base) return embeddings async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]: """Asynchronously get text embeddings.""" if self.openai_api_type and self.openai_api_type == 'azure': embeddings = [] for text in texts: embeddings.append(await self._aget_text_embedding(text)) return embeddings if self.deployment_name is not None: engine = self.deployment_name else: key = (self.mode, self.model) if key not in _TEXT_MODE_MODEL_DICT: raise ValueError(f"Invalid mode, model combination: {key}") engine = _TEXT_MODE_MODEL_DICT[key] embeddings = await aget_embeddings(texts, engine=engine, api_key=self.openai_api_key, api_type=self.openai_api_type, api_version=self.openai_api_version, api_base=self.openai_api_base) return embeddings