@@ -1,5 +1,6 @@
 import json
 import logging
+import re
 from collections.abc import Generator
 from typing import Any, Optional, Union, cast
 
@@ -621,11 +622,19 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
 
         # o1 compatibility
+        block_as_stream = False
         if model.startswith("o1"):
             if "max_tokens" in model_parameters:
                 model_parameters["max_completion_tokens"] = model_parameters["max_tokens"]
                 del model_parameters["max_tokens"]
 
+            if re.match(r"^o1(-\d{4}-\d{2}-\d{2})?$", model):
+                if stream:
+                    block_as_stream = True
+                    stream = False
+
+                    if "stream_options" in extra_model_kwargs:
+                        del extra_model_kwargs["stream_options"]
+
             if "stop" in extra_model_kwargs:
                 del extra_model_kwargs["stop"]
 
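Note on the gating above: the regex ^o1(-\d{4}-\d{2}-\d{2})?$ routes only the base "o1" model and dated "o1-YYYY-MM-DD" snapshots into the block-as-stream path; names like "o1-mini" or "o1-preview" do not match and keep their normal streaming behavior. A small standalone sketch of what the check accepts (sample model IDs assumed purely for illustration):

import re

O1_BLOCKING_PATTERN = r"^o1(-\d{4}-\d{2}-\d{2})?$"

# Sample model IDs, purely illustrative of what the regex accepts.
for model in ("o1", "o1-2024-12-17", "o1-mini", "o1-preview", "o1-mini-2024-09-12"):
    print(f"{model}: forced blocking = {bool(re.match(O1_BLOCKING_PATTERN, model))}")
# Only "o1" and "o1-2024-12-17" print True; the -mini/-preview variants do not.
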
@@ -642,7 +651,45 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         if stream:
             return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
 
-        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+        block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+
+        if block_as_stream:
+            return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
+
+        return block_result
+
+    def _handle_chat_block_as_stream_response(
+        self,
+        block_result: LLMResult,
+        prompt_messages: list[PromptMessage],
+        stop: Optional[list[str]] = None,
+    ) -> Generator[LLMResultChunk, None, None]:
+ """
|
|
|
+ Handle llm chat response
|
|
|
+ :param model: model name
|
|
|
+ :param credentials: credentials
|
|
|
+ :param response: response
|
|
|
+ :param prompt_messages: prompt messages
|
|
|
+ :param tools: tools for tool calling
|
|
|
+ :return: llm response chunk generator
|
|
|
+ """
|
|
|
+        text = block_result.message.content
+        text = cast(str, text)
+
+        if stop:
+            block_result.message.content = self.enforce_stop_tokens(text, stop)
+
+        yield LLMResultChunk(
+            model=block_result.model,
+            prompt_messages=prompt_messages,
+            system_fingerprint=block_result.system_fingerprint,
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=block_result.message,
+                finish_reason="stop",
+                usage=block_result.usage,
+            ),
+        )
 
     def _handle_chat_generate_response(
         self,
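The new _handle_chat_block_as_stream_response helper boils down to one pattern: run the request in blocking mode, then hand the finished result back through a generator that yields a single chunk, so callers written against the streaming interface keep working for o1 models where streaming is disabled; because the patch also strips the "stop" parameter for o1, stop words are enforced client-side on the returned text. A minimal self-contained sketch of that pattern, using hypothetical stand-in types rather than Dify's actual LLMResult / LLMResultChunk classes:

from collections.abc import Generator
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeBlockResult:
    # Hypothetical stand-in for a blocking LLM result (illustration only).
    model: str
    text: str


@dataclass
class FakeChunk:
    # Hypothetical stand-in for one streamed chunk (illustration only).
    model: str
    text: str
    finish_reason: str


def block_as_stream(result: FakeBlockResult, stop: Optional[list[str]] = None) -> Generator[FakeChunk, None, None]:
    text = result.text
    # Stop words are enforced client-side by truncating at the first match,
    # since the blocking request was sent without a server-side "stop" parameter.
    for token in stop or []:
        idx = text.find(token)
        if idx != -1:
            text = text[:idx]
    # The entire blocking answer is emitted as a single "stream" chunk.
    yield FakeChunk(model=result.model, text=text, finish_reason="stop")


# A caller written against the streaming interface works unchanged:
for chunk in block_as_stream(FakeBlockResult(model="o1", text="Answer. IGNORED tail."), stop=["IGNORED"]):
    print(chunk.finish_reason, chunk.text)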