Bladeren bron

fix llm integration problem: It doesn't work on docker env (#8701)

Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
ybalbert001 7 maanden geleden
bovenliggende
commit
7c485f8bb8
1 gewijzigde bestanden met toevoegingen van 9 en 11 verwijderingen
  1. 9 11
      api/core/model_runtime/model_providers/sagemaker/llm/llm.py

+ 9 - 11
api/core/model_runtime/model_providers/sagemaker/llm/llm.py

@@ -85,7 +85,6 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
     """
 
     sagemaker_client: Any = None
-    sagemaker_sess: Any = None
     predictor: Any = None
 
     def _handle_chat_generate_response(
@@ -213,23 +212,22 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
         :return: full response or stream response chunk generator result
         """
         if not self.sagemaker_client:
-            access_key = credentials.get("access_key")
-            secret_key = credentials.get("secret_key")
+            access_key = credentials.get("aws_access_key_id")
+            secret_key = credentials.get("aws_secret_access_key")
             aws_region = credentials.get("aws_region")
+            boto_session = None
             if aws_region:
                 if access_key and secret_key:
-                    self.sagemaker_client = boto3.client(
-                        "sagemaker-runtime",
-                        aws_access_key_id=access_key,
-                        aws_secret_access_key=secret_key,
-                        region_name=aws_region,
+                    boto_session = boto3.Session(
+                        aws_access_key_id=access_key, aws_secret_access_key=secret_key, region_name=aws_region
                     )
                 else:
-                    self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
+                    boto_session = boto3.Session(region_name=aws_region)
             else:
-                self.sagemaker_client = boto3.client("sagemaker-runtime")
+                boto_session = boto3.Session()
 
-            sagemaker_session = Session(sagemaker_runtime_client=self.sagemaker_client)
+            self.sagemaker_client = boto_session.client("sagemaker")
+            sagemaker_session = Session(boto_session=boto_session, sagemaker_client=self.sagemaker_client)
             self.predictor = Predictor(
                 endpoint_name=credentials.get("sagemaker_endpoint"),
                 sagemaker_session=sagemaker_session,