|
@@ -113,7 +113,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|
|
try:
|
|
|
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
|
|
|
|
|
|
- if "o1" in model:
|
|
|
+ if model.startswith("o1"):
|
|
|
client.chat.completions.create(
|
|
|
messages=[{"role": "user", "content": "ping"}],
|
|
|
model=model,
|
|
@@ -311,7 +311,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|
|
prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
|
|
|
|
|
|
block_as_stream = False
|
|
|
- if "o1" in model:
|
|
|
+ if model.startswith("o1"):
|
|
|
+ if "max_tokens" in model_parameters:
|
|
|
+ model_parameters["max_completion_tokens"] = model_parameters["max_tokens"]
|
|
|
+ del model_parameters["max_tokens"]
|
|
|
if stream:
|
|
|
block_as_stream = True
|
|
|
stream = False
|
|
@@ -404,7 +407,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|
|
]
|
|
|
)
|
|
|
|
|
|
- if "o1" in model:
|
|
|
+ if model.startswith("o1"):
|
|
|
system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)])
|
|
|
if system_message_count > 0:
|
|
|
new_prompt_messages = []
|