瀏覽代碼

feat (new llm): add support for openrouter (#3042)

Salem Korayem 1 年之前
父節點
當前提交
6b4c8e76e6

+ 1 - 0
api/core/model_runtime/model_providers/_position.yaml

@@ -6,6 +6,7 @@
 - cohere
 - bedrock
 - togetherai
+- openrouter
 - ollama
 - mistralai
 - groq

+ 0 - 0
api/core/model_runtime/model_providers/openrouter/__init__.py


文件差異過大導致無法顯示
+ 1 - 0
api/core/model_runtime/model_providers/openrouter/_assets/openrouter.svg


+ 10 - 0
api/core/model_runtime/model_providers/openrouter/_assets/openrouter_square.svg

@@ -0,0 +1,10 @@
+<svg width="25" height="21" viewBox="0 0 25 21" fill="none" xmlns="http://www.w3.org/2000/svg">
+<path d="M1.05858 10.1738C1.76158 10.1738 4.47988 9.56715 5.88589 8.77041C7.2919 7.97367 7.2919 7.97367 10.1977 5.91152C13.8766 3.30069 16.4779 4.17486 20.7428 4.17486" fill="black"/>
+<path fill-rule="evenodd" clip-rule="evenodd" d="M11.4182 7.63145L11.3787 7.65951C8.50565 9.69845 8.42504 9.75566 6.92566 10.6053C5.98567 11.138 4.74704 11.5436 3.75151 11.8089C2.80313 12.0615 1.71203 12.2829 1.05858 12.2829V8.06483C1.05075 8.06483 1.05422 8.06445 1.06984 8.06276C1.11491 8.05788 1.26116 8.04203 1.52896 7.9926C1.84599 7.9341 2.24205 7.84582 2.6657 7.73296C3.55657 7.49564 4.3801 7.1996 4.84612 6.93552C4.88175 6.91533 4.91635 6.89573 4.95001 6.87666C6.15007 6.19693 6.15657 6.19325 8.97708 4.1916C12.5199 1.67735 15.5815 1.83587 18.5849 1.99138C19.3056 2.0287 20.0229 2.06584 20.7428 2.06584V6.28388C19.6102 6.28388 18.6583 6.24193 17.8263 6.20527C15.1245 6.08621 13.685 6.02278 11.4182 7.63145Z" fill="black"/>
+<path d="M24.8671 4.20087L17.6613 8.36117V0.0405881L24.8671 4.20087Z" fill="black"/>
+<path fill-rule="evenodd" clip-rule="evenodd" d="M17.6378 0L24.9139 4.20087L17.6378 8.40176V0ZM17.6847 0.0811762V8.32058L24.8202 4.20087L17.6847 0.0811762Z" fill="black"/>
+<path d="M0.917975 10.1764C1.62098 10.1764 4.33927 10.7831 5.74529 11.5799C7.1513 12.3766 7.1513 12.3766 10.0571 14.4388C13.736 17.0496 16.3373 16.1754 20.6022 16.1754" fill="black"/>
+<path fill-rule="evenodd" clip-rule="evenodd" d="M0.929234 12.2875C0.913615 12.2858 0.910145 12.2854 0.917975 12.2854V8.06741C1.57142 8.06741 2.66253 8.28878 3.61091 8.54142C4.60644 8.80663 5.84507 9.21231 6.78506 9.74497C8.28444 10.5946 8.36505 10.6518 11.2381 12.6908L11.2776 12.7188C13.5444 14.3275 14.9839 14.2641 17.6857 14.145C18.5177 14.1083 19.4696 14.0664 20.6022 14.0664V18.2844C19.8823 18.2844 19.165 18.3216 18.4443 18.3589C15.4409 18.5144 12.3793 18.6729 8.83648 16.1587C6.01597 14.157 6.00947 14.1533 4.80941 13.4736C4.77575 13.4545 4.74115 13.4349 4.70551 13.4148C4.2395 13.1507 3.41597 12.8546 2.5251 12.6173C2.10145 12.5045 1.70538 12.4162 1.38836 12.3577C1.12056 12.3083 0.974309 12.2924 0.929234 12.2875Z" fill="black"/>
+<path d="M24.7265 16.1494L17.5207 11.9892V20.3097L24.7265 16.1494Z" fill="black"/>
+<path fill-rule="evenodd" clip-rule="evenodd" d="M17.4972 11.9486L24.7733 16.1494L17.4972 20.3503V11.9486ZM17.5441 12.0297V20.2691L24.6796 16.1494L17.5441 12.0297Z" fill="black"/>
+</svg>

+ 0 - 0
api/core/model_runtime/model_providers/openrouter/llm/__init__.py


+ 46 - 0
api/core/model_runtime/model_providers/openrouter/llm/llm.py

@@ -0,0 +1,46 @@
+from collections.abc import Generator
+from typing import Optional, Union
+
+from core.model_runtime.entities.llm_entities import LLMResult
+from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
+from core.model_runtime.entities.model_entities import AIModelEntity
+from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
+
+
+class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
+
+    def _update_endpoint_url(self, credentials: dict):
+        credentials['endpoint_url'] = "https://openrouter.ai/api/v1"
+        return credentials
+
+    def _invoke(self, model: str, credentials: dict,
+                prompt_messages: list[PromptMessage], model_parameters: dict,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+                stream: bool = True, user: Optional[str] = None) \
+            -> Union[LLMResult, Generator]:
+        cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
+
+        return super()._invoke(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user)
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
+
+        return super().validate_credentials(model, cred_with_endpoint)
+
+    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
+                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+                  stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
+        cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
+
+        return super()._generate(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user)
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+        cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
+
+        return super().get_customizable_model_schema(model, cred_with_endpoint)
+
+    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                       tools: Optional[list[PromptMessageTool]] = None) -> int:
+        cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
+
+        return super().get_num_tokens(model, cred_with_endpoint, prompt_messages, tools)

+ 11 - 0
api/core/model_runtime/model_providers/openrouter/openrouter.py

@@ -0,0 +1,11 @@
+import logging
+
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class OpenRouterProvider(ModelProvider):
+
+    def validate_provider_credentials(self, credentials: dict) -> None:
+        pass

+ 75 - 0
api/core/model_runtime/model_providers/openrouter/openrouter.yaml

@@ -0,0 +1,75 @@
+provider: openrouter
+label:
+  en_US: openrouter.ai
+icon_small:
+  en_US: openrouter_square.svg
+icon_large:
+  en_US: openrouter.svg
+background: "#F1EFED"
+help:
+  title:
+    en_US: Get your API key from openrouter.ai
+    zh_Hans: 从 openrouter.ai 获取 API Key
+  url:
+    en_US: https://openrouter.ai/keys
+supported_model_types:
+  - llm
+configurate_methods:
+  - customizable-model
+model_credential_schema:
+  model:
+    label:
+      en_US: Model Name
+      zh_Hans: 模型名称
+    placeholder:
+      en_US: Enter full model name
+      zh_Hans: 输入模型全称
+  credential_form_schemas:
+    - variable: api_key
+      required: true
+      label:
+        en_US: API Key
+      type: secret-input
+      placeholder:
+        zh_Hans: 在此输入您的 API Key
+        en_US: Enter your API Key
+    - variable: mode
+      show_on:
+        - variable: __model_type
+          value: llm
+      label:
+        en_US: Completion mode
+      type: select
+      required: false
+      default: chat
+      placeholder:
+        zh_Hans: 选择对话类型
+        en_US: Select completion mode
+      options:
+        - value: completion
+          label:
+            en_US: Completion
+            zh_Hans: 补全
+        - value: chat
+          label:
+            en_US: Chat
+            zh_Hans: 对话
+    - variable: context_size
+      label:
+        zh_Hans: 模型上下文长度
+        en_US: Model context size
+      required: true
+      type: text-input
+      default: "4096"
+      placeholder:
+        zh_Hans: 在此输入您的模型上下文长度
+        en_US: Enter your Model context size
+    - variable: max_tokens_to_sample
+      label:
+        zh_Hans: 最大 token 上限
+        en_US: Upper bound for max tokens
+      show_on:
+        - variable: __model_type
+          value: llm
+      default: "4096"
+      type: text-input

+ 0 - 0
api/tests/integration_tests/model_runtime/openrouter/__init__.py


+ 118 - 0
api/tests/integration_tests/model_runtime/openrouter/test_llm.py

@@ -0,0 +1,118 @@
+import os
+from typing import Generator
+
+import pytest
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessageTool,
+                                                          SystemPromptMessage, UserPromptMessage)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.openrouter.llm.llm import OpenRouterLargeLanguageModel
+
+
+def test_validate_credentials():
+    model = OpenRouterLargeLanguageModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model='mistralai/mixtral-8x7b-instruct',
+            credentials={
+                'api_key': 'invalid_key',
+                'mode': 'chat'
+            }
+        )
+
+    model.validate_credentials(
+        model='mistralai/mixtral-8x7b-instruct',
+        credentials={
+            'api_key': os.environ.get('TOGETHER_API_KEY'),
+            'mode': 'chat'
+        }
+    )
+
+
+def test_invoke_model():
+    model = OpenRouterLargeLanguageModel()
+
+    response = model.invoke(
+        model='mistralai/mixtral-8x7b-instruct',
+        credentials={
+            'api_key': os.environ.get('TOGETHER_API_KEY'),
+            'mode': 'completion'
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content='Who are you?'
+            )
+        ],
+        model_parameters={
+            'temperature': 1.0,
+            'top_k': 2,
+            'top_p': 0.5,
+        },
+        stop=['How'],
+        stream=False,
+        user="abc-123"
+    )
+
+    assert isinstance(response, LLMResult)
+    assert len(response.message.content) > 0
+
+
+def test_invoke_stream_model():
+    model = OpenRouterLargeLanguageModel()
+
+    response = model.invoke(
+        model='mistralai/mixtral-8x7b-instruct',
+        credentials={
+            'api_key': os.environ.get('TOGETHER_API_KEY'),
+            'mode': 'chat'
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content='Who are you?'
+            )
+        ],
+        model_parameters={
+            'temperature': 1.0,
+            'top_k': 2,
+            'top_p': 0.5,
+        },
+        stop=['How'],
+        stream=True,
+        user="abc-123"
+    )
+
+    assert isinstance(response, Generator)
+
+    for chunk in response:
+        assert isinstance(chunk, LLMResultChunk)
+        assert isinstance(chunk.delta, LLMResultChunkDelta)
+        assert isinstance(chunk.delta.message, AssistantPromptMessage)
+
+
+def test_get_num_tokens():
+    model = OpenRouterLargeLanguageModel()
+
+    num_tokens = model.get_num_tokens(
+        model='mistralai/mixtral-8x7b-instruct',
+        credentials={
+            'api_key': os.environ.get('TOGETHER_API_KEY'),
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ]
+    )
+
+    assert isinstance(num_tokens, int)
+    assert num_tokens == 21

部分文件因文件數量過多而無法顯示