浏览代码

feat: add support for Bedrock LLAMA3 (#3890)

longzhihun 1 年之前
父节点
当前提交
43a5ba9415

+ 2 - 0
api/core/model_runtime/model_providers/bedrock/llm/_position.yaml

@@ -8,6 +8,8 @@
 - anthropic.claude-3-haiku-v1:0
 - cohere.command-light-text-v14
 - cohere.command-text-v14
+- meta.llama3-8b-instruct-v1:0
+- meta.llama3-70b-instruct-v1:0
 - meta.llama2-13b-chat-v1
 - meta.llama2-70b-chat-v1
 - mistral.mistral-large-2402-v1:0

+ 22 - 27
api/core/model_runtime/model_providers/bedrock/llm/llm.py

@@ -370,29 +370,14 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         :return:md = genai.GenerativeModel(model)
         """
         prefix = model.split('.')[0]
-
+        model_name = model.split('.')[1]
         if isinstance(messages, str):
             prompt = messages
         else:
-            prompt = self._convert_messages_to_prompt(messages, prefix)
+            prompt = self._convert_messages_to_prompt(messages, prefix, model_name)
 
         return self._get_num_tokens_by_gpt2(prompt)
     
-    def _convert_messages_to_prompt(self, model_prefix: str, messages: list[PromptMessage]) -> str:
-        """
-        Format a list of messages into a full prompt for the Google model
-
-        :param messages: List of PromptMessage to combine.
-        :return: Combined string with necessary human_prompt and ai_prompt tags.
-        """
-        messages = messages.copy()  # don't mutate the original list
-        
-        text = "".join(
-            self._convert_one_message_to_text(message, model_prefix)
-            for message in messages
-        )
-
-        return text.rstrip()
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
@@ -432,7 +417,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
 
-    def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str) -> str:
+    def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str, model_name: Optional[str] = None) -> str:
         """
         Convert a single message to a string.
 
@@ -446,10 +431,17 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             ai_prompt = "\n\nAssistant:"
 
         elif model_prefix == "meta":
-            human_prompt_prefix = "\n[INST]"
-            human_prompt_postfix = "[\\INST]\n"
-            ai_prompt = ""
-        
+            # LLAMA3
+            if model_name.startswith("llama3"):
+                human_prompt_prefix = "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
+                human_prompt_postfix = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+                ai_prompt = "\n\nAssistant:"
+            else:
+                # LLAMA2
+                human_prompt_prefix = "\n[INST]"
+                human_prompt_postfix = "[\\INST]\n"
+                ai_prompt = ""
+
         elif model_prefix == "mistral":
             human_prompt_prefix = "<s>[INST]"
             human_prompt_postfix = "[\\INST]\n"
@@ -478,11 +470,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
 
         return message_text
 
-    def _convert_messages_to_prompt(self, messages: list[PromptMessage], model_prefix: str) -> str:
+    def _convert_messages_to_prompt(self, messages: list[PromptMessage], model_prefix: str, model_name: Optional[str] = None) -> str:
         """
         Format a list of messages into a full prompt for the Anthropic, Amazon and Llama models
 
         :param messages: List of PromptMessage to combine.
+        :param model_name: specific model name.Optional,just to distinguish llama2 and llama3
         :return: Combined string with necessary human_prompt and ai_prompt tags.
         """
         if not messages:
@@ -493,18 +486,20 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             messages.append(AssistantPromptMessage(content=""))
 
         text = "".join(
-            self._convert_one_message_to_text(message, model_prefix)
+            self._convert_one_message_to_text(message, model_prefix, model_name)
             for message in messages
         )
 
         # trim off the trailing ' ' that might come from the "Assistant: "
         return text.rstrip()
 
-    def _create_payload(self, model_prefix: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True):
+    def _create_payload(self, model: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True):
         """
         Create payload for bedrock api call depending on model provider
         """
         payload = dict()
+        model_prefix = model.split('.')[0]
+        model_name = model.split('.')[1]
 
         if model_prefix == "amazon":
             payload["textGenerationConfig"] = { **model_parameters }
@@ -544,7 +539,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         
         elif model_prefix == "meta":
             payload = { **model_parameters }
-            payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
+            payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix, model_name)
 
         else:
             raise ValueError(f"Got unknown model prefix {model_prefix}")
@@ -579,7 +574,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         )
 
         model_prefix = model.split('.')[0]
-        payload = self._create_payload(model_prefix, prompt_messages, model_parameters, stop, stream)
+        payload = self._create_payload(model, prompt_messages, model_parameters, stop, stream)
 
         # need workaround for ai21 models which doesn't support streaming
         if stream and model_prefix != "ai21":

+ 23 - 0
api/core/model_runtime/model_providers/bedrock/llm/meta.llama3-70b-instruct-v1.yaml

@@ -0,0 +1,23 @@
+model: meta.llama3-70b-instruct-v1:0
+label:
+  en_US: Llama 3 Instruct 70B
+model_type: llm
+model_properties:
+  mode: completion
+  context_size: 8192
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: max_gen_len
+    use_template: max_tokens
+    required: true
+    default: 512
+    min: 1
+    max: 2048
+pricing:
+  input: '0.00265'
+  output: '0.0035'
+  unit: '0.00001'
+  currency: USD

+ 23 - 0
api/core/model_runtime/model_providers/bedrock/llm/meta.llama3-8b-instruct-v1.yaml

@@ -0,0 +1,23 @@
+model: meta.llama3-8b-instruct-v1:0
+label:
+  en_US: Llama 3 Instruct 8B
+model_type: llm
+model_properties:
+  mode: completion
+  context_size: 8192
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: max_gen_len
+    use_template: max_tokens
+    required: true
+    default: 512
+    min: 1
+    max: 2048
+pricing:
+  input: '0.0004'
+  output: '0.0006'
+  unit: '0.0001'
+  currency: USD