Parcourir la source

[Fix] revert sagemaker llm to support model hub (#12378)

Warren Chen il y a 3 mois
Parent
commit
147d578922
1 fichiers modifiés avec 52 ajouts et 108 suppressions
  1. 52 108
      api/core/model_runtime/model_providers/sagemaker/llm/llm.py

+ 52 - 108
api/core/model_runtime/model_providers/sagemaker/llm/llm.py

@@ -1,5 +1,6 @@
 import json
 import logging
+import re
 from collections.abc import Generator, Iterator
 from typing import Any, Optional, Union, cast
 
@@ -131,115 +132,58 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
         """
         handle stream chat generate response
         """
-
-        class ChunkProcessor:
-            def __init__(self):
-                self.buffer = bytearray()
-
-            def try_decode_chunk(self, chunk: bytes) -> list[dict]:
-                """尝试从chunk中解码出完整的JSON对象"""
-                self.buffer.extend(chunk)
-                results = []
-
-                while True:
-                    try:
-                        start = self.buffer.find(b"{")
-                        if start == -1:
-                            self.buffer.clear()
-                            break
-
-                        bracket_count = 0
-                        end = start
-
-                        for i in range(start, len(self.buffer)):
-                            if self.buffer[i] == ord("{"):
-                                bracket_count += 1
-                            elif self.buffer[i] == ord("}"):
-                                bracket_count -= 1
-                                if bracket_count == 0:
-                                    end = i + 1
-                                    break
-
-                        if bracket_count != 0:
-                            # JSON不完整,等待更多数据
-                            if start > 0:
-                                self.buffer = self.buffer[start:]
-                            break
-
-                        json_bytes = self.buffer[start:end]
-                        try:
-                            data = json.loads(json_bytes)
-                            results.append(data)
-                            self.buffer = self.buffer[end:]
-                        except json.JSONDecodeError:
-                            self.buffer = self.buffer[start + 1 :]
-
-                    except Exception as e:
-                        logger.debug(f"Warning: Error processing chunk ({str(e)})")
-                        if start > 0:
-                            self.buffer = self.buffer[start:]
-                        break
-
-                return results
-
         full_response = ""
-        processor = ChunkProcessor()
-
-        try:
-            for chunk in resp:
-                json_objects = processor.try_decode_chunk(chunk)
-
-                for data in json_objects:
-                    if data.get("choices"):
-                        choice = data["choices"][0]
-
-                        if "delta" in choice and "content" in choice["delta"]:
-                            chunk_content = choice["delta"]["content"]
-                            assistant_prompt_message = AssistantPromptMessage(content=chunk_content, tool_calls=[])
-
-                            if choice.get("finish_reason") is not None:
-                                temp_assistant_prompt_message = AssistantPromptMessage(
-                                    content=full_response, tool_calls=[]
-                                )
-
-                                prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
-                                completion_tokens = self._num_tokens_from_messages(
-                                    messages=[temp_assistant_prompt_message], tools=[]
-                                )
-
-                                usage = self._calc_response_usage(
-                                    model=model,
-                                    credentials=credentials,
-                                    prompt_tokens=prompt_tokens,
-                                    completion_tokens=completion_tokens,
-                                )
-
-                                yield LLMResultChunk(
-                                    model=model,
-                                    prompt_messages=prompt_messages,
-                                    system_fingerprint=None,
-                                    delta=LLMResultChunkDelta(
-                                        index=0,
-                                        message=assistant_prompt_message,
-                                        finish_reason=choice["finish_reason"],
-                                        usage=usage,
-                                    ),
-                                )
-                            else:
-                                yield LLMResultChunk(
-                                    model=model,
-                                    prompt_messages=prompt_messages,
-                                    system_fingerprint=None,
-                                    delta=LLMResultChunkDelta(index=0, message=assistant_prompt_message),
-                                )
-
-                                full_response += chunk_content
-
-        except Exception as e:
-            raise
-
-        if not full_response:
-            logger.warning("No content received from stream response")
+        buffer = ""
+        for chunk_bytes in resp:
+            buffer += chunk_bytes.decode("utf-8")
+            last_idx = 0
+            for match in re.finditer(r"^data:\s*(.+?)(\n\n)", buffer):
+                try:
+                    data = json.loads(match.group(1).strip())
+                    last_idx = match.span()[1]
+
+                    if "content" in data["choices"][0]["delta"]:
+                        chunk_content = data["choices"][0]["delta"]["content"]
+                        assistant_prompt_message = AssistantPromptMessage(content=chunk_content, tool_calls=[])
+
+                        if data["choices"][0]["finish_reason"] is not None:
+                            temp_assistant_prompt_message = AssistantPromptMessage(content=full_response, tool_calls=[])
+                            prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
+                            completion_tokens = self._num_tokens_from_messages(
+                                messages=[temp_assistant_prompt_message], tools=[]
+                            )
+                            usage = self._calc_response_usage(
+                                model=model,
+                                credentials=credentials,
+                                prompt_tokens=prompt_tokens,
+                                completion_tokens=completion_tokens,
+                            )
+
+                            yield LLMResultChunk(
+                                model=model,
+                                prompt_messages=prompt_messages,
+                                system_fingerprint=None,
+                                delta=LLMResultChunkDelta(
+                                    index=0,
+                                    message=assistant_prompt_message,
+                                    finish_reason=data["choices"][0]["finish_reason"],
+                                    usage=usage,
+                                ),
+                            )
+                        else:
+                            yield LLMResultChunk(
+                                model=model,
+                                prompt_messages=prompt_messages,
+                                system_fingerprint=None,
+                                delta=LLMResultChunkDelta(index=0, message=assistant_prompt_message),
+                            )
+
+                            full_response += chunk_content
+                except (json.JSONDecodeError, KeyError, IndexError) as e:
+                    logger.info("json parse exception, content: {}".format(match.group(1).strip()))
+                    pass
+
+            buffer = buffer[last_idx:]
 
     def _invoke(
         self,