|
@@ -84,8 +84,9 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
|
|
|
Model class for SageMaker large language model.
|
|
|
"""
|
|
|
|
|
|
- sagemaker_client: Any = None
|
|
|
+ sagemaker_session: Any = None
|
|
|
predictor: Any = None
|
|
|
+ sagemaker_endpoint: str = None
|
|
|
|
|
|
def _handle_chat_generate_response(
|
|
|
self,
|
|
@@ -211,7 +212,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
|
|
|
:param user: unique user id
|
|
|
:return: full response or stream response chunk generator result
|
|
|
"""
|
|
|
- if not self.sagemaker_client:
|
|
|
+ if not self.sagemaker_session:
|
|
|
access_key = credentials.get("aws_access_key_id")
|
|
|
secret_key = credentials.get("aws_secret_access_key")
|
|
|
aws_region = credentials.get("aws_region")
|
|
@@ -226,11 +227,14 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
|
|
|
else:
|
|
|
boto_session = boto3.Session()
|
|
|
|
|
|
- self.sagemaker_client = boto_session.client("sagemaker")
|
|
|
- sagemaker_session = Session(boto_session=boto_session, sagemaker_client=self.sagemaker_client)
|
|
|
+ sagemaker_client = boto_session.client("sagemaker")
|
|
|
+ self.sagemaker_session = Session(boto_session=boto_session, sagemaker_client=sagemaker_client)
|
|
|
+
|
|
|
+ if self.sagemaker_endpoint != credentials.get("sagemaker_endpoint"):
|
|
|
+ self.sagemaker_endpoint = credentials.get("sagemaker_endpoint")
|
|
|
self.predictor = Predictor(
|
|
|
- endpoint_name=credentials.get("sagemaker_endpoint"),
|
|
|
- sagemaker_session=sagemaker_session,
|
|
|
+ endpoint_name=self.sagemaker_endpoint,
|
|
|
+ sagemaker_session=self.sagemaker_session,
|
|
|
serializer=serializers.JSONSerializer(),
|
|
|
)
|
|
|
|