@@ -25,6 +25,7 @@ from botocore.exceptions import (
     ServiceNotInRegionError,
     UnknownServiceError,
 )
+from cohere import ChatMessage
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
 from core.model_runtime.entities.message_entities import (
@@ -48,6 +49,7 @@ from core.model_runtime.errors.invoke import (
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.model_runtime.model_providers.cohere.llm.llm import CohereLargeLanguageModel
 
 logger = logging.getLogger(__name__)
 
@@ -75,8 +77,96 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         # invoke anthropic models via anthropic official SDK
         if "anthropic" in model:
             return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user)
+        # invoke Cohere Command R models via boto3 client
+        if "cohere.command-r" in model:
+            return self._generate_cohere_chat(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
         # invoke other models via boto3 client
         return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
+
+    def _generate_cohere_chat(
+        self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
+        stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
+        tools: Optional[list[PromptMessageTool]] = None) -> Union[LLMResult, Generator]:
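+        # reuse the Cohere provider implementation for its prompt and tool conversion helpers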
+        cohere_llm = CohereLargeLanguageModel()
+        client_config = Config(
+            region_name=credentials["aws_region"]
+        )
+
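+        # Bedrock runtime client bound to the configured region and credentials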
+        runtime_client = boto3.client(
+            service_name='bedrock-runtime',
+            config=client_config,
+            aws_access_key_id=credentials["aws_access_key_id"],
+            aws_secret_access_key=credentials["aws_secret_access_key"]
+        )
+
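+        # collect optional stop sequences; they are merged into the payload below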
+        extra_model_kwargs = {}
+        if stop:
+            extra_model_kwargs['stop_sequences'] = stop
+
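+        # translate Dify tool definitions into Cohere's tool schema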
+        if tools:
+            tools = cohere_llm._convert_tools(tools)
+            model_parameters['tools'] = tools
+
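+        # split the prompt into the latest message, prior chat history and any tool results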
+        message, chat_histories, tool_results = \
+            cohere_llm._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
+
+        if tool_results:
+            model_parameters['tool_results'] = tool_results
+
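+        # assemble the request body expected by Cohere models on Bedrock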
+        payload = {
+            **model_parameters,
+            **extra_model_kwargs,
+            "message": message,
+            "chat_history": chat_histories,
+        }
+
+        # choose the streaming or blocking invoke API
+        if stream:
+            invoke = runtime_client.invoke_model_with_response_stream
+        else:
+            invoke = runtime_client.invoke_model
+
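+        # Cohere ChatMessage objects in the chat history are not JSON-serializable by default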
+        def serialize(obj):
+            if isinstance(obj, ChatMessage):
+                return obj.__dict__
+            raise TypeError(f"Type {type(obj)} not serializable")
+
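+        # send the request, mapping botocore failures onto Dify's invoke error hierarchy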
+        try:
+            body_jsonstr = json.dumps(payload, default=serialize)
+            response = invoke(
+                modelId=model,
+                contentType="application/json",
+                accept="*/*",
+                body=body_jsonstr
+            )
+        except ClientError as ex:
+            error_code = ex.response['Error']['Code']
+            full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
+            raise self._map_client_to_invoke_error(error_code, full_error_msg)
+
+        except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex:
+            raise InvokeConnectionError(str(ex))
+
+        except UnknownServiceError as ex:
+            raise InvokeServerUnavailableError(str(ex))
+
+        except Exception as ex:
+            raise InvokeError(str(ex))
+
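+        # unpack the response into Dify's LLM result entities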
+        if stream:
+            return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
+
+        return self._handle_generate_response(model, credentials, response, prompt_messages)
+
 
     def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
                             stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: