feat: add support for bedrock Mistral AI model (#3676)

Co-authored-by: Chenhe Gu <guchenhe@gmail.com>
longzhihun · 1 year ago
Parent commit: 28236147ee

+ 3 - 0
api/core/model_runtime/model_providers/bedrock/llm/_position.yaml

@@ -10,3 +10,6 @@
 - cohere.command-text-v14
 - meta.llama2-13b-chat-v1
 - meta.llama2-70b-chat-v1
+- mistral.mistral-large-2402-v1:0
+- mistral.mixtral-8x7b-instruct-v0:1
+- mistral.mistral-7b-instruct-v0:2
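_position.yaml is a plain YAML list that fixes the display order of the Bedrock models; this hunk appends the three new Mistral entries to the end. Loading it is straightforward (a sketch, not part of the commit):

```python
# Sketch only: the position file is just an ordered YAML list of model IDs.
import yaml

def load_model_order(path: str = "_position.yaml") -> list[str]:
    with open(path, encoding="utf-8") as f:
        return yaml.safe_load(f)  # returns model IDs in display order
```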

+ 21 - 0
api/core/model_runtime/model_providers/bedrock/llm/llm.py

@@ -449,6 +449,11 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             human_prompt_prefix = "\n[INST]"
             human_prompt_postfix = "[\\INST]\n"
             ai_prompt = ""
+        
+        elif model_prefix == "mistral":
+            human_prompt_prefix = "<s>[INST]"
+            human_prompt_postfix = "[\\INST]\n"
+            ai_prompt = "\n\nAssistant:"
 
         elif model_prefix == "amazon":
             human_prompt_prefix = "\n\nUser:"
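This hunk adds a Mistral branch to the provider-specific prompt formatting. Note that it closes the instruction tag as "[\INST]", mirroring the pre-existing meta branch above, whereas Mistral's published chat template uses "[/INST]". A minimal sketch of what the branch produces, reusing the exact strings from the hunk (the message shape is assumed for illustration):

```python
# Sketch only: reproduces the mistral branch's prompt assembly with the
# prefix/postfix strings from the hunk. The message dicts are an assumed shape.
def to_mistral_prompt(messages: list[dict]) -> str:
    human_prefix, human_postfix = "<s>[INST]", "[\\INST]\n"
    parts = []
    for m in messages:
        if m["role"] == "user":
            parts.append(f"{human_prefix} {m['content']} {human_postfix}")
        else:  # assistant turns pass through unwrapped in this sketch
            parts.append(m["content"])
    # ai_prompt = "\n\nAssistant:" in the hunk cues the completion-mode reply
    return "".join(parts) + "\n\nAssistant:"

print(to_mistral_prompt([{"role": "user", "content": "Hello!"}]))
```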
@@ -519,6 +524,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                 payload["frequencyPenalty"] = {model_parameters.get("frequencyPenalty")}
             if model_parameters.get("countPenalty"):
                 payload["countPenalty"] = {model_parameters.get("countPenalty")}
+        
+        elif model_prefix == "mistral":
+            payload["temperature"] = model_parameters.get("temperature")
+            payload["top_p"] = model_parameters.get("top_p")
+            payload["max_tokens"] = model_parameters.get("max_tokens")
+            payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
+            payload["stop"] = stop[:10] if stop else []
 
         elif model_prefix == "anthropic":
             payload = { **model_parameters }
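The second hunk builds the Mistral request body. A hedged, self-contained sketch of the resulting call through boto3's bedrock-runtime client (region, prompt, and parameter values are illustrative; the stop[:10] truncation in the hunk suggests Bedrock caps Mistral stop sequences at ten):

```python
# Illustrative request using the payload shape from the hunk above.
# Assumes AWS credentials and Bedrock model access are already configured.
import json
import boto3

client = boto3.client("bedrock-runtime", region_name="us-east-1")
payload = {
    "prompt": "<s>[INST] Hello! [\\INST]\n\n\nAssistant:",
    "temperature": 0.5,
    "top_p": 0.9,
    "max_tokens": 512,
    "stop": [],  # the hunk truncates to stop[:10] before sending
}
response = client.invoke_model(
    modelId="mistral.mistral-7b-instruct-v0:2",
    body=json.dumps(payload),
)
print(json.loads(response["body"].read())["outputs"][0]["text"])
```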
@@ -648,6 +660,11 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             output = response_body.get("generation").strip('\n')
             prompt_tokens = response_body.get("prompt_token_count")
             completion_tokens = response_body.get("generation_token_count")
+        
+        elif model_prefix == "mistral":
+            output = response_body.get("outputs")[0].get("text")
+            prompt_tokens = response.get('ResponseMetadata').get('HTTPHeaders').get('x-amzn-bedrock-input-token-count')
+            completion_tokens = response.get('ResponseMetadata').get('HTTPHeaders').get('x-amzn-bedrock-output-token-count')
 
         else:
             raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
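The blocking-response hunk reads the generated text from outputs[0] in the body but pulls token usage from the Bedrock HTTP response headers, since the Mistral response body carries no usage fields. One detail worth noting: header values arrive as strings, so downstream token arithmetic needs a cast. A sketch:

```python
# Sketch of the mistral blocking-response handling from the hunk above.
import json

def parse_mistral_response(response: dict) -> tuple[str, int, int]:
    body = json.loads(response["body"].read())
    output = body["outputs"][0]["text"]
    headers = response["ResponseMetadata"]["HTTPHeaders"]
    # Header values are strings; cast before using them as token counts.
    prompt_tokens = int(headers["x-amzn-bedrock-input-token-count"])
    completion_tokens = int(headers["x-amzn-bedrock-output-token-count"])
    return output, prompt_tokens, completion_tokens
```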
@@ -731,6 +748,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                 content_delta = payload.get("text")
                 finish_reason = payload.get("finish_reason")
             
+            elif model_prefix == "mistral":
+                content_delta = payload.get('outputs')[0].get("text")
+                finish_reason = payload.get('outputs')[0].get("stop_reason")
+
             elif model_prefix == "meta":
                 content_delta = payload.get("generation").strip('\n')
                 finish_reason = payload.get("stop_reason")
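The streaming hunk applies the same outputs[0] shape per chunk. Assuming the EventStream returned by invoke_model_with_response_stream, each event's chunk bytes decode to one such payload:

```python
# Sketch: iterate text deltas and the eventual stop reason from a Bedrock
# Mistral event stream (the "body" of invoke_model_with_response_stream).
import json

def iter_mistral_deltas(stream):
    for event in stream:
        payload = json.loads(event["chunk"]["bytes"])
        out = payload["outputs"][0]
        yield out.get("text", ""), out.get("stop_reason")
```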

+ 39 - 0
api/core/model_runtime/model_providers/bedrock/llm/mistral.mistral-7b-instruct-v0:2.yaml

@@ -0,0 +1,39 @@
+model: mistral.mistral-7b-instruct-v0:2
+label:
+  en_US: Mistral 7B Instruct
+model_type: llm
+model_properties:
+  mode: completion
+  context_size: 32000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    required: false
+    default: 0.5
+  - name: top_p
+    use_template: top_p
+    required: false
+    default: 0.9
+  - name: top_k
+    use_template: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 50
+    max: 200
+  - name: max_tokens
+    use_template: max_tokens
+    required: true
+    default: 512
+    min: 1
+    max: 8192
+pricing:
+  input: '0.00015'
+  output: '0.0002'
+  unit: '0.00001'
+  currency: USD
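The parameter_rules block above declares the user-facing knobs with their defaults and bounds. A hypothetical illustration of how such rules resolve user input (the helper below is not part of the commit):

```python
# Hypothetical helper mirroring the rules in
# mistral.mistral-7b-instruct-v0:2.yaml; not part of the commit.
RULES = {
    "temperature": {"default": 0.5},
    "top_p": {"default": 0.9},
    "top_k": {"default": 50, "max": 200},
    "max_tokens": {"default": 512, "min": 1, "max": 8192},
}

def resolve_params(user: dict) -> dict:
    resolved = {}
    for name, rule in RULES.items():
        value = user.get(name, rule["default"])
        if "min" in rule:
            value = max(rule["min"], value)
        if "max" in rule:
            value = min(rule["max"], value)
        resolved[name] = value
    return resolved

print(resolve_params({"max_tokens": 99999}))  # max_tokens clamps to 8192
```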

+ 27 - 0
api/core/model_runtime/model_providers/bedrock/llm/mistral.mistral-large-2402-v1:0.yaml

@@ -0,0 +1,27 @@
+model: mistral.mistral-large-2402-v1:0
+label:
+  en_US: Mistral Large
+model_type: llm
+model_properties:
+  mode: completion
+  context_size: 32000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    required: false
+    default: 0.7
+  - name: top_p
+    use_template: top_p
+    required: false
+    default: 1
+  - name: max_tokens
+    use_template: max_tokens
+    required: true
+    default: 512
+    min: 1
+    max: 4096
+pricing:
+  input: '0.008'
+  output: '0.024'
+  unit: '0.001'
+  currency: USD
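If the pricing block follows the convention that cost equals tokens × price × unit (price quoted per 1/unit tokens, here per 1,000), a quick worked example for Mistral Large:

```python
# Assumed formula: cost = tokens * price * unit. With unit 0.001 this reads
# the prices above as USD per 1,000 tokens.
input_price, output_price, unit = 0.008, 0.024, 0.001
prompt_tokens, completion_tokens = 1000, 500
cost = (prompt_tokens * input_price + completion_tokens * output_price) * unit
print(f"${cost:.4f}")  # $0.0200 under this assumed formula
```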

+ 39 - 0
api/core/model_runtime/model_providers/bedrock/llm/mistral.mixtral-8x7b-instruct-v0:1.yaml

@@ -0,0 +1,39 @@
+model: mistral.mixtral-8x7b-instruct-v0:1
+label:
+  en_US: Mixtral 8X7B Instruct
+model_type: llm
+model_properties:
+  mode: completion
+  context_size: 32000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    required: false
+    default: 0.5
+  - name: top_p
+    use_template: top_p
+    required: false
+    default: 0.9
+  - name: top_k
+    use_template: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 50
+    max: 200
+  - name: max_tokens
+    use_template: max_tokens
+    required: true
+    default: 512
+    min: 1
+    max: 8192
+pricing:
+  input: '0.00045'
+  output: '0.0007'
+  unit: '0.00001'
+  currency: USD
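Putting the pieces together, a hedged end-to-end example streaming from the Mixtral model this commit registers (assumes AWS credentials and Bedrock model access in the chosen region):

```python
# End-to-end sketch combining the request and streaming shapes shown above.
import json
import boto3

client = boto3.client("bedrock-runtime", region_name="us-east-1")
response = client.invoke_model_with_response_stream(
    modelId="mistral.mixtral-8x7b-instruct-v0:1",
    body=json.dumps({
        "prompt": "<s>[INST] Name three uses of YAML. [\\INST]\n\n\nAssistant:",
        "max_tokens": 256,
        "temperature": 0.5,
        "top_p": 0.9,
    }),
)
for event in response["body"]:
    chunk = json.loads(event["chunk"]["bytes"])
    print(chunk["outputs"][0].get("text", ""), end="", flush=True)
```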