|
@@ -620,6 +620,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
|
|
if "stream_options" in extra_model_kwargs:
|
|
|
del extra_model_kwargs["stream_options"]
|
|
|
|
|
|
+ if "stop" in extra_model_kwargs:
|
|
|
+ del extra_model_kwargs["stop"]
|
|
|
+
|
|
|
# chat model
|
|
|
response = client.chat.completions.create(
|
|
|
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
|
|
@@ -635,7 +638,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
|
|
block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
|
|
|
|
|
|
if block_as_stream:
|
|
|
- return self._handle_chat_block_as_stream_response(block_result, prompt_messages)
|
|
|
+ return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
|
|
|
|
|
|
return block_result
|
|
|
|
|
@@ -643,6 +646,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
|
|
self,
|
|
|
block_result: LLMResult,
|
|
|
prompt_messages: list[PromptMessage],
|
|
|
+ stop: Optional[list[str]] = None,
|
|
|
) -> Generator[LLMResultChunk, None, None]:
|
|
|
"""
|
|
|
Handle llm chat response
|
|
@@ -652,15 +656,22 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
|
|
:param response: response
|
|
|
:param prompt_messages: prompt messages
|
|
|
:param tools: tools for tool calling
|
|
|
+ :param stop: stop words
|
|
|
:return: llm response chunk generator
|
|
|
"""
|
|
|
+ text = block_result.message.content
|
|
|
+ text = cast(str, text)
|
|
|
+
|
|
|
+ if stop:
|
|
|
+ text = self.enforce_stop_tokens(text, stop)
|
|
|
+
|
|
|
yield LLMResultChunk(
|
|
|
model=block_result.model,
|
|
|
prompt_messages=prompt_messages,
|
|
|
system_fingerprint=block_result.system_fingerprint,
|
|
|
delta=LLMResultChunkDelta(
|
|
|
index=0,
|
|
|
- message=block_result.message,
|
|
|
+ message=AssistantPromptMessage(content=text),
|
|
|
finish_reason="stop",
|
|
|
usage=block_result.usage,
|
|
|
),
|
|
@@ -912,6 +923,20 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
|
|
]
|
|
|
)
|
|
|
|
|
|
+ if model.startswith("o1"):
|
|
|
+ system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)])
|
|
|
+ if system_message_count > 0:
|
|
|
+ new_prompt_messages = []
|
|
|
+ for prompt_message in prompt_messages:
|
|
|
+ if isinstance(prompt_message, SystemPromptMessage):
|
|
|
+ prompt_message = UserPromptMessage(
|
|
|
+ content=prompt_message.content,
|
|
|
+ name=prompt_message.name,
|
|
|
+ )
|
|
|
+
|
|
|
+ new_prompt_messages.append(prompt_message)
|
|
|
+ prompt_messages = new_prompt_messages
|
|
|
+
|
|
|
return prompt_messages
|
|
|
|
|
|
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
|