|
@@ -0,0 +1,238 @@
|
|
|
+import json
|
|
|
+import logging
|
|
|
+from collections.abc import Generator
|
|
|
+from typing import Any, Optional, Union
|
|
|
+
|
|
|
+import boto3
|
|
|
+
|
|
|
+from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
|
|
|
+from core.model_runtime.entities.message_entities import (
|
|
|
+ AssistantPromptMessage,
|
|
|
+ PromptMessage,
|
|
|
+ PromptMessageTool,
|
|
|
+)
|
|
|
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType
|
|
|
+from core.model_runtime.errors.invoke import (
|
|
|
+ InvokeAuthorizationError,
|
|
|
+ InvokeBadRequestError,
|
|
|
+ InvokeConnectionError,
|
|
|
+ InvokeError,
|
|
|
+ InvokeRateLimitError,
|
|
|
+ InvokeServerUnavailableError,
|
|
|
+)
|
|
|
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
|
|
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
|
|
+
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+
|
|
|
+class SageMakerLargeLanguageModel(LargeLanguageModel):
|
|
|
+ """
|
|
|
+ Model class for Cohere large language model.
|
|
|
+ """
|
|
|
+ sagemaker_client: Any = None
|
|
|
+
|
|
|
+ def _invoke(self, model: str, credentials: dict,
|
|
|
+ prompt_messages: list[PromptMessage], model_parameters: dict,
|
|
|
+ tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
|
|
+ stream: bool = True, user: Optional[str] = None) \
|
|
|
+ -> Union[LLMResult, Generator]:
|
|
|
+ """
|
|
|
+ Invoke large language model
|
|
|
+
|
|
|
+ :param model: model name
|
|
|
+ :param credentials: model credentials
|
|
|
+ :param prompt_messages: prompt messages
|
|
|
+ :param model_parameters: model parameters
|
|
|
+ :param tools: tools for tool calling
|
|
|
+ :param stop: stop words
|
|
|
+ :param stream: is stream response
|
|
|
+ :param user: unique user id
|
|
|
+ :return: full response or stream response chunk generator result
|
|
|
+ """
|
|
|
+ # get model mode
|
|
|
+ model_mode = self.get_model_mode(model, credentials)
|
|
|
+
|
|
|
+ if not self.sagemaker_client:
|
|
|
+ access_key = credentials.get('access_key')
|
|
|
+ secret_key = credentials.get('secret_key')
|
|
|
+ aws_region = credentials.get('aws_region')
|
|
|
+ if aws_region:
|
|
|
+ if access_key and secret_key:
|
|
|
+ self.sagemaker_client = boto3.client("sagemaker-runtime",
|
|
|
+ aws_access_key_id=access_key,
|
|
|
+ aws_secret_access_key=secret_key,
|
|
|
+ region_name=aws_region)
|
|
|
+ else:
|
|
|
+ self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
|
|
|
+ else:
|
|
|
+ self.sagemaker_client = boto3.client("sagemaker-runtime")
|
|
|
+
|
|
|
+
|
|
|
+ sagemaker_endpoint = credentials.get('sagemaker_endpoint')
|
|
|
+ response_model = self.sagemaker_client.invoke_endpoint(
|
|
|
+ EndpointName=sagemaker_endpoint,
|
|
|
+ Body=json.dumps(
|
|
|
+ {
|
|
|
+ "inputs": prompt_messages[0].content,
|
|
|
+ "parameters": { "stop" : stop},
|
|
|
+ "history" : []
|
|
|
+ }
|
|
|
+ ),
|
|
|
+ ContentType="application/json",
|
|
|
+ )
|
|
|
+
|
|
|
+ assistant_text = response_model['Body'].read().decode('utf8')
|
|
|
+
|
|
|
+ # transform assistant message to prompt message
|
|
|
+ assistant_prompt_message = AssistantPromptMessage(
|
|
|
+ content=assistant_text
|
|
|
+ )
|
|
|
+
|
|
|
+ usage = self._calc_response_usage(model, credentials, 0, 0)
|
|
|
+
|
|
|
+ response = LLMResult(
|
|
|
+ model=model,
|
|
|
+ prompt_messages=prompt_messages,
|
|
|
+ message=assistant_prompt_message,
|
|
|
+ usage=usage
|
|
|
+ )
|
|
|
+
|
|
|
+ return response
|
|
|
+
|
|
|
+ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
|
|
+ tools: Optional[list[PromptMessageTool]] = None) -> int:
|
|
|
+ """
|
|
|
+ Get number of tokens for given prompt messages
|
|
|
+
|
|
|
+ :param model: model name
|
|
|
+ :param credentials: model credentials
|
|
|
+ :param prompt_messages: prompt messages
|
|
|
+ :param tools: tools for tool calling
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ # get model mode
|
|
|
+ model_mode = self.get_model_mode(model)
|
|
|
+
|
|
|
+ try:
|
|
|
+ return 0
|
|
|
+ except Exception as e:
|
|
|
+ raise self._transform_invoke_error(e)
|
|
|
+
|
|
|
+ def validate_credentials(self, model: str, credentials: dict) -> None:
|
|
|
+ """
|
|
|
+ Validate model credentials
|
|
|
+
|
|
|
+ :param model: model name
|
|
|
+ :param credentials: model credentials
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # get model mode
|
|
|
+ model_mode = self.get_model_mode(model)
|
|
|
+ except Exception as ex:
|
|
|
+ raise CredentialsValidateFailedError(str(ex))
|
|
|
+
|
|
|
+ @property
|
|
|
+ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
|
|
+ """
|
|
|
+ Map model invoke error to unified error
|
|
|
+ The key is the error type thrown to the caller
|
|
|
+ The value is the error type thrown by the model,
|
|
|
+ which needs to be converted into a unified error type for the caller.
|
|
|
+
|
|
|
+ :return: Invoke error mapping
|
|
|
+ """
|
|
|
+ return {
|
|
|
+ InvokeConnectionError: [
|
|
|
+ InvokeConnectionError
|
|
|
+ ],
|
|
|
+ InvokeServerUnavailableError: [
|
|
|
+ InvokeServerUnavailableError
|
|
|
+ ],
|
|
|
+ InvokeRateLimitError: [
|
|
|
+ InvokeRateLimitError
|
|
|
+ ],
|
|
|
+ InvokeAuthorizationError: [
|
|
|
+ InvokeAuthorizationError
|
|
|
+ ],
|
|
|
+ InvokeBadRequestError: [
|
|
|
+ InvokeBadRequestError,
|
|
|
+ KeyError,
|
|
|
+ ValueError
|
|
|
+ ]
|
|
|
+ }
|
|
|
+
|
|
|
+ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
|
|
+ """
|
|
|
+ used to define customizable model schema
|
|
|
+ """
|
|
|
+ rules = [
|
|
|
+ ParameterRule(
|
|
|
+ name='temperature',
|
|
|
+ type=ParameterType.FLOAT,
|
|
|
+ use_template='temperature',
|
|
|
+ label=I18nObject(
|
|
|
+ zh_Hans='温度',
|
|
|
+ en_US='Temperature'
|
|
|
+ ),
|
|
|
+ ),
|
|
|
+ ParameterRule(
|
|
|
+ name='top_p',
|
|
|
+ type=ParameterType.FLOAT,
|
|
|
+ use_template='top_p',
|
|
|
+ label=I18nObject(
|
|
|
+ zh_Hans='Top P',
|
|
|
+ en_US='Top P'
|
|
|
+ )
|
|
|
+ ),
|
|
|
+ ParameterRule(
|
|
|
+ name='max_tokens',
|
|
|
+ type=ParameterType.INT,
|
|
|
+ use_template='max_tokens',
|
|
|
+ min=1,
|
|
|
+ max=credentials.get('context_length', 2048),
|
|
|
+ default=512,
|
|
|
+ label=I18nObject(
|
|
|
+ zh_Hans='最大生成长度',
|
|
|
+ en_US='Max Tokens'
|
|
|
+ )
|
|
|
+ )
|
|
|
+ ]
|
|
|
+
|
|
|
+ completion_type = LLMMode.value_of(credentials["mode"])
|
|
|
+
|
|
|
+ if completion_type == LLMMode.CHAT:
|
|
|
+ print(f"completion_type : {LLMMode.CHAT.value}")
|
|
|
+
|
|
|
+ if completion_type == LLMMode.COMPLETION:
|
|
|
+ print(f"completion_type : {LLMMode.COMPLETION.value}")
|
|
|
+
|
|
|
+ features = []
|
|
|
+
|
|
|
+ support_function_call = credentials.get('support_function_call', False)
|
|
|
+ if support_function_call:
|
|
|
+ features.append(ModelFeature.TOOL_CALL)
|
|
|
+
|
|
|
+ support_vision = credentials.get('support_vision', False)
|
|
|
+ if support_vision:
|
|
|
+ features.append(ModelFeature.VISION)
|
|
|
+
|
|
|
+ context_length = credentials.get('context_length', 2048)
|
|
|
+
|
|
|
+ entity = AIModelEntity(
|
|
|
+ model=model,
|
|
|
+ label=I18nObject(
|
|
|
+ en_US=model
|
|
|
+ ),
|
|
|
+ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
|
|
+ model_type=ModelType.LLM,
|
|
|
+ features=features,
|
|
|
+ model_properties={
|
|
|
+ ModelPropertyKey.MODE: completion_type,
|
|
|
+ ModelPropertyKey.CONTEXT_SIZE: context_length
|
|
|
+ },
|
|
|
+ parameter_rules=rules
|
|
|
+ )
|
|
|
+
|
|
|
+ return entity
|