Ver código fonte

feat: add o1-series models support in Agent App (ReACT only) (#8350)

takatost 7 meses atrás
pai
commit
4637ddaa7f

+ 27 - 2
api/core/model_runtime/model_providers/openai/llm/llm.py

@@ -620,6 +620,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
             if "stream_options" in extra_model_kwargs:
                 del extra_model_kwargs["stream_options"]
 
+            if "stop" in extra_model_kwargs:
+                del extra_model_kwargs["stop"]
+
         # chat model
         response = client.chat.completions.create(
             messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
@@ -635,7 +638,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
 
         if block_as_stream:
-            return self._handle_chat_block_as_stream_response(block_result, prompt_messages)
+            return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
 
         return block_result
 
@@ -643,6 +646,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         self,
         block_result: LLMResult,
         prompt_messages: list[PromptMessage],
+        stop: Optional[list[str]] = None,
     ) -> Generator[LLMResultChunk, None, None]:
         """
         Handle llm chat response
@@ -652,15 +656,22 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         :param response: response
         :param prompt_messages: prompt messages
         :param tools: tools for tool calling
+        :param stop: stop words
         :return: llm response chunk generator
         """
+        text = block_result.message.content
+        text = cast(str, text)
+
+        if stop:
+            text = self.enforce_stop_tokens(text, stop)
+
         yield LLMResultChunk(
             model=block_result.model,
             prompt_messages=prompt_messages,
             system_fingerprint=block_result.system_fingerprint,
             delta=LLMResultChunkDelta(
                 index=0,
-                message=block_result.message,
+                message=AssistantPromptMessage(content=text),
                 finish_reason="stop",
                 usage=block_result.usage,
             ),
@@ -912,6 +923,20 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
                                 ]
                             )
 
+        if model.startswith("o1"):
+            system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)])
+            if system_message_count > 0:
+                new_prompt_messages = []
+                for prompt_message in prompt_messages:
+                    if isinstance(prompt_message, SystemPromptMessage):
+                        prompt_message = UserPromptMessage(
+                            content=prompt_message.content,
+                            name=prompt_message.name,
+                        )
+
+                    new_prompt_messages.append(prompt_message)
+                prompt_messages = new_prompt_messages
+
         return prompt_messages
 
     def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: