@@ -1,20 +1,38 @@
+import json
 import logging
-from collections.abc import Generator
+from collections.abc import Generator, Iterator
 from typing import Optional, Union, cast

 import cohere
-from cohere.responses import Chat, Generations
-from cohere.responses.chat import StreamEnd, StreamingChat, StreamTextGeneration
-from cohere.responses.generation import StreamingGenerations, StreamingText
+from cohere import (
+    ChatMessage,
+    ChatStreamRequestToolResultsItem,
+    GenerateStreamedResponse,
+    GenerateStreamedResponse_StreamEnd,
+    GenerateStreamedResponse_StreamError,
+    GenerateStreamedResponse_TextGeneration,
+    Generation,
+    NonStreamedChatResponse,
+    StreamedChatResponse,
+    StreamedChatResponse_StreamEnd,
+    StreamedChatResponse_TextGeneration,
+    StreamedChatResponse_ToolCallsGeneration,
+    Tool,
+    ToolCall,
+    ToolParameterDefinitionsValue,
+)
+from cohere.core import RequestOptions

 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
     PromptMessageContentType,
+    PromptMessageRole,
     PromptMessageTool,
     SystemPromptMessage,
     TextPromptMessageContent,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType
@@ -64,6 +82,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                 credentials=credentials,
                 prompt_messages=prompt_messages,
                 model_parameters=model_parameters,
+                tools=tools,
                 stop=stop,
                 stream=stream,
                 user=user
@@ -159,19 +178,26 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         if stop:
             model_parameters['end_sequences'] = stop

-        response = client.generate(
-            prompt=prompt_messages[0].content,
-            model=model,
-            stream=stream,
-            **model_parameters,
-        )
-
         if stream:
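+            # the new cohere SDK uses a dedicated generate_stream() call; max_retries=0 disables the client's built-in retries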
+            response = client.generate_stream(
+                prompt=prompt_messages[0].content,
+                model=model,
+                **model_parameters,
+                request_options=RequestOptions(max_retries=0)
+            )
+
             return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
+        else:
+            response = client.generate(
+                prompt=prompt_messages[0].content,
+                model=model,
+                **model_parameters,
+                request_options=RequestOptions(max_retries=0)
+            )

-        return self._handle_generate_response(model, credentials, response, prompt_messages)
+            return self._handle_generate_response(model, credentials, response, prompt_messages)

-    def _handle_generate_response(self, model: str, credentials: dict, response: Generations,
+    def _handle_generate_response(self, model: str, credentials: dict, response: Generation,
                                   prompt_messages: list[PromptMessage]) \
             -> LLMResult:
         """
@@ -191,8 +217,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         )

         # calculate num tokens
-        prompt_tokens = response.meta['billed_units']['input_tokens']
-        completion_tokens = response.meta['billed_units']['output_tokens']
+        prompt_tokens = int(response.meta.billed_units.input_tokens)
+        completion_tokens = int(response.meta.billed_units.output_tokens)

         # transform usage
         usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
@@ -207,7 +233,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):

         return response

-    def _handle_generate_stream_response(self, model: str, credentials: dict, response: StreamingGenerations,
+    def _handle_generate_stream_response(self, model: str, credentials: dict, response: Iterator[GenerateStreamedResponse],
                                          prompt_messages: list[PromptMessage]) -> Generator:
         """
         Handle llm stream response
@@ -220,8 +246,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         index = 1
         full_assistant_content = ''
         for chunk in response:
-            if isinstance(chunk, StreamingText):
-                chunk = cast(StreamingText, chunk)
+            if isinstance(chunk, GenerateStreamedResponse_TextGeneration):
+                chunk = cast(GenerateStreamedResponse_TextGeneration, chunk)
                 text = chunk.text

                 if text is None:
@@ -244,10 +270,16 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                 )

                 index += 1
-            elif chunk is None:
+            elif isinstance(chunk, GenerateStreamedResponse_StreamEnd):
+                chunk = cast(GenerateStreamedResponse_StreamEnd, chunk)
+
                 # calculate num tokens
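+                # usage is estimated via _num_tokens_from_messages (Cohere tokenize) instead of being read from the stream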
-                prompt_tokens = response.meta['billed_units']['input_tokens']
-                completion_tokens = response.meta['billed_units']['output_tokens']
+                prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
+                completion_tokens = self._num_tokens_from_messages(
+                    model,
+                    credentials,
+                    [AssistantPromptMessage(content=full_assistant_content)]
+                )

                 # transform usage
                 usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
@@ -258,14 +290,18 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                     delta=LLMResultChunkDelta(
                         index=index,
                         message=AssistantPromptMessage(content=''),
-                        finish_reason=response.finish_reason,
+                        finish_reason=chunk.finish_reason,
                        usage=usage
                     )
                 )
                 break
+            elif isinstance(chunk, GenerateStreamedResponse_StreamError):
+                chunk = cast(GenerateStreamedResponse_StreamError, chunk)
+                raise InvokeBadRequestError(chunk.err)

     def _chat_generate(self, model: str, credentials: dict,
-                       prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
+                       prompt_messages: list[PromptMessage], model_parameters: dict,
+                       tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                        stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
         Invoke llm chat model
@@ -274,6 +310,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         :param credentials: credentials
         :param prompt_messages: prompt messages
         :param model_parameters: model parameters
+        :param tools: tools for tool calling
         :param stop: stop words
         :param stream: is stream response
         :param user: unique user id
@@ -282,31 +319,46 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         # initialize client
         client = cohere.Client(credentials.get('api_key'))

-        if user:
-            model_parameters['user_name'] = user
+        if stop:
+            model_parameters['stop_sequences'] = stop
+
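+        # translate the runtime's tool definitions into Cohere Tool objects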
+        if tools:
+            model_parameters['tools'] = self._convert_tools(tools)

-        message, chat_histories = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
+        message, chat_histories, tool_results \
+            = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
+
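+        # tool outputs recovered from the conversation are passed back to Cohere as tool_results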
+        if tool_results:
+            model_parameters['tool_results'] = tool_results

         # chat model
         real_model = model
         if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL:
             real_model = model.removesuffix('-chat')

-        response = client.chat(
-            message=message,
-            chat_history=chat_histories,
-            model=real_model,
-            stream=stream,
-            **model_parameters,
-        )
-
         if stream:
-            return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, stop)
+            response = client.chat_stream(
+                message=message,
+                chat_history=chat_histories,
+                model=real_model,
+                **model_parameters,
+                request_options=RequestOptions(max_retries=0)
+            )

-        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, stop)
+            return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
+        else:
+            response = client.chat(
+                message=message,
+                chat_history=chat_histories,
+                model=real_model,
+                **model_parameters,
+                request_options=RequestOptions(max_retries=0)
+            )

-    def _handle_chat_generate_response(self, model: str, credentials: dict, response: Chat,
-                                       prompt_messages: list[PromptMessage], stop: Optional[list[str]] = None) \
+            return self._handle_chat_generate_response(model, credentials, response, prompt_messages)
+
+    def _handle_chat_generate_response(self, model: str, credentials: dict, response: NonStreamedChatResponse,
+                                       prompt_messages: list[PromptMessage]) \
             -> LLMResult:
         """
         Handle llm chat response
@@ -315,14 +367,27 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         :param credentials: credentials
         :param response: response
         :param prompt_messages: prompt messages
-        :param stop: stop words
         :return: llm response
         """
         assistant_text = response.text

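+        # map Cohere tool calls onto AssistantPromptMessage.ToolCall; the tool name is reused as the call id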
+        tool_calls = []
+        if response.tool_calls:
+            for cohere_tool_call in response.tool_calls:
+                tool_call = AssistantPromptMessage.ToolCall(
+                    id=cohere_tool_call.name,
+                    type='function',
+                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                        name=cohere_tool_call.name,
+                        arguments=json.dumps(cohere_tool_call.parameters)
+                    )
+                )
+                tool_calls.append(tool_call)
+
         # transform assistant message to prompt message
         assistant_prompt_message = AssistantPromptMessage(
-            content=assistant_text
+            content=assistant_text,
+            tool_calls=tool_calls
         )

         # calculate num tokens
@@ -332,44 +397,38 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         # transform usage
         usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

-        if stop:
-            # enforce stop tokens
-            assistant_text = self.enforce_stop_tokens(assistant_text, stop)
-            assistant_prompt_message = AssistantPromptMessage(
-                content=assistant_text
-            )
-
         # transform response
         response = LLMResult(
             model=model,
             prompt_messages=prompt_messages,
             message=assistant_prompt_message,
-            usage=usage,
-            system_fingerprint=response.preamble
+            usage=usage
         )

         return response

-    def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: StreamingChat,
-                                              prompt_messages: list[PromptMessage],
-                                              stop: Optional[list[str]] = None) -> Generator:
+    def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
+                                              response: Iterator[StreamedChatResponse],
+                                              prompt_messages: list[PromptMessage]) -> Generator:
         """
         Handle llm chat stream response

         :param model: model name
         :param response: response
         :param prompt_messages: prompt messages
-        :param stop: stop words
         :return: llm response chunk generator
         """

-        def final_response(full_text: str, index: int, finish_reason: Optional[str] = None,
-                           preamble: Optional[str] = None) -> LLMResultChunk:
+        def final_response(full_text: str,
+                           tool_calls: list[AssistantPromptMessage.ToolCall],
+                           index: int,
+                           finish_reason: Optional[str] = None) -> LLMResultChunk:
             # calculate num tokens
             prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)

             full_assistant_prompt_message = AssistantPromptMessage(
-                content=full_text
+                content=full_text,
+                tool_calls=tool_calls
             )
             completion_tokens = self._num_tokens_from_messages(model, credentials, [full_assistant_prompt_message])

@@ -379,10 +438,9 @@ class CohereLargeLanguageModel(LargeLanguageModel):
             return LLMResultChunk(
                 model=model,
                 prompt_messages=prompt_messages,
-                system_fingerprint=preamble,
                 delta=LLMResultChunkDelta(
                     index=index,
-                    message=AssistantPromptMessage(content=''),
+                    message=AssistantPromptMessage(content='', tool_calls=tool_calls),
                     finish_reason=finish_reason,
                     usage=usage
                 )
@@ -390,9 +448,10 @@ class CohereLargeLanguageModel(LargeLanguageModel):

         index = 1
         full_assistant_content = ''
+        tool_calls = []
         for chunk in response:
-            if isinstance(chunk, StreamTextGeneration):
-                chunk = cast(StreamTextGeneration, chunk)
+            if isinstance(chunk, StreamedChatResponse_TextGeneration):
+                chunk = cast(StreamedChatResponse_TextGeneration, chunk)
                 text = chunk.text

                 if text is None:
@@ -403,12 +462,6 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                     content=text
                 )

-                # stop
-                # notice: This logic can only cover few stop scenarios
-                if stop and text in stop:
-                    yield final_response(full_assistant_content, index, 'stop')
-                    break
-
                 full_assistant_content += text

                 yield LLMResultChunk(
@@ -421,39 +474,98 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                 )

                 index += 1
-            elif isinstance(chunk, StreamEnd):
-                chunk = cast(StreamEnd, chunk)
-                yield final_response(full_assistant_content, index, chunk.finish_reason, response.preamble)
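+            # tool calls arrive in a dedicated ToolCallsGeneration chunk; keep them so they can be attached at stream end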
+            elif isinstance(chunk, StreamedChatResponse_ToolCallsGeneration):
+                chunk = cast(StreamedChatResponse_ToolCallsGeneration, chunk)
+
+                tool_calls = []
+                if chunk.tool_calls:
+                    for cohere_tool_call in chunk.tool_calls:
+                        tool_call = AssistantPromptMessage.ToolCall(
+                            id=cohere_tool_call.name,
+                            type='function',
+                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                name=cohere_tool_call.name,
+                                arguments=json.dumps(cohere_tool_call.parameters)
+                            )
+                        )
+                        tool_calls.append(tool_call)
+            elif isinstance(chunk, StreamedChatResponse_StreamEnd):
+                chunk = cast(StreamedChatResponse_StreamEnd, chunk)
+                yield final_response(full_assistant_content, tool_calls, index, chunk.finish_reason)
                 index += 1

     def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages: list[PromptMessage]) \
-            -> tuple[str, list[dict]]:
+            -> tuple[str, list[ChatMessage], list[ChatStreamRequestToolResultsItem]]:
         """
         Convert prompt messages to message and chat histories
         :param prompt_messages: prompt messages
         :return:
         """
         chat_histories = []
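+        # tool calls from assistant messages are collected here and paired with later ToolPromptMessage outputs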
+        latest_tool_call_n_outputs = []
         for prompt_message in prompt_messages:
-            chat_histories.append(self._convert_prompt_message_to_dict(prompt_message))
+            if prompt_message.role == PromptMessageRole.ASSISTANT:
+                prompt_message = cast(AssistantPromptMessage, prompt_message)
+                if prompt_message.tool_calls:
+                    for tool_call in prompt_message.tool_calls:
+                        latest_tool_call_n_outputs.append(ChatStreamRequestToolResultsItem(
+                            call=ToolCall(
+                                name=tool_call.function.name,
+                                parameters=json.loads(tool_call.function.arguments)
+                            ),
+                            outputs=[]
+                        ))
+                else:
+                    cohere_prompt_message = self._convert_prompt_message_to_dict(prompt_message)
+                    if cohere_prompt_message:
+                        chat_histories.append(cohere_prompt_message)
+            elif prompt_message.role == PromptMessageRole.TOOL:
+                prompt_message = cast(ToolPromptMessage, prompt_message)
+                if latest_tool_call_n_outputs:
+                    i = 0
+                    for tool_call_n_outputs in latest_tool_call_n_outputs:
+                        if tool_call_n_outputs.call.name == prompt_message.tool_call_id:
+                            latest_tool_call_n_outputs[i] = ChatStreamRequestToolResultsItem(
+                                call=ToolCall(
+                                    name=tool_call_n_outputs.call.name,
+                                    parameters=tool_call_n_outputs.call.parameters
+                                ),
+                                outputs=[{
+                                    "result": prompt_message.content
+                                }]
+                            )
+                            break
+                        i += 1
+            else:
+                cohere_prompt_message = self._convert_prompt_message_to_dict(prompt_message)
+                if cohere_prompt_message:
+                    chat_histories.append(cohere_prompt_message)
+
+        if latest_tool_call_n_outputs:
+            new_latest_tool_call_n_outputs = []
+            for tool_call_n_outputs in latest_tool_call_n_outputs:
+                if tool_call_n_outputs.outputs:
+                    new_latest_tool_call_n_outputs.append(tool_call_n_outputs)
+
+            latest_tool_call_n_outputs = new_latest_tool_call_n_outputs

         # get latest message from chat histories and pop it
         if len(chat_histories) > 0:
             latest_message = chat_histories.pop()
-            message = latest_message['message']
+            message = latest_message.message
         else:
             raise ValueError('Prompt messages is empty')

-        return message, chat_histories
+        return message, chat_histories, latest_tool_call_n_outputs

-    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> Optional[ChatMessage]:
         """
         Convert PromptMessage to dict for Cohere model
         """
         if isinstance(message, UserPromptMessage):
             message = cast(UserPromptMessage, message)
             if isinstance(message.content, str):
-                message_dict = {"role": "USER", "message": message.content}
+                chat_message = ChatMessage(role="USER", message=message.content)
             else:
                 sub_message_text = ''
                 for message_content in message.content:
@@ -461,20 +573,57 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                         message_content = cast(TextPromptMessageContent, message_content)
                         sub_message_text += message_content.data

-                message_dict = {"role": "USER", "message": sub_message_text}
+                chat_message = ChatMessage(role="USER", message=sub_message_text)
         elif isinstance(message, AssistantPromptMessage):
             message = cast(AssistantPromptMessage, message)
-            message_dict = {"role": "CHATBOT", "message": message.content}
+            if not message.content:
+                return None
+            chat_message = ChatMessage(role="CHATBOT", message=message.content)
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
-            message_dict = {"role": "USER", "message": message.content}
+            chat_message = ChatMessage(role="USER", message=message.content)
+        elif isinstance(message, ToolPromptMessage):
+            return None
         else:
             raise ValueError(f"Got unknown type {message}")

-        if message.name:
-            message_dict["user_name"] = message.name
+        return chat_message
+
+    def _convert_tools(self, tools: list[PromptMessageTool]) -> list[Tool]:
+        """
+        Convert tools to Cohere model
+        """
+        cohere_tools = []
+        for tool in tools:
+            properties = tool.parameters['properties']
+            required_properties = tool.parameters['required']
+
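+            # map each JSON schema property to a ToolParameterDefinitionsValue; enum values are folded into the description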
+            parameter_definitions = {}
+            for p_key, p_val in properties.items():
+                required = False
+                if p_key in required_properties:
+                    required = True
+
+                desc = p_val['description']
+                if 'enum' in p_val:
+                    desc += (f"; Only accepts one of the following predefined options: "
+                             f"[{', '.join(p_val['enum'])}]")
+
+                parameter_definitions[p_key] = ToolParameterDefinitionsValue(
+                    description=desc,
+                    type=p_val['type'],
+                    required=required
+                )

-        return message_dict
+            cohere_tool = Tool(
+                name=tool.name,
+                description=tool.description,
+                parameter_definitions=parameter_definitions
+            )
+
+            cohere_tools.append(cohere_tool)
+
+        return cohere_tools

     def _num_tokens_from_string(self, model: str, credentials: dict, text: str) -> int:
         """
@@ -493,12 +642,16 @@ class CohereLargeLanguageModel(LargeLanguageModel):
             model=model
         )

-        return response.length
+        return len(response.tokens)

     def _num_tokens_from_messages(self, model: str, credentials: dict, messages: list[PromptMessage]) -> int:
         """Calculate num tokens Cohere model."""
-        messages = [self._convert_prompt_message_to_dict(m) for m in messages]
-        message_strs = [f"{message['role']}: {message['message']}" for message in messages]
+        calc_messages = []
+        for message in messages:
+            cohere_message = self._convert_prompt_message_to_dict(message)
+            if cohere_message:
+                calc_messages.append(cohere_message)
+        message_strs = [f"{message.role}: {message.message}" for message in calc_messages]
         message_str = "\n".join(message_strs)

         real_model = model
@@ -564,13 +717,21 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         """
         return {
             InvokeConnectionError: [
-                cohere.CohereConnectionError
+                cohere.errors.service_unavailable_error.ServiceUnavailableError
+            ],
+            InvokeServerUnavailableError: [
+                cohere.errors.internal_server_error.InternalServerError
+            ],
+            InvokeRateLimitError: [
+                cohere.errors.too_many_requests_error.TooManyRequestsError
+            ],
+            InvokeAuthorizationError: [
+                cohere.errors.unauthorized_error.UnauthorizedError,
+                cohere.errors.forbidden_error.ForbiddenError
             ],
-            InvokeServerUnavailableError: [],
-            InvokeRateLimitError: [],
-            InvokeAuthorizationError: [],
             InvokeBadRequestError: [
-                cohere.CohereAPIError,
-                cohere.CohereError,
+                cohere.core.api_error.ApiError,
+                cohere.errors.bad_request_error.BadRequestError,
+                cohere.errors.not_found_error.NotFoundError,
             ]
         }