Browse Source

feat: support openai stream usage (#4140)

Yeuoly 1 year ago
parent
commit
d5d8b98d82
2 changed files with 77 additions and 30 deletions
  1. 76 29
      api/core/model_runtime/model_providers/openai/llm/llm.py
  2. 1 1
      api/requirements.txt

+ 76 - 29
api/core/model_runtime/model_providers/openai/llm/llm.py

@@ -378,6 +378,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         if user:
             extra_model_kwargs['user'] = user
 
+        if stream:
+            extra_model_kwargs['stream_options'] = {
+                "include_usage": True
+            }
+        
         # text completion model
         response = client.completions.create(
             prompt=prompt_messages[0].content,
@@ -446,8 +451,24 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         :return: llm response chunk generator result
         """
         full_text = ''
+        prompt_tokens = 0
+        completion_tokens = 0
+
+        final_chunk = LLMResultChunk(
+            model=model,
+            prompt_messages=prompt_messages,
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=AssistantPromptMessage(content=''),
+            )
+        )
+
         for chunk in response:
             if len(chunk.choices) == 0:
+                if chunk.usage:
+                    # calculate num tokens
+                    prompt_tokens = chunk.usage.prompt_tokens
+                    completion_tokens = chunk.usage.completion_tokens
                 continue
 
             delta = chunk.choices[0]
@@ -464,20 +485,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
             full_text += text
 
             if delta.finish_reason is not None:
-                # calculate num tokens
-                if chunk.usage:
-                    # transform usage
-                    prompt_tokens = chunk.usage.prompt_tokens
-                    completion_tokens = chunk.usage.completion_tokens
-                else:
-                    # calculate num tokens
-                    prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
-                    completion_tokens = self._num_tokens_from_string(model, full_text)
-
-                # transform usage
-                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
-
-                yield LLMResultChunk(
+                final_chunk = LLMResultChunk(
                     model=chunk.model,
                     prompt_messages=prompt_messages,
                     system_fingerprint=chunk.system_fingerprint,
@@ -485,7 +493,6 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
                         index=delta.index,
                         message=assistant_prompt_message,
                         finish_reason=delta.finish_reason,
-                        usage=usage
                     )
                 )
             else:
@@ -499,6 +506,19 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
                     )
                 )
 
+        if not prompt_tokens:
+            prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
+
+        if not completion_tokens:
+            completion_tokens = self._num_tokens_from_string(model, full_text)
+
+        # transform usage
+        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+        final_chunk.delta.usage = usage
+
+        yield final_chunk
+
     def _chat_generate(self, model: str, credentials: dict,
                        prompt_messages: list[PromptMessage], model_parameters: dict,
                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
@@ -531,6 +551,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
 
             model_parameters["response_format"] = response_format
 
+
         extra_model_kwargs = {}
 
         if tools:
@@ -547,6 +568,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         if user:
             extra_model_kwargs['user'] = user
 
+        if stream:
+            extra_model_kwargs['stream_options'] = {
+                'include_usage': True
+            }
+
         # clear illegal prompt messages
         prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
 
@@ -630,8 +656,24 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         """
         full_assistant_content = ''
         delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None
+        prompt_tokens = 0
+        completion_tokens = 0
+        final_tool_calls = []
+        final_chunk = LLMResultChunk(
+            model=model,
+            prompt_messages=prompt_messages,
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=AssistantPromptMessage(content=''),
+            )
+        )
+
         for chunk in response:
             if len(chunk.choices) == 0:
+                if chunk.usage:
+                    # calculate num tokens
+                    prompt_tokens = chunk.usage.prompt_tokens
+                    completion_tokens = chunk.usage.completion_tokens
                 continue
 
             delta = chunk.choices[0]
@@ -667,6 +709,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
             # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
             function_call = self._extract_response_function_call(assistant_message_function_call)
             tool_calls = [function_call] if function_call else []
+            if tool_calls:
+                final_tool_calls.extend(tool_calls)
 
             # transform assistant message to prompt message
             assistant_prompt_message = AssistantPromptMessage(
@@ -677,19 +721,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
             full_assistant_content += delta.delta.content if delta.delta.content else ''
 
             if has_finish_reason:
-                # calculate num tokens
-                prompt_tokens = self._num_tokens_from_messages(model, prompt_messages, tools)
-
-                full_assistant_prompt_message = AssistantPromptMessage(
-                    content=full_assistant_content,
-                    tool_calls=tool_calls
-                )
-                completion_tokens = self._num_tokens_from_messages(model, [full_assistant_prompt_message])
-
-                # transform usage
-                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
-
-                yield LLMResultChunk(
+                final_chunk = LLMResultChunk(
                     model=chunk.model,
                     prompt_messages=prompt_messages,
                     system_fingerprint=chunk.system_fingerprint,
@@ -697,7 +729,6 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
                         index=delta.index,
                         message=assistant_prompt_message,
                         finish_reason=delta.finish_reason,
-                        usage=usage
                     )
                 )
             else:
@@ -711,6 +742,22 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
                     )
                 )
 
+        if not prompt_tokens:
+            prompt_tokens = self._num_tokens_from_messages(model, prompt_messages, tools)
+
+        if not completion_tokens:
+            full_assistant_prompt_message = AssistantPromptMessage(
+                content=full_assistant_content,
+                tool_calls=final_tool_calls
+            )
+            completion_tokens = self._num_tokens_from_messages(model, [full_assistant_prompt_message])
+
+        # transform usage
+        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+        final_chunk.delta.usage = usage
+
+        yield final_chunk
+
     def _extract_response_tool_calls(self,
                                      response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \
             -> list[AssistantPromptMessage.ToolCall]:

+ 1 - 1
api/requirements.txt

@@ -9,7 +9,7 @@ flask-restful~=0.3.10
 flask-cors~=4.0.0
 gunicorn~=22.0.0
 gevent~=23.9.1
-openai~=1.13.3
+openai~=1.26.0
 tiktoken~=0.6.0
 psycopg2-binary~=2.9.6
 pycryptodome==3.19.1