浏览代码

feat: add xAI model provider (#10272)

非法操作 5 月之前
父节点
当前提交
bf9349c4dc

+ 0 - 0
api/core/model_runtime/model_providers/x/__init__.py


+ 1 - 0
api/core/model_runtime/model_providers/x/_assets/x-ai-logo.svg

@@ -0,0 +1 @@
+<svg xmlns="http://www.w3.org/2000/svg" fill="currentColor" viewBox="0 0 24 24" aria-hidden="true" class="" focusable="false" style="fill:currentColor;height:28px;width:28px"><path d="m3.005 8.858 8.783 12.544h3.904L6.908 8.858zM6.905 15.825 3 21.402h3.907l1.951-2.788zM16.585 2l-6.75 9.64 1.953 2.79L20.492 2zM17.292 7.965v13.437h3.2V3.395z"></path></svg>

+ 0 - 0
api/core/model_runtime/model_providers/x/llm/__init__.py


+ 63 - 0
api/core/model_runtime/model_providers/x/llm/grok-beta.yaml

@@ -0,0 +1,63 @@
+model: grok-beta
+label:
+  en_US: Grok beta
+model_type: llm
+features:
+  - multi-tool-call
+model_properties:
+  mode: chat
+  context_size: 131072
+parameter_rules:
+  - name: temperature
+    label:
+      en_US: "Temperature"
+      zh_Hans: "采样温度"
+    type: float
+    default: 0.7
+    min: 0.0
+    max: 2.0
+    precision: 1
+    required: true
+    help:
+      en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
+      zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
+
+  - name: top_p
+    label:
+      en_US: "Top P"
+      zh_Hans: "Top P"
+    type: float
+    default: 0.7
+    min: 0.0
+    max: 1.0
+    precision: 1
+    required: true
+    help:
+      en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
+      zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
+
+  - name: frequency_penalty
+    use_template: frequency_penalty
+    label:
+      en_US: "Frequency Penalty"
+      zh_Hans: "频率惩罚"
+    type: float
+    default: 0
+    min: 0
+    max: 2.0
+    precision: 1
+    required: false
+    help:
+      en_US: "Number between 0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim."
+      zh_Hans: "介于0和2.0之间的数字。正值会根据新标记在文本中迄今为止的现有频率来惩罚它们,从而降低模型一字不差地重复同一句话的可能性。"
+
+  - name: user
+    use_template: text
+    label:
+      en_US: "User"
+      zh_Hans: "用户"
+    type: string
+    required: false
+    help:
+      en_US: "Used to track and differentiate conversation requests from different users."
+      zh_Hans: "用于追踪和区分不同用户的对话请求。"

+ 37 - 0
api/core/model_runtime/model_providers/x/llm/llm.py

@@ -0,0 +1,37 @@
+from collections.abc import Generator
+from typing import Optional, Union
+
+from yarl import URL
+
+from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
+from core.model_runtime.entities.message_entities import (
+    PromptMessage,
+    PromptMessageTool,
+)
+from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
+
+
+class XAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
+    def _invoke(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: Optional[list[PromptMessageTool]] = None,
+        stop: Optional[list[str]] = None,
+        stream: bool = True,
+        user: Optional[str] = None,
+    ) -> Union[LLMResult, Generator]:
+        self._add_custom_parameters(credentials)
+        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        self._add_custom_parameters(credentials)
+        super().validate_credentials(model, credentials)
+
+    @staticmethod
+    def _add_custom_parameters(credentials) -> None:
+        credentials["endpoint_url"] = str(URL(credentials["endpoint_url"])) or "https://api.x.ai/v1"
+        credentials["mode"] = LLMMode.CHAT.value
+        credentials["function_calling_type"] = "tool_call"

+ 25 - 0
api/core/model_runtime/model_providers/x/x.py

@@ -0,0 +1,25 @@
+import logging
+
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class XAIProvider(ModelProvider):
+    def validate_provider_credentials(self, credentials: dict) -> None:
+        """
+        Validate provider credentials
+        if validate failed, raise exception
+
+        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
+        """
+        try:
+            model_instance = self.get_model_instance(ModelType.LLM)
+            model_instance.validate_credentials(model="grok-beta", credentials=credentials)
+        except CredentialsValidateFailedError as ex:
+            raise ex
+        except Exception as ex:
+            logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
+            raise ex

+ 38 - 0
api/core/model_runtime/model_providers/x/x.yaml

@@ -0,0 +1,38 @@
+provider: x
+label:
+  en_US: xAI
+description:
+  en_US: xAI is a company working on building artificial intelligence to accelerate human scientific discovery. We are guided by our mission to advance our collective understanding of the universe.
+icon_small:
+  en_US: x-ai-logo.svg
+icon_large:
+  en_US: x-ai-logo.svg
+help:
+  title:
+    en_US: Get your token from xAI
+    zh_Hans: 从 xAI 获取 token
+  url:
+    en_US: https://x.ai/api
+supported_model_types:
+  - llm
+configurate_methods:
+  - predefined-model
+provider_credential_schema:
+  credential_form_schemas:
+    - variable: api_key
+      label:
+        en_US: API Key
+      type: secret-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入您的 API Key
+        en_US: Enter your API Key
+    - variable: endpoint_url
+      label:
+        en_US: API Base
+      type: text-input
+      required: false
+      default: https://api.x.ai/v1
+      placeholder:
+        zh_Hans: 在此输入您的 API Base
+        en_US: Enter your API Base

+ 4 - 0
api/tests/integration_tests/.env.example

@@ -95,3 +95,7 @@ GPUSTACK_API_KEY=
 
 # Gitee AI Credentials
 GITEE_AI_API_KEY=
+
+# xAI Credentials
+XAI_API_KEY=
+XAI_API_BASE=

+ 0 - 0
api/tests/integration_tests/model_runtime/x/__init__.py


+ 204 - 0
api/tests/integration_tests/model_runtime/x/test_llm.py

@@ -0,0 +1,204 @@
+import os
+from collections.abc import Generator
+
+import pytest
+
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessageTool,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import AIModelEntity
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.x.llm.llm import XAILargeLanguageModel
+
+"""FOR MOCK FIXTURES, DO NOT REMOVE"""
+from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
+
+
+def test_predefined_models():
+    model = XAILargeLanguageModel()
+    model_schemas = model.predefined_models()
+
+    assert len(model_schemas) >= 1
+    assert isinstance(model_schemas[0], AIModelEntity)
+
+
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
+def test_validate_credentials_for_chat_model(setup_openai_mock):
+    model = XAILargeLanguageModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        # model name to gpt-3.5-turbo because of mocking
+        model.validate_credentials(
+            model="gpt-3.5-turbo",
+            credentials={"api_key": "invalid_key", "endpoint_url": os.environ.get("XAI_API_BASE"), "mode": "chat"},
+        )
+
+    model.validate_credentials(
+        model="grok-beta",
+        credentials={
+            "api_key": os.environ.get("XAI_API_KEY"),
+            "endpoint_url": os.environ.get("XAI_API_BASE"),
+            "mode": "chat",
+        },
+    )
+
+
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
+def test_invoke_chat_model(setup_openai_mock):
+    model = XAILargeLanguageModel()
+
+    result = model.invoke(
+        model="grok-beta",
+        credentials={
+            "api_key": os.environ.get("XAI_API_KEY"),
+            "endpoint_url": os.environ.get("XAI_API_BASE"),
+            "mode": "chat",
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content="You are a helpful AI assistant.",
+            ),
+            UserPromptMessage(content="Hello World!"),
+        ],
+        model_parameters={
+            "temperature": 0.0,
+            "top_p": 1.0,
+            "presence_penalty": 0.0,
+            "frequency_penalty": 0.0,
+            "max_tokens": 10,
+        },
+        stop=["How"],
+        stream=False,
+        user="foo",
+    )
+
+    assert isinstance(result, LLMResult)
+    assert len(result.message.content) > 0
+
+
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
+def test_invoke_chat_model_with_tools(setup_openai_mock):
+    model = XAILargeLanguageModel()
+
+    result = model.invoke(
+        model="grok-beta",
+        credentials={
+            "api_key": os.environ.get("XAI_API_KEY"),
+            "endpoint_url": os.environ.get("XAI_API_BASE"),
+            "mode": "chat",
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content="You are a helpful AI assistant.",
+            ),
+            UserPromptMessage(
+                content="what's the weather today in London?",
+            ),
+        ],
+        model_parameters={"temperature": 0.0, "max_tokens": 100},
+        tools=[
+            PromptMessageTool(
+                name="get_weather",
+                description="Determine weather in my location",
+                parameters={
+                    "type": "object",
+                    "properties": {
+                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
+                        "unit": {"type": "string", "enum": ["c", "f"]},
+                    },
+                    "required": ["location"],
+                },
+            ),
+            PromptMessageTool(
+                name="get_stock_price",
+                description="Get the current stock price",
+                parameters={
+                    "type": "object",
+                    "properties": {"symbol": {"type": "string", "description": "The stock symbol"}},
+                    "required": ["symbol"],
+                },
+            ),
+        ],
+        stream=False,
+        user="foo",
+    )
+
+    assert isinstance(result, LLMResult)
+    assert isinstance(result.message, AssistantPromptMessage)
+
+
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
+def test_invoke_stream_chat_model(setup_openai_mock):
+    model = XAILargeLanguageModel()
+
+    result = model.invoke(
+        model="grok-beta",
+        credentials={
+            "api_key": os.environ.get("XAI_API_KEY"),
+            "endpoint_url": os.environ.get("XAI_API_BASE"),
+            "mode": "chat",
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content="You are a helpful AI assistant.",
+            ),
+            UserPromptMessage(content="Hello World!"),
+        ],
+        model_parameters={"temperature": 0.0, "max_tokens": 100},
+        stream=True,
+        user="foo",
+    )
+
+    assert isinstance(result, Generator)
+
+    for chunk in result:
+        assert isinstance(chunk, LLMResultChunk)
+        assert isinstance(chunk.delta, LLMResultChunkDelta)
+        assert isinstance(chunk.delta.message, AssistantPromptMessage)
+        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
+        if chunk.delta.finish_reason is not None:
+            assert chunk.delta.usage is not None
+            assert chunk.delta.usage.completion_tokens > 0
+
+
+def test_get_num_tokens():
+    model = XAILargeLanguageModel()
+
+    num_tokens = model.get_num_tokens(
+        model="grok-beta",
+        credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")},
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
+    )
+
+    assert num_tokens == 10
+
+    num_tokens = model.get_num_tokens(
+        model="grok-beta",
+        credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")},
+        prompt_messages=[
+            SystemPromptMessage(
+                content="You are a helpful AI assistant.",
+            ),
+            UserPromptMessage(content="Hello World!"),
+        ],
+        tools=[
+            PromptMessageTool(
+                name="get_weather",
+                description="Determine weather in my location",
+                parameters={
+                    "type": "object",
+                    "properties": {
+                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
+                        "unit": {"type": "string", "enum": ["c", "f"]},
+                    },
+                    "required": ["location"],
+                },
+            ),
+        ],
+    )
+
+    assert num_tokens == 77