|
@@ -1,7 +1,7 @@
|
|
|
import base64
|
|
|
import copy
|
|
|
import time
|
|
|
-from typing import Optional, Tuple
|
|
|
+from typing import Optional, Tuple, Union
|
|
|
|
|
|
import numpy as np
|
|
|
import tiktoken
|
|
@@ -76,7 +76,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
|
|
|
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
|
|
|
model=model,
|
|
|
client=client,
|
|
|
- texts=[""],
|
|
|
+ texts="",
|
|
|
extra_model_kwargs=extra_model_kwargs
|
|
|
)
|
|
|
|
|
@@ -147,7 +147,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
|
|
|
return ai_model_entity.entity
|
|
|
|
|
|
@staticmethod
|
|
|
- def _embedding_invoke(model: str, client: AzureOpenAI, texts: list[str],
|
|
|
+ def _embedding_invoke(model: str, client: AzureOpenAI, texts: Union[list[str], str],
|
|
|
extra_model_kwargs: dict) -> Tuple[list[list[float]], int]:
|
|
|
response = client.embeddings.create(
|
|
|
input=texts,
|