Ver Fonte

Add Fireworks AI as new model provider (#8428)

ice yao há 7 meses atrás
pai
commit
0665268578
31 ficheiros alterados com 1682 adições e 2 exclusões
  1. 1 0
      api/core/model_runtime/model_providers/_position.yaml
  2. 0 0
      api/core/model_runtime/model_providers/fireworks/__init__.py
  3. 2 0
      api/core/model_runtime/model_providers/fireworks/_assets/icon_l_en.svg
  4. 5 0
      api/core/model_runtime/model_providers/fireworks/_assets/icon_s_en.svg
  5. 52 0
      api/core/model_runtime/model_providers/fireworks/_common.py
  6. 27 0
      api/core/model_runtime/model_providers/fireworks/fireworks.py
  7. 29 0
      api/core/model_runtime/model_providers/fireworks/fireworks.yaml
  8. 0 0
      api/core/model_runtime/model_providers/fireworks/llm/__init__.py
  9. 16 0
      api/core/model_runtime/model_providers/fireworks/llm/_position.yaml
  10. 46 0
      api/core/model_runtime/model_providers/fireworks/llm/firefunction-v1.yaml
  11. 46 0
      api/core/model_runtime/model_providers/fireworks/llm/firefunction-v2.yaml
  12. 45 0
      api/core/model_runtime/model_providers/fireworks/llm/gemma2-9b-it.yaml
  13. 46 0
      api/core/model_runtime/model_providers/fireworks/llm/llama-v3-70b-instruct-hf.yaml
  14. 46 0
      api/core/model_runtime/model_providers/fireworks/llm/llama-v3-70b-instruct.yaml
  15. 46 0
      api/core/model_runtime/model_providers/fireworks/llm/llama-v3-8b-instruct-hf.yaml
  16. 46 0
      api/core/model_runtime/model_providers/fireworks/llm/llama-v3-8b-instruct.yaml
  17. 46 0
      api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-405b-instruct.yaml
  18. 46 0
      api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-70b-instruct.yaml
  19. 46 0
      api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-8b-instruct.yaml
  20. 610 0
      api/core/model_runtime/model_providers/fireworks/llm/llm.py
  21. 46 0
      api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x22b-instruct.yaml
  22. 46 0
      api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x7b-instruct-hf.yaml
  23. 46 0
      api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x7b-instruct.yaml
  24. 46 0
      api/core/model_runtime/model_providers/fireworks/llm/mythomax-l2-13b.yaml
  25. 46 0
      api/core/model_runtime/model_providers/fireworks/llm/phi-3-vision-128k-instruct.yaml
  26. 45 0
      api/core/model_runtime/model_providers/fireworks/llm/yi-large.yaml
  27. 1 0
      api/pyproject.toml
  28. 0 0
      api/tests/integration_tests/model_runtime/fireworks/__init__.py
  29. 186 0
      api/tests/integration_tests/model_runtime/fireworks/test_llm.py
  30. 17 0
      api/tests/integration_tests/model_runtime/fireworks/test_provider.py
  31. 2 2
      dev/pytest/pytest_model_runtime.sh

+ 1 - 0
api/core/model_runtime/model_providers/_position.yaml

@@ -37,3 +37,4 @@
 - siliconflow
 - perfxcloud
 - zhinao
+- fireworks

+ 0 - 0
api/core/model_runtime/model_providers/fireworks/__init__.py


Diff do ficheiro suprimidas por serem muito extensas
+ 2 - 0
api/core/model_runtime/model_providers/fireworks/_assets/icon_l_en.svg


+ 5 - 0
api/core/model_runtime/model_providers/fireworks/_assets/icon_s_en.svg

@@ -0,0 +1,5 @@
+<svg width="638" height="315" viewBox="0 0 638 315" fill="none" xmlns="http://www.w3.org/2000/svg">
+<path d="M318.563 221.755C300.863 221.755 284.979 211.247 278.206 194.978L196.549 0H244.342L318.842 178.361L393.273 0H441.066L358.92 195.048C352.112 211.247 336.263 221.755 318.563 221.755Z" fill="#6720FF"/>
+<path d="M425.111 314.933C407.481 314.933 391.667 304.494 384.824 288.366C377.947 272.097 381.507 253.524 393.936 240.921L542.657 90.2803L561.229 134.094L425.076 271.748L619.147 270.666L637.72 314.479L425.146 315.003L425.076 314.933H425.111Z" fill="#6720FF"/>
+<path d="M0 314.408L18.5727 270.595L212.643 271.677L76.525 133.988L95.0977 90.1748L243.819 240.816C256.247 253.384 259.843 272.026 252.93 288.26C246.088 304.424 230.203 314.827 212.643 314.827L0.0698221 314.339L0 314.408Z" fill="#6720FF"/>
+</svg>

+ 52 - 0
api/core/model_runtime/model_providers/fireworks/_common.py

@@ -0,0 +1,52 @@
+from collections.abc import Mapping
+
+import openai
+
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+
+
+class _CommonFireworks:
+    def _to_credential_kwargs(self, credentials: Mapping) -> dict:
+        """
+        Transform credentials to kwargs for model instance
+
+        :param credentials:
+        :return:
+        """
+        credentials_kwargs = {
+            "api_key": credentials["fireworks_api_key"],
+            "base_url": "https://api.fireworks.ai/inference/v1",
+            "max_retries": 1,
+        }
+
+        return credentials_kwargs
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError],
+            InvokeServerUnavailableError: [openai.InternalServerError],
+            InvokeRateLimitError: [openai.RateLimitError],
+            InvokeAuthorizationError: [openai.AuthenticationError, openai.PermissionDeniedError],
+            InvokeBadRequestError: [
+                openai.BadRequestError,
+                openai.NotFoundError,
+                openai.UnprocessableEntityError,
+                openai.APIError,
+            ],
+        }

+ 27 - 0
api/core/model_runtime/model_providers/fireworks/fireworks.py

@@ -0,0 +1,27 @@
+import logging
+
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class FireworksProvider(ModelProvider):
+    def validate_provider_credentials(self, credentials: dict) -> None:
+        """
+        Validate provider credentials
+        if validate failed, raise exception
+
+        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
+        """
+        try:
+            model_instance = self.get_model_instance(ModelType.LLM)
+            model_instance.validate_credentials(
+                model="accounts/fireworks/models/llama-v3p1-8b-instruct", credentials=credentials
+            )
+        except CredentialsValidateFailedError as ex:
+            raise ex
+        except Exception as ex:
+            logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
+            raise ex

+ 29 - 0
api/core/model_runtime/model_providers/fireworks/fireworks.yaml

@@ -0,0 +1,29 @@
+provider: fireworks
+label:
+  zh_Hans: Fireworks AI
+  en_US: Fireworks AI
+icon_small:
+  en_US: icon_s_en.svg
+icon_large:
+  en_US: icon_l_en.svg
+background: "#FCFDFF"
+help:
+  title:
+    en_US: Get your API Key from Fireworks AI
+    zh_Hans: 从 Fireworks AI 获取 API Key
+  url:
+    en_US: https://fireworks.ai/account/api-keys
+supported_model_types:
+  - llm
+configurate_methods:
+  - predefined-model
+provider_credential_schema:
+  credential_form_schemas:
+    - variable: fireworks_api_key
+      label:
+        en_US: API Key
+      type: secret-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入您的 API Key
+        en_US: Enter your API Key

+ 0 - 0
api/core/model_runtime/model_providers/fireworks/llm/__init__.py


+ 16 - 0
api/core/model_runtime/model_providers/fireworks/llm/_position.yaml

@@ -0,0 +1,16 @@
+- llama-v3p1-405b-instruct
+- llama-v3p1-70b-instruct
+- llama-v3p1-8b-instruct
+- llama-v3-70b-instruct
+- mixtral-8x22b-instruct
+- mixtral-8x7b-instruct
+- firefunction-v2
+- firefunction-v1
+- gemma2-9b-it
+- llama-v3-70b-instruct-hf
+- llama-v3-8b-instruct
+- llama-v3-8b-instruct-hf
+- mixtral-8x7b-instruct-hf
+- mythomax-l2-13b
+- phi-3-vision-128k-instruct
+- yi-large

+ 46 - 0
api/core/model_runtime/model_providers/fireworks/llm/firefunction-v1.yaml

@@ -0,0 +1,46 @@
+model: accounts/fireworks/models/firefunction-v1
+label:
+  zh_Hans: Firefunction V1
+  en_US: Firefunction V1
+model_type: llm
+features:
+  - agent-thought
+  - tool-call
+model_properties:
+  mode: chat
+  context_size: 32768
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+  - name: max_tokens
+    use_template: max_tokens
+  - name: context_length_exceeded_behavior
+    default: None
+    label:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    help:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    type: string
+    options:
+      - None
+      - truncate
+      - error
+  - name: response_format
+    use_template: response_format
+pricing:
+  input: '0.5'
+  output: '0.5'
+  unit: '0.000001'
+  currency: USD

+ 46 - 0
api/core/model_runtime/model_providers/fireworks/llm/firefunction-v2.yaml

@@ -0,0 +1,46 @@
+model: accounts/fireworks/models/firefunction-v2
+label:
+  zh_Hans: Firefunction V2
+  en_US: Firefunction V2
+model_type: llm
+features:
+  - agent-thought
+  - tool-call
+model_properties:
+  mode: chat
+  context_size: 8192
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+  - name: max_tokens
+    use_template: max_tokens
+  - name: context_length_exceeded_behavior
+    default: None
+    label:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    help:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    type: string
+    options:
+      - None
+      - truncate
+      - error
+  - name: response_format
+    use_template: response_format
+pricing:
+  input: '0.9'
+  output: '0.9'
+  unit: '0.000001'
+  currency: USD

+ 45 - 0
api/core/model_runtime/model_providers/fireworks/llm/gemma2-9b-it.yaml

@@ -0,0 +1,45 @@
+model: accounts/fireworks/models/gemma2-9b-it
+label:
+  zh_Hans: Gemma2 9B Instruct
+  en_US: Gemma2 9B Instruct
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 8192
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+  - name: max_tokens
+    use_template: max_tokens
+  - name: context_length_exceeded_behavior
+    default: None
+    label:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    help:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    type: string
+    options:
+      - None
+      - truncate
+      - error
+  - name: response_format
+    use_template: response_format
+pricing:
+  input: '0.2'
+  output: '0.2'
+  unit: '0.000001'
+  currency: USD

+ 46 - 0
api/core/model_runtime/model_providers/fireworks/llm/llama-v3-70b-instruct-hf.yaml

@@ -0,0 +1,46 @@
+model: accounts/fireworks/models/llama-v3-70b-instruct-hf
+label:
+  zh_Hans: Llama3 70B Instruct(HF version)
+  en_US: Llama3 70B Instruct(HF version)
+model_type: llm
+features:
+  - agent-thought
+  - tool-call
+model_properties:
+  mode: chat
+  context_size: 8192
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+  - name: max_tokens
+    use_template: max_tokens
+  - name: context_length_exceeded_behavior
+    default: None
+    label:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    help:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    type: string
+    options:
+      - None
+      - truncate
+      - error
+  - name: response_format
+    use_template: response_format
+pricing:
+  input: '0.9'
+  output: '0.9'
+  unit: '0.000001'
+  currency: USD

+ 46 - 0
api/core/model_runtime/model_providers/fireworks/llm/llama-v3-70b-instruct.yaml

@@ -0,0 +1,46 @@
+model: accounts/fireworks/models/llama-v3-70b-instruct
+label:
+  zh_Hans: Llama3 70B Instruct
+  en_US: Llama3 70B Instruct
+model_type: llm
+features:
+  - agent-thought
+  - tool-call
+model_properties:
+  mode: chat
+  context_size: 8192
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+  - name: max_tokens
+    use_template: max_tokens
+  - name: context_length_exceeded_behavior
+    default: None
+    label:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    help:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    type: string
+    options:
+      - None
+      - truncate
+      - error
+  - name: response_format
+    use_template: response_format
+pricing:
+  input: '0.9'
+  output: '0.9'
+  unit: '0.000001'
+  currency: USD

+ 46 - 0
api/core/model_runtime/model_providers/fireworks/llm/llama-v3-8b-instruct-hf.yaml

@@ -0,0 +1,46 @@
+model: accounts/fireworks/models/llama-v3-8b-instruct-hf
+label:
+  zh_Hans: Llama3 8B Instruct(HF version)
+  en_US: Llama3 8B Instruct(HF version)
+model_type: llm
+features:
+  - agent-thought
+  - tool-call
+model_properties:
+  mode: chat
+  context_size: 8192
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+  - name: max_tokens
+    use_template: max_tokens
+  - name: context_length_exceeded_behavior
+    default: None
+    label:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    help:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    type: string
+    options:
+      - None
+      - truncate
+      - error
+  - name: response_format
+    use_template: response_format
+pricing:
+  input: '0.2'
+  output: '0.2'
+  unit: '0.000001'
+  currency: USD

+ 46 - 0
api/core/model_runtime/model_providers/fireworks/llm/llama-v3-8b-instruct.yaml

@@ -0,0 +1,46 @@
+model: accounts/fireworks/models/llama-v3-8b-instruct
+label:
+  zh_Hans: Llama3 8B Instruct
+  en_US: Llama3 8B Instruct
+model_type: llm
+features:
+  - agent-thought
+  - tool-call
+model_properties:
+  mode: chat
+  context_size: 8192
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+  - name: max_tokens
+    use_template: max_tokens
+  - name: context_length_exceeded_behavior
+    default: None
+    label:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    help:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    type: string
+    options:
+      - None
+      - truncate
+      - error
+  - name: response_format
+    use_template: response_format
+pricing:
+  input: '0.2'
+  output: '0.2'
+  unit: '0.000001'
+  currency: USD

+ 46 - 0
api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-405b-instruct.yaml

@@ -0,0 +1,46 @@
+model: accounts/fireworks/models/llama-v3p1-405b-instruct
+label:
+  zh_Hans: Llama3.1 405B Instruct
+  en_US: Llama3.1 405B Instruct
+model_type: llm
+features:
+  - agent-thought
+  - tool-call
+model_properties:
+  mode: chat
+  context_size: 131072
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+  - name: max_tokens
+    use_template: max_tokens
+  - name: context_length_exceeded_behavior
+    default: None
+    label:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    help:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    type: string
+    options:
+      - None
+      - truncate
+      - error
+  - name: response_format
+    use_template: response_format
+pricing:
+  input: '3'
+  output: '3'
+  unit: '0.000001'
+  currency: USD

+ 46 - 0
api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-70b-instruct.yaml

@@ -0,0 +1,46 @@
+model: accounts/fireworks/models/llama-v3p1-70b-instruct
+label:
+  zh_Hans: Llama3.1 70B Instruct
+  en_US: Llama3.1 70B Instruct
+model_type: llm
+features:
+  - agent-thought
+  - tool-call
+model_properties:
+  mode: chat
+  context_size: 131072
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+  - name: max_tokens
+    use_template: max_tokens
+  - name: context_length_exceeded_behavior
+    default: None
+    label:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    help:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    type: string
+    options:
+      - None
+      - truncate
+      - error
+  - name: response_format
+    use_template: response_format
+pricing:
+  input: '0.2'
+  output: '0.2'
+  unit: '0.000001'
+  currency: USD

+ 46 - 0
api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-8b-instruct.yaml

@@ -0,0 +1,46 @@
+model: accounts/fireworks/models/llama-v3p1-8b-instruct
+label:
+  zh_Hans: Llama3.1 8B Instruct
+  en_US: Llama3.1 8B Instruct
+model_type: llm
+features:
+  - agent-thought
+  - tool-call
+model_properties:
+  mode: chat
+  context_size: 131072
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+  - name: max_tokens
+    use_template: max_tokens
+  - name: context_length_exceeded_behavior
+    default: None
+    label:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    help:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    type: string
+    options:
+      - None
+      - truncate
+      - error
+  - name: response_format
+    use_template: response_format
+pricing:
+  input: '0.2'
+  output: '0.2'
+  unit: '0.000001'
+  currency: USD

+ 610 - 0
api/core/model_runtime/model_providers/fireworks/llm/llm.py

@@ -0,0 +1,610 @@
+import logging
+from collections.abc import Generator
+from typing import Optional, Union, cast
+
+from openai import OpenAI, Stream
+from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
+from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
+from openai.types.chat.chat_completion_message import FunctionCall
+
+from core.model_runtime.callbacks.base_callback import Callback
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    ImagePromptMessageContent,
+    PromptMessage,
+    PromptMessageContentType,
+    PromptMessageTool,
+    SystemPromptMessage,
+    TextPromptMessageContent,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.model_runtime.model_providers.fireworks._common import _CommonFireworks
+
+logger = logging.getLogger(__name__)
+
+FIREWORKS_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
+The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+
+<instructions>
+{{instructions}}
+</instructions>
+"""  # noqa: E501
+
+
+class FireworksLargeLanguageModel(_CommonFireworks, LargeLanguageModel):
+    """
+    Model class for Fireworks large language model.
+    """
+
+    def _invoke(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: Optional[list[PromptMessageTool]] = None,
+        stop: Optional[list[str]] = None,
+        stream: bool = True,
+        user: Optional[str] = None,
+    ) -> Union[LLMResult, Generator]:
+        """
+        Invoke large language model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param tools: tools for tool calling
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        :return: full response or stream response chunk generator result
+        """
+
+        return self._chat_generate(
+            model=model,
+            credentials=credentials,
+            prompt_messages=prompt_messages,
+            model_parameters=model_parameters,
+            tools=tools,
+            stop=stop,
+            stream=stream,
+            user=user,
+        )
+
+    def _code_block_mode_wrapper(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: Optional[list[PromptMessageTool]] = None,
+        stop: Optional[list[str]] = None,
+        stream: bool = True,
+        user: Optional[str] = None,
+        callbacks: Optional[list[Callback]] = None,
+    ) -> Union[LLMResult, Generator]:
+        """
+        Code block mode wrapper for invoking large language model
+        """
+        if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}:
+            stop = stop or []
+            self._transform_chat_json_prompts(
+                model=model,
+                credentials=credentials,
+                prompt_messages=prompt_messages,
+                model_parameters=model_parameters,
+                tools=tools,
+                stop=stop,
+                stream=stream,
+                user=user,
+                response_format=model_parameters["response_format"],
+            )
+            model_parameters.pop("response_format")
+
+            return self._invoke(
+                model=model,
+                credentials=credentials,
+                prompt_messages=prompt_messages,
+                model_parameters=model_parameters,
+                tools=tools,
+                stop=stop,
+                stream=stream,
+                user=user,
+            )
+
+    def _transform_chat_json_prompts(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: list[PromptMessageTool] | None = None,
+        stop: list[str] | None = None,
+        stream: bool = True,
+        user: str | None = None,
+        response_format: str = "JSON",
+    ) -> None:
+        """
+        Transform json prompts
+        """
+        if stop is None:
+            stop = []
+        if "```\n" not in stop:
+            stop.append("```\n")
+        if "\n```" not in stop:
+            stop.append("\n```")
+
+        if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
+            prompt_messages[0] = SystemPromptMessage(
+                content=FIREWORKS_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace(
+                    "{{block}}", response_format
+                )
+            )
+            prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}\n"))
+        else:
+            prompt_messages.insert(
+                0,
+                SystemPromptMessage(
+                    content=FIREWORKS_BLOCK_MODE_PROMPT.replace(
+                        "{{instructions}}", f"Please output a valid {response_format} object."
+                    ).replace("{{block}}", response_format)
+                ),
+            )
+            prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
+
+    def get_num_tokens(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        tools: Optional[list[PromptMessageTool]] = None,
+    ) -> int:
+        """
+        Get number of tokens for given prompt messages
+
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param tools: tools for tool calling
+        :return:
+        """
+        return self._num_tokens_from_messages(model, prompt_messages, tools)
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            credentials_kwargs = self._to_credential_kwargs(credentials)
+            client = OpenAI(**credentials_kwargs)
+
+            client.chat.completions.create(
+                messages=[{"role": "user", "content": "ping"}], model=model, temperature=0, max_tokens=10, stream=False
+            )
+        except Exception as e:
+            raise CredentialsValidateFailedError(str(e))
+
+    def _chat_generate(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: Optional[list[PromptMessageTool]] = None,
+        stop: Optional[list[str]] = None,
+        stream: bool = True,
+        user: Optional[str] = None,
+    ) -> Union[LLMResult, Generator]:
+        credentials_kwargs = self._to_credential_kwargs(credentials)
+        client = OpenAI(**credentials_kwargs)
+
+        extra_model_kwargs = {}
+
+        if tools:
+            extra_model_kwargs["functions"] = [
+                {"name": tool.name, "description": tool.description, "parameters": tool.parameters} for tool in tools
+            ]
+
+        if stop:
+            extra_model_kwargs["stop"] = stop
+
+        if user:
+            extra_model_kwargs["user"] = user
+
+        # chat model
+        response = client.chat.completions.create(
+            messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
+            model=model,
+            stream=stream,
+            **model_parameters,
+            **extra_model_kwargs,
+        )
+
+        if stream:
+            return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
+        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+
+    def _handle_chat_generate_response(
+        self,
+        model: str,
+        credentials: dict,
+        response: ChatCompletion,
+        prompt_messages: list[PromptMessage],
+        tools: Optional[list[PromptMessageTool]] = None,
+    ) -> LLMResult:
+        """
+        Handle llm chat response
+
+        :param model: model name
+        :param credentials: credentials
+        :param response: response
+        :param prompt_messages: prompt messages
+        :param tools: tools for tool calling
+        :return: llm response
+        """
+        assistant_message = response.choices[0].message
+        # assistant_message_tool_calls = assistant_message.tool_calls
+        assistant_message_function_call = assistant_message.function_call
+
+        # extract tool calls from response
+        # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
+        function_call = self._extract_response_function_call(assistant_message_function_call)
+        tool_calls = [function_call] if function_call else []
+
+        # transform assistant message to prompt message
+        assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls)
+
+        # calculate num tokens
+        if response.usage:
+            # transform usage
+            prompt_tokens = response.usage.prompt_tokens
+            completion_tokens = response.usage.completion_tokens
+        else:
+            # calculate num tokens
+            prompt_tokens = self._num_tokens_from_messages(model, prompt_messages, tools)
+            completion_tokens = self._num_tokens_from_messages(model, [assistant_prompt_message])
+
+        # transform usage
+        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+        # transform response
+        response = LLMResult(
+            model=response.model,
+            prompt_messages=prompt_messages,
+            message=assistant_prompt_message,
+            usage=usage,
+            system_fingerprint=response.system_fingerprint,
+        )
+
+        return response
+
+    def _handle_chat_generate_stream_response(
+        self,
+        model: str,
+        credentials: dict,
+        response: Stream[ChatCompletionChunk],
+        prompt_messages: list[PromptMessage],
+        tools: Optional[list[PromptMessageTool]] = None,
+    ) -> Generator:
+        """
+        Handle llm chat stream response
+
+        :param model: model name
+        :param response: response
+        :param prompt_messages: prompt messages
+        :param tools: tools for tool calling
+        :return: llm response chunk generator
+        """
+        full_assistant_content = ""
+        delta_assistant_message_function_call_storage: Optional[ChoiceDeltaFunctionCall] = None
+        prompt_tokens = 0
+        completion_tokens = 0
+        final_tool_calls = []
+        final_chunk = LLMResultChunk(
+            model=model,
+            prompt_messages=prompt_messages,
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=AssistantPromptMessage(content=""),
+            ),
+        )
+
+        for chunk in response:
+            if len(chunk.choices) == 0:
+                if chunk.usage:
+                    # calculate num tokens
+                    prompt_tokens = chunk.usage.prompt_tokens
+                    completion_tokens = chunk.usage.completion_tokens
+                continue
+
+            delta = chunk.choices[0]
+            has_finish_reason = delta.finish_reason is not None
+
+            if (
+                not has_finish_reason
+                and (delta.delta.content is None or delta.delta.content == "")
+                and delta.delta.function_call is None
+            ):
+                continue
+
+            # assistant_message_tool_calls = delta.delta.tool_calls
+            assistant_message_function_call = delta.delta.function_call
+
+            # extract tool calls from response
+            if delta_assistant_message_function_call_storage is not None:
+                # handle process of stream function call
+                if assistant_message_function_call:
+                    # message has not ended ever
+                    delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments
+                    continue
+                else:
+                    # message has ended
+                    assistant_message_function_call = delta_assistant_message_function_call_storage
+                    delta_assistant_message_function_call_storage = None
+            else:
+                if assistant_message_function_call:
+                    # start of stream function call
+                    delta_assistant_message_function_call_storage = assistant_message_function_call
+                    if delta_assistant_message_function_call_storage.arguments is None:
+                        delta_assistant_message_function_call_storage.arguments = ""
+                    if not has_finish_reason:
+                        continue
+
+            # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
+            function_call = self._extract_response_function_call(assistant_message_function_call)
+            tool_calls = [function_call] if function_call else []
+            if tool_calls:
+                final_tool_calls.extend(tool_calls)
+
+            # transform assistant message to prompt message
+            assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls)
+
+            full_assistant_content += delta.delta.content or ""
+
+            if has_finish_reason:
+                final_chunk = LLMResultChunk(
+                    model=chunk.model,
+                    prompt_messages=prompt_messages,
+                    system_fingerprint=chunk.system_fingerprint,
+                    delta=LLMResultChunkDelta(
+                        index=delta.index,
+                        message=assistant_prompt_message,
+                        finish_reason=delta.finish_reason,
+                    ),
+                )
+            else:
+                yield LLMResultChunk(
+                    model=chunk.model,
+                    prompt_messages=prompt_messages,
+                    system_fingerprint=chunk.system_fingerprint,
+                    delta=LLMResultChunkDelta(
+                        index=delta.index,
+                        message=assistant_prompt_message,
+                    ),
+                )
+
+        if not prompt_tokens:
+            prompt_tokens = self._num_tokens_from_messages(model, prompt_messages, tools)
+
+        if not completion_tokens:
+            full_assistant_prompt_message = AssistantPromptMessage(
+                content=full_assistant_content, tool_calls=final_tool_calls
+            )
+            completion_tokens = self._num_tokens_from_messages(model, [full_assistant_prompt_message])
+
+        # transform usage
+        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+        final_chunk.delta.usage = usage
+
+        yield final_chunk
+
+    def _extract_response_tool_calls(
+        self, response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]
+    ) -> list[AssistantPromptMessage.ToolCall]:
+        """
+        Extract tool calls from response
+
+        :param response_tool_calls: response tool calls
+        :return: list of tool calls
+        """
+        tool_calls = []
+        if response_tool_calls:
+            for response_tool_call in response_tool_calls:
+                function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                    name=response_tool_call.function.name, arguments=response_tool_call.function.arguments
+                )
+
+                tool_call = AssistantPromptMessage.ToolCall(
+                    id=response_tool_call.id, type=response_tool_call.type, function=function
+                )
+                tool_calls.append(tool_call)
+
+        return tool_calls
+
+    def _extract_response_function_call(
+        self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall
+    ) -> AssistantPromptMessage.ToolCall:
+        """
+        Extract function call from response
+
+        :param response_function_call: response function call
+        :return: tool call
+        """
+        tool_call = None
+        if response_function_call:
+            function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                name=response_function_call.name, arguments=response_function_call.arguments
+            )
+
+            tool_call = AssistantPromptMessage.ToolCall(
+                id=response_function_call.name, type="function", function=function
+            )
+
+        return tool_call
+
+    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+        """
+        Convert PromptMessage to dict for Fireworks API
+        """
+        if isinstance(message, UserPromptMessage):
+            message = cast(UserPromptMessage, message)
+            if isinstance(message.content, str):
+                message_dict = {"role": "user", "content": message.content}
+            else:
+                sub_messages = []
+                for message_content in message.content:
+                    if message_content.type == PromptMessageContentType.TEXT:
+                        message_content = cast(TextPromptMessageContent, message_content)
+                        sub_message_dict = {"type": "text", "text": message_content.data}
+                        sub_messages.append(sub_message_dict)
+                    elif message_content.type == PromptMessageContentType.IMAGE:
+                        message_content = cast(ImagePromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "image_url",
+                            "image_url": {"url": message_content.data, "detail": message_content.detail.value},
+                        }
+                        sub_messages.append(sub_message_dict)
+
+                message_dict = {"role": "user", "content": sub_messages}
+        elif isinstance(message, AssistantPromptMessage):
+            message = cast(AssistantPromptMessage, message)
+            message_dict = {"role": "assistant", "content": message.content}
+            if message.tool_calls:
+                # message_dict["tool_calls"] = [tool_call.dict() for tool_call in
+                #                               message.tool_calls]
+                function_call = message.tool_calls[0]
+                message_dict["function_call"] = {
+                    "name": function_call.function.name,
+                    "arguments": function_call.function.arguments,
+                }
+        elif isinstance(message, SystemPromptMessage):
+            message = cast(SystemPromptMessage, message)
+            message_dict = {"role": "system", "content": message.content}
+        elif isinstance(message, ToolPromptMessage):
+            message = cast(ToolPromptMessage, message)
+            # message_dict = {
+            #     "role": "tool",
+            #     "content": message.content,
+            #     "tool_call_id": message.tool_call_id
+            # }
+            message_dict = {"role": "function", "content": message.content, "name": message.tool_call_id}
+        else:
+            raise ValueError(f"Got unknown type {message}")
+
+        if message.name:
+            message_dict["name"] = message.name
+
+        return message_dict
+
+    def _num_tokens_from_messages(
+        self,
+        model: str,
+        messages: list[PromptMessage],
+        tools: Optional[list[PromptMessageTool]] = None,
+        credentials: dict = None,
+    ) -> int:
+        """
+        Approximate num tokens with GPT2 tokenizer.
+        """
+
+        tokens_per_message = 3
+        tokens_per_name = 1
+
+        num_tokens = 0
+        messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
+        for message in messages_dict:
+            num_tokens += tokens_per_message
+            for key, value in message.items():
+                # Cast str(value) in case the message value is not a string
+                # This occurs with function messages
+                # TODO: The current token calculation method for the image type is not implemented,
+                #  which need to download the image and then get the resolution for calculation,
+                #  and will increase the request delay
+                if isinstance(value, list):
+                    text = ""
+                    for item in value:
+                        if isinstance(item, dict) and item["type"] == "text":
+                            text += item["text"]
+
+                    value = text
+
+                if key == "tool_calls":
+                    for tool_call in value:
+                        for t_key, t_value in tool_call.items():
+                            num_tokens += self._get_num_tokens_by_gpt2(t_key)
+                            if t_key == "function":
+                                for f_key, f_value in t_value.items():
+                                    num_tokens += self._get_num_tokens_by_gpt2(f_key)
+                                    num_tokens += self._get_num_tokens_by_gpt2(f_value)
+                            else:
+                                num_tokens += self._get_num_tokens_by_gpt2(t_key)
+                                num_tokens += self._get_num_tokens_by_gpt2(t_value)
+                else:
+                    num_tokens += self._get_num_tokens_by_gpt2(str(value))
+
+                if key == "name":
+                    num_tokens += tokens_per_name
+
+        # every reply is primed with <im_start>assistant
+        num_tokens += 3
+
+        if tools:
+            num_tokens += self._num_tokens_for_tools(tools)
+
+        return num_tokens
+
+    def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
+        """
+        Calculate num tokens for tool calling with tiktoken package.
+
+        :param tools: tools for tool calling
+        :return: number of tokens
+        """
+        num_tokens = 0
+        for tool in tools:
+            num_tokens += self._get_num_tokens_by_gpt2("type")
+            num_tokens += self._get_num_tokens_by_gpt2("function")
+            num_tokens += self._get_num_tokens_by_gpt2("function")
+
+            # calculate num tokens for function object
+            num_tokens += self._get_num_tokens_by_gpt2("name")
+            num_tokens += self._get_num_tokens_by_gpt2(tool.name)
+            num_tokens += self._get_num_tokens_by_gpt2("description")
+            num_tokens += self._get_num_tokens_by_gpt2(tool.description)
+            parameters = tool.parameters
+            num_tokens += self._get_num_tokens_by_gpt2("parameters")
+            if "title" in parameters:
+                num_tokens += self._get_num_tokens_by_gpt2("title")
+                num_tokens += self._get_num_tokens_by_gpt2(parameters.get("title"))
+            num_tokens += self._get_num_tokens_by_gpt2("type")
+            num_tokens += self._get_num_tokens_by_gpt2(parameters.get("type"))
+            if "properties" in parameters:
+                num_tokens += self._get_num_tokens_by_gpt2("properties")
+                for key, value in parameters.get("properties").items():
+                    num_tokens += self._get_num_tokens_by_gpt2(key)
+                    for field_key, field_value in value.items():
+                        num_tokens += self._get_num_tokens_by_gpt2(field_key)
+                        if field_key == "enum":
+                            for enum_field in field_value:
+                                num_tokens += 3
+                                num_tokens += self._get_num_tokens_by_gpt2(enum_field)
+                        else:
+                            num_tokens += self._get_num_tokens_by_gpt2(field_key)
+                            num_tokens += self._get_num_tokens_by_gpt2(str(field_value))
+            if "required" in parameters:
+                num_tokens += self._get_num_tokens_by_gpt2("required")
+                for required_field in parameters["required"]:
+                    num_tokens += 3
+                    num_tokens += self._get_num_tokens_by_gpt2(required_field)
+
+        return num_tokens

+ 46 - 0
api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x22b-instruct.yaml

@@ -0,0 +1,46 @@
+model: accounts/fireworks/models/mixtral-8x22b-instruct
+label:
+  zh_Hans: Mixtral MoE 8x22B Instruct
+  en_US: Mixtral MoE 8x22B Instruct
+model_type: llm
+features:
+  - agent-thought
+  - tool-call
+model_properties:
+  mode: chat
+  context_size: 65536
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+  - name: max_tokens
+    use_template: max_tokens
+  - name: context_length_exceeded_behavior
+    default: None
+    label:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    help:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    type: string
+    options:
+      - None
+      - truncate
+      - error
+  - name: response_format
+    use_template: response_format
+pricing:
+  input: '1.2'
+  output: '1.2'
+  unit: '0.000001'
+  currency: USD

+ 46 - 0
api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x7b-instruct-hf.yaml

@@ -0,0 +1,46 @@
+model: accounts/fireworks/models/mixtral-8x7b-instruct-hf
+label:
+  zh_Hans: Mixtral MoE 8x7B Instruct(HF version)
+  en_US: Mixtral MoE 8x7B Instruct(HF version)
+model_type: llm
+features:
+  - agent-thought
+  - tool-call
+model_properties:
+  mode: chat
+  context_size: 32768
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+  - name: max_tokens
+    use_template: max_tokens
+  - name: context_length_exceeded_behavior
+    default: None
+    label:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    help:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    type: string
+    options:
+      - None
+      - truncate
+      - error
+  - name: response_format
+    use_template: response_format
+pricing:
+  input: '0.5'
+  output: '0.5'
+  unit: '0.000001'
+  currency: USD

+ 46 - 0
api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x7b-instruct.yaml

@@ -0,0 +1,46 @@
+model: accounts/fireworks/models/mixtral-8x7b-instruct
+label:
+  zh_Hans: Mixtral MoE 8x7B Instruct
+  en_US: Mixtral MoE 8x7B Instruct
+model_type: llm
+features:
+  - agent-thought
+  - tool-call
+model_properties:
+  mode: chat
+  context_size: 32768
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+  - name: max_tokens
+    use_template: max_tokens
+  - name: context_length_exceeded_behavior
+    default: None
+    label:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    help:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    type: string
+    options:
+      - None
+      - truncate
+      - error
+  - name: response_format
+    use_template: response_format
+pricing:
+  input: '0.5'
+  output: '0.5'
+  unit: '0.000001'
+  currency: USD

+ 46 - 0
api/core/model_runtime/model_providers/fireworks/llm/mythomax-l2-13b.yaml

@@ -0,0 +1,46 @@
+model: accounts/fireworks/models/mythomax-l2-13b
+label:
+  zh_Hans: MythoMax L2 13b
+  en_US: MythoMax L2 13b
+model_type: llm
+features:
+  - agent-thought
+  - tool-call
+model_properties:
+  mode: chat
+  context_size: 4096
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+  - name: max_tokens
+    use_template: max_tokens
+  - name: context_length_exceeded_behavior
+    default: None
+    label:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    help:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    type: string
+    options:
+      - None
+      - truncate
+      - error
+  - name: response_format
+    use_template: response_format
+pricing:
+  input: '0.2'
+  output: '0.2'
+  unit: '0.000001'
+  currency: USD

+ 46 - 0
api/core/model_runtime/model_providers/fireworks/llm/phi-3-vision-128k-instruct.yaml

@@ -0,0 +1,46 @@
+model: accounts/fireworks/models/phi-3-vision-128k-instruct
+label:
+  zh_Hans: Phi3.5 Vision Instruct
+  en_US: Phi3.5 Vision Instruct
+model_type: llm
+features:
+  - agent-thought
+  - vision
+model_properties:
+  mode: chat
+  context_size: 8192
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+  - name: max_tokens
+    use_template: max_tokens
+  - name: context_length_exceeded_behavior
+    default: None
+    label:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    help:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    type: string
+    options:
+      - None
+      - truncate
+      - error
+  - name: response_format
+    use_template: response_format
+pricing:
+  input: '0.2'
+  output: '0.2'
+  unit: '0.000001'
+  currency: USD

+ 45 - 0
api/core/model_runtime/model_providers/fireworks/llm/yi-large.yaml

@@ -0,0 +1,45 @@
+model: accounts/yi-01-ai/models/yi-large
+label:
+  zh_Hans: Yi-Large
+  en_US: Yi-Large
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 32768
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+  - name: max_tokens
+    use_template: max_tokens
+  - name: context_length_exceeded_behavior
+    default: None
+    label:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    help:
+      zh_Hans: 上下文长度超出行为
+      en_US: Context Length Exceeded Behavior
+    type: string
+    options:
+      - None
+      - truncate
+      - error
+  - name: response_format
+    use_template: response_format
+pricing:
+  input: '3'
+  output: '3'
+  unit: '0.000001'
+  currency: USD

+ 1 - 0
api/pyproject.toml

@@ -100,6 +100,7 @@ exclude = [
 [tool.pytest_env]
 OPENAI_API_KEY = "sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii"
 UPSTAGE_API_KEY = "up-aaaaaaaaaaaaaaaaaaaa"
+FIREWORKS_API_KEY = "fw_aaaaaaaaaaaaaaaaaaaa"
 AZURE_OPENAI_API_BASE = "https://difyai-openai.openai.azure.com"
 AZURE_OPENAI_API_KEY = "xxxxb1707exxxxxxxxxxaaxxxxxf94"
 ANTHROPIC_API_KEY = "sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz"

+ 0 - 0
api/tests/integration_tests/model_runtime/fireworks/__init__.py


+ 186 - 0
api/tests/integration_tests/model_runtime/fireworks/test_llm.py

@@ -0,0 +1,186 @@
+import os
+from collections.abc import Generator
+
+import pytest
+
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessageTool,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import AIModelEntity
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.fireworks.llm.llm import FireworksLargeLanguageModel
+
+"""FOR MOCK FIXTURES, DO NOT REMOVE"""
+from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
+
+
+def test_predefined_models():
+    model = FireworksLargeLanguageModel()
+    model_schemas = model.predefined_models()
+
+    assert len(model_schemas) >= 1
+    assert isinstance(model_schemas[0], AIModelEntity)
+
+
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
+def test_validate_credentials_for_chat_model(setup_openai_mock):
+    model = FireworksLargeLanguageModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        # model name to gpt-3.5-turbo because of mocking
+        model.validate_credentials(model="gpt-3.5-turbo", credentials={"fireworks_api_key": "invalid_key"})
+
+    model.validate_credentials(
+        model="accounts/fireworks/models/llama-v3p1-8b-instruct",
+        credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
+    )
+
+
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
+def test_invoke_chat_model(setup_openai_mock):
+    model = FireworksLargeLanguageModel()
+
+    result = model.invoke(
+        model="accounts/fireworks/models/llama-v3p1-8b-instruct",
+        credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
+        prompt_messages=[
+            SystemPromptMessage(
+                content="You are a helpful AI assistant.",
+            ),
+            UserPromptMessage(content="Hello World!"),
+        ],
+        model_parameters={
+            "temperature": 0.0,
+            "top_p": 1.0,
+            "presence_penalty": 0.0,
+            "frequency_penalty": 0.0,
+            "max_tokens": 10,
+        },
+        stop=["How"],
+        stream=False,
+        user="foo",
+    )
+
+    assert isinstance(result, LLMResult)
+    assert len(result.message.content) > 0
+
+
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
+def test_invoke_chat_model_with_tools(setup_openai_mock):
+    model = FireworksLargeLanguageModel()
+
+    result = model.invoke(
+        model="accounts/fireworks/models/llama-v3p1-8b-instruct",
+        credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
+        prompt_messages=[
+            SystemPromptMessage(
+                content="You are a helpful AI assistant.",
+            ),
+            UserPromptMessage(
+                content="what's the weather today in London?",
+            ),
+        ],
+        model_parameters={"temperature": 0.0, "max_tokens": 100},
+        tools=[
+            PromptMessageTool(
+                name="get_weather",
+                description="Determine weather in my location",
+                parameters={
+                    "type": "object",
+                    "properties": {
+                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
+                        "unit": {"type": "string", "enum": ["c", "f"]},
+                    },
+                    "required": ["location"],
+                },
+            ),
+            PromptMessageTool(
+                name="get_stock_price",
+                description="Get the current stock price",
+                parameters={
+                    "type": "object",
+                    "properties": {"symbol": {"type": "string", "description": "The stock symbol"}},
+                    "required": ["symbol"],
+                },
+            ),
+        ],
+        stream=False,
+        user="foo",
+    )
+
+    assert isinstance(result, LLMResult)
+    assert isinstance(result.message, AssistantPromptMessage)
+    assert len(result.message.tool_calls) > 0
+
+
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
+def test_invoke_stream_chat_model(setup_openai_mock):
+    model = FireworksLargeLanguageModel()
+
+    result = model.invoke(
+        model="accounts/fireworks/models/llama-v3p1-8b-instruct",
+        credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
+        prompt_messages=[
+            SystemPromptMessage(
+                content="You are a helpful AI assistant.",
+            ),
+            UserPromptMessage(content="Hello World!"),
+        ],
+        model_parameters={"temperature": 0.0, "max_tokens": 100},
+        stream=True,
+        user="foo",
+    )
+
+    assert isinstance(result, Generator)
+
+    for chunk in result:
+        assert isinstance(chunk, LLMResultChunk)
+        assert isinstance(chunk.delta, LLMResultChunkDelta)
+        assert isinstance(chunk.delta.message, AssistantPromptMessage)
+        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
+        if chunk.delta.finish_reason is not None:
+            assert chunk.delta.usage is not None
+            assert chunk.delta.usage.completion_tokens > 0
+
+
+def test_get_num_tokens():
+    model = FireworksLargeLanguageModel()
+
+    num_tokens = model.get_num_tokens(
+        model="accounts/fireworks/models/llama-v3p1-8b-instruct",
+        credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
+    )
+
+    assert num_tokens == 10
+
+    num_tokens = model.get_num_tokens(
+        model="accounts/fireworks/models/llama-v3p1-8b-instruct",
+        credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
+        prompt_messages=[
+            SystemPromptMessage(
+                content="You are a helpful AI assistant.",
+            ),
+            UserPromptMessage(content="Hello World!"),
+        ],
+        tools=[
+            PromptMessageTool(
+                name="get_weather",
+                description="Determine weather in my location",
+                parameters={
+                    "type": "object",
+                    "properties": {
+                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
+                        "unit": {"type": "string", "enum": ["c", "f"]},
+                    },
+                    "required": ["location"],
+                },
+            ),
+        ],
+    )
+
+    assert num_tokens == 77

+ 17 - 0
api/tests/integration_tests/model_runtime/fireworks/test_provider.py

@@ -0,0 +1,17 @@
+import os
+
+import pytest
+
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.fireworks.fireworks import FireworksProvider
+from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
+
+
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
+def test_validate_provider_credentials(setup_openai_mock):
+    provider = FireworksProvider()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        provider.validate_provider_credentials(credentials={})
+
+    provider.validate_provider_credentials(credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")})

+ 2 - 2
dev/pytest/pytest_model_runtime.sh

@@ -6,5 +6,5 @@ pytest api/tests/integration_tests/model_runtime/anthropic \
   api/tests/integration_tests/model_runtime/openai api/tests/integration_tests/model_runtime/chatglm \
   api/tests/integration_tests/model_runtime/google api/tests/integration_tests/model_runtime/xinference \
   api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py \
-  api/tests/integration_tests/model_runtime/upstage
-
+  api/tests/integration_tests/model_runtime/upstage \
+  api/tests/integration_tests/model_runtime/fireworks

Alguns ficheiros não foram mostrados porque muitos ficheiros mudaram neste diff