Переглянути джерело

Fix Issue: switch LLM of SageMaker endpoint doesn't take effect (#8737)

Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
ybalbert001 7 місяців тому
батько
коміт
68c7e68a8a

+ 10 - 6
api/core/model_runtime/model_providers/sagemaker/llm/llm.py

@@ -84,8 +84,9 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
     Model class for Cohere large language model.
     """
 
-    sagemaker_client: Any = None
+    sagemaker_session: Any = None
     predictor: Any = None
+    sagemaker_endpoint: str = None
 
     def _handle_chat_generate_response(
         self,
@@ -211,7 +212,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
         :param user: unique user id
         :return: full response or stream response chunk generator result
         """
-        if not self.sagemaker_client:
+        if not self.sagemaker_session:
             access_key = credentials.get("aws_access_key_id")
             secret_key = credentials.get("aws_secret_access_key")
             aws_region = credentials.get("aws_region")
@@ -226,11 +227,14 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
             else:
                 boto_session = boto3.Session()
 
-            self.sagemaker_client = boto_session.client("sagemaker")
-            sagemaker_session = Session(boto_session=boto_session, sagemaker_client=self.sagemaker_client)
+            sagemaker_client = boto_session.client("sagemaker")
+            self.sagemaker_session = Session(boto_session=boto_session, sagemaker_client=sagemaker_client)
+
+        if self.sagemaker_endpoint != credentials.get("sagemaker_endpoint"):
+            self.sagemaker_endpoint = credentials.get("sagemaker_endpoint")
             self.predictor = Predictor(
-                endpoint_name=credentials.get("sagemaker_endpoint"),
-                sagemaker_session=sagemaker_session,
+                endpoint_name=self.sagemaker_endpoint,
+                sagemaker_session=self.sagemaker_session,
                 serializer=serializers.JSONSerializer(),
             )