Переглянути джерело

Feat/add 360-zhinao provider (#7069)

小羽 8 місяців тому
батько
коміт
34a9dbe826

+ 1 - 0
api/core/model_runtime/model_providers/_position.yaml

@@ -36,3 +36,4 @@
 - hunyuan
 - siliconflow
 - perfxcloud
+- zhinao

+ 0 - 0
api/core/model_runtime/model_providers/zhinao/__init__.py


Різницю між файлами не показано, бо вона завелика
+ 5 - 0
api/core/model_runtime/model_providers/zhinao/_assets/icon_l_en.svg


Різницю між файлами не показано, бо вона завелика
+ 5 - 0
api/core/model_runtime/model_providers/zhinao/_assets/icon_s_en.svg


+ 36 - 0
api/core/model_runtime/model_providers/zhinao/llm/360gpt-turbo-responsibility-8k.yaml

@@ -0,0 +1,36 @@
+model: 360gpt-turbo-responsibility-8k
+label:
+  zh_Hans: 360gpt-turbo-responsibility-8k
+  en_US: 360gpt-turbo-responsibility-8k
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 8192
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    min: 0
+    max: 1
+    default: 0.5
+  - name: top_p
+    use_template: top_p
+    min: 0
+    max: 1
+    default: 1
+  - name: max_tokens
+    use_template: max_tokens
+    min: 1
+    max: 8192
+    default: 1024
+  - name: frequency_penalty
+    use_template: frequency_penalty
+    min: -2
+    max: 2
+    default: 0
+  - name: presence_penalty
+    use_template: presence_penalty
+    min: -2
+    max: 2
+    default: 0

+ 36 - 0
api/core/model_runtime/model_providers/zhinao/llm/360gpt-turbo.yaml

@@ -0,0 +1,36 @@
+model: 360gpt-turbo
+label:
+  zh_Hans: 360gpt-turbo
+  en_US: 360gpt-turbo
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 2048
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    min: 0
+    max: 1
+    default: 0.5
+  - name: top_p
+    use_template: top_p
+    min: 0
+    max: 1
+    default: 1
+  - name: max_tokens
+    use_template: max_tokens
+    min: 1
+    max: 2048
+    default: 1024
+  - name: frequency_penalty
+    use_template: frequency_penalty
+    min: -2
+    max: 2
+    default: 0
+  - name: presence_penalty
+    use_template: presence_penalty
+    min: -2
+    max: 2
+    default: 0

+ 36 - 0
api/core/model_runtime/model_providers/zhinao/llm/360gpt2-pro.yaml

@@ -0,0 +1,36 @@
+model: 360gpt2-pro
+label:
+  zh_Hans: 360gpt2-pro
+  en_US: 360gpt2-pro
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 2048
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    min: 0
+    max: 1
+    default: 0.5
+  - name: top_p
+    use_template: top_p
+    min: 0
+    max: 1
+    default: 1
+  - name: max_tokens
+    use_template: max_tokens
+    min: 1
+    max: 2048
+    default: 1024
+  - name: frequency_penalty
+    use_template: frequency_penalty
+    min: -2
+    max: 2
+    default: 0
+  - name: presence_penalty
+    use_template: presence_penalty
+    min: -2
+    max: 2
+    default: 0

+ 0 - 0
api/core/model_runtime/model_providers/zhinao/llm/__init__.py


+ 3 - 0
api/core/model_runtime/model_providers/zhinao/llm/_position.yaml

@@ -0,0 +1,3 @@
+- 360gpt2-pro
+- 360gpt-turbo
+- 360gpt-turbo-responsibility-8k

+ 25 - 0
api/core/model_runtime/model_providers/zhinao/llm/llm.py

@@ -0,0 +1,25 @@
+from collections.abc import Generator
+from typing import Optional, Union
+
+from core.model_runtime.entities.llm_entities import LLMResult
+from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
+from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
+
+
+class ZhinaoLargeLanguageModel(OAIAPICompatLargeLanguageModel):
+    def _invoke(self, model: str, credentials: dict,
+                prompt_messages: list[PromptMessage], model_parameters: dict,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+                stream: bool = True, user: Optional[str] = None) \
+            -> Union[LLMResult, Generator]:
+        self._add_custom_parameters(credentials)
+        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        self._add_custom_parameters(credentials)
+        super().validate_credentials(model, credentials)
+
+    @classmethod
+    def _add_custom_parameters(cls, credentials: dict) -> None:
+        credentials['mode'] = 'chat'
+        credentials['endpoint_url'] = 'https://api.360.cn/v1'

+ 32 - 0
api/core/model_runtime/model_providers/zhinao/zhinao.py

@@ -0,0 +1,32 @@
+import logging
+
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class ZhinaoProvider(ModelProvider):
+
+    def validate_provider_credentials(self, credentials: dict) -> None:
+        """
+        Validate provider credentials
+        if validate failed, raise exception
+
+        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
+        """
+        try:
+            model_instance = self.get_model_instance(ModelType.LLM)
+
+            # Use `360gpt-turbo` model for validate,
+            # no matter what model you pass in, text completion model or chat model
+            model_instance.validate_credentials(
+                model='360gpt-turbo',
+                credentials=credentials
+            )
+        except CredentialsValidateFailedError as ex:
+            raise ex
+        except Exception as ex:
+            logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
+            raise ex

+ 32 - 0
api/core/model_runtime/model_providers/zhinao/zhinao.yaml

@@ -0,0 +1,32 @@
+provider: zhinao
+label:
+  en_US: 360 AI
+  zh_Hans: 360 智脑
+description:
+  en_US: Models provided by 360 AI.
+  zh_Hans: 360 智脑提供的模型。
+icon_small:
+  en_US: icon_s_en.svg
+icon_large:
+  en_US: icon_l_en.svg
+background: "#e3f0ff"
+help:
+  title:
+    en_US: Get your API Key from 360 AI.
+    zh_Hans: 从360 智脑获取 API Key
+  url:
+    en_US: https://ai.360.com/platform/keys
+supported_model_types:
+  - llm
+configurate_methods:
+  - predefined-model
+provider_credential_schema:
+  credential_form_schemas:
+    - variable: api_key
+      label:
+        en_US: API Key
+      type: secret-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入您的 API Key
+        en_US: Enter your API Key

+ 4 - 1
api/tests/integration_tests/.env.example

@@ -79,4 +79,7 @@ CODE_EXECUTION_API_KEY=
 VOLC_API_KEY=
 VOLC_SECRET_KEY=
 VOLC_MODEL_ENDPOINT_ID=
-VOLC_EMBEDDING_ENDPOINT_ID=
+VOLC_EMBEDDING_ENDPOINT_ID=
+
+# 360 AI Credentials
+ZHINAO_API_KEY=

+ 0 - 0
api/tests/integration_tests/model_runtime/zhinao/__init__.py


+ 106 - 0
api/tests/integration_tests/model_runtime/zhinao/test_llm.py

@@ -0,0 +1,106 @@
+import os
+from collections.abc import Generator
+
+import pytest
+
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.zhinao.llm.llm import ZhinaoLargeLanguageModel
+
+
+def test_validate_credentials():
+    model = ZhinaoLargeLanguageModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model='360gpt2-pro',
+            credentials={
+                'api_key': 'invalid_key'
+            }
+        )
+
+    model.validate_credentials(
+        model='360gpt2-pro',
+        credentials={
+            'api_key': os.environ.get('ZHINAO_API_KEY')
+        }
+    )
+
+
+def test_invoke_model():
+    model = ZhinaoLargeLanguageModel()
+
+    response = model.invoke(
+        model='360gpt2-pro',
+        credentials={
+            'api_key': os.environ.get('ZHINAO_API_KEY')
+        },
+        prompt_messages=[
+            UserPromptMessage(
+                content='Who are you?'
+            )
+        ],
+        model_parameters={
+            'temperature': 0.5,
+            'max_tokens': 10
+        },
+        stop=['How'],
+        stream=False,
+        user="abc-123"
+    )
+
+    assert isinstance(response, LLMResult)
+    assert len(response.message.content) > 0
+
+
+def test_invoke_stream_model():
+    model = ZhinaoLargeLanguageModel()
+
+    response = model.invoke(
+        model='360gpt2-pro',
+        credentials={
+            'api_key': os.environ.get('ZHINAO_API_KEY')
+        },
+        prompt_messages=[
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ],
+        model_parameters={
+            'temperature': 0.5,
+            'max_tokens': 100,
+            'seed': 1234
+        },
+        stream=True,
+        user="abc-123"
+    )
+
+    assert isinstance(response, Generator)
+
+    for chunk in response:
+        assert isinstance(chunk, LLMResultChunk)
+        assert isinstance(chunk.delta, LLMResultChunkDelta)
+        assert isinstance(chunk.delta.message, AssistantPromptMessage)
+        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
+
+
+def test_get_num_tokens():
+    model = ZhinaoLargeLanguageModel()
+
+    num_tokens = model.get_num_tokens(
+        model='360gpt2-pro',
+        credentials={
+            'api_key': os.environ.get('ZHINAO_API_KEY')
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ]
+    )
+
+    assert num_tokens == 21

+ 21 - 0
api/tests/integration_tests/model_runtime/zhinao/test_provider.py

@@ -0,0 +1,21 @@
+import os
+
+import pytest
+
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.zhinao.zhinao import ZhinaoProvider
+
+
+def test_validate_provider_credentials():
+    provider = ZhinaoProvider()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        provider.validate_provider_credentials(
+            credentials={}
+        )
+
+    provider.validate_provider_credentials(
+        credentials={
+            'api_key': os.environ.get('ZHINAO_API_KEY')
+        }
+    )

Деякі файли не було показано, через те що забагато файлів було змінено