
feat: add gpustack model provider (#10158)

Lawrence Li 5 months ago
parent commit 76b0328eb1

BIN
api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.png


File diff suppressed because it is too large
+ 2 - 0
api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.svg


BIN
api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.png


+ 11 - 0
api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.svg

@@ -0,0 +1,11 @@
+<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
+<rect width="24" height="24" rx="6" fill="url(#paint0_linear_7301_16076)"/>
+<path d="M20 12.0116C15.7043 12.42 12.3692 15.757 11.9995 20C11.652 15.8183 8.20301 12.361 4 12.0181C8.21855 11.6991 11.6656 8.1853 12.006 4C12.2833 8.19653 15.8057 11.7005 20 12.0116Z" fill="white" fill-opacity="0.88"/>
+<defs>
+<linearGradient id="paint0_linear_7301_16076" x1="-9" y1="29.5" x2="19.4387" y2="1.43791" gradientUnits="userSpaceOnUse">
+<stop offset="0.192878" stop-color="#1C7DFF"/>
+<stop offset="0.520213" stop-color="#1C69FF"/>
+<stop offset="1" stop-color="#F0DCD6"/>
+</linearGradient>
+</defs>
+</svg>

+ 10 - 0
api/core/model_runtime/model_providers/gpustack/gpustack.py

@@ -0,0 +1,10 @@
+import logging
+
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class GPUStackProvider(ModelProvider):
+    def validate_provider_credentials(self, credentials: dict) -> None:
+        pass
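
Note that validate_provider_credentials is intentionally a no-op: the provider only exposes customizable models, so credential checks happen per model in each model class's validate_credentials. A minimal sketch of that per-model check, mirroring the integration tests further down; the endpoint, key, and model name are placeholders, not real values:

from core.model_runtime.model_providers.gpustack.llm.llm import GPUStackLanguageModel

# Placeholder values; a real check needs a reachable GPUStack server.
GPUStackLanguageModel().validate_credentials(
    model="llama-3.2-1b-instruct",
    credentials={
        "endpoint_url": "http://192.168.1.100",
        "api_key": "your-api-key",
        "mode": "chat",
    },
)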

+ 120 - 0
api/core/model_runtime/model_providers/gpustack/gpustack.yaml

@@ -0,0 +1,120 @@
+provider: gpustack
+label:
+  en_US: GPUStack
+icon_small:
+  en_US: icon_s_en.png
+icon_large:
+  en_US: icon_l_en.png
+supported_model_types:
+  - llm
+  - text-embedding
+  - rerank
+configurate_methods:
+  - customizable-model
+model_credential_schema:
+  model:
+    label:
+      en_US: Model Name
+      zh_Hans: 模型名称
+    placeholder:
+      en_US: Enter your model name
+      zh_Hans: 输入模型名称
+  credential_form_schemas:
+    - variable: endpoint_url
+      label:
+        zh_Hans: 服务器地址
+        en_US: Server URL
+      type: text-input
+      required: true
+      placeholder:
+        zh_Hans: 输入 GPUStack 的服务器地址,如 http://192.168.1.100
+        en_US: Enter the GPUStack server URL, e.g. http://192.168.1.100
+    - variable: api_key
+      label:
+        en_US: API Key
+      type: secret-input
+      required: true
+      placeholder:
+        zh_Hans: 输入您的 API Key
+        en_US: Enter your API Key
+    - variable: mode
+      show_on:
+        - variable: __model_type
+          value: llm
+      label:
+        en_US: Completion mode
+      type: select
+      required: false
+      default: chat
+      placeholder:
+        zh_Hans: 选择补全类型
+        en_US: Select completion type
+      options:
+        - value: completion
+          label:
+            en_US: Completion
+            zh_Hans: 补全
+        - value: chat
+          label:
+            en_US: Chat
+            zh_Hans: 对话
+    - variable: context_size
+      label:
+        zh_Hans: 模型上下文长度
+        en_US: Model context size
+      required: true
+      type: text-input
+      default: "8192"
+      placeholder:
+        zh_Hans: 输入您的模型上下文长度
+        en_US: Enter your Model context size
+    - variable: max_tokens_to_sample
+      label:
+        zh_Hans: 最大 token 上限
+        en_US: Upper bound for max tokens
+      show_on:
+        - variable: __model_type
+          value: llm
+      default: "8192"
+      type: text-input
+    - variable: function_calling_type
+      show_on:
+        - variable: __model_type
+          value: llm
+      label:
+        en_US: Function calling
+      type: select
+      required: false
+      default: no_call
+      options:
+        - value: function_call
+          label:
+            en_US: Function Call
+            zh_Hans: Function Call
+        - value: tool_call
+          label:
+            en_US: Tool Call
+            zh_Hans: Tool Call
+        - value: no_call
+          label:
+            en_US: Not Support
+            zh_Hans: 不支持
+    - variable: vision_support
+      show_on:
+        - variable: __model_type
+          value: llm
+      label:
+        zh_Hans: Vision 支持
+        en_US: Vision Support
+      type: select
+      required: false
+      default: no_support
+      options:
+        - value: support
+          label:
+            en_US: Support
+            zh_Hans: 支持
+        - value: no_support
+          label:
+            en_US: Not Support
+            zh_Hans: 不支持
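
Taken together, the schema above resolves to a flat credentials dict at runtime. A hypothetical example of what a fully filled-in LLM entry could look like; the field names come from the schema, all values are placeholders:

# Hypothetical credentials assembled from the form fields defined above.
credentials = {
    "endpoint_url": "http://192.168.1.100",
    "api_key": "your-api-key",
    "mode": "chat",                      # completion | chat
    "context_size": "8192",
    "max_tokens_to_sample": "8192",
    "function_calling_type": "no_call",  # function_call | tool_call | no_call
    "vision_support": "no_support",      # support | no_support
}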

+ 0 - 0
api/core/model_runtime/model_providers/gpustack/llm/__init__.py


+ 45 - 0
api/core/model_runtime/model_providers/gpustack/llm/llm.py

@@ -0,0 +1,45 @@
+from collections.abc import Generator
+
+from yarl import URL
+
+from core.model_runtime.entities.llm_entities import LLMResult
+from core.model_runtime.entities.message_entities import (
+    PromptMessage,
+    PromptMessageTool,
+)
+from core.model_runtime.model_providers.openai_api_compatible.llm.llm import (
+    OAIAPICompatLargeLanguageModel,
+)
+
+
+class GPUStackLanguageModel(OAIAPICompatLargeLanguageModel):
+    def _invoke(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: list[PromptMessageTool] | None = None,
+        stop: list[str] | None = None,
+        stream: bool = True,
+        user: str | None = None,
+    ) -> LLMResult | Generator:
+        return super()._invoke(
+            model,
+            credentials,
+            prompt_messages,
+            model_parameters,
+            tools,
+            stop,
+            stream,
+            user,
+        )
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        self._add_custom_parameters(credentials)
+        super().validate_credentials(model, credentials)
+
+    @staticmethod
+    def _add_custom_parameters(credentials: dict) -> None:
+        credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai")
+        credentials["mode"] = "chat"
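
_add_custom_parameters rewrites endpoint_url by appending the v1-openai path segment with yarl, so users only enter the base server URL and the OpenAI-compatible base class talks to the right prefix. A quick sketch of the same join, with a placeholder base URL:

from yarl import URL

# Same transformation as _add_custom_parameters: append the OpenAI-compatible
# API prefix to the server URL supplied in the credentials form.
base_url = "http://192.168.1.100"  # placeholder server address
print(str(URL(base_url) / "v1-openai"))  # -> http://192.168.1.100/v1-openai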

+ 0 - 0
api/core/model_runtime/model_providers/gpustack/rerank/__init__.py


+ 146 - 0
api/core/model_runtime/model_providers/gpustack/rerank/rerank.py

@@ -0,0 +1,146 @@
+from json import dumps
+from typing import Optional
+
+import httpx
+from requests import post
+from yarl import URL
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    FetchFrom,
+    ModelPropertyKey,
+    ModelType,
+)
+from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.rerank_model import RerankModel
+
+
+class GPUStackRerankModel(RerankModel):
+    """
+    Model class for GPUStack rerank model.
+    """
+
+    def _invoke(
+        self,
+        model: str,
+        credentials: dict,
+        query: str,
+        docs: list[str],
+        score_threshold: Optional[float] = None,
+        top_n: Optional[int] = None,
+        user: Optional[str] = None,
+    ) -> RerankResult:
+        """
+        Invoke rerank model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param query: search query
+        :param docs: docs for reranking
+        :param score_threshold: score threshold
+        :param top_n: top n documents to return
+        :param user: unique user id
+        :return: rerank result
+        """
+        if len(docs) == 0:
+            return RerankResult(model=model, docs=[])
+
+        endpoint_url = credentials["endpoint_url"]
+        headers = {
+            "Authorization": f"Bearer {credentials.get('api_key')}",
+            "Content-Type": "application/json",
+        }
+
+        data = {"model": model, "query": query, "documents": docs, "top_n": top_n}
+
+        try:
+            response = post(
+                str(URL(endpoint_url) / "v1" / "rerank"),
+                headers=headers,
+                data=dumps(data),
+                timeout=10,
+            )
+            response.raise_for_status()
+            results = response.json()
+
+            rerank_documents = []
+            for result in results["results"]:
+                index = result["index"]
+                if "document" in result:
+                    text = result["document"]["text"]
+                else:
+                    text = docs[index]
+
+                rerank_document = RerankDocument(
+                    index=index,
+                    text=text,
+                    score=result["relevance_score"],
+                )
+
+                if score_threshold is None or result["relevance_score"] >= score_threshold:
+                    rerank_documents.append(rerank_document)
+
+            return RerankResult(model=model, docs=rerank_documents)
+        except httpx.HTTPStatusError as e:
+            raise InvokeServerUnavailableError(str(e))
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            self._invoke(
+                model=model,
+                credentials=credentials,
+                query="What is the capital of the United States?",
+                docs=[
+                    "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
+                    "Census, Carson City had a population of 55,274.",
+                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
+                    "are a political division controlled by the United States. Its capital is Saipan.",
+                ],
+                score_threshold=0.8,
+            )
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        """
+        return {
+            InvokeConnectionError: [httpx.ConnectError],
+            InvokeServerUnavailableError: [httpx.RemoteProtocolError],
+            InvokeRateLimitError: [],
+            InvokeAuthorizationError: [httpx.HTTPStatusError],
+            InvokeBadRequestError: [httpx.RequestError],
+        }
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+        """
+        generate custom model entities from credentials
+        """
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(en_US=model),
+            model_type=ModelType.RERANK,
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
+        )
+
+        return entity
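
The rerank call is a plain HTTP POST to <endpoint_url>/v1/rerank, and _invoke expects a JSON response with a results list carrying index, relevance_score, and optionally document.text. A standalone sketch of the same request against a hypothetical server; the URL, API key, and model name are placeholders:

from json import dumps

from requests import post
from yarl import URL

# Placeholder server, key, and model; the payload mirrors GPUStackRerankModel._invoke.
endpoint_url = "http://192.168.1.100"
headers = {
    "Authorization": "Bearer your-api-key",
    "Content-Type": "application/json",
}
payload = {
    "model": "bge-reranker-v2-m3",
    "query": "Organic skincare products for sensitive skin",
    "documents": ["Natural organic skincare range for sensitive skin"],
    "top_n": 1,
}

response = post(
    str(URL(endpoint_url) / "v1" / "rerank"),
    headers=headers,
    data=dumps(payload),
    timeout=10,
)
response.raise_for_status()
for item in response.json()["results"]:
    print(item["index"], item["relevance_score"])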

+ 0 - 0
api/core/model_runtime/model_providers/gpustack/text_embedding/__init__.py


+ 35 - 0
api/core/model_runtime/model_providers/gpustack/text_embedding/text_embedding.py

@@ -0,0 +1,35 @@
+from typing import Optional
+
+from yarl import URL
+
+from core.entities.embedding_type import EmbeddingInputType
+from core.model_runtime.entities.text_embedding_entities import (
+    TextEmbeddingResult,
+)
+from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import (
+    OAICompatEmbeddingModel,
+)
+
+
+class GPUStackTextEmbeddingModel(OAICompatEmbeddingModel):
+    """
+    Model class for GPUStack text embedding model.
+    """
+
+    def _invoke(
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
+    ) -> TextEmbeddingResult:
+        return super()._invoke(model, credentials, texts, user, input_type)
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        self._add_custom_parameters(credentials)
+        super().validate_credentials(model, credentials)
+
+    @staticmethod
+    def _add_custom_parameters(credentials: dict) -> None:
+        credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai")

+ 5 - 1
api/tests/integration_tests/.env.example

@@ -89,5 +89,9 @@ VESSL_AI_MODEL_NAME=
 VESSL_AI_API_KEY=
 VESSL_AI_ENDPOINT_URL=
 
+# GPUStack Credentials
+GPUSTACK_SERVER_URL=
+GPUSTACK_API_KEY=
+
 # Gitee AI Credentials
-GITEE_AI_API_KEY=
+GITEE_AI_API_KEY=

+ 0 - 0
api/tests/integration_tests/model_runtime/gpustack/__init__.py


+ 49 - 0
api/tests/integration_tests/model_runtime/gpustack/test_embedding.py

@@ -0,0 +1,49 @@
+import os
+
+import pytest
+
+from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.gpustack.text_embedding.text_embedding import (
+    GPUStackTextEmbeddingModel,
+)
+
+
+def test_validate_credentials():
+    model = GPUStackTextEmbeddingModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model="bge-m3",
+            credentials={
+                "endpoint_url": "invalid_url",
+                "api_key": "invalid_api_key",
+            },
+        )
+
+    model.validate_credentials(
+        model="bge-m3",
+        credentials={
+            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
+            "api_key": os.environ.get("GPUSTACK_API_KEY"),
+        },
+    )
+
+
+def test_invoke_model():
+    model = GPUStackTextEmbeddingModel()
+
+    result = model.invoke(
+        model="bge-m3",
+        credentials={
+            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
+            "api_key": os.environ.get("GPUSTACK_API_KEY"),
+            "context_size": 8192,
+        },
+        texts=["hello", "world"],
+        user="abc-123",
+    )
+
+    assert isinstance(result, TextEmbeddingResult)
+    assert len(result.embeddings) == 2
+    assert result.usage.total_tokens == 7

+ 162 - 0
api/tests/integration_tests/model_runtime/gpustack/test_llm.py

@@ -0,0 +1,162 @@
+import os
+from collections.abc import Generator
+
+import pytest
+
+from core.model_runtime.entities.llm_entities import (
+    LLMResult,
+    LLMResultChunk,
+    LLMResultChunkDelta,
+)
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessageTool,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.gpustack.llm.llm import GPUStackLanguageModel
+
+
+def test_validate_credentials_for_chat_model():
+    model = GPUStackLanguageModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model="llama-3.2-1b-instruct",
+            credentials={
+                "endpoint_url": "invalid_url",
+                "api_key": "invalid_api_key",
+                "mode": "chat",
+            },
+        )
+
+    model.validate_credentials(
+        model="llama-3.2-1b-instruct",
+        credentials={
+            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
+            "api_key": os.environ.get("GPUSTACK_API_KEY"),
+            "mode": "chat",
+        },
+    )
+
+
+def test_invoke_completion_model():
+    model = GPUStackLanguageModel()
+
+    response = model.invoke(
+        model="llama-3.2-1b-instruct",
+        credentials={
+            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
+            "api_key": os.environ.get("GPUSTACK_API_KEY"),
+            "mode": "completion",
+        },
+        prompt_messages=[UserPromptMessage(content="ping")],
+        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
+        stop=[],
+        user="abc-123",
+        stream=False,
+    )
+
+    assert isinstance(response, LLMResult)
+    assert len(response.message.content) > 0
+    assert response.usage.total_tokens > 0
+
+
+def test_invoke_chat_model():
+    model = GPUStackLanguageModel()
+
+    response = model.invoke(
+        model="llama-3.2-1b-instruct",
+        credentials={
+            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
+            "api_key": os.environ.get("GPUSTACK_API_KEY"),
+            "mode": "chat",
+        },
+        prompt_messages=[UserPromptMessage(content="ping")],
+        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
+        stop=[],
+        user="abc-123",
+        stream=False,
+    )
+
+    assert isinstance(response, LLMResult)
+    assert len(response.message.content) > 0
+    assert response.usage.total_tokens > 0
+
+
+def test_invoke_stream_chat_model():
+    model = GPUStackLanguageModel()
+
+    response = model.invoke(
+        model="llama-3.2-1b-instruct",
+        credentials={
+            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
+            "api_key": os.environ.get("GPUSTACK_API_KEY"),
+            "mode": "chat",
+        },
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
+        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
+        stop=["you"],
+        stream=True,
+        user="abc-123",
+    )
+
+    assert isinstance(response, Generator)
+    for chunk in response:
+        assert isinstance(chunk, LLMResultChunk)
+        assert isinstance(chunk.delta, LLMResultChunkDelta)
+        assert isinstance(chunk.delta.message, AssistantPromptMessage)
+        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
+
+
+def test_get_num_tokens():
+    model = GPUStackLanguageModel()
+
+    num_tokens = model.get_num_tokens(
+        model="????",
+        credentials={
+            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
+            "api_key": os.environ.get("GPUSTACK_API_KEY"),
+            "mode": "chat",
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content="You are a helpful AI assistant.",
+            ),
+            UserPromptMessage(content="Hello World!"),
+        ],
+        tools=[
+            PromptMessageTool(
+                name="get_current_weather",
+                description="Get the current weather in a given location",
+                parameters={
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["c", "f"]},
+                    },
+                    "required": ["location"],
+                },
+            )
+        ],
+    )
+
+    assert isinstance(num_tokens, int)
+    assert num_tokens == 80
+
+    num_tokens = model.get_num_tokens(
+        model="????",
+        credentials={
+            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
+            "api_key": os.environ.get("GPUSTACK_API_KEY"),
+            "mode": "chat",
+        },
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
+    )
+
+    assert isinstance(num_tokens, int)
+    assert num_tokens == 10

+ 107 - 0
api/tests/integration_tests/model_runtime/gpustack/test_rerank.py

@@ -0,0 +1,107 @@
+import os
+
+import pytest
+
+from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.gpustack.rerank.rerank import (
+    GPUStackRerankModel,
+)
+
+
+def test_validate_credentials_for_rerank_model():
+    model = GPUStackRerankModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model="bge-reranker-v2-m3",
+            credentials={
+                "endpoint_url": "invalid_url",
+                "api_key": "invalid_api_key",
+            },
+        )
+
+    model.validate_credentials(
+        model="bge-reranker-v2-m3",
+        credentials={
+            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
+            "api_key": os.environ.get("GPUSTACK_API_KEY"),
+        },
+    )
+
+
+def test_invoke_rerank_model():
+    model = GPUStackRerankModel()
+
+    response = model.invoke(
+        model="bge-reranker-v2-m3",
+        credentials={
+            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
+            "api_key": os.environ.get("GPUSTACK_API_KEY"),
+        },
+        query="Organic skincare products for sensitive skin",
+        docs=[
+            "Eco-friendly kitchenware for modern homes",
+            "Biodegradable cleaning supplies for eco-conscious consumers",
+            "Organic cotton baby clothes for sensitive skin",
+            "Natural organic skincare range for sensitive skin",
+            "Tech gadgets for smart homes: 2024 edition",
+            "Sustainable gardening tools and compost solutions",
+            "Sensitive skin-friendly facial cleansers and toners",
+            "Organic food wraps and storage solutions",
+            "Yoga mats made from recycled materials",
+        ],
+        top_n=3,
+        score_threshold=-0.75,
+        user="abc-123",
+    )
+
+    assert isinstance(response, RerankResult)
+    assert len(response.docs) == 3
+
+
+def test__invoke():
+    model = GPUStackRerankModel()
+
+    # Test case 1: Empty docs
+    result = model._invoke(
+        model="bge-reranker-v2-m3",
+        credentials={
+            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
+            "api_key": os.environ.get("GPUSTACK_API_KEY"),
+        },
+        query="Organic skincare products for sensitive skin",
+        docs=[],
+        top_n=3,
+        score_threshold=0.75,
+        user="abc-123",
+    )
+    assert isinstance(result, RerankResult)
+    assert len(result.docs) == 0
+
+    # Test case 2: Expected docs
+    result = model._invoke(
+        model="bge-reranker-v2-m3",
+        credentials={
+            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
+            "api_key": os.environ.get("GPUSTACK_API_KEY"),
+        },
+        query="Organic skincare products for sensitive skin",
+        docs=[
+            "Eco-friendly kitchenware for modern homes",
+            "Biodegradable cleaning supplies for eco-conscious consumers",
+            "Organic cotton baby clothes for sensitive skin",
+            "Natural organic skincare range for sensitive skin",
+            "Tech gadgets for smart homes: 2024 edition",
+            "Sustainable gardening tools and compost solutions",
+            "Sensitive skin-friendly facial cleansers and toners",
+            "Organic food wraps and storage solutions",
+            "Yoga mats made from recycled materials",
+        ],
+        top_n=3,
+        score_threshold=-0.75,
+        user="abc-123",
+    )
+    assert isinstance(result, RerankResult)
+    assert len(result.docs) == 3
+    assert all(isinstance(doc, RerankDocument) for doc in result.docs)

Some files were not shown because too many files have changed in this diff