Pārlūkot izejas kodu

[Fix] modify sagemaker llm (#12274)

Warren Chen 3 mēneši atpakaļ
vecāks
revīzija
9954ddb780
1 mainītis faili ar 108 papildinājumiem un 52 dzēšanām
  1. 108 52
      api/core/model_runtime/model_providers/sagemaker/llm/llm.py

+ 108 - 52
api/core/model_runtime/model_providers/sagemaker/llm/llm.py

@@ -1,6 +1,5 @@
 import json
 import logging
-import re
 from collections.abc import Generator, Iterator
 from typing import Any, Optional, Union, cast
 
@@ -132,58 +131,115 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
         """
         handle stream chat generate response
         """
+
+        class ChunkProcessor:
+            def __init__(self):
+                self.buffer = bytearray()
+
+            def try_decode_chunk(self, chunk: bytes) -> list[dict]:
+                """尝试从chunk中解码出完整的JSON对象"""
+                self.buffer.extend(chunk)
+                results = []
+
+                while True:
+                    try:
+                        start = self.buffer.find(b"{")
+                        if start == -1:
+                            self.buffer.clear()
+                            break
+
+                        bracket_count = 0
+                        end = start
+
+                        for i in range(start, len(self.buffer)):
+                            if self.buffer[i] == ord("{"):
+                                bracket_count += 1
+                            elif self.buffer[i] == ord("}"):
+                                bracket_count -= 1
+                                if bracket_count == 0:
+                                    end = i + 1
+                                    break
+
+                        if bracket_count != 0:
+                            # JSON不完整,等待更多数据
+                            if start > 0:
+                                self.buffer = self.buffer[start:]
+                            break
+
+                        json_bytes = self.buffer[start:end]
+                        try:
+                            data = json.loads(json_bytes)
+                            results.append(data)
+                            self.buffer = self.buffer[end:]
+                        except json.JSONDecodeError:
+                            self.buffer = self.buffer[start + 1 :]
+
+                    except Exception as e:
+                        logger.debug(f"Warning: Error processing chunk ({str(e)})")
+                        if start > 0:
+                            self.buffer = self.buffer[start:]
+                        break
+
+                return results
+
         full_response = ""
-        buffer = ""
-        for chunk_bytes in resp:
-            buffer += chunk_bytes.decode("utf-8")
-            last_idx = 0
-            for match in re.finditer(r"^data:\s*(.+?)(\n\n)", buffer):
-                try:
-                    data = json.loads(match.group(1).strip())
-                    last_idx = match.span()[1]
-
-                    if "content" in data["choices"][0]["delta"]:
-                        chunk_content = data["choices"][0]["delta"]["content"]
-                        assistant_prompt_message = AssistantPromptMessage(content=chunk_content, tool_calls=[])
-
-                        if data["choices"][0]["finish_reason"] is not None:
-                            temp_assistant_prompt_message = AssistantPromptMessage(content=full_response, tool_calls=[])
-                            prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
-                            completion_tokens = self._num_tokens_from_messages(
-                                messages=[temp_assistant_prompt_message], tools=[]
-                            )
-                            usage = self._calc_response_usage(
-                                model=model,
-                                credentials=credentials,
-                                prompt_tokens=prompt_tokens,
-                                completion_tokens=completion_tokens,
-                            )
-
-                            yield LLMResultChunk(
-                                model=model,
-                                prompt_messages=prompt_messages,
-                                system_fingerprint=None,
-                                delta=LLMResultChunkDelta(
-                                    index=0,
-                                    message=assistant_prompt_message,
-                                    finish_reason=data["choices"][0]["finish_reason"],
-                                    usage=usage,
-                                ),
-                            )
-                        else:
-                            yield LLMResultChunk(
-                                model=model,
-                                prompt_messages=prompt_messages,
-                                system_fingerprint=None,
-                                delta=LLMResultChunkDelta(index=0, message=assistant_prompt_message),
-                            )
-
-                            full_response += chunk_content
-                except (json.JSONDecodeError, KeyError, IndexError) as e:
-                    logger.info("json parse exception, content: {}".format(match.group(1).strip()))
-                    pass
-
-            buffer = buffer[last_idx:]
+        processor = ChunkProcessor()
+
+        try:
+            for chunk in resp:
+                json_objects = processor.try_decode_chunk(chunk)
+
+                for data in json_objects:
+                    if data.get("choices"):
+                        choice = data["choices"][0]
+
+                        if "delta" in choice and "content" in choice["delta"]:
+                            chunk_content = choice["delta"]["content"]
+                            assistant_prompt_message = AssistantPromptMessage(content=chunk_content, tool_calls=[])
+
+                            if choice.get("finish_reason") is not None:
+                                temp_assistant_prompt_message = AssistantPromptMessage(
+                                    content=full_response, tool_calls=[]
+                                )
+
+                                prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
+                                completion_tokens = self._num_tokens_from_messages(
+                                    messages=[temp_assistant_prompt_message], tools=[]
+                                )
+
+                                usage = self._calc_response_usage(
+                                    model=model,
+                                    credentials=credentials,
+                                    prompt_tokens=prompt_tokens,
+                                    completion_tokens=completion_tokens,
+                                )
+
+                                yield LLMResultChunk(
+                                    model=model,
+                                    prompt_messages=prompt_messages,
+                                    system_fingerprint=None,
+                                    delta=LLMResultChunkDelta(
+                                        index=0,
+                                        message=assistant_prompt_message,
+                                        finish_reason=choice["finish_reason"],
+                                        usage=usage,
+                                    ),
+                                )
+                            else:
+                                yield LLMResultChunk(
+                                    model=model,
+                                    prompt_messages=prompt_messages,
+                                    system_fingerprint=None,
+                                    delta=LLMResultChunkDelta(index=0, message=assistant_prompt_message),
+                                )
+
+                                full_response += chunk_content
+
+        except Exception as e:
+            raise
+
+        if not full_response:
+            logger.warning("No content received from stream response")
 
     def _invoke(
         self,