Explorar el Código

fix: Fix the problem of system not working (#2884)

Su Yang hace 1 año
padre
commit
507aa6d949
Se han modificado 1 ficheros con 36 adiciones y 11 borrados
  1. 36 11
      api/core/model_runtime/model_providers/bedrock/llm/llm.py

+ 36 - 11
api/core/model_runtime/model_providers/bedrock/llm/llm.py

@@ -74,12 +74,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
 
         # invoke claude 3 models via anthropic official SDK
         if "anthropic.claude-3" in model:
-            return self._invoke_claude3(model, credentials, prompt_messages, model_parameters, stop, stream)
+            return self._invoke_claude3(model, credentials, prompt_messages, model_parameters, stop, stream, user)
         # invoke model
         return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
 
     def _invoke_claude3(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
-                stop: Optional[list[str]] = None, stream: bool = True) -> Union[LLMResult, Generator]:
+                stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
         Invoke Claude3 large language model
 
@@ -100,22 +100,38 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             aws_region=credentials["aws_region"],
         )
 
+        extra_model_kwargs = {}
+        if stop:
+            extra_model_kwargs['stop_sequences'] = stop
+
+        # Notice: If you request the current version of the SDK to the bedrock server,
+        #         you will get the following error message and you need to wait for the service or SDK to be updated.
+        #         Response:  Error code: 400
+        #                    {'message': 'Malformed input request: #: subject must not be valid against schema
+        #                        {"required":["messages"]}#: extraneous key [metadata] is not permitted, please reformat your input and try again.'}
+        # TODO: Open in the future when the interface is properly supported
+        # if user:
+            # ref: https://github.com/anthropics/anthropic-sdk-python/blob/e84645b07ca5267066700a104b4d8d6a8da1383d/src/anthropic/resources/messages.py#L465
+            # extra_model_kwargs['metadata'] = message_create_params.Metadata(user_id=user)
+
         system, prompt_message_dicts = self._convert_claude3_prompt_messages(prompt_messages)
 
+        if system:
+            extra_model_kwargs['system'] = system
+
         response = client.messages.create(
             model=model,
             messages=prompt_message_dicts,
-            stop_sequences=stop if stop else [],
-            system=system,
             stream=stream,
             **model_parameters,
+            **extra_model_kwargs
         )
 
-        if stream is False:
-            return self._handle_claude3_response(model, credentials, response, prompt_messages)
-        else:
+        if stream:
             return self._handle_claude3_stream_response(model, credentials, response, prompt_messages)
 
+        return self._handle_claude3_response(model, credentials, response, prompt_messages)
+
     def _handle_claude3_response(self, model: str, credentials: dict, response: Message,
                                 prompt_messages: list[PromptMessage]) -> LLMResult:
         """
@@ -263,13 +279,22 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         """
         Convert prompt messages to dict list and system
         """
-        system = ""
-        prompt_message_dicts = []
 
+        system = ""
+        first_loop = True
         for message in prompt_messages:
             if isinstance(message, SystemPromptMessage):
-                system += message.content + ("\n" if not system else "")
-            else:
+                message.content=message.content.strip()
+                if first_loop:
+                    system=message.content
+                    first_loop=False
+                else:
+                    system+="\n"
+                    system+=message.content
+
+        prompt_message_dicts = []
+        for message in prompt_messages:
+            if not isinstance(message, SystemPromptMessage):
                 prompt_message_dicts.append(self._convert_claude3_prompt_message_to_dict(message))
 
         return system, prompt_message_dicts