@@ -11,9 +11,10 @@ from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_except
 
 @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
 def get_embedding(
-    text: str,
-    engine: Optional[str] = None,
-    openai_api_key: Optional[str] = None,
+        text: str,
+        engine: Optional[str] = None,
+        api_key: Optional[str] = None,
+        **kwargs
 ) -> List[float]:
     """Get embedding.
 
@@ -25,11 +26,12 @@ def get_embedding(
 
     """
     text = text.replace("\n", " ")
-    return openai.Embedding.create(input=[text], engine=engine, api_key=openai_api_key)["data"][0]["embedding"]
+    return openai.Embedding.create(input=[text], engine=engine, api_key=api_key, **kwargs)["data"][0]["embedding"]
 
 
 @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
-async def aget_embedding(text: str, engine: Optional[str] = None, openai_api_key: Optional[str] = None) -> List[float]:
+async def aget_embedding(text: str, engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs) -> List[
+    float]:
     """Asynchronously get embedding.
 
     NOTE: Copied from OpenAI's embedding utils:
@@ -42,16 +44,17 @@ async def aget_embedding(text: str, engine: Optional[str] = None, openai_api_key
     # replace newlines, which can negatively affect performance.
     text = text.replace("\n", " ")
 
-    return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=openai_api_key))["data"][0][
+    return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=api_key, **kwargs))["data"][0][
         "embedding"
     ]
 
 
 @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
 def get_embeddings(
-    list_of_text: List[str],
-    engine: Optional[str] = None,
-    openai_api_key: Optional[str] = None
+        list_of_text: List[str],
+        engine: Optional[str] = None,
+        api_key: Optional[str] = None,
+        **kwargs
 ) -> List[List[float]]:
     """Get embeddings.
 
@@ -67,14 +70,14 @@ def get_embeddings(
     # replace newlines, which can negatively affect performance.
     list_of_text = [text.replace("\n", " ") for text in list_of_text]
 
-    data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=openai_api_key).data
+    data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=api_key, **kwargs).data
     data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
     return [d["embedding"] for d in data]
 
 
 @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
 async def aget_embeddings(
-    list_of_text: List[str], engine: Optional[str] = None, openai_api_key: Optional[str] = None
+        list_of_text: List[str], engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs
 ) -> List[List[float]]:
     """Asynchronously get embeddings.
 
@@ -90,7 +93,7 @@ async def aget_embeddings(
     # replace newlines, which can negatively affect performance.
     list_of_text = [text.replace("\n", " ") for text in list_of_text]
 
-    data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=openai_api_key)).data
+    data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=api_key, **kwargs)).data
     data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
     return [d["embedding"] for d in data]
 
@@ -98,19 +101,30 @@ async def aget_embeddings(
 class OpenAIEmbedding(BaseEmbedding):
 
     def __init__(
-        self,
-        mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
-        model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
-        deployment_name: Optional[str] = None,
-        openai_api_key: Optional[str] = None,
-        **kwargs: Any,
+            self,
+            mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
+            model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
+            deployment_name: Optional[str] = None,
+            openai_api_key: Optional[str] = None,
+            **kwargs: Any,
     ) -> None:
         """Init params."""
-        super().__init__(**kwargs)
+        new_kwargs = {}
+
+        if 'embed_batch_size' in kwargs:
+            new_kwargs['embed_batch_size'] = kwargs['embed_batch_size']
+
+        if 'tokenizer' in kwargs:
+            new_kwargs['tokenizer'] = kwargs['tokenizer']
+
+        super().__init__(**new_kwargs)
         self.mode = OpenAIEmbeddingMode(mode)
         self.model = OpenAIEmbeddingModelType(model)
         self.deployment_name = deployment_name
         self.openai_api_key = openai_api_key
+        self.openai_api_type = kwargs.get('openai_api_type')
+        self.openai_api_version = kwargs.get('openai_api_version')
+        self.openai_api_base = kwargs.get('openai_api_base')
 
     @handle_llm_exceptions
     def _get_query_embedding(self, query: str) -> List[float]:
@@ -122,7 +136,9 @@ class OpenAIEmbedding(BaseEmbedding):
         if key not in _QUERY_MODE_MODEL_DICT:
             raise ValueError(f"Invalid mode, model combination: {key}")
         engine = _QUERY_MODE_MODEL_DICT[key]
-        return get_embedding(query, engine=engine, openai_api_key=self.openai_api_key)
+        return get_embedding(query, engine=engine, api_key=self.openai_api_key,
+                             api_type=self.openai_api_type, api_version=self.openai_api_version,
+                             api_base=self.openai_api_base)
 
     def _get_text_embedding(self, text: str) -> List[float]:
         """Get text embedding."""
@@ -133,7 +149,9 @@ class OpenAIEmbedding(BaseEmbedding):
         if key not in _TEXT_MODE_MODEL_DICT:
             raise ValueError(f"Invalid mode, model combination: {key}")
         engine = _TEXT_MODE_MODEL_DICT[key]
-        return get_embedding(text, engine=engine, openai_api_key=self.openai_api_key)
+        return get_embedding(text, engine=engine, api_key=self.openai_api_key,
+                             api_type=self.openai_api_type, api_version=self.openai_api_version,
+                             api_base=self.openai_api_base)
 
     async def _aget_text_embedding(self, text: str) -> List[float]:
         """Asynchronously get text embedding."""
@@ -144,7 +162,9 @@ class OpenAIEmbedding(BaseEmbedding):
         if key not in _TEXT_MODE_MODEL_DICT:
             raise ValueError(f"Invalid mode, model combination: {key}")
         engine = _TEXT_MODE_MODEL_DICT[key]
-        return await aget_embedding(text, engine=engine, openai_api_key=self.openai_api_key)
+        return await aget_embedding(text, engine=engine, api_key=self.openai_api_key,
+                                    api_type=self.openai_api_type, api_version=self.openai_api_version,
+                                    api_base=self.openai_api_base)
 
     def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
         """Get text embeddings.
@@ -160,7 +180,9 @@ class OpenAIEmbedding(BaseEmbedding):
         if key not in _TEXT_MODE_MODEL_DICT:
             raise ValueError(f"Invalid mode, model combination: {key}")
         engine = _TEXT_MODE_MODEL_DICT[key]
-        embeddings = get_embeddings(texts, engine=engine, openai_api_key=self.openai_api_key)
+        embeddings = get_embeddings(texts, engine=engine, api_key=self.openai_api_key,
+                                    api_type=self.openai_api_type, api_version=self.openai_api_version,
+                                    api_base=self.openai_api_base)
         return embeddings
 
     async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
@@ -172,5 +194,7 @@ class OpenAIEmbedding(BaseEmbedding):
         if key not in _TEXT_MODE_MODEL_DICT:
             raise ValueError(f"Invalid mode, model combination: {key}")
         engine = _TEXT_MODE_MODEL_DICT[key]
-        embeddings = await aget_embeddings(texts, engine=engine, openai_api_key=self.openai_api_key)
+        embeddings = await aget_embeddings(texts, engine=engine, api_key=self.openai_api_key,
+                                           api_type=self.openai_api_type, api_version=self.openai_api_version,
+                                           api_base=self.openai_api_base)
         return embeddings
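
Usage sketch (not part of the patch): the diff above threads Azure-style connection
parameters through the helpers via **kwargs, so a caller can pass api_type,
api_version and api_base alongside the key. The module path, engine name and
endpoint below are placeholders for illustration only; the per-call overrides
themselves (api_key, api_type, api_version, api_base) are standard keyword
arguments of openai.Embedding.create in openai-python < 1.0.

    # Minimal sketch under the assumptions stated above; values are placeholders.
    from core.embedding.openai_embedding import get_embedding  # hypothetical module path

    vector = get_embedding(
        "hello world",
        engine="text-embedding-ada-002",                 # for Azure this would be the deployment name
        api_key="azure-key-placeholder",
        api_type="azure",                                # forwarded to openai.Embedding.create via **kwargs
        api_version="2023-05-15",                        # forwarded via **kwargs
        api_base="https://example.openai.azure.com",     # placeholder endpoint
    )
    print(len(vector))  # embedding dimensionality, e.g. 1536 for ada-002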