@@ -1,5 +1,6 @@
import json
import logging
+import re
from collections.abc import Generator, Iterator
from typing import Any, Optional, Union, cast

@@ -131,115 +132,58 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
"""
handle stream chat generate response
"""
-
-        class ChunkProcessor:
-            def __init__(self):
-                self.buffer = bytearray()
-
-            def try_decode_chunk(self, chunk: bytes) -> list[dict]:
-                """Try to decode complete JSON objects from the chunk."""
-                self.buffer.extend(chunk)
-                results = []
-
-                while True:
-                    try:
-                        start = self.buffer.find(b"{")
-                        if start == -1:
-                            self.buffer.clear()
-                            break
-
-                        bracket_count = 0
-                        end = start
-
-                        for i in range(start, len(self.buffer)):
-                            if self.buffer[i] == ord("{"):
-                                bracket_count += 1
-                            elif self.buffer[i] == ord("}"):
-                                bracket_count -= 1
-                                if bracket_count == 0:
-                                    end = i + 1
-                                    break
-
-                        if bracket_count != 0:
-                            # JSON incomplete, wait for more data
-                            if start > 0:
-                                self.buffer = self.buffer[start:]
-                            break
-
-                        json_bytes = self.buffer[start:end]
-                        try:
-                            data = json.loads(json_bytes)
-                            results.append(data)
-                            self.buffer = self.buffer[end:]
-                        except json.JSONDecodeError:
-                            self.buffer = self.buffer[start + 1 :]
-
-                    except Exception as e:
-                        logger.debug(f"Warning: Error processing chunk ({str(e)})")
-                        if start > 0:
-                            self.buffer = self.buffer[start:]
-                        break
-
-                return results
-
full_response = ""
-        processor = ChunkProcessor()
-
-        try:
-            for chunk in resp:
-                json_objects = processor.try_decode_chunk(chunk)
-
-                for data in json_objects:
-                    if data.get("choices"):
-                        choice = data["choices"][0]
-
-                        if "delta" in choice and "content" in choice["delta"]:
-                            chunk_content = choice["delta"]["content"]
-                            assistant_prompt_message = AssistantPromptMessage(content=chunk_content, tool_calls=[])
-
-                            if choice.get("finish_reason") is not None:
-                                temp_assistant_prompt_message = AssistantPromptMessage(
-                                    content=full_response, tool_calls=[]
-                                )
-
-                                prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
-                                completion_tokens = self._num_tokens_from_messages(
-                                    messages=[temp_assistant_prompt_message], tools=[]
-                                )
-
-                                usage = self._calc_response_usage(
-                                    model=model,
-                                    credentials=credentials,
-                                    prompt_tokens=prompt_tokens,
-                                    completion_tokens=completion_tokens,
-                                )
-
-                                yield LLMResultChunk(
-                                    model=model,
-                                    prompt_messages=prompt_messages,
-                                    system_fingerprint=None,
-                                    delta=LLMResultChunkDelta(
-                                        index=0,
-                                        message=assistant_prompt_message,
-                                        finish_reason=choice["finish_reason"],
-                                        usage=usage,
-                                    ),
-                                )
-                            else:
-                                yield LLMResultChunk(
-                                    model=model,
-                                    prompt_messages=prompt_messages,
-                                    system_fingerprint=None,
-                                    delta=LLMResultChunkDelta(index=0, message=assistant_prompt_message),
-                                )
-
-                            full_response += chunk_content
-
-        except Exception as e:
-            raise
-
-        if not full_response:
-            logger.warning("No content received from stream response")
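+        # Buffer the decoded bytes across network chunks and emit an LLMResultChunk
+        # for every complete SSE event found in the buffer.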
+        buffer = ""
+        for chunk_bytes in resp:
+            buffer += chunk_bytes.decode("utf-8")
+            last_idx = 0
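+            # A complete event has the form "data: {...}" terminated by a blank line.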
+            for match in re.finditer(r"^data:\s*(.+?)(\n\n)", buffer):
+                try:
+                    data = json.loads(match.group(1).strip())
+                    last_idx = match.span()[1]
+
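+                    # Only deltas that actually carry content are turned into chunks.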
+                    if "content" in data["choices"][0]["delta"]:
+                        chunk_content = data["choices"][0]["delta"]["content"]
+                        assistant_prompt_message = AssistantPromptMessage(content=chunk_content, tool_calls=[])
+
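+                        # The last event carries finish_reason: compute token usage and
+                        # emit the closing chunk.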
+                        if data["choices"][0]["finish_reason"] is not None:
+                            temp_assistant_prompt_message = AssistantPromptMessage(content=full_response, tool_calls=[])
+                            prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
+                            completion_tokens = self._num_tokens_from_messages(
+                                messages=[temp_assistant_prompt_message], tools=[]
+                            )
+                            usage = self._calc_response_usage(
+                                model=model,
+                                credentials=credentials,
+                                prompt_tokens=prompt_tokens,
+                                completion_tokens=completion_tokens,
+                            )
+
+                            yield LLMResultChunk(
+                                model=model,
+                                prompt_messages=prompt_messages,
+                                system_fingerprint=None,
+                                delta=LLMResultChunkDelta(
+                                    index=0,
+                                    message=assistant_prompt_message,
+                                    finish_reason=data["choices"][0]["finish_reason"],
+                                    usage=usage,
+                                ),
+                            )
+                        else:
+                            yield LLMResultChunk(
+                                model=model,
+                                prompt_messages=prompt_messages,
+                                system_fingerprint=None,
+                                delta=LLMResultChunkDelta(index=0, message=assistant_prompt_message),
+                            )
+
+                        full_response += chunk_content
+                except (json.JSONDecodeError, KeyError, IndexError):
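+                    # Malformed or unexpectedly shaped events are logged and skipped.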
+                    logger.info("json parse exception, content: {}".format(match.group(1).strip()))
+
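+            # Drop everything parsed in this pass; any incomplete tail stays in the
+            # buffer until the next chunk arrives.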
+            buffer = buffer[last_idx:]

def _invoke(
self,