Quellcode durchsuchen

openai compatiable api usage and id (#9800)

Co-authored-by: jinqi.guo <jinqi.guo@ubtrobot.com>
guogeer vor 6 Monaten
Ursprung
Commit
70ddc0ce43

+ 1 - 0
api/core/model_runtime/entities/llm_entities.py

@@ -105,6 +105,7 @@ class LLMResult(BaseModel):
     Model class for llm result.
     """
 
+    id: Optional[str] = None
     model: str
     prompt_messages: list[PromptMessage]
     message: AssistantPromptMessage

+ 26 - 7
api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py

@@ -397,16 +397,21 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
         chunk_index = 0
 
         def create_final_llm_result_chunk(
-            index: int, message: AssistantPromptMessage, finish_reason: str
+            id: Optional[str], index: int, message: AssistantPromptMessage, finish_reason: str, usage: dict
         ) -> LLMResultChunk:
             # calculate num tokens
-            prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
-            completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
+            prompt_tokens = usage and usage.get("prompt_tokens")
+            if prompt_tokens is None:
+                prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
+            completion_tokens = usage and usage.get("completion_tokens")
+            if completion_tokens is None:
+                completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
 
             # transform usage
             usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
 
             return LLMResultChunk(
+                id=id,
                 model=model,
                 prompt_messages=prompt_messages,
                 delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
@@ -450,7 +455,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                     tool_call.function.arguments += new_tool_call.function.arguments
 
         finish_reason = None  # The default value of finish_reason is None
-
+        message_id, usage = None, None
         for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
             chunk = chunk.strip()
             if chunk:
@@ -462,20 +467,26 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                     continue
 
                 try:
-                    chunk_json = json.loads(decoded_chunk)
+                    chunk_json: dict = json.loads(decoded_chunk)
                 # stream ended
                 except json.JSONDecodeError as e:
                     yield create_final_llm_result_chunk(
+                        id=message_id,
                         index=chunk_index + 1,
                         message=AssistantPromptMessage(content=""),
                         finish_reason="Non-JSON encountered.",
+                        usage=usage,
                     )
                     break
+                if chunk_json:
+                    if u := chunk_json.get("usage"):
+                        usage = u
                 if not chunk_json or len(chunk_json["choices"]) == 0:
                     continue
 
                 choice = chunk_json["choices"][0]
                 finish_reason = chunk_json["choices"][0].get("finish_reason")
+                message_id = chunk_json.get("id")
                 chunk_index += 1
 
                 if "delta" in choice:
@@ -524,6 +535,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                     continue
 
                 yield LLMResultChunk(
+                    id=message_id,
                     model=model,
                     prompt_messages=prompt_messages,
                     delta=LLMResultChunkDelta(
@@ -536,6 +548,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
 
         if tools_calls:
             yield LLMResultChunk(
+                id=message_id,
                 model=model,
                 prompt_messages=prompt_messages,
                 delta=LLMResultChunkDelta(
@@ -545,17 +558,22 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
             )
 
         yield create_final_llm_result_chunk(
-            index=chunk_index, message=AssistantPromptMessage(content=""), finish_reason=finish_reason
+            id=message_id,
+            index=chunk_index,
+            message=AssistantPromptMessage(content=""),
+            finish_reason=finish_reason,
+            usage=usage,
         )
 
     def _handle_generate_response(
         self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
     ) -> LLMResult:
-        response_json = response.json()
+        response_json: dict = response.json()
 
         completion_type = LLMMode.value_of(credentials["mode"])
 
         output = response_json["choices"][0]
+        message_id = response_json.get("id")
 
         response_content = ""
         tool_calls = None
@@ -593,6 +611,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
 
         # transform response
         result = LLMResult(
+            id=message_id,
             model=response_json["model"],
             prompt_messages=prompt_messages,
             message=assistant_message,