
fix: azure openai stream response usage missing (#1998)

takatost 1 year ago
parent
commit
5e97eb1840

+ 3 - 0
api/core/app_runner/app_runner.py

@@ -257,6 +257,9 @@ class AppRunner:
             if not usage and result.delta.usage:
                 usage = result.delta.usage
 
+        if not usage:
+            usage = LLMUsage.empty_usage()
+
         llm_result = LLMResult(
             model=model,
             prompt_messages=prompt_messages,
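
The app_runner.py hunk above guards against streams that never report usage by falling back to LLMUsage.empty_usage(). Below is a minimal standalone sketch of the same fallback pattern; the Usage dataclass, empty_usage helper, and collect_usage function are illustrative stand-ins, not dify's actual LLMUsage API.

from dataclasses import dataclass
from typing import Optional

@dataclass
class Usage:
    # illustrative stand-in for LLMUsage; field names are assumptions
    prompt_tokens: int = 0
    completion_tokens: int = 0

    @classmethod
    def empty_usage(cls) -> "Usage":
        # zeroed usage so downstream accounting never sees None
        return cls()

def collect_usage(chunks) -> Usage:
    usage: Optional[Usage] = None
    for chunk in chunks:
        # keep the first non-empty usage reported by the stream
        if usage is None and chunk.get("usage"):
            usage = chunk["usage"]
    # fall back to an empty usage object when the provider never sent one,
    # mirroring the `if not usage: usage = LLMUsage.empty_usage()` guard above
    return usage or Usage.empty_usage()

print(collect_usage([{"usage": None}, {"usage": None}]))  # Usage(prompt_tokens=0, completion_tokens=0)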

+ 38 - 31
api/core/model_runtime/model_providers/azure_openai/llm/llm.py

@@ -322,8 +322,11 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                                               response: Stream[ChatCompletionChunk],
                                               prompt_messages: list[PromptMessage],
                                               tools: Optional[list[PromptMessageTool]] = None) -> Generator:
-
+        index = 0
         full_assistant_content = ''
+        real_model = model
+        system_fingerprint = None
+        completion = ''
         for chunk in response:
             if len(chunk.choices) == 0:
                 continue
@@ -349,40 +352,44 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
 
             full_assistant_content += delta.delta.content if delta.delta.content else ''
 
-            if delta.finish_reason is not None:
-                # calculate num tokens
-                prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)
+            real_model = chunk.model
+            system_fingerprint = chunk.system_fingerprint
+            completion += delta.delta.content if delta.delta.content else ''
 
-                full_assistant_prompt_message = AssistantPromptMessage(
-                    content=full_assistant_content,
-                    tool_calls=tool_calls
+            yield LLMResultChunk(
+                model=real_model,
+                prompt_messages=prompt_messages,
+                system_fingerprint=system_fingerprint,
+                delta=LLMResultChunkDelta(
+                    index=index,
+                    message=assistant_prompt_message,
                 )
-                completion_tokens = self._num_tokens_from_messages(credentials, [full_assistant_prompt_message])
+            )
 
-                # transform usage
-                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+            index += 1
 
-                yield LLMResultChunk(
-                    model=chunk.model,
-                    prompt_messages=prompt_messages,
-                    system_fingerprint=chunk.system_fingerprint,
-                    delta=LLMResultChunkDelta(
-                        index=delta.index,
-                        message=assistant_prompt_message,
-                        finish_reason=delta.finish_reason,
-                        usage=usage
-                    )
-                )
-            else:
-                yield LLMResultChunk(
-                    model=chunk.model,
-                    prompt_messages=prompt_messages,
-                    system_fingerprint=chunk.system_fingerprint,
-                    delta=LLMResultChunkDelta(
-                        index=delta.index,
-                        message=assistant_prompt_message,
-                    )
-                )
+        # calculate num tokens
+        prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)
+
+        full_assistant_prompt_message = AssistantPromptMessage(
+            content=completion
+        )
+        completion_tokens = self._num_tokens_from_messages(credentials, [full_assistant_prompt_message])
+
+        # transform usage
+        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+        yield LLMResultChunk(
+            model=real_model,
+            prompt_messages=prompt_messages,
+            system_fingerprint=system_fingerprint,
+            delta=LLMResultChunkDelta(
+                index=index,
+                message=AssistantPromptMessage(content=''),
+                finish_reason='stop',
+                usage=usage
+            )
+        )
 
     @staticmethod
     def _extract_response_tool_calls(response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \
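
The llm.py hunk restructures the streaming handler: content chunks are yielded as they arrive, the completion text is accumulated, and a final empty-content chunk with finish_reason='stop' carries the usage computed from the accumulated text. Below is a minimal standalone sketch of that accumulate-then-finalize pattern, assuming simplified stand-ins for LLMResultChunk / LLMResultChunkDelta and a placeholder tokenizer; the real code computes usage via _num_tokens_from_messages and _calc_response_usage.

from dataclasses import dataclass
from typing import Generator, Iterable, Optional

@dataclass
class Delta:
    # illustrative stand-in for LLMResultChunkDelta
    index: int
    content: str
    finish_reason: Optional[str] = None
    usage: Optional[dict] = None

@dataclass
class ResultChunk:
    # illustrative stand-in for LLMResultChunk
    model: str
    delta: Delta

def stream_with_final_usage(model: str,
                            pieces: Iterable[str],
                            count_tokens=len) -> Generator[ResultChunk, None, None]:
    """Yield one chunk per streamed piece, then a closing chunk that carries usage.

    `count_tokens` is a placeholder tokenizer, not dify's actual token counting.
    """
    index = 0
    completion = ''
    for piece in pieces:
        completion += piece
        # intermediate chunks: content only, no finish_reason, no usage
        yield ResultChunk(model=model, delta=Delta(index=index, content=piece))
        index += 1

    # usage is computed once from the accumulated completion and attached
    # to a final empty-content chunk with finish_reason='stop'
    usage = {'completion_tokens': count_tokens(completion)}
    yield ResultChunk(model=model,
                      delta=Delta(index=index, content='', finish_reason='stop', usage=usage))

for chunk in stream_with_final_usage('gpt-35-turbo', ['Hel', 'lo']):
    print(chunk)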

+ 0 - 1
api/tests/integration_tests/model_runtime/azure_openai/test_llm.py

@@ -190,7 +190,6 @@ def test_invoke_stream_chat_model(setup_openai_mock):
         assert isinstance(chunk, LLMResultChunk)
         assert isinstance(chunk.delta, LLMResultChunkDelta)
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
-        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
         if chunk.delta.finish_reason is not None:
             assert chunk.delta.usage is not None
             assert chunk.delta.usage.completion_tokens > 0