|
@@ -1,11 +1,13 @@
|
|
|
-from typing import Dict, Any, Optional, List, Tuple, Union
|
|
|
+from typing import Dict, Any, Optional, List, Tuple, Union, cast
|
|
|
|
|
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
|
|
from langchain.chat_models import AzureChatOpenAI
|
|
|
from langchain.chat_models.openai import _convert_dict_to_message
|
|
|
-from langchain.schema import ChatResult, BaseMessage, ChatGeneration
|
|
|
from pydantic import root_validator
|
|
|
|
|
|
+from langchain.schema import ChatResult, BaseMessage, ChatGeneration, ChatMessage, HumanMessage, AIMessage, SystemMessage, FunctionMessage
|
|
|
+from core.model_providers.models.entity.message import LCHumanMessageWithFiles, PromptMessageFileType, ImagePromptMessageFile
|
|
|
+
|
|
|
|
|
|
class EnhanceAzureChatOpenAI(AzureChatOpenAI):
|
|
|
request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
|
|
@@ -51,13 +53,18 @@ class EnhanceAzureChatOpenAI(AzureChatOpenAI):
|
|
|
}
|
|
|
|
|
|
def _generate(
|
|
|
- self,
|
|
|
- messages: List[BaseMessage],
|
|
|
- stop: Optional[List[str]] = None,
|
|
|
- run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
|
- **kwargs: Any,
|
|
|
+ self,
|
|
|
+ messages: List[BaseMessage],
|
|
|
+ stop: Optional[List[str]] = None,
|
|
|
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
|
+ **kwargs: Any,
|
|
|
) -> ChatResult:
|
|
|
- message_dicts, params = self._create_message_dicts(messages, stop)
|
|
|
+ params = self._client_params
|
|
|
+ if stop is not None:
|
|
|
+ if "stop" in params:
|
|
|
+ raise ValueError("`stop` found in both the input and default params.")
|
|
|
+ params["stop"] = stop
|
|
|
+ message_dicts = [self._convert_message_to_dict(m) for m in messages]
|
|
|
params = {**params, **kwargs}
|
|
|
if self.streaming:
|
|
|
inner_completion = ""
|
|
@@ -65,7 +72,7 @@ class EnhanceAzureChatOpenAI(AzureChatOpenAI):
|
|
|
params["stream"] = True
|
|
|
function_call: Optional[dict] = None
|
|
|
for stream_resp in self.completion_with_retry(
|
|
|
- messages=message_dicts, **params
|
|
|
+ messages=message_dicts, **params
|
|
|
):
|
|
|
if len(stream_resp["choices"]) > 0:
|
|
|
role = stream_resp["choices"][0]["delta"].get("role", role)
|
|
@@ -88,4 +95,47 @@ class EnhanceAzureChatOpenAI(AzureChatOpenAI):
|
|
|
)
|
|
|
return ChatResult(generations=[ChatGeneration(message=message)])
|
|
|
response = self.completion_with_retry(messages=message_dicts, **params)
|
|
|
- return self._create_chat_result(response)
|
|
|
+ return self._create_chat_result(response)
|
|
|
+
|
|
|
+ def _convert_message_to_dict(self, message: BaseMessage) -> dict:
|
|
|
+ if isinstance(message, ChatMessage):
|
|
|
+ message_dict = {"role": message.role, "content": message.content}
|
|
|
+ elif isinstance(message, LCHumanMessageWithFiles):
|
|
|
+ content = [
|
|
|
+ {
|
|
|
+ "type": "text",
|
|
|
+ "text": message.content
|
|
|
+ }
|
|
|
+ ]
|
|
|
+
|
|
|
+ for file in message.files:
|
|
|
+ if file.type == PromptMessageFileType.IMAGE:
|
|
|
+ file = cast(ImagePromptMessageFile, file)
|
|
|
+ content.append({
|
|
|
+ "type": "image_url",
|
|
|
+ "image_url": {
|
|
|
+ "url": file.data,
|
|
|
+ "detail": file.detail.value
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ message_dict = {"role": "user", "content": content}
|
|
|
+ elif isinstance(message, HumanMessage):
|
|
|
+ message_dict = {"role": "user", "content": message.content}
|
|
|
+ elif isinstance(message, AIMessage):
|
|
|
+ message_dict = {"role": "assistant", "content": message.content}
|
|
|
+ if "function_call" in message.additional_kwargs:
|
|
|
+ message_dict["function_call"] = message.additional_kwargs["function_call"]
|
|
|
+ elif isinstance(message, SystemMessage):
|
|
|
+ message_dict = {"role": "system", "content": message.content}
|
|
|
+ elif isinstance(message, FunctionMessage):
|
|
|
+ message_dict = {
|
|
|
+ "role": "function",
|
|
|
+ "content": message.content,
|
|
|
+ "name": message.name,
|
|
|
+ }
|
|
|
+ else:
|
|
|
+ raise ValueError(f"Got unknown type {message}")
|
|
|
+ if "name" in message.additional_kwargs:
|
|
|
+ message_dict["name"] = message.additional_kwargs["name"]
|
|
|
+ return message_dict
|