@@ -1,7 +1,12 @@
-from collections.abc import Generator
+import json
+from collections.abc import Generator, Iterator
 from typing import cast

-from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.llm_entities import (
+    LLMResult,
+    LLMResultChunk,
+    LLMResultChunkDelta,
+)
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
@@ -21,7 +26,7 @@ from core.model_runtime.errors.invoke import (
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer
-from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanMessage, BaichuanModel
+from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanModel
 from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
     BadRequestError,
     InsufficientAccountBalance,
@@ -33,19 +38,40 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (


 class BaichuanLarguageModel(LargeLanguageModel):
-    def _invoke(self, model: str, credentials: dict,
-                prompt_messages: list[PromptMessage], model_parameters: dict,
-                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
-                stream: bool = True, user: str | None = None) \
-            -> LLMResult | Generator:
-        return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages,
-                              model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)
-
-    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                       tools: list[PromptMessageTool] | None = None) -> int:
+
+    def _invoke(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: list[PromptMessageTool] | None = None,
+        stop: list[str] | None = None,
+        stream: bool = True,
+        user: str | None = None,
+    ) -> LLMResult | Generator:
+        return self._generate(
+            model=model,
+            credentials=credentials,
+            prompt_messages=prompt_messages,
+            model_parameters=model_parameters,
+            tools=tools,
+            stream=stream,
+        )
+
+    def get_num_tokens(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        tools: list[PromptMessageTool] | None = None,
+    ) -> int:
         return self._num_tokens_from_messages(prompt_messages)

-    def _num_tokens_from_messages(self, messages: list[PromptMessage], ) -> int:
+    def _num_tokens_from_messages(
+        self,
+        messages: list[PromptMessage],
+    ) -> int:
         """Calculate num tokens for baichuan model"""

         def tokens(text: str):
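
For orientation, a minimal sketch of driving the refactored entry point; the model name, credential value, constructor call, and the UserPromptMessage import are illustrative assumptions. Note that stop and user are still accepted for interface compatibility but are no longer forwarded to _generate:

    from core.model_runtime.entities.message_entities import UserPromptMessage

    llm = BaichuanLarguageModel()
    result = llm._invoke(
        model="baichuan2-turbo",            # assumed model name
        credentials={"api_key": "sk-..."},  # placeholder key
        prompt_messages=[UserPromptMessage(content="ping")],
        model_parameters={"temperature": 0.3, "max_tokens": 64},
        stream=False,
    )
    print(result.message.content)
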
@@ -59,10 +85,10 @@ class BaichuanLarguageModel(LargeLanguageModel):
             num_tokens += tokens_per_message
             for key, value in message.items():
                 if isinstance(value, list):
-                    text = ''
+                    text = ""
                     for item in value:
-                        if isinstance(item, dict) and item['type'] == 'text':
-                            text += item['text']
+                        if isinstance(item, dict) and item["type"] == "text":
+                            text += item["text"]

                     value = text
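
The counting walk above flattens list-valued (multimodal) content down to its text parts before tokenizing. A standalone sketch of the same walk, with a whitespace tokenizer standing in for BaichuanTokenizer and a made-up per-message constant:

    def tokens(text: str) -> int:
        # stand-in for BaichuanTokenizer; counts whitespace-separated words
        return len(text.split())

    def num_tokens(messages: list[dict]) -> int:
        tokens_per_message = 3  # illustrative overhead constant
        total = 0
        for message in messages:
            total += tokens_per_message
            for key, value in message.items():
                if isinstance(value, list):
                    # multimodal content: keep only the text segments
                    value = "".join(
                        item["text"]
                        for item in value
                        if isinstance(item, dict) and item["type"] == "text"
                    )
                total += tokens(str(value))
        return total

    print(num_tokens([{"role": "user", "content": "hello there"}]))  # -> 6
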
@@ -84,19 +110,18 @@ class BaichuanLarguageModel(LargeLanguageModel):
         elif isinstance(message, AssistantPromptMessage):
             message = cast(AssistantPromptMessage, message)
             message_dict = {"role": "assistant", "content": message.content}
+            if message.tool_calls:
+                message_dict["tool_calls"] = [
+                    tool_call.dict() for tool_call in message.tool_calls]
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
-            message_dict = {"role": "user", "content": message.content}
+            message_dict = {"role": "system", "content": message.content}
         elif isinstance(message, ToolPromptMessage):
-            # copy from core/model_runtime/model_providers/anthropic/llm/llm.py
             message = cast(ToolPromptMessage, message)
             message_dict = {
-                "role": "user",
-                "content": [{
-                    "type": "tool_result",
-                    "tool_use_id": message.tool_call_id,
-                    "content": message.content
-                }]
+                "role": "tool",
+                "content": message.content,
+                "tool_call_id": message.tool_call_id
             }
         else:
             raise ValueError(f"Unknown message type {type(message)}")
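
For reference, the conversion above now emits OpenAI-style message dicts. The shapes for a full tool-call round trip look roughly like this; the ids, function names, and payloads are fabricated for illustration:

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What's the weather in Beijing?"},
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {
                    "id": "call_0",  # fabricated id
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "arguments": '{"city": "Beijing"}',
                    },
                }
            ],
        },
        {"role": "tool", "content": '{"temp_c": 21}', "tool_call_id": "call_0"},
    ]
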
@@ -105,102 +130,160 @@

     def validate_credentials(self, model: str, credentials: dict) -> None:
         # ping
-        instance = BaichuanModel(
-            api_key=credentials['api_key'],
-            secret_key=credentials.get('secret_key', '')
-        )
+        instance = BaichuanModel(api_key=credentials["api_key"])

         try:
-            instance.generate(model=model, stream=False, messages=[
-                BaichuanMessage(content='ping', role='user')
-            ], parameters={
-                'max_tokens': 1,
-            }, timeout=60)
+            instance.generate(
+                model=model,
+                stream=False,
+                messages=[{"content": "ping", "role": "user"}],
+                parameters={
+                    "max_tokens": 1,
+                },
+                timeout=60,
+            )
         except Exception as e:
             raise CredentialsValidateFailedError(f"Invalid API key: {e}")

-    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                  model_parameters: dict, tools: list[PromptMessageTool] | None = None,
-                  stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
-            -> LLMResult | Generator:
-        if tools is not None and len(tools) > 0:
-            raise InvokeBadRequestError("Baichuan model doesn't support tools")
-
-        instance = BaichuanModel(
-            api_key=credentials['api_key'],
-            secret_key=credentials.get('secret_key', '')
-        )
+    def _generate(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: list[PromptMessageTool] | None = None,
+        stream: bool = True,
+    ) -> LLMResult | Generator:

-        # convert prompt messages to baichuan messages
-        messages = [
-            BaichuanMessage(
-                content=message.content if isinstance(message.content, str) else ''.join([
-                    content.data for content in message.content
-                ]),
-                role=message.role.value
-            ) for message in prompt_messages
-        ]
+        instance = BaichuanModel(api_key=credentials["api_key"])
+        messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]

         # invoke model
-        response = instance.generate(model=model, stream=stream, messages=messages, parameters=model_parameters,
-                                     timeout=60)
+        response = instance.generate(
+            model=model,
+            stream=stream,
+            messages=messages,
+            parameters=model_parameters,
+            timeout=60,
+            tools=tools,
+        )

         if stream:
-            return self._handle_chat_generate_stream_response(model, prompt_messages, credentials, response)
-
-        return self._handle_chat_generate_response(model, prompt_messages, credentials, response)
-
-    def _handle_chat_generate_response(self, model: str,
-                                       prompt_messages: list[PromptMessage],
-                                       credentials: dict,
-                                       response: BaichuanMessage) -> LLMResult:
-        # convert baichuan message to llm result
-        usage = self._calc_response_usage(model=model, credentials=credentials,
-                                          prompt_tokens=response.usage['prompt_tokens'],
-                                          completion_tokens=response.usage['completion_tokens'])
+            return self._handle_chat_generate_stream_response(
+                model, prompt_messages, credentials, response
+            )
+
+        return self._handle_chat_generate_response(
+            model, prompt_messages, credentials, response
+        )
+
+    def _handle_chat_generate_response(
+        self,
+        model: str,
+        prompt_messages: list[PromptMessage],
+        credentials: dict,
+        response: dict,
+    ) -> LLMResult:
+        choices = response.get("choices", [])
+        assistant_message = AssistantPromptMessage(content="", tool_calls=[])
+        if choices and choices[0]["finish_reason"] == "tool_calls":
+            for choice in choices:
+                for tool_call in choice["message"]["tool_calls"]:
+                    tool = AssistantPromptMessage.ToolCall(
+                        id=tool_call.get("id", ""),
+                        type=tool_call.get("type", ""),
+                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            name=tool_call.get("function", {}).get("name", ""),
+                            arguments=tool_call.get("function", {}).get("arguments", ""),
+                        ),
+                    )
+                    assistant_message.tool_calls.append(tool)
+        else:
+            for choice in choices:
+                assistant_message.content += choice["message"]["content"]
+                assistant_message.role = choice["message"]["role"]
+
+        usage = response.get("usage")
+        if usage:
+            # transform usage
+            prompt_tokens = usage["prompt_tokens"]
+            completion_tokens = usage["completion_tokens"]
+        else:
+            # calculate num tokens
+            prompt_tokens = self._num_tokens_from_messages(prompt_messages)
+            completion_tokens = self._num_tokens_from_messages([assistant_message])
+
+        usage = self._calc_response_usage(
+            model=model,
+            credentials=credentials,
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+        )
+
         return LLMResult(
             model=model,
             prompt_messages=prompt_messages,
-            message=AssistantPromptMessage(
-                content=response.content,
-                tool_calls=[]
-            ),
+            message=assistant_message,
             usage=usage,
         )

-    def _handle_chat_generate_stream_response(self, model: str,
-                                              prompt_messages: list[PromptMessage],
-                                              credentials: dict,
-                                              response: Generator[BaichuanMessage, None, None]) -> Generator:
-        for message in response:
-            if message.usage:
-                usage = self._calc_response_usage(model=model, credentials=credentials,
-                                                  prompt_tokens=message.usage['prompt_tokens'],
-                                                  completion_tokens=message.usage['completion_tokens'])
+    def _handle_chat_generate_stream_response(
+        self,
+        model: str,
+        prompt_messages: list[PromptMessage],
+        credentials: dict,
+        response: Iterator,
+    ) -> Generator:
+        for line in response:
+            if not line:
+                continue
+            line = line.decode("utf-8")
+            # strip the leading `data: ` prefix from each SSE line
+            if line.startswith("data:"):
+                line = line[5:].strip()
+            try:
+                data = json.loads(line)
+            except Exception:
+                if line.strip() == "[DONE]":
+                    return
+                continue
+            choices = data.get("choices", [])
+
+            stop_reason = ""
+            for choice in choices:
+                if choice.get("finish_reason"):
+                    stop_reason = choice["finish_reason"]
+
+                if len(choice["delta"]["content"]) == 0:
+                    continue
                 yield LLMResultChunk(
                     model=model,
                     prompt_messages=prompt_messages,
                     delta=LLMResultChunkDelta(
                         index=0,
                         message=AssistantPromptMessage(
-                            content=message.content,
-                            tool_calls=[]
+                            content=choice["delta"]["content"], tool_calls=[]
                         ),
-                        usage=usage,
-                        finish_reason=message.stop_reason if message.stop_reason else None,
+                        finish_reason=stop_reason,
                     ),
                 )
-            else:
+
+            # a chunk carrying usage is the final one; yield it with the usage attached
+            if "usage" in data:
+                usage = self._calc_response_usage(
+                    model=model,
+                    credentials=credentials,
+                    prompt_tokens=data["usage"]["prompt_tokens"],
+                    completion_tokens=data["usage"]["completion_tokens"],
+                )
                 yield LLMResultChunk(
                     model=model,
                     prompt_messages=prompt_messages,
                     delta=LLMResultChunkDelta(
                         index=0,
-                        message=AssistantPromptMessage(
-                            content=message.content,
-                            tool_calls=[]
-                        ),
-                        finish_reason=message.stop_reason if message.stop_reason else None,
+                        message=AssistantPromptMessage(content="", tool_calls=[]),
+                        usage=usage,
+                        finish_reason=stop_reason,
                     ),
                 )
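
The stream handler now parses raw SSE lines instead of consuming pre-built BaichuanMessage objects. A self-contained sketch of the same loop, assuming the response yields bytes the way an HTTP client's iter_lines() does; the sample payloads are fabricated:

    import json

    def parse_sse(lines):
        # mirror of the loop above, reduced to extracting delta text
        for line in lines:
            if not line:
                continue
            text = line.decode("utf-8")
            if text.startswith("data:"):
                text = text[5:].strip()
            try:
                data = json.loads(text)
            except ValueError:
                if text.strip() == "[DONE]":
                    return
                continue
            for choice in data.get("choices", []):
                if choice["delta"]["content"]:
                    yield choice["delta"]["content"]

    sample = [
        b'data: {"choices": [{"delta": {"content": "Hel"}}]}',
        b'data: {"choices": [{"delta": {"content": "lo"}, "finish_reason": "stop"}]}',
        b'data: [DONE]',
    ]
    print("".join(parse_sse(sample)))  # -> Hello
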
@@ -215,21 +297,13 @@ class BaichuanLarguageModel(LargeLanguageModel):
         :return: Invoke error mapping
         """
         return {
-            InvokeConnectionError: [
-            ],
-            InvokeServerUnavailableError: [
-                InternalServerError
-            ],
-            InvokeRateLimitError: [
-                RateLimitReachedError
-            ],
+            InvokeConnectionError: [],
+            InvokeServerUnavailableError: [InternalServerError],
+            InvokeRateLimitError: [RateLimitReachedError],
             InvokeAuthorizationError: [
                 InvalidAuthenticationError,
                 InsufficientAccountBalance,
                 InvalidAPIKeyError,
             ],
-            InvokeBadRequestError: [
-                BadRequestError,
-                KeyError
-            ]
+            InvokeBadRequestError: [BadRequestError, KeyError],
         }
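
The mapping now reads one line per bucket. A sketch of how a runtime might consume it to normalize provider exceptions; the helper below is an assumption about the caller, not code from this module:

    def normalize_error(
        exc: Exception, mapping: dict[type, list[type]]
    ) -> Exception:
        # wrap a provider-specific error in its invoke-error bucket
        for invoke_error, provider_errors in mapping.items():
            if any(isinstance(exc, err) for err in provider_errors):
                return invoke_error(str(exc))
        return exc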