|
@@ -51,6 +51,9 @@ class BaichuanModel:
|
|
|
'baichuan2-turbo': 'Baichuan2-Turbo',
|
|
|
'baichuan2-turbo-192k': 'Baichuan2-Turbo-192k',
|
|
|
'baichuan2-53b': 'Baichuan2-53B',
|
|
|
+ 'baichuan3-turbo': 'Baichuan3-Turbo',
|
|
|
+ 'baichuan3-turbo-128k': 'Baichuan3-Turbo-128k',
|
|
|
+ 'baichuan4': 'Baichuan4',
|
|
|
}[model]
|
|
|
|
|
|
def _handle_chat_generate_response(self, response) -> BaichuanMessage:
|
|
@@ -110,7 +113,8 @@ class BaichuanModel:
|
|
|
def _build_parameters(self, model: str, stream: bool, messages: list[BaichuanMessage],
|
|
|
parameters: dict[str, Any]) \
|
|
|
-> dict[str, Any]:
|
|
|
- if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b':
|
|
|
+ if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
|
|
|
+ or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'):
|
|
|
prompt_messages = []
|
|
|
for message in messages:
|
|
|
if message.role == BaichuanMessage.Role.USER.value or message.role == BaichuanMessage.Role._SYSTEM.value:
|
|
@@ -143,7 +147,8 @@ class BaichuanModel:
|
|
|
raise BadRequestError(f"Unknown model: {model}")
|
|
|
|
|
|
def _build_headers(self, model: str, data: dict[str, Any]) -> dict[str, Any]:
|
|
|
- if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b':
|
|
|
+ if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
|
|
|
+ or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'):
|
|
|
# there is no secret key for turbo api
|
|
|
return {
|
|
|
'Content-Type': 'application/json',
|
|
@@ -160,7 +165,8 @@ class BaichuanModel:
|
|
|
parameters: dict[str, Any], timeout: int) \
|
|
|
-> Union[Generator, BaichuanMessage]:
|
|
|
|
|
|
- if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b':
|
|
|
+ if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
|
|
|
+ or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'):
|
|
|
api_base = 'https://api.baichuan-ai.com/v1/chat/completions'
|
|
|
else:
|
|
|
raise BadRequestError(f"Unknown model: {model}")
|