Add bedrock (#2119)

Co-authored-by: takatost <takatost@users.noreply.github.com>
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: Charlie.Wei <luowei@cvte.com>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: Benjamin <benjaminx@gmail.com>
Chenhe Gu 1 year ago
parent commit 14a2eeba0c
23 changed files with 1143 additions and 0 deletions
  1. 0 0
      api/core/model_runtime/model_providers/bedrock/__init__.py
  2. 8 0
      api/core/model_runtime/model_providers/bedrock/_assets/icon_l_en.svg
  3. 3 0
      api/core/model_runtime/model_providers/bedrock/_assets/icon_s_en.svg
  4. 30 0
      api/core/model_runtime/model_providers/bedrock/bedrock.py
  5. 71 0
      api/core/model_runtime/model_providers/bedrock/bedrock.yaml
  6. 0 0
      api/core/model_runtime/model_providers/bedrock/llm/__init__.py
  7. 10 0
      api/core/model_runtime/model_providers/bedrock/llm/_position.yaml
  8. 47 0
      api/core/model_runtime/model_providers/bedrock/llm/ai21.j2-mid-v1.yaml
  9. 47 0
      api/core/model_runtime/model_providers/bedrock/llm/ai21.j2-ultra-v1.yaml
  10. 25 0
      api/core/model_runtime/model_providers/bedrock/llm/amazon.titan-text-express-v1.yaml
  11. 25 0
      api/core/model_runtime/model_providers/bedrock/llm/amazon.titan-text-lite-v1.yaml
  12. 35 0
      api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-instant-v1.yaml
  13. 35 0
      api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v1.yaml
  14. 35 0
      api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.yaml
  15. 35 0
      api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2:1.yaml
  16. 35 0
      api/core/model_runtime/model_providers/bedrock/llm/cohere.command-light-text-v14.yaml
  17. 32 0
      api/core/model_runtime/model_providers/bedrock/llm/cohere.command-text-v14.yaml
  18. 486 0
      api/core/model_runtime/model_providers/bedrock/llm/llm.py
  19. 23 0
      api/core/model_runtime/model_providers/bedrock/llm/meta.llama2-13b-chat-v1.yaml
  20. 23 0
      api/core/model_runtime/model_providers/bedrock/llm/meta.llama2-70b-chat-v1.yaml
  21. 0 0
      api/tests/integration_tests/model_runtime/bedrock/__init__.py
  22. 117 0
      api/tests/integration_tests/model_runtime/bedrock/test_llm.py
  23. 21 0
      api/tests/integration_tests/model_runtime/bedrock/test_provider.py

+ 0 - 0
api/core/model_runtime/model_providers/bedrock/__init__.py


File diff suppressed because it is too large
+ 8 - 0
api/core/model_runtime/model_providers/bedrock/_assets/icon_l_en.svg


File diff suppressed because it is too large
+ 3 - 0
api/core/model_runtime/model_providers/bedrock/_assets/icon_s_en.svg


+ 30 - 0
api/core/model_runtime/model_providers/bedrock/bedrock.py

@@ -0,0 +1,30 @@
+import logging
+
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+class BedrockProvider(ModelProvider):
+    def validate_provider_credentials(self, credentials: dict) -> None:
+        """
+        Validate provider credentials
+
+        If validation fails, an exception is raised.
+
+        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
+        """
+        try:
+            model_instance = self.get_model_instance(ModelType.LLM)
+
+            # use the lightweight `amazon.titan-text-lite-v1` model for validation
+            model_instance.validate_credentials(
+                model='amazon.titan-text-lite-v1',
+                credentials=credentials
+            )
+        except CredentialsValidateFailedError as ex:
+            raise ex
+        except Exception as ex:
+            logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
+            raise ex
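
For orientation, a minimal sketch of how this provider-level check can be exercised, assuming valid AWS credentials (values here are placeholders; the keys follow the schema in bedrock.yaml below):

    from core.model_runtime.errors.validate import CredentialsValidateFailedError
    from core.model_runtime.model_providers.bedrock.bedrock import BedrockProvider

    provider = BedrockProvider()
    try:
        # returns None on success, raises on bad or missing credentials
        provider.validate_provider_credentials(credentials={
            "aws_access_key_id": "AKIA...",         # placeholder
            "aws_secret_access_key": "wJalr...",    # placeholder
            "aws_region": "us-east-1",
        })
    except CredentialsValidateFailedError as e:
        print(f"validation failed: {e}")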

+ 71 - 0
api/core/model_runtime/model_providers/bedrock/bedrock.yaml

@@ -0,0 +1,71 @@
+provider: bedrock
+label:
+  en_US: AWS
+description:
+  en_US: AWS Bedrock's models.
+icon_small:
+  en_US: icon_s_en.svg
+icon_large:
+  en_US: icon_l_en.svg
+background: "#FCFDFF"
+help:
+  title:
+    en_US: Get your Access Key and Secret Access Key from AWS Console
+  url:
+    en_US: https://console.aws.amazon.com/
+supported_model_types:
+  - llm
+configurate_methods:
+  - predefined-model
+provider_credential_schema:
+  credential_form_schemas:
+    - variable: aws_access_key_id
+      required: true
+      label:
+        en_US: Access Key
+        zh_Hans: Access Key
+      type: secret-input
+      placeholder:
+        en_US: Enter your Access Key
+        zh_Hans: 在此输入您的 Access Key
+    - variable: aws_secret_access_key
+      required: true
+      label:
+        en_US: Secret Access Key
+        zh_Hans: Secret Access Key
+      type: secret-input
+      placeholder:
+        en_US: Enter your Secret Access Key
+        zh_Hans: 在此输入您的 Secret Access Key
+    - variable: aws_region
+      required: true
+      label:
+        en_US: AWS Region
+        zh_Hans: AWS 地区
+      type: select
+      default: us-east-1
+      options:
+        - value: us-east-1
+          label:
+            en_US: US East (N. Virginia)
+            zh_Hans: US East (N. Virginia)
+        - value: us-west-2
+          label:
+            en_US: US West (Oregon)
+            zh_Hans: US West (Oregon)
+        - value: ap-southeast-1
+          label:
+            en_US: Asia Pacific (Singapore)
+            zh_Hans: Asia Pacific (Singapore)
+        - value: ap-northeast-1
+          label:
+            en_US: Asia Pacific (Tokyo)
+            zh_Hans: Asia Pacific (Tokyo)
+        - value: eu-central-1
+          label:
+            en_US: Europe (Frankfurt)
+            zh_Hans: Europe (Frankfurt)
+        - value: us-gov-west-1
+          label:
+            en_US: AWS GovCloud (US-West)
+            zh_Hans: AWS GovCloud (US-West)
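
The schema above yields a flat credentials dict that the runtime hands through to boto3; a sketch of its shape (placeholder values):

    credentials = {
        "aws_access_key_id": "<access key>",      # secret-input
        "aws_secret_access_key": "<secret key>",  # secret-input
        "aws_region": "us-east-1",                # one of the select options above
    }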

+ 0 - 0
api/core/model_runtime/model_providers/bedrock/llm/__init__.py


+ 10 - 0
api/core/model_runtime/model_providers/bedrock/llm/_position.yaml

@@ -0,0 +1,10 @@
+- amazon.titan-text-express-v1
+- amazon.titan-text-lite-v1
+- anthropic.claude-instant-v1
+- anthropic.claude-v1
+- anthropic.claude-v2
+- anthropic.claude-v2:1
+- cohere.command-light-text-v14
+- cohere.command-text-v14
+- meta.llama2-13b-chat-v1
+- meta.llama2-70b-chat-v1

+ 47 - 0
api/core/model_runtime/model_providers/bedrock/llm/ai21.j2-mid-v1.yaml

@@ -0,0 +1,47 @@
+model: ai21.j2-mid-v1
+label:
+  en_US: J2 Mid V1
+model_type: llm
+model_properties:
+  mode: completion
+  context_size: 8191
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: topP
+    use_template: top_p
+  - name: maxTokens
+    use_template: max_tokens
+    required: true
+    default: 2048
+    min: 1
+    max: 2048
+  - name: count_penalty
+    label:
+      en_US: Count Penalty
+    required: false
+    type: float
+    default: 0
+    min: 0
+    max: 1
+  - name: presence_penalty
+    label:
+      en_US: Presence Penalty
+    required: false
+    type: float
+    default: 0
+    min: 0
+    max: 5
+  - name: frequency_penalty
+    label:
+      en_US: Frequency Penalty
+    required: false
+    type: float
+    default: 0
+    min: 0
+    max: 500
+pricing:
+  input: '0.00'
+  output: '0.00'
+  unit: '0.000001'
+  currency: USD
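
At invoke time, each rule's name becomes a key in model_parameters, so for this model the dict looks roughly like the sketch below (illustrative values; the penalty keys are consumed by _create_payload in llm.py further down):

    model_parameters = {
        "temperature": 0.7,
        "topP": 0.9,
        "maxTokens": 2048,
        "presence_penalty": 0.5,  # wrapped into AI21's {"scale": ...} form by the provider code
    }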

+ 47 - 0
api/core/model_runtime/model_providers/bedrock/llm/ai21.j2-ultra-v1.yaml

@@ -0,0 +1,47 @@
+model: ai21.j2-ultra-v1
+label:
+  en_US: J2 Ultra V1
+model_type: llm
+model_properties:
+  mode: completion
+  context_size: 8191
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: topP
+    use_template: top_p
+  - name: maxTokens
+    use_template: max_tokens
+    required: true
+    default: 2048
+    min: 1
+    max: 2048
+  - name: count_penalty
+    label:
+      en_US: Count Penalty
+    required: false
+    type: float
+    default: 0
+    min: 0
+    max: 1
+  - name: presence_penalty
+    label:
+      en_US: Presence Penalty
+    required: false
+    type: float
+    default: 0
+    min: 0
+    max: 5
+  - name: frequency_penalty
+    label:
+      en_US: Frequency Penalty
+    required: false
+    type: float
+    default: 0
+    min: 0
+    max: 500
+pricing:
+  input: '0.00'
+  output: '0.00'
+  unit: '0.000001'
+  currency: USD

+ 25 - 0
api/core/model_runtime/model_providers/bedrock/llm/amazon.titan-text-express-v1.yaml

@@ -0,0 +1,25 @@
+model: amazon.titan-text-express-v1
+label:
+  en_US: Titan Text G1 - Express
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 8192
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: topP
+    use_template: top_p
+  - name: maxTokenCount
+    use_template: max_tokens
+    required: true
+    default: 2048
+    min: 1
+    max: 8000
+pricing:
+  input: '0.0008'
+  output: '0.0016'
+  unit: '0.001'
+  currency: USD

+ 25 - 0
api/core/model_runtime/model_providers/bedrock/llm/amazon.titan-text-lite-v1.yaml

@@ -0,0 +1,25 @@
+model: amazon.titan-text-lite-v1
+label:
+  en_US: Titan Text G1 - Lite
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 4096
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: topP
+    use_template: top_p
+  - name: maxTokenCount
+    use_template: max_tokens
+    required: true
+    default: 2048
+    min: 1
+    max: 2048
+pricing:
+  input: '0.0003'
+  output: '0.0004'
+  unit: '0.001'
+  currency: USD

+ 35 - 0
api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-instant-v1.yaml

@@ -0,0 +1,35 @@
+model: anthropic.claude-instant-v1
+label:
+  en_US: Claude Instant V1
+model_type: llm
+model_properties:
+  mode: chat
+  context_size: 100000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: topP
+    use_template: top_p
+  - name: topK
+    label:
+      zh_Hans: 取样数量
+      en_US: Top K
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 250
+    min: 0
+    max: 500
+  - name: max_tokens_to_sample
+    use_template: max_tokens
+    required: true
+    default: 4096
+    min: 1
+    max: 4096
+pricing:
+  input: '0.0008'
+  output: '0.0024'
+  unit: '0.001'
+  currency: USD

+ 35 - 0
api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v1.yaml

@@ -0,0 +1,35 @@
+model: anthropic.claude-v1
+label:
+  en_US: Claude V1
+model_type: llm
+model_properties:
+  mode: chat
+  context_size: 100000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: topP
+    use_template: top_p
+  - name: topK
+    label:
+      zh_Hans: 取样数量
+      en_US: Top K
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 250
+    min: 0
+    max: 500
+  - name: max_tokens_to_sample
+    use_template: max_tokens
+    required: true
+    default: 4096
+    min: 1
+    max: 4096
+pricing:
+  input: '0.008'
+  output: '0.024'
+  unit: '0.001'
+  currency: USD

+ 35 - 0
api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.yaml

@@ -0,0 +1,35 @@
+model: anthropic.claude-v2
+label:
+  en_US: Claude V2
+model_type: llm
+model_properties:
+  mode: chat
+  context_size: 100000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: topP
+    use_template: top_p
+  - name: topK
+    label:
+      zh_Hans: 取样数量
+      en_US: Top K
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 250
+    min: 0
+    max: 500
+  - name: max_tokens_to_sample
+    use_template: max_tokens
+    required: true
+    default: 4096
+    min: 1
+    max: 4096
+pricing:
+  input: '0.008'
+  output: '0.024'
+  unit: '0.001'
+  currency: USD

+ 35 - 0
api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2:1.yaml

@@ -0,0 +1,35 @@
+model: anthropic.claude-v2:1
+label:
+  en_US: Claude V2.1
+model_type: llm
+model_properties:
+  mode: chat
+  context_size: 200000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: topP
+    use_template: top_p
+  - name: topK
+    label:
+      zh_Hans: 取样数量
+      en_US: Top K
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 250
+    min: 0
+    max: 500
+  - name: max_tokens_to_sample
+    use_template: max_tokens
+    required: true
+    default: 4096
+    min: 1
+    max: 4096
+pricing:
+  input: '0.008'
+  output: '0.024'
+  unit: '0.001'
+  currency: USD

+ 35 - 0
api/core/model_runtime/model_providers/bedrock/llm/cohere.command-light-text-v14.yaml

@@ -0,0 +1,35 @@
+model: cohere.command-light-text-v14
+label:
+  en_US: Command Light Text V14
+model_type: llm
+model_properties:
+  mode: completion
+  context_size: 4096
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: p
+    use_template: top_p
+  - name: k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    min: 0
+    max: 500
+    default: 0
+  - name: max_tokens_to_sample
+    use_template: max_tokens
+    required: true
+    default: 4096
+    min: 1
+    max: 4096
+pricing:
+  input: '0.0003'
+  output: '0.0006'
+  unit: '0.001'
+  currency: USD

+ 32 - 0
api/core/model_runtime/model_providers/bedrock/llm/cohere.command-text-v14.yaml

@@ -0,0 +1,32 @@
+model: cohere.command-text-v14
+label:
+  en_US: Command Text V14
+model_type: llm
+model_properties:
+  mode: completion
+  context_size: 4096
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+  - name: max_tokens_to_sample
+    use_template: max_tokens
+    required: true
+    default: 4096
+    min: 1
+    max: 4096
+pricing:
+  input: '0.0015'
+  output: '0.0020'
+  unit: '0.001'
+  currency: USD

+ 486 - 0
api/core/model_runtime/model_providers/bedrock/llm/llm.py

@@ -0,0 +1,486 @@
+import json
+import logging
+from typing import Generator, List, Optional, Union
+
+import boto3
+from botocore.config import Config
+from botocore.exceptions import ClientError, EndpointConnectionError, NoRegionError, ServiceNotInRegionError, UnknownServiceError
+
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+
+from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage,
+                                                          PromptMessageTool, SystemPromptMessage, UserPromptMessage)
+from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
+                                              InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+
+logger = logging.getLogger(__name__)
+
+class BedrockLargeLanguageModel(LargeLanguageModel):
+
+    def _invoke(self, model: str, credentials: dict,
+                prompt_messages: list[PromptMessage], model_parameters: dict,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                stream: bool = True, user: Optional[str] = None) \
+            -> Union[LLMResult, Generator]:
+        """
+        Invoke large language model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param tools: tools for tool calling
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        :return: full response or stream response chunk generator result
+        """
+        # invoke model
+        return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
+
+    def get_num_tokens(self, model: str, credentials: dict, messages: list[PromptMessage] | str,
+                       tools: Optional[list[PromptMessageTool]] = None) -> int:
+        """
+        Get number of tokens for given prompt messages
+
+        :param model: model name
+        :param credentials: model credentials
+        :param messages: prompt messages or message string
+        :param tools: tools for tool calling
+        :return: number of tokens
+        """
+        prefix = model.split('.')[0]
+
+        if isinstance(messages, str):
+            prompt = messages
+        else:
+            prompt = self._convert_messages_to_prompt(messages, prefix)
+
+        return self._get_num_tokens_by_gpt2(prompt)
+    
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        
+        try:
+            ping_message = UserPromptMessage(content="ping")
+            self._generate(model=model,
+                           credentials=credentials,
+                           prompt_messages=[ping_message],
+                           model_parameters={},
+                           stream=False)
+        
+        except ClientError as ex:
+            error_code = ex.response['Error']['Code']
+            full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
+
+            raise CredentialsValidateFailedError(str(self._map_client_to_invoke_error(error_code, full_error_msg)))
+
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str) -> str:
+        """
+        Convert a single message to a string.
+
+        :param message: PromptMessage to convert.
+        :return: String representation of the message.
+        """
+        
+        if model_prefix == "anthropic":
+            human_prompt_prefix = "\n\nHuman:"
+            human_prompt_postfix = ""
+            ai_prompt = "\n\nAssistant:"
+
+        elif model_prefix == "meta":
+            human_prompt_prefix = "\n[INST]"
+            human_prompt_postfix = "[/INST]\n"
+            ai_prompt = ""
+
+        elif model_prefix == "amazon":
+            human_prompt_prefix = "\n\nUser:"
+            human_prompt_postfix = ""
+            ai_prompt = "\n\nBot:"
+        
+        else:
+            human_prompt_prefix = ""
+            human_prompt_postfix = ""
+            ai_prompt = ""
+
+        content = message.content
+
+        if isinstance(message, UserPromptMessage):
+            message_text = f"{human_prompt_prefix} {content} {human_prompt_postfix}"
+        elif isinstance(message, AssistantPromptMessage):
+            message_text = f"{ai_prompt} {content}"
+        elif isinstance(message, SystemPromptMessage):
+            message_text = content
+        else:
+            raise ValueError(f"Got unknown type {message}")
+
+        return message_text
+
+    def _convert_messages_to_prompt(self, messages: List[PromptMessage], model_prefix: str) -> str:
+        """
+        Format a list of messages into a full prompt for the Anthropic, Amazon and Llama models
+
+        :param messages: List of PromptMessage to combine.
+        :return: Combined string with necessary human_prompt and ai_prompt tags.
+        """
+        if not messages:
+            return ''
+
+        messages = messages.copy()  # don't mutate the original list
+        if not isinstance(messages[-1], AssistantPromptMessage):
+            messages.append(AssistantPromptMessage(content=""))
+
+        text = "".join(
+            self._convert_one_message_to_text(message, model_prefix)
+            for message in messages
+        )
+
+        # trim off the trailing ' ' that might come from the "Assistant: "
+        return text.rstrip()
+
+    def _create_payload(self, model_prefix: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None, stream: bool = True):
+        """
+        Create payload for bedrock api call depending on model provider
+        """
+        payload = dict()
+
+        if model_prefix == "amazon":
+            payload["textGenerationConfig"] = { **model_parameters }
+            payload["textGenerationConfig"]["stopSequences"] = ["User:"] + (stop if stop else [])
+            
+            payload["inputText"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
+        
+        elif model_prefix == "ai21":
+            payload["temperature"] = model_parameters.get("temperature")
+            payload["topP"] = model_parameters.get("topP")
+            payload["maxTokens"] = model_parameters.get("maxTokens")
+            payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
+
+            # jurassic models only support a single stop sequence
+            if stop:
+                payload["stopSequences"] = stop[0]
+
+            if model_parameters.get("presencePenalty"):
+                payload["presencePenalty"] = {model_parameters.get("presencePenalty")}
+            if model_parameters.get("frequencyPenalty"):
+                payload["frequencyPenalty"] = {model_parameters.get("frequencyPenalty")}
+            if model_parameters.get("countPenalty"):
+                payload["countPenalty"] = {model_parameters.get("countPenalty")}
+
+        elif model_prefix == "anthropic":
+            payload = { **model_parameters }
+            payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
+            payload["stop_sequences"] = ["\n\nHuman:"] + (stop if stop else [])
+            
+        elif model_prefix == "cohere":
+            payload = { **model_parameters }
+            payload["prompt"] = prompt_messages[0].content
+            payload["stream"] = stream
+        
+        elif model_prefix == "meta":
+            payload = { **model_parameters }
+            payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
+
+        else:
+            raise ValueError(f"Got unknown model prefix {model_prefix}")
+        
+        return payload
+
+    def _generate(self, model: str, credentials: dict,
+                  prompt_messages: list[PromptMessage], model_parameters: dict,
+                  stop: Optional[List[str]] = None, stream: bool = True,
+                  user: Optional[str] = None) -> Union[LLMResult, Generator]:
+        """
+        Invoke large language model
+
+        :param model: model name
+        :param credentials: credentials kwargs
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        :return: full response or stream response chunk generator result
+        """
+        client_config = Config(
+            region_name=credentials["aws_region"]
+        )
+
+        runtime_client = boto3.client(
+            service_name='bedrock-runtime',
+            config=client_config,
+            aws_access_key_id=credentials["aws_access_key_id"],
+            aws_secret_access_key=credentials["aws_secret_access_key"]
+        )
+
+        model_prefix = model.split('.')[0]
+        payload = self._create_payload(model_prefix, prompt_messages, model_parameters, stop, stream)
+
+        # ai21 models do not support streaming, so fall back to a blocking invoke
+        if stream and model_prefix != "ai21":
+            invoke = runtime_client.invoke_model_with_response_stream
+        else:
+            invoke = runtime_client.invoke_model
+
+        try:
+            response = invoke(
+                body=json.dumps(payload),
+                modelId=model,
+            )
+        except ClientError as ex:
+            error_code = ex.response['Error']['Code']
+            full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
+            raise self._map_client_to_invoke_error(error_code, full_error_msg)
+        
+        except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex:
+            raise InvokeConnectionError(str(ex))
+
+        except UnknownServiceError as ex:
+            raise InvokeServerUnavailableError(str(ex))
+
+        except Exception as ex:
+            raise InvokeError(str(ex))
+        
+
+        if stream:
+            return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
+
+        return self._handle_generate_response(model, credentials, response, prompt_messages)
+
+    def _handle_generate_response(self, model: str, credentials: dict, response: dict,
+                                  prompt_messages: list[PromptMessage]) -> LLMResult:
+        """
+        Handle llm response
+
+        :param model: model name
+        :param credentials: credentials
+        :param response: response
+        :param prompt_messages: prompt messages
+        :return: llm response
+        """
+        response_body = json.loads(response.get('body').read().decode('utf-8'))
+
+        error = response_body.get("error")
+
+        if error is not None:
+            raise InvokeError(error)
+
+        # get output text and calculate num tokens based on model / provider
+        model_prefix = model.split('.')[0]
+
+        if model_prefix == "amazon":
+            output = response_body.get("results")[0].get("outputText").strip('\n')
+            prompt_tokens = response_body.get("inputTextTokenCount")
+            completion_tokens = response_body.get("results")[0].get("tokenCount")
+
+        elif model_prefix == "ai21":
+            output = response_body.get('completions')[0].get('data').get('text')
+            prompt_tokens = len(response_body.get("prompt").get("tokens"))
+            completion_tokens = len(response_body.get('completions')[0].get('data').get('tokens'))
+
+        elif model_prefix == "anthropic":
+            output = response_body.get("completion")
+            prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+            completion_tokens = self.get_num_tokens(model, credentials, output if output else '')
+            
+        elif model_prefix == "cohere":
+            output = response_body.get("generations")[0].get("text")
+            prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+            completion_tokens = self.get_num_tokens(model, credentials, output if output else '')
+            
+        elif model_prefix == "meta":
+            output = response_body.get("generation").strip('\n')
+            prompt_tokens = response_body.get("prompt_token_count")
+            completion_tokens = response_body.get("generation_token_count")
+
+        else:
+            raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
+
+        # construct assistant message from output
+        assistant_prompt_message = AssistantPromptMessage(
+            content=output
+        )
+
+        # calculate usage
+        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+        # construct response
+        result = LLMResult(
+            model=model,
+            prompt_messages=prompt_messages,
+            message=assistant_prompt_message,
+            usage=usage,
+        )
+
+        return result
+
+    def _handle_generate_stream_response(self, model: str, credentials: dict, response: dict,
+                                         prompt_messages: list[PromptMessage]) -> Generator:
+        """
+        Handle llm stream response
+
+        :param model: model name
+        :param credentials: credentials
+        :param response: response
+        :param prompt_messages: prompt messages
+        :return: llm response chunk generator result
+        """
+        model_prefix = model.split('.')[0]
+        if model_prefix == "ai21":
+            response_body = json.loads(response.get('body').read().decode('utf-8'))
+
+            content = response_body.get('completions')[0].get('data').get('text')
+            finish_reason = response_body.get('completions')[0].get('finish_reason')
+
+            prompt_tokens = len(response_body.get("prompt").get("tokens"))
+            completion_tokens = len(response_body.get('completions')[0].get('data').get('tokens'))
+            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+            yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=0,
+                        message=AssistantPromptMessage(content=content),
+                        finish_reason=finish_reason,
+                        usage=usage
+                    )
+                )
+            return
+        
+        stream = response.get('body')
+        if not stream:
+            raise InvokeError('No response body')
+        
+        index = -1
+        for event in stream:
+            chunk = event.get('chunk')
+            
+            if not chunk:
+                exception_name = next(iter(event))
+                full_ex_msg = f"{exception_name}: {event[exception_name]['message']}"
+
+                raise self._map_client_to_invoke_error(exception_name, full_ex_msg)
+
+            payload = json.loads(chunk.get('bytes').decode())
+
+            model_prefix = model.split('.')[0]
+            if model_prefix == "amazon":
+                content_delta = payload.get("outputText").strip('\n')
+                finish_reason = payload.get("completionReason")
+ 
+            elif model_prefix == "anthropic":
+                content_delta = payload.get("completion")
+                finish_reason = payload.get("stop_reason")
+
+            elif model_prefix == "cohere":
+                content_delta = payload.get("text")
+                finish_reason = payload.get("finish_reason")
+            
+            elif model_prefix == "meta":
+                content_delta = payload.get("generation").strip('\n')
+                finish_reason = payload.get("stop_reason")
+            
+            else:
+                raise ValueError(f"Got unknown model prefix {model_prefix} when handling stream response")
+
+            index += 1
+           
+            assistant_prompt_message = AssistantPromptMessage(
+                content = content_delta if content_delta else '',
+            )
+  
+            if not finish_reason:
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=assistant_prompt_message
+                    )
+                )
+
+            else:
+                # get num tokens from metrics in last chunk
+                prompt_tokens = payload["amazon-bedrock-invocationMetrics"]["inputTokenCount"]
+                completion_tokens = payload["amazon-bedrock-invocationMetrics"]["outputTokenCount"]
+
+                # transform usage
+                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+                
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=assistant_prompt_message,
+                        finish_reason=finish_reason,
+                        usage=usage
+                    )
+                )
+    
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the ermd = genai.GenerativeModel(model)ror type thrown to the caller
+        The value is the md = genai.GenerativeModel(model)error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke emd = genai.GenerativeModel(model)rror mapping
+        """
+        return {
+            InvokeConnectionError: [],
+            InvokeServerUnavailableError: [],
+            InvokeRateLimitError: [],
+            InvokeAuthorizationError: [],
+            InvokeBadRequestError: []
+        }
+    
+    def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[InvokeError]:
+        """
+        Map client error to invoke error
+
+        :param error_code: error code
+        :param error_msg: error message
+        :return: invoke error
+        """
+
+        if error_code == "AccessDeniedException":
+            return InvokeAuthorizationError(error_msg)
+        elif error_code in ["ResourceNotFoundException", "ValidationException"]:
+            return InvokeBadRequestError(error_msg)
+        elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
+            return InvokeRateLimitError(error_msg)
+        elif error_code in ["ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException"]:
+            return InvokeServerUnavailableError(error_msg)
+        elif error_code == "ModelStreamErrorException":
+            return InvokeConnectionError(error_msg)
+
+        return InvokeError(error_msg)
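
To make the per-provider branching in _create_payload concrete, here are illustrative request bodies for a single "ping" user message, derived from the prompt-formatting and payload code above (parameter values are examples only):

    amazon_body = {
        "inputText": "\n\nUser: ping \n\nBot:",
        "textGenerationConfig": {"temperature": 0.0, "topP": 1.0,
                                 "maxTokenCount": 64, "stopSequences": ["User:"]},
    }
    anthropic_body = {
        "prompt": "\n\nHuman: ping \n\nAssistant:",
        "stop_sequences": ["\n\nHuman:"],
        "temperature": 0.0,
    }
    cohere_body = {"prompt": "ping", "stream": True, "temperature": 0.0}
    meta_body = {"prompt": "\n[INST] ping [/INST]", "temperature": 0.0}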

+ 23 - 0
api/core/model_runtime/model_providers/bedrock/llm/meta.llama2-13b-chat-v1.yaml

@@ -0,0 +1,23 @@
+model: meta.llama2-13b-chat-v1
+label:
+  en_US: Llama 2 Chat 13B
+model_type: llm
+model_properties:
+  mode: chat
+  context_size: 4096
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: max_gen_len
+    use_template: max_tokens
+    required: true
+    default: 2048
+    min: 1
+    max: 2048
+pricing:
+  input: '0.00075'
+  output: '0.00100'
+  unit: '0.001'
+  currency: USD

+ 23 - 0
api/core/model_runtime/model_providers/bedrock/llm/meta.llama2-70b-chat-v1.yaml

@@ -0,0 +1,23 @@
+model: meta.llama2-70b-chat-v1
+label:
+  en_US: Llama 2 Chat 70B
+model_type: llm
+model_properties:
+  mode: chat
+  context_size: 4096
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: max_gen_len
+    use_template: max_tokens
+    required: true
+    default: 2048
+    min: 1
+    max: 2048
+pricing:
+  input: '0.00195'
+  output: '0.00256'
+  unit: '0.001'
+  currency: USD

+ 0 - 0
api/tests/integration_tests/model_runtime/bedrock/__init__.py


+ 117 - 0
api/tests/integration_tests/model_runtime/bedrock/test_llm.py

@@ -0,0 +1,117 @@
+import os
+from typing import Generator
+
+import pytest
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.bedrock.llm.llm import BedrockLargeLanguageModel
+
+def test_validate_credentials():
+    model = BedrockLargeLanguageModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model='meta.llama2-13b-chat-v1',
+            credentials={
+                'aws_region': 'us-east-1',
+                'aws_access_key_id': 'invalid_key',
+                'aws_secret_access_key': 'invalid_secret'
+            }
+        )
+
+    model.validate_credentials(
+        model='meta.llama2-13b-chat-v1',
+        credentials={
+            "aws_region": os.getenv("AWS_REGION"),
+            "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
+            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
+        }
+    )
+
+def test_invoke_model():
+    model = BedrockLargeLanguageModel()
+
+    response = model.invoke(
+        model='meta.llama2-13b-chat-v1',
+        credentials={
+            "aws_region": os.getenv("AWS_REGION"),
+            "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
+            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ],
+        model_parameters={
+            'temperature': 0.0,
+            'top_p': 1.0,
+            'max_gen_len': 10
+        },
+        stop=['How'],
+        stream=False,
+        user="abc-123"
+    )
+
+    assert isinstance(response, LLMResult)
+    assert len(response.message.content) > 0
+
+def test_invoke_stream_model():
+    model = BedrockLargeLanguageModel()
+
+    response = model.invoke(
+        model='meta.llama2-13b-chat-v1',
+        credentials={
+            "aws_region": os.getenv("AWS_REGION"),
+            "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
+            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ],
+        model_parameters={
+            'temperature': 0.0,
+            'max_gen_len': 100
+        },
+        stream=True,
+        user="abc-123"
+    )
+
+    assert isinstance(response, Generator)
+
+    for chunk in response:
+        print(chunk)
+        assert isinstance(chunk, LLMResultChunk)
+        assert isinstance(chunk.delta, LLMResultChunkDelta)
+        assert isinstance(chunk.delta.message, AssistantPromptMessage)
+        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
+
+
+def test_get_num_tokens():
+    model = BedrockLargeLanguageModel()
+
+    num_tokens = model.get_num_tokens(
+        model='meta.llama2-13b-chat-v1',
+        credentials = {
+            "aws_region": os.getenv("AWS_REGION"),
+            "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
+            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
+        },
+        messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ]
+    )
+
+    assert num_tokens == 18

+ 21 - 0
api/tests/integration_tests/model_runtime/bedrock/test_provider.py

@@ -0,0 +1,21 @@
+import os
+
+import pytest
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.bedrock.bedrock import BedrockProvider
+
+def test_validate_provider_credentials():
+    provider = BedrockProvider()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        provider.validate_provider_credentials(
+            credentials={}
+        )
+
+    provider.validate_provider_credentials(
+        credentials={
+            "aws_region": os.getenv("AWS_REGION"),
+            "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
+            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
+        }
+    )

Some files were not shown because too many files changed in this diff