
chore: refactor the baichuan model (#7953)

非法操作 7 months ago
Parent
Commit
0f72a8e89d

+ 0 - 8
api/core/model_runtime/model_providers/baichuan/baichuan.yaml

@@ -27,11 +27,3 @@ provider_credential_schema:
       placeholder:
         zh_Hans: 在此输入您的 API Key
         en_US: Enter your API Key
-    - variable: secret_key
-      label:
-        en_US: Secret Key
-      type: secret-input
-      required: false
-      placeholder:
-        zh_Hans: 在此输入您的 Secret Key
-        en_US: Enter your Secret Key

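With the secret_key credential removed, the refactored client authenticates with the API key alone. A minimal sketch of the resulting Bearer-auth request and credential-validation ping, assuming the endpoint and header layout shown later in baichuan_turbo.py (illustrative code, not part of the commit):

from requests import post

# Endpoint taken from the refactored baichuan_turbo.py below
API_BASE = "https://api.baichuan-ai.com/v1/chat/completions"

def ping(api_key: str) -> None:
    # Bearer auth only; the old secret_key / md5 signing path is gone
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + api_key,
    }
    body = {
        "model": "Baichuan2-Turbo",
        "stream": False,
        "messages": [{"role": "user", "content": "ping"}],
        "max_tokens": 1,
    }
    post(API_BASE, headers=headers, json=body, timeout=60).raise_for_status()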
+ 1 - 0
api/core/model_runtime/model_providers/baichuan/llm/baichuan2-53b.yaml

@@ -43,3 +43,4 @@ parameter_rules:
       zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
       en_US: Allow the model to perform external search to enhance the generation results.
     required: false
+deprecated: true

+ 1 - 0
api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo-192k.yaml

@@ -43,3 +43,4 @@ parameter_rules:
       zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
       en_US: Allow the model to perform external search to enhance the generation results.
     required: false
+deprecated: true

+ 7 - 11
api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo.yaml

@@ -4,36 +4,32 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - multi-tool-call
 model_properties:
   mode: chat
   context_size: 32000
 parameter_rules:
   - name: temperature
     use_template: temperature
+    default: 0.3
   - name: top_p
     use_template: top_p
+    default: 0.85
   - name: top_k
     label:
       zh_Hans: 取样数量
       en_US: Top k
     type: int
+    min: 0
+    max: 20
+    default: 5
     help:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
   - name: max_tokens
     use_template: max_tokens
-    required: true
-    default: 8000
-    min: 1
-    max: 192000
-  - name: presence_penalty
-    use_template: presence_penalty
-  - name: frequency_penalty
-    use_template: frequency_penalty
-    default: 1
-    min: 1
-    max: 2
+    default: 2048
   - name: with_search_enhance
     label:
       zh_Hans: 搜索增强

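The adjusted parameter_rules above land as a flat payload, since the turbo API accepts flat parameters (see _build_parameters further down). Illustrative values matching the new YAML defaults:

# Example model_parameters reflecting the new defaults (illustrative only)
model_parameters = {
    "temperature": 0.3,   # new default
    "top_p": 0.85,        # new default
    "top_k": 5,           # now constrained to 0-20
    "max_tokens": 2048,   # replaces the old required default of 8000
}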
+ 19 - 11
api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo-128k.yaml

@@ -4,36 +4,44 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - multi-tool-call
 model_properties:
   mode: chat
   context_size: 128000
 parameter_rules:
   - name: temperature
     use_template: temperature
+    default: 0.3
   - name: top_p
     use_template: top_p
+    default: 0.85
   - name: top_k
     label:
       zh_Hans: 取样数量
       en_US: Top k
     type: int
+    min: 0
+    max: 20
+    default: 5
     help:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
   - name: max_tokens
     use_template: max_tokens
-    required: true
-    default: 8000
-    min: 1
-    max: 128000
-  - name: presence_penalty
-    use_template: presence_penalty
-  - name: frequency_penalty
-    use_template: frequency_penalty
-    default: 1
-    min: 1
-    max: 2
+    default: 2048
+  - name: res_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
   - name: with_search_enhance
     label:
       zh_Hans: 搜索增强

+ 19 - 11
api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo.yaml

@@ -4,36 +4,44 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - multi-tool-call
 model_properties:
   mode: chat
   context_size: 32000
 parameter_rules:
   - name: temperature
     use_template: temperature
+    default: 0.3
   - name: top_p
     use_template: top_p
+    default: 0.85
   - name: top_k
     label:
       zh_Hans: 取样数量
       en_US: Top k
     type: int
+    min: 0
+    max: 20
+    default: 5
     help:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
   - name: max_tokens
     use_template: max_tokens
-    required: true
-    default: 8000
-    min: 1
-    max: 32000
-  - name: presence_penalty
-    use_template: presence_penalty
-  - name: frequency_penalty
-    use_template: frequency_penalty
-    default: 1
-    min: 1
-    max: 2
+    default: 2048
+  - name: res_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
   - name: with_search_enhance
     label:
       zh_Hans: 搜索增强

+ 19 - 11
api/core/model_runtime/model_providers/baichuan/llm/baichuan4.yaml

@@ -4,36 +4,44 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - multi-tool-call
 model_properties:
   mode: chat
   context_size: 32000
 parameter_rules:
   - name: temperature
     use_template: temperature
+    default: 0.3
   - name: top_p
     use_template: top_p
+    default: 0.85
   - name: top_k
     label:
       zh_Hans: 取样数量
       en_US: Top k
     type: int
+    min: 0
+    max: 20
+    default: 5
     help:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
   - name: max_tokens
     use_template: max_tokens
-    required: true
-    default: 8000
-    min: 1
-    max: 32000
-  - name: presence_penalty
-    use_template: presence_penalty
-  - name: frequency_penalty
-    use_template: frequency_penalty
-    default: 1
-    min: 1
-    max: 2
+    default: 2048
+  - name: res_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
   - name: with_search_enhance
     label:
       zh_Hans: 搜索增强

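The res_format rule added to the three models above is deliberately a plain string option; the client later rewrites it into the API's response_format object (see the comment in _build_parameters below about _code_block_mode_wrapper() stripping response_format). A hedged sketch of that mapping:

def normalize_response_format(parameters: dict) -> dict:
    # "res_format" comes from the YAML rules above; the Baichuan API itself
    # expects an OpenAI-style response_format object.
    if parameters.get("res_format") == "json_object":
        parameters["response_format"] = {"type": "json_object"}
    return parameters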
+ 94 - 165
api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py

@@ -1,11 +1,10 @@
-from collections.abc import Generator
-from enum import Enum
-from hashlib import md5
-from json import dumps, loads
-from typing import Any, Union
+import json
+from collections.abc import Iterator
+from typing import Any, Optional, Union
 
 from requests import post
 
+from core.model_runtime.entities.message_entities import PromptMessageTool
 from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
     BadRequestError,
     InsufficientAccountBalance,
@@ -16,203 +15,133 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors impor
 )
 
 
-class BaichuanMessage:
-    class Role(Enum):
-        USER = 'user'
-        ASSISTANT = 'assistant'
-        # Baichuan does not have system message
-        _SYSTEM = 'system'
-
-    role: str = Role.USER.value
-    content: str
-    usage: dict[str, int] = None
-    stop_reason: str = ''
-
-    def to_dict(self) -> dict[str, Any]:
-        return {
-            'role': self.role,
-            'content': self.content,
-        }
-    
-    def __init__(self, content: str, role: str = 'user') -> None:
-        self.content = content
-        self.role = role
-
 class BaichuanModel:
     api_key: str
-    secret_key: str
 
-    def __init__(self, api_key: str, secret_key: str = '') -> None:
+    def __init__(self, api_key: str) -> None:
         self.api_key = api_key
-        self.secret_key = secret_key
 
-    def _model_mapping(self, model: str) -> str:
+    @property
+    def _model_mapping(self) -> dict:
         return {
-            'baichuan2-turbo': 'Baichuan2-Turbo',
-            'baichuan2-turbo-192k': 'Baichuan2-Turbo-192k',
-            'baichuan2-53b': 'Baichuan2-53B',
-            'baichuan3-turbo': 'Baichuan3-Turbo',
-            'baichuan3-turbo-128k': 'Baichuan3-Turbo-128k',
-            'baichuan4': 'Baichuan4',
-        }[model]
-
-    def _handle_chat_generate_response(self, response) -> BaichuanMessage:
-        resp = response.json()
-        choices = resp.get('choices', [])
-        message = BaichuanMessage(content='', role='assistant')
-        for choice in choices:
-            message.content += choice['message']['content']
-            message.role = choice['message']['role']
-            if choice['finish_reason']:
-                message.stop_reason = choice['finish_reason']
-
-        if 'usage' in resp:
-            message.usage = {
-                'prompt_tokens': resp['usage']['prompt_tokens'],
-                'completion_tokens': resp['usage']['completion_tokens'],
-                'total_tokens': resp['usage']['total_tokens'],
-            }
+            "baichuan2-turbo": "Baichuan2-Turbo",
+            "baichuan3-turbo": "Baichuan3-Turbo",
+            "baichuan3-turbo-128k": "Baichuan3-Turbo-128k",
+            "baichuan4": "Baichuan4",
+        }
 
-        return message
-    
-    def _handle_chat_stream_generate_response(self, response) -> Generator:
-        for line in response.iter_lines():
-            if not line:
-                continue
-            line = line.decode('utf-8')
-            # remove the first `data: ` prefix
-            if line.startswith('data:'):
-                line = line[5:].strip()
-            try:
-                data = loads(line)
-            except Exception as e:
-                if line.strip() == '[DONE]':
-                    return
-            choices = data.get('choices', [])
-            # save stop reason temporarily
-            stop_reason = ''
-            for choice in choices:
-                if choice.get('finish_reason'):
-                    stop_reason = choice['finish_reason']
-
-                if len(choice['delta']['content']) == 0:
-                    continue
-                yield BaichuanMessage(**choice['delta'])
-
-            # if there is usage, the response is the last one, yield it and return
-            if 'usage' in data:
-                message = BaichuanMessage(content='', role='assistant')
-                message.usage = {
-                    'prompt_tokens': data['usage']['prompt_tokens'],
-                    'completion_tokens': data['usage']['completion_tokens'],
-                    'total_tokens': data['usage']['total_tokens'],
-                }
-                message.stop_reason = stop_reason
-                yield message
-
-    def _build_parameters(self, model: str, stream: bool, messages: list[BaichuanMessage],
-                               parameters: dict[str, Any]) \
-        -> dict[str, Any]:
-        if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
-                or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k'  or model == 'baichuan4'):
-            prompt_messages = []
-            for message in messages:
-                if message.role == BaichuanMessage.Role.USER.value or message.role == BaichuanMessage.Role._SYSTEM.value:
-                    # check if the latest message is a user message
-                    if len(prompt_messages) > 0 and prompt_messages[-1]['role'] == BaichuanMessage.Role.USER.value:
-                        prompt_messages[-1]['content'] += message.content
-                    else:
-                        prompt_messages.append({
-                            'content': message.content,
-                            'role': BaichuanMessage.Role.USER.value,
-                        })
-                elif message.role == BaichuanMessage.Role.ASSISTANT.value:
-                    prompt_messages.append({
-                        'content': message.content,
-                        'role': message.role,
-                    })
-            # [baichuan] frequency_penalty must be between 1 and 2
-            if 'frequency_penalty' in parameters:
-                if parameters['frequency_penalty'] < 1 or parameters['frequency_penalty'] > 2:
-                    parameters['frequency_penalty'] = 1
+    @property
+    def request_headers(self) -> dict[str, Any]:
+        return {
+            "Content-Type": "application/json",
+            "Authorization": "Bearer " + self.api_key,
+        }
+
+    def _build_parameters(
+        self,
+        model: str,
+        stream: bool,
+        messages: list[dict],
+        parameters: dict[str, Any],
+        tools: Optional[list[PromptMessageTool]] = None,
+    ) -> dict[str, Any]:
+        if model in self._model_mapping.keys():
+            # the LargeLanguageModel._code_block_mode_wrapper() method will remove the response_format of parameters.
+            # we need to rename it to res_format to get its value
+            if parameters.get("res_format") == "json_object":
+                parameters["response_format"] = {"type": "json_object"}
+
+            if tools or parameters.get("with_search_enhance") is True:
+                parameters["tools"] = []
+
+            # with_search_enhance is deprecated, use web_search instead
+            if parameters.get("with_search_enhance") is True:
+                parameters["tools"].append(
+                    {
+                        "type": "web_search",
+                        "web_search": {"enable": True},
+                    }
+                )
+            if tools:
+                for tool in tools:
+                    parameters["tools"].append(
+                        {
+                            "type": "function",
+                            "function": {
+                                "name": tool.name,
+                                "description": tool.description,
+                                "parameters": tool.parameters,
+                            },
+                        }
+                    )
 
             # turbo api accepts flat parameters
             return {
-                'model': self._model_mapping(model),
-                'stream': stream,
-                'messages': prompt_messages,
+                "model": self._model_mapping.get(model),
+                "stream": stream,
+                "messages": messages,
                 **parameters,
             }
         else:
             raise BadRequestError(f"Unknown model: {model}")
-        
-    def _build_headers(self, model: str, data: dict[str, Any]) -> dict[str, Any]:
-        if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
-                or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k'  or model == 'baichuan4'):
-            # there is no secret key for turbo api
-            return {
-                'Content-Type': 'application/json',
-                'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) ',
-                'Authorization': 'Bearer ' + self.api_key,
-            }
-        else:
-            raise BadRequestError(f"Unknown model: {model}")
-        
-    def _calculate_md5(self, input_string):
-        return md5(input_string.encode('utf-8')).hexdigest()
-
-    def generate(self, model: str, stream: bool, messages: list[BaichuanMessage], 
-                 parameters: dict[str, Any], timeout: int) \
-        -> Union[Generator, BaichuanMessage]:
-        
-        if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
-                or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k'  or model == 'baichuan4'):
-            api_base = 'https://api.baichuan-ai.com/v1/chat/completions'
+
+    def generate(
+        self,
+        model: str,
+        stream: bool,
+        messages: list[dict],
+        parameters: dict[str, Any],
+        timeout: int,
+        tools: Optional[list[PromptMessageTool]] = None,
+    ) -> Union[Iterator, dict]:
+
+        if model in self._model_mapping.keys():
+            api_base = "https://api.baichuan-ai.com/v1/chat/completions"
         else:
             raise BadRequestError(f"Unknown model: {model}")
-        
-        try:
-            data = self._build_parameters(model, stream, messages, parameters)
-            headers = self._build_headers(model, data)
-        except KeyError:
-            raise InternalServerError(f"Failed to build parameters for model: {model}")
+
+        data = self._build_parameters(model, stream, messages, parameters, tools)
 
         try:
             response = post(
                 url=api_base,
-                headers=headers,
-                data=dumps(data),
+                headers=self.request_headers,
+                data=json.dumps(data),
                 timeout=timeout,
-                stream=stream
+                stream=stream,
             )
         except Exception as e:
             raise InternalServerError(f"Failed to invoke model: {e}")
-        
+
         if response.status_code != 200:
             try:
                 resp = response.json()
                 # try to parse error message
-                err = resp['error']['code']
-                msg = resp['error']['message']
+                err = resp["error"]["type"]
+                msg = resp["error"]["message"]
             except Exception as e:
-                raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}")
+                raise InternalServerError(
+                    f"Failed to convert response to json: {e} with text: {response.text}"
+                )
 
-            if err == 'invalid_api_key':
+            if err == "invalid_api_key":
                 raise InvalidAPIKeyError(msg)
-            elif err == 'insufficient_quota':
+            elif err == "insufficient_quota":
                 raise InsufficientAccountBalance(msg)
-            elif err == 'invalid_authentication':
+            elif err == "invalid_authentication":
                 raise InvalidAuthenticationError(msg)
-            elif 'rate' in err:
+            elif err == "invalid_request_error":
+                raise BadRequestError(msg)
+            elif "rate" in err:
                 raise RateLimitReachedError(msg)
-            elif 'internal' in err:
+            elif "internal" in err:
                 raise InternalServerError(msg)
-            elif err == 'api_key_empty':
+            elif err == "api_key_empty":
                 raise InvalidAPIKeyError(msg)
             else:
                 raise InternalServerError(f"Unknown error: {err} with message: {msg}")
-            
+
         if stream:
-            return self._handle_chat_stream_generate_response(response)
+            return response.iter_lines()
         else:
-            return self._handle_chat_generate_response(response)
+            return response.json()

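To illustrate the new tool handling in _build_parameters: the deprecated with_search_enhance flag is folded into the tools array as a web_search entry, and each PromptMessageTool becomes an OpenAI-style function tool. A hedged sketch of the flat request body this produces for Baichuan4 (the get_weather tool is a made-up example):

request_body = {
    "model": "Baichuan4",
    "stream": True,
    "messages": [{"role": "user", "content": "What's the weather in Beijing?"}],
    "temperature": 0.3,
    "tools": [
        # legacy with_search_enhance flag mapped to the web_search tool
        {"type": "web_search", "web_search": {"enable": True}},
        # a PromptMessageTool converted to a function tool
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Look up the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        },
    ],
}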
+ 177 - 103
api/core/model_runtime/model_providers/baichuan/llm/llm.py

@@ -1,7 +1,12 @@
-from collections.abc import Generator
+import json
+from collections.abc import Generator, Iterator
 from typing import cast
 
-from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.llm_entities import (
+    LLMResult,
+    LLMResultChunk,
+    LLMResultChunkDelta,
+)
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
@@ -21,7 +26,7 @@ from core.model_runtime.errors.invoke import (
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer
-from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanMessage, BaichuanModel
+from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanModel
 from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
     BadRequestError,
     InsufficientAccountBalance,
@@ -33,19 +38,40 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors impor
 
 
 class BaichuanLarguageModel(LargeLanguageModel):
-    def _invoke(self, model: str, credentials: dict,
-                prompt_messages: list[PromptMessage], model_parameters: dict,
-                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
-                stream: bool = True, user: str | None = None) \
-            -> LLMResult | Generator:
-        return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages,
-                              model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)
-
-    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                       tools: list[PromptMessageTool] | None = None) -> int:
+
+    def _invoke(
+            self,
+            model: str,
+            credentials: dict,
+            prompt_messages: list[PromptMessage],
+            model_parameters: dict,
+            tools: list[PromptMessageTool] | None = None,
+            stop: list[str] | None = None,
+            stream: bool = True,
+            user: str | None = None,
+    ) -> LLMResult | Generator:
+        return self._generate(
+            model=model,
+            credentials=credentials,
+            prompt_messages=prompt_messages,
+            model_parameters=model_parameters,
+            tools=tools,
+            stream=stream,
+        )
+
+    def get_num_tokens(
+            self,
+            model: str,
+            credentials: dict,
+            prompt_messages: list[PromptMessage],
+            tools: list[PromptMessageTool] | None = None,
+    ) -> int:
         return self._num_tokens_from_messages(prompt_messages)
 
-    def _num_tokens_from_messages(self, messages: list[PromptMessage], ) -> int:
+    def _num_tokens_from_messages(
+            self,
+            messages: list[PromptMessage],
+    ) -> int:
         """Calculate num tokens for baichuan model"""
 
         def tokens(text: str):
@@ -59,10 +85,10 @@ class BaichuanLarguageModel(LargeLanguageModel):
             num_tokens += tokens_per_message
             for key, value in message.items():
                 if isinstance(value, list):
-                    text = ''
+                    text = ""
                     for item in value:
-                        if isinstance(item, dict) and item['type'] == 'text':
-                            text += item['text']
+                        if isinstance(item, dict) and item["type"] == "text":
+                            text += item["text"]
 
                     value = text
 
@@ -84,19 +110,18 @@ class BaichuanLarguageModel(LargeLanguageModel):
         elif isinstance(message, AssistantPromptMessage):
             message = cast(AssistantPromptMessage, message)
             message_dict = {"role": "assistant", "content": message.content}
+            if message.tool_calls:
+                message_dict["tool_calls"] = [tool_call.dict() for tool_call in
+                                              message.tool_calls]
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
-            message_dict = {"role": "user", "content": message.content}
+            message_dict = {"role": "system", "content": message.content}
         elif isinstance(message, ToolPromptMessage):
-            # copy from core/model_runtime/model_providers/anthropic/llm/llm.py
             message = cast(ToolPromptMessage, message)
             message_dict = {
-                "role": "user",
-                "content": [{
-                    "type": "tool_result",
-                    "tool_use_id": message.tool_call_id,
-                    "content": message.content
-                }]
+                "role": "tool",
+                "content": message.content,
+                "tool_call_id": message.tool_call_id
             }
         else:
             raise ValueError(f"Unknown message type {type(message)}")
@@ -105,102 +130,159 @@ class BaichuanLarguageModel(LargeLanguageModel):
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         # ping
-        instance = BaichuanModel(
-            api_key=credentials['api_key'],
-            secret_key=credentials.get('secret_key', '')
-        )
+        instance = BaichuanModel(api_key=credentials["api_key"])
 
         try:
-            instance.generate(model=model, stream=False, messages=[
-                BaichuanMessage(content='ping', role='user')
-            ], parameters={
-                'max_tokens': 1,
-            }, timeout=60)
+            instance.generate(
+                model=model,
+                stream=False,
+                messages=[{"content": "ping", "role": "user"}],
+                parameters={
+                    "max_tokens": 1,
+                },
+                timeout=60,
+            )
         except Exception as e:
             raise CredentialsValidateFailedError(f"Invalid API key: {e}")
 
-    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                  model_parameters: dict, tools: list[PromptMessageTool] | None = None,
-                  stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
-            -> LLMResult | Generator:
-        if tools is not None and len(tools) > 0:
-            raise InvokeBadRequestError("Baichuan model doesn't support tools")
-
-        instance = BaichuanModel(
-            api_key=credentials['api_key'],
-            secret_key=credentials.get('secret_key', '')
-        )
+    def _generate(
+            self,
+            model: str,
+            credentials: dict,
+            prompt_messages: list[PromptMessage],
+            model_parameters: dict,
+            tools: list[PromptMessageTool] | None = None,
+            stream: bool = True,
+    ) -> LLMResult | Generator:
 
-        # convert prompt messages to baichuan messages
-        messages = [
-            BaichuanMessage(
-                content=message.content if isinstance(message.content, str) else ''.join([
-                    content.data for content in message.content
-                ]),
-                role=message.role.value
-            ) for message in prompt_messages
-        ]
+        instance = BaichuanModel(api_key=credentials["api_key"])
+        messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
 
         # invoke model
-        response = instance.generate(model=model, stream=stream, messages=messages, parameters=model_parameters,
-                                     timeout=60)
+        response = instance.generate(
+            model=model,
+            stream=stream,
+            messages=messages,
+            parameters=model_parameters,
+            timeout=60,
+            tools=tools,
+        )
 
         if stream:
-            return self._handle_chat_generate_stream_response(model, prompt_messages, credentials, response)
-
-        return self._handle_chat_generate_response(model, prompt_messages, credentials, response)
-
-    def _handle_chat_generate_response(self, model: str,
-                                       prompt_messages: list[PromptMessage],
-                                       credentials: dict,
-                                       response: BaichuanMessage) -> LLMResult:
-        # convert baichuan message to llm result
-        usage = self._calc_response_usage(model=model, credentials=credentials,
-                                          prompt_tokens=response.usage['prompt_tokens'],
-                                          completion_tokens=response.usage['completion_tokens'])
+            return self._handle_chat_generate_stream_response(
+                model, prompt_messages, credentials, response
+            )
+
+        return self._handle_chat_generate_response(
+            model, prompt_messages, credentials, response
+        )
+
+    def _handle_chat_generate_response(
+            self,
+            model: str,
+            prompt_messages: list[PromptMessage],
+            credentials: dict,
+            response: dict,
+    ) -> LLMResult:
+        choices = response.get("choices", [])
+        assistant_message = AssistantPromptMessage(content='', tool_calls=[])
+        if choices and choices[0]["finish_reason"] == "tool_calls":
+            for choice in choices:
+                for tool_call in choice["message"]["tool_calls"]:
+                    tool = AssistantPromptMessage.ToolCall(
+                        id=tool_call.get("id", ""),
+                        type=tool_call.get("type", ""),
+                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            name=tool_call.get("function", {}).get("name", ""),
+                            arguments=tool_call.get("function", {}).get("arguments", "")
+                        ),
+                    )
+                    assistant_message.tool_calls.append(tool)
+        else:
+            for choice in choices:
+                assistant_message.content += choice["message"]["content"]
+                assistant_message.role = choice["message"]["role"]
+
+        usage = response.get("usage")
+        if usage:
+            # transform usage
+            prompt_tokens = usage["prompt_tokens"]
+            completion_tokens = usage["completion_tokens"]
+        else:
+            # calculate num tokens
+            prompt_tokens = self._num_tokens_from_messages(prompt_messages)
+            completion_tokens = self._num_tokens_from_messages([assistant_message])
+
+        usage = self._calc_response_usage(
+            model=model,
+            credentials=credentials,
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+        )
+
         return LLMResult(
             model=model,
             prompt_messages=prompt_messages,
-            message=AssistantPromptMessage(
-                content=response.content,
-                tool_calls=[]
-            ),
+            message=assistant_message,
             usage=usage,
         )
 
-    def _handle_chat_generate_stream_response(self, model: str,
-                                              prompt_messages: list[PromptMessage],
-                                              credentials: dict,
-                                              response: Generator[BaichuanMessage, None, None]) -> Generator:
-        for message in response:
-            if message.usage:
-                usage = self._calc_response_usage(model=model, credentials=credentials,
-                                                  prompt_tokens=message.usage['prompt_tokens'],
-                                                  completion_tokens=message.usage['completion_tokens'])
+    def _handle_chat_generate_stream_response(
+            self,
+            model: str,
+            prompt_messages: list[PromptMessage],
+            credentials: dict,
+            response: Iterator,
+    ) -> Generator:
+        for line in response:
+            if not line:
+                continue
+            line = line.decode("utf-8")
+            # remove the first `data: ` prefix
+            if line.startswith("data:"):
+                line = line[5:].strip()
+            try:
+                data = json.loads(line)
+            except Exception as e:
+                if line.strip() == "[DONE]":
+                    return
+            choices = data.get("choices", [])
+
+            stop_reason = ""
+            for choice in choices:
+                if choice.get("finish_reason"):
+                    stop_reason = choice["finish_reason"]
+
+                if len(choice["delta"]["content"]) == 0:
+                    continue
                 yield LLMResultChunk(
                     model=model,
                     prompt_messages=prompt_messages,
                     delta=LLMResultChunkDelta(
                         index=0,
                         message=AssistantPromptMessage(
-                            content=message.content,
-                            tool_calls=[]
+                            content=choice["delta"]["content"], tool_calls=[]
                         ),
-                        usage=usage,
-                        finish_reason=message.stop_reason if message.stop_reason else None,
+                        finish_reason=stop_reason,
                     ),
                 )
-            else:
+
+            # if there is usage, the response is the last one, yield it and return
+            if "usage" in data:
+                usage = self._calc_response_usage(
+                    model=model,
+                    credentials=credentials,
+                    prompt_tokens=data["usage"]["prompt_tokens"],
+                    completion_tokens=data["usage"]["completion_tokens"],
+                )
                 yield LLMResultChunk(
                     model=model,
                     prompt_messages=prompt_messages,
                     delta=LLMResultChunkDelta(
                         index=0,
-                        message=AssistantPromptMessage(
-                            content=message.content,
-                            tool_calls=[]
-                        ),
-                        finish_reason=message.stop_reason if message.stop_reason else None,
+                        message=AssistantPromptMessage(content="", tool_calls=[]),
+                        usage=usage,
+                        finish_reason=stop_reason,
                     ),
                 )
 
@@ -215,21 +297,13 @@ class BaichuanLarguageModel(LargeLanguageModel):
         :return: Invoke error mapping
         """
         return {
-            InvokeConnectionError: [
-            ],
-            InvokeServerUnavailableError: [
-                InternalServerError
-            ],
-            InvokeRateLimitError: [
-                RateLimitReachedError
-            ],
+            InvokeConnectionError: [],
+            InvokeServerUnavailableError: [InternalServerError],
+            InvokeRateLimitError: [RateLimitReachedError],
             InvokeAuthorizationError: [
                 InvalidAuthenticationError,
                 InsufficientAccountBalance,
                 InvalidAPIKeyError,
             ],
-            InvokeBadRequestError: [
-                BadRequestError,
-                KeyError
-            ]
+            InvokeBadRequestError: [BadRequestError, KeyError],
         }
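Since generate() now returns the raw response.iter_lines() when streaming, the handler above owns the SSE decoding. A standalone sketch of that loop, assuming the same data:/[DONE] framing used in _handle_chat_generate_stream_response:

import json
from collections.abc import Iterator

def iter_stream_chunks(lines: Iterator[bytes]) -> Iterator[dict]:
    # Decode the Baichuan SSE stream: each frame is "data: {...}",
    # terminated by a final "data: [DONE]" frame.
    for raw in lines:
        if not raw:
            continue
        line = raw.decode("utf-8")
        if line.startswith("data:"):
            line = line[5:].strip()
        if line == "[DONE]":
            return
        try:
            yield json.loads(line)
        except json.JSONDecodeError:
            continue  # skip keep-alive or malformed frames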