Pārlūkot izejas kodu

Add VESSL AI OpenAI API-compatible model provider and LLM model (#9474)

Co-authored-by: moon <moon@vessl.ai>
larcane97 5 mēneši atpakaļ
vecāks
revīzija
8d5456b6d0

+ 0 - 0
api/core/model_runtime/model_providers/vessl_ai/__init__.py


BIN
api/core/model_runtime/model_providers/vessl_ai/_assets/icon_l_en.png


+ 3 - 0
api/core/model_runtime/model_providers/vessl_ai/_assets/icon_s_en.svg

@@ -0,0 +1,3 @@
+<svg width="1200" height="925" viewBox="0 0 1200 925" fill="none" xmlns="http://www.w3.org/2000/svg">
+<path d="M780.152 250.999L907.882 462.174C907.882 462.174 880.925 510.854 867.43 535.21C834.845 594.039 764.171 612.49 710.442 508.333L420.376 0H0L459.926 803.307C552.303 964.663 787.366 964.663 879.743 803.307C989.874 610.952 1089.87 441.97 1200 249.646L1052.28 0H639.519L780.152 250.999Z" fill="#3366FF"/>
+</svg>

+ 0 - 0
api/core/model_runtime/model_providers/vessl_ai/llm/__init__.py


+ 83 - 0
api/core/model_runtime/model_providers/vessl_ai/llm/llm.py

@@ -0,0 +1,83 @@
+from decimal import Decimal
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.llm_entities import LLMMode
+from core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    DefaultParameterName,
+    FetchFrom,
+    ModelPropertyKey,
+    ModelType,
+    ParameterRule,
+    ParameterType,
+    PriceConfig,
+)
+from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
+
+
+class VesslAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+        features = []
+
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(en_US=model),
+            model_type=ModelType.LLM,
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            features=features,
+            model_properties={
+                ModelPropertyKey.MODE: credentials.get("mode"),
+            },
+            parameter_rules=[
+                ParameterRule(
+                    name=DefaultParameterName.TEMPERATURE.value,
+                    label=I18nObject(en_US="Temperature"),
+                    type=ParameterType.FLOAT,
+                    default=float(credentials.get("temperature", 0.7)),
+                    min=0,
+                    max=2,
+                    precision=2,
+                ),
+                ParameterRule(
+                    name=DefaultParameterName.TOP_P.value,
+                    label=I18nObject(en_US="Top P"),
+                    type=ParameterType.FLOAT,
+                    default=float(credentials.get("top_p", 1)),
+                    min=0,
+                    max=1,
+                    precision=2,
+                ),
+                ParameterRule(
+                    name=DefaultParameterName.TOP_K.value,
+                    label=I18nObject(en_US="Top K"),
+                    type=ParameterType.INT,
+                    default=int(credentials.get("top_k", 50)),
+                    min=-2147483647,
+                    max=2147483647,
+                    precision=0,
+                ),
+                ParameterRule(
+                    name=DefaultParameterName.MAX_TOKENS.value,
+                    label=I18nObject(en_US="Max Tokens"),
+                    type=ParameterType.INT,
+                    default=512,
+                    min=1,
+                    max=int(credentials.get("max_tokens_to_sample", 4096)),
+                ),
+            ],
+            pricing=PriceConfig(
+                input=Decimal(credentials.get("input_price", 0)),
+                output=Decimal(credentials.get("output_price", 0)),
+                unit=Decimal(credentials.get("unit", 0)),
+                currency=credentials.get("currency", "USD"),
+            ),
+        )
+
+        if credentials["mode"] == "chat":
+            entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
+        elif credentials["mode"] == "completion":
+            entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
+        else:
+            raise ValueError(f"Unknown completion type {credentials['completion_type']}")
+
+        return entity

+ 10 - 0
api/core/model_runtime/model_providers/vessl_ai/vessl_ai.py

@@ -0,0 +1,10 @@
+import logging
+
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class VesslAIProvider(ModelProvider):
+    def validate_provider_credentials(self, credentials: dict) -> None:
+        pass

+ 56 - 0
api/core/model_runtime/model_providers/vessl_ai/vessl_ai.yaml

@@ -0,0 +1,56 @@
+provider: vessl_ai
+label:
+  en_US: vessl_ai
+icon_small:
+  en_US: icon_s_en.svg
+icon_large:
+  en_US: icon_l_en.png
+background: "#F1EFED"
+help:
+  title:
+    en_US: How to deploy VESSL AI LLM Model Endpoint
+  url:
+    en_US: https://docs.vessl.ai/guides/get-started/llama3-deployment
+supported_model_types:
+  - llm
+configurate_methods:
+  - customizable-model
+model_credential_schema:
+  model:
+    label:
+      en_US: Model Name
+    placeholder:
+      en_US: Enter your model name
+  credential_form_schemas:
+    - variable: endpoint_url
+      label:
+        en_US: endpoint url
+      type: text-input
+      required: true
+      placeholder:
+        en_US: Enter the url of your endpoint url
+    - variable: api_key
+      required: true
+      label:
+        en_US: API Key
+      type: secret-input
+      placeholder:
+        en_US: Enter your VESSL AI secret key
+    - variable: mode
+      show_on:
+        - variable: __model_type
+          value: llm
+      label:
+        en_US: Completion mode
+      type: select
+      required: false
+      default: chat
+      placeholder:
+        en_US: Select completion mode
+      options:
+        - value: completion
+          label:
+            en_US: Completion
+        - value: chat
+          label:
+            en_US: Chat

+ 6 - 1
api/tests/integration_tests/.env.example

@@ -84,5 +84,10 @@ VOLC_EMBEDDING_ENDPOINT_ID=
 # 360 AI Credentials
 ZHINAO_API_KEY=
 
+# VESSL AI Credentials
+VESSL_AI_MODEL_NAME=
+VESSL_AI_API_KEY=
+VESSL_AI_ENDPOINT_URL=
+
 # Gitee AI Credentials
-GITEE_AI_API_KEY=
+GITEE_AI_API_KEY=

+ 0 - 0
api/tests/integration_tests/model_runtime/vessl_ai/__init__.py


+ 131 - 0
api/tests/integration_tests/model_runtime/vessl_ai/test_llm.py

@@ -0,0 +1,131 @@
+import os
+from collections.abc import Generator
+
+import pytest
+
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.vessl_ai.llm.llm import VesslAILargeLanguageModel
+
+
+def test_validate_credentials():
+    model = VesslAILargeLanguageModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model=os.environ.get("VESSL_AI_MODEL_NAME"),
+            credentials={
+                "api_key": "invalid_key",
+                "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"),
+                "mode": "chat",
+            },
+        )
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model=os.environ.get("VESSL_AI_MODEL_NAME"),
+            credentials={
+                "api_key": os.environ.get("VESSL_AI_API_KEY"),
+                "endpoint_url": "http://invalid_url",
+                "mode": "chat",
+            },
+        )
+
+    model.validate_credentials(
+        model=os.environ.get("VESSL_AI_MODEL_NAME"),
+        credentials={
+            "api_key": os.environ.get("VESSL_AI_API_KEY"),
+            "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"),
+            "mode": "chat",
+        },
+    )
+
+
+def test_invoke_model():
+    model = VesslAILargeLanguageModel()
+
+    response = model.invoke(
+        model=os.environ.get("VESSL_AI_MODEL_NAME"),
+        credentials={
+            "api_key": os.environ.get("VESSL_AI_API_KEY"),
+            "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"),
+            "mode": "chat",
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content="You are a helpful AI assistant.",
+            ),
+            UserPromptMessage(content="Who are you?"),
+        ],
+        model_parameters={
+            "temperature": 1.0,
+            "top_k": 2,
+            "top_p": 0.5,
+        },
+        stop=["How"],
+        stream=False,
+        user="abc-123",
+    )
+
+    assert isinstance(response, LLMResult)
+    assert len(response.message.content) > 0
+
+
+def test_invoke_stream_model():
+    model = VesslAILargeLanguageModel()
+
+    response = model.invoke(
+        model=os.environ.get("VESSL_AI_MODEL_NAME"),
+        credentials={
+            "api_key": os.environ.get("VESSL_AI_API_KEY"),
+            "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"),
+            "mode": "chat",
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content="You are a helpful AI assistant.",
+            ),
+            UserPromptMessage(content="Who are you?"),
+        ],
+        model_parameters={
+            "temperature": 1.0,
+            "top_k": 2,
+            "top_p": 0.5,
+        },
+        stop=["How"],
+        stream=True,
+        user="abc-123",
+    )
+
+    assert isinstance(response, Generator)
+
+    for chunk in response:
+        assert isinstance(chunk, LLMResultChunk)
+        assert isinstance(chunk.delta, LLMResultChunkDelta)
+        assert isinstance(chunk.delta.message, AssistantPromptMessage)
+
+
+def test_get_num_tokens():
+    model = VesslAILargeLanguageModel()
+
+    num_tokens = model.get_num_tokens(
+        model=os.environ.get("VESSL_AI_MODEL_NAME"),
+        credentials={
+            "api_key": os.environ.get("VESSL_AI_API_KEY"),
+            "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"),
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content="You are a helpful AI assistant.",
+            ),
+            UserPromptMessage(content="Hello World!"),
+        ],
+    )
+
+    assert isinstance(num_tokens, int)
+    assert num_tokens == 21