Browse Source

feat: support hunyuan llm models (#5013)

Co-authored-by: takatost <takatost@users.noreply.github.com>
Co-authored-by: Bowen Liang <bowenliang@apache.org>
xielong 10 months ago
parent
commit
ea69dc2a7e

+ 1 - 0
api/core/model_runtime/model_providers/_position.yaml

@@ -31,3 +31,4 @@
 - volcengine_maas
 - openai_api_compatible
 - deepseek
+- hunyuan

+ 0 - 0
api/core/model_runtime/model_providers/hunyuan/__init__.py


BIN
api/core/model_runtime/model_providers/hunyuan/_assets/icon_l_en.png


BIN
api/core/model_runtime/model_providers/hunyuan/_assets/icon_s_en.png


+ 30 - 0
api/core/model_runtime/model_providers/hunyuan/hunyuan.py

@@ -0,0 +1,30 @@
+import logging
+
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
class HunyuanProvider(ModelProvider):
    """Model provider for Tencent Hunyuan LLMs."""

    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials.

        Raises CredentialsValidateFailedError when validation fails.

        :param credentials: provider credentials, form defined in `provider_credential_schema`
                            (expects `secret_id` and `secret_key`).
        """
        try:
            model_instance = self.get_model_instance(ModelType.LLM)

            # Use the `hunyuan-standard` model to validate the credentials.
            model_instance.validate_credentials(
                model='hunyuan-standard',
                credentials=credentials
            )
        except CredentialsValidateFailedError as ex:
            raise ex
        except Exception as ex:
            # Lazy %-style args: the message is only formatted if the record is emitted.
            logger.exception('%s credentials validate failed', self.get_provider_schema().provider)
            raise ex

+ 40 - 0
api/core/model_runtime/model_providers/hunyuan/hunyuan.yaml

@@ -0,0 +1,40 @@
+provider: hunyuan
+label:
+  zh_Hans: 腾讯混元
+  en_US: Hunyuan
+description:
+  en_US: Models provided by Tencent Hunyuan, such as hunyuan-standard, hunyuan-standard-256k, hunyuan-pro and hunyuan-lite.
+  zh_Hans: 腾讯混元提供的模型,例如 hunyuan-standard、hunyuan-standard-256k、hunyuan-pro 和 hunyuan-lite。
+icon_small:
+  en_US: icon_s_en.png
+icon_large:
+  en_US: icon_l_en.png
+background: "#F6F7F7"
+help:
+  title:
+    en_US: Get your API Key from Tencent Hunyuan
+    zh_Hans: 从腾讯混元获取 API Key
+  url:
+    en_US: https://console.cloud.tencent.com/cam/capi
+supported_model_types:
+  - llm
+configurate_methods:
+  - predefined-model
+provider_credential_schema:
+  credential_form_schemas:
+    - variable: secret_id
+      label:
+        en_US: Secret ID
+      type: secret-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入您的 Secret ID
+        en_US: Enter your Secret ID
+    - variable: secret_key
+      label:
+        en_US: Secret Key
+      type: secret-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入您的 Secret Key
+        en_US: Enter your Secret Key

+ 0 - 0
api/core/model_runtime/model_providers/hunyuan/llm/__init__.py


+ 4 - 0
api/core/model_runtime/model_providers/hunyuan/llm/_position.yaml

@@ -0,0 +1,4 @@
+- hunyuan-lite
+- hunyuan-standard
+- hunyuan-standard-256k
+- hunyuan-pro

+ 28 - 0
api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-lite.yaml

@@ -0,0 +1,28 @@
+model: hunyuan-lite
+label:
+  zh_Hans: hunyuan-lite
+  en_US: hunyuan-lite
+model_type: llm
+features:
+  - agent-thought
+  - tool-call
+  - multi-tool-call
+  - stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 256000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: max_tokens
+    use_template: max_tokens
+    default: 1024
+    min: 1
+    max: 256000
+pricing:
+  input: '0.00'
+  output: '0.00'
+  unit: '0.001'
+  currency: RMB

+ 28 - 0
api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-pro.yaml

@@ -0,0 +1,28 @@
+model: hunyuan-pro
+label:
+  zh_Hans: hunyuan-pro
+  en_US: hunyuan-pro
+model_type: llm
+features:
+  - agent-thought
+  - tool-call
+  - multi-tool-call
+  - stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 32000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: max_tokens
+    use_template: max_tokens
+    default: 1024
+    min: 1
+    max: 32000
+pricing:
+  input: '0.03'
+  output: '0.10'
+  unit: '0.001'
+  currency: RMB

+ 28 - 0
api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-standard-256k.yaml

@@ -0,0 +1,28 @@
+model: hunyuan-standard-256k
+label:
+  zh_Hans: hunyuan-standard-256k
+  en_US: hunyuan-standard-256k
+model_type: llm
+features:
+  - agent-thought
+  - tool-call
+  - multi-tool-call
+  - stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 256000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: max_tokens
+    use_template: max_tokens
+    default: 1024
+    min: 1
+    max: 256000
+pricing:
+  input: '0.015'
+  output: '0.06'
+  unit: '0.001'
+  currency: RMB

+ 28 - 0
api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-standard.yaml

@@ -0,0 +1,28 @@
+model: hunyuan-standard
+label:
+  zh_Hans: hunyuan-standard
+  en_US: hunyuan-standard
+model_type: llm
+features:
+  - agent-thought
+  - tool-call
+  - multi-tool-call
+  - stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 32000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: max_tokens
+    use_template: max_tokens
+    default: 1024
+    min: 1
+    max: 32000
+pricing:
+  input: '0.0045'
+  output: '0.0005'
+  unit: '0.001'
+  currency: RMB

+ 205 - 0
api/core/model_runtime/model_providers/hunyuan/llm/llm.py

@@ -0,0 +1,205 @@
+import json
+import logging
+from collections.abc import Generator
+
+from tencentcloud.common import credential
+from tencentcloud.common.exception import TencentCloudSDKException
+from tencentcloud.common.profile.client_profile import ClientProfile
+from tencentcloud.common.profile.http_profile import HttpProfile
+from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
+
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    PromptMessageTool,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.errors.invoke import InvokeError
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+
+logger = logging.getLogger(__name__)
+
class HunyuanLargeLanguageModel(LargeLanguageModel):
    """Large language model backed by the Tencent Cloud Hunyuan ChatCompletions API."""

    def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                model_parameters: dict, tools: list[PromptMessageTool] | None = None,
                stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
            -> LLMResult | Generator:
        """
        Invoke the Hunyuan chat-completions endpoint.

        :param model: Hunyuan model name, e.g. `hunyuan-standard`.
        :param credentials: dict with `secret_id` and `secret_key`.
        :param prompt_messages: conversation messages to send.
        :param model_parameters: supports `temperature` and `top_p`.
        :param tools: NOTE(review): currently ignored — tool definitions are not forwarded to the API.
        :param stop: NOTE(review): currently ignored — no stop sequences are sent in the request.
        :param stream: when True, return a generator of LLMResultChunk.
        :param user: NOTE(review): currently ignored.
        :return: LLMResult for blocking calls, generator of LLMResultChunk when streaming.
        """
        client = self._setup_hunyuan_client(credentials)
        request = models.ChatCompletionsRequest()
        messages_dict = self._convert_prompt_messages_to_dicts(prompt_messages)

        custom_parameters = {
            'Temperature': model_parameters.get('temperature', 0.0),
            'TopP': model_parameters.get('top_p', 1.0)
        }

        params = {
            "Model": model,
            "Messages": messages_dict,
            "Stream": stream,
            **custom_parameters,
        }

        # The SDK request object is populated from a JSON string of PascalCase keys.
        request.from_json_string(json.dumps(params))
        response = client.ChatCompletions(request)

        if stream:
            return self._handle_stream_chat_response(model, credentials, prompt_messages, response)

        return self._handle_chat_response(credentials, model, prompt_messages, response)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate credentials by issuing a minimal non-streaming chat request.

        :raises CredentialsValidateFailedError: when the request fails for any reason.
        """
        try:
            client = self._setup_hunyuan_client(credentials)

            req = models.ChatCompletionsRequest()
            params = {
                "Model": model,
                "Messages": [{
                    "Role": "user",
                    "Content": "hello"
                }],
                "TopP": 1,
                "Temperature": 0,
                "Stream": False
            }
            req.from_json_string(json.dumps(params))
            client.ChatCompletions(req)
        except Exception as e:
            raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')

    def _setup_hunyuan_client(self, credentials):
        """Build a HunyuanClient from `secret_id`/`secret_key` credentials."""
        secret_id = credentials['secret_id']
        secret_key = credentials['secret_key']
        cred = credential.Credential(secret_id, secret_key)
        httpProfile = HttpProfile()
        httpProfile.endpoint = "hunyuan.tencentcloudapi.com"
        clientProfile = ClientProfile()
        clientProfile.httpProfile = httpProfile
        # Empty region string: the Hunyuan endpoint is not region-scoped here.
        client = hunyuan_client.HunyuanClient(cred, "", clientProfile)
        return client

    def _convert_prompt_messages_to_dicts(self, prompt_messages: list[PromptMessage]) -> list[dict]:
        """Convert a list of PromptMessage objects to a list of dictionaries with 'Role' and 'Content' keys."""
        return [{"Role": message.role.value, "Content": message.content} for message in prompt_messages]

    def _handle_stream_chat_response(self, model, credentials, prompt_messages, resp):
        """Yield an LLMResultChunk for every SSE event of a streaming response."""
        for index, event in enumerate(resp):
            # Fix: use the module-level named logger, not the root `logging` module,
            # so the record carries this module's logger name and honors its config.
            logger.debug("_handle_stream_chat_response, event: %s", event)

            data_str = event['data']
            data = json.loads(data_str)

            choices = data.get('Choices', [])
            if not choices:
                continue
            choice = choices[0]
            delta = choice.get('Delta', {})
            message_content = delta.get('Content', '')
            finish_reason = choice.get('FinishReason', '')

            # Usage counters arrive on every event; recompute usage per chunk.
            usage = data.get('Usage', {})
            prompt_tokens = usage.get('PromptTokens', 0)
            completion_tokens = usage.get('CompletionTokens', 0)
            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

            assistant_prompt_message = AssistantPromptMessage(
                content=message_content,
                tool_calls=[]
            )

            delta_chunk = LLMResultChunkDelta(
                index=index,
                role=delta.get('Role', 'assistant'),
                message=assistant_prompt_message,
                usage=usage,
                finish_reason=finish_reason,
            )

            yield LLMResultChunk(
                model=model,
                prompt_messages=prompt_messages,
                delta=delta_chunk,
            )

    def _handle_chat_response(self, credentials, model, prompt_messages, response):
        """Convert a blocking ChatCompletions response into an LLMResult."""
        usage = self._calc_response_usage(model, credentials, response.Usage.PromptTokens,
                                          response.Usage.CompletionTokens)
        # Fix: construct an AssistantPromptMessage directly rather than instantiating
        # the PromptMessage base class with role="assistant" and mutating its content.
        # This keeps the message type consistent with the streaming path and with
        # isinstance-based handling downstream.
        assistant_prompt_message = AssistantPromptMessage(
            content=response.Choices[0].Message.Content
        )
        result = LLMResult(
            model=model,
            prompt_messages=prompt_messages,
            message=assistant_prompt_message,
            usage=usage,
        )

        return result

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: list[PromptMessageTool] | None = None) -> int:
        """
        Approximate token count using the GPT-2 tokenizer over a flattened prompt.

        NOTE(review): this is an approximation — Hunyuan's own tokenizer may differ.
        """
        if len(prompt_messages) == 0:
            return 0
        prompt = self._convert_messages_to_prompt(prompt_messages)
        return self._get_num_tokens_by_gpt2(prompt)

    def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
        """
        Format a list of messages into a single text prompt for token counting.

        :param messages: List of PromptMessage to combine.
        :return: Combined string with Human/Assistant turn tags.
        """
        messages = messages.copy()  # don't mutate the original list

        text = "".join(
            self._convert_one_message_to_text(message)
            for message in messages
        )

        # trim off the trailing ' ' that might come from the "Assistant: "
        return text.rstrip()

    def _convert_one_message_to_text(self, message: PromptMessage) -> str:
        """
        Convert a single message to a string.

        :param message: PromptMessage to convert.
        :return: String representation of the message.
        :raises ValueError: for unknown PromptMessage subclasses.
        """
        human_prompt = "\n\nHuman:"
        ai_prompt = "\n\nAssistant:"
        content = message.content

        if isinstance(message, UserPromptMessage):
            message_text = f"{human_prompt} {content}"
        elif isinstance(message, AssistantPromptMessage):
            message_text = f"{ai_prompt} {content}"
        elif isinstance(message, SystemPromptMessage):
            # System messages are passed through untagged.
            message_text = content
        else:
            raise ValueError(f"Got unknown type {message}")

        return message_text

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeError: [TencentCloudSDKException],
        }

+ 1 - 0
api/requirements.txt

@@ -85,3 +85,4 @@ pymysql==1.1.1
 tidb-vector==0.0.9
 google-cloud-aiplatform==1.49.0
 vanna[postgres,mysql,clickhouse,duckdb]==0.5.5
+tencentcloud-sdk-python-hunyuan~=3.0.1158

+ 0 - 0
api/tests/integration_tests/model_runtime/hunyuan/__init__.py


+ 111 - 0
api/tests/integration_tests/model_runtime/hunyuan/test_llm.py

@@ -0,0 +1,111 @@
+import os
+from collections.abc import Generator
+
+import pytest
+
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.hunyuan.llm.llm import HunyuanLargeLanguageModel
+
+
def test_validate_credentials():
    """Invalid keys must raise; real keys taken from the environment must pass."""
    model = HunyuanLargeLanguageModel()

    bad_credentials = {
        'secret_id': 'invalid_key',
        'secret_key': 'invalid_key',
    }
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(model='hunyuan-standard', credentials=bad_credentials)

    good_credentials = {
        'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
        'secret_key': os.environ.get('HUNYUAN_SECRET_KEY'),
    }
    model.validate_credentials(model='hunyuan-standard', credentials=good_credentials)
+
+
def test_invoke_model():
    """A blocking invoke should return a non-empty LLMResult."""
    model = HunyuanLargeLanguageModel()

    credentials = {
        'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
        'secret_key': os.environ.get('HUNYUAN_SECRET_KEY'),
    }
    result = model.invoke(
        model='hunyuan-standard',
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content='Hi')],
        model_parameters={'temperature': 0.5, 'max_tokens': 10},
        stop=['How'],
        stream=False,
        user="abc-123",
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
+
+
def test_invoke_stream_model():
    """A streaming invoke should yield well-formed chunks with content until finish."""
    model = HunyuanLargeLanguageModel()

    credentials = {
        'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
        'secret_key': os.environ.get('HUNYUAN_SECRET_KEY'),
    }
    stream = model.invoke(
        model='hunyuan-standard',
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content='Hi')],
        model_parameters={'temperature': 0.5, 'max_tokens': 100, 'seed': 1234},
        stream=True,
        user="abc-123",
    )

    assert isinstance(stream, Generator)

    for chunk in stream:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # Content may legitimately be empty on the final (finish-reason) chunk.
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0
+
+
def test_get_num_tokens():
    """Token counting for a system + user message pair matches the expected total."""
    model = HunyuanLargeLanguageModel()

    messages = [
        SystemPromptMessage(content='You are a helpful AI assistant.'),
        UserPromptMessage(content='Hello World!'),
    ]
    num_tokens = model.get_num_tokens(
        model='hunyuan-standard',
        credentials={
            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY'),
        },
        prompt_messages=messages,
    )

    assert num_tokens == 14

+ 25 - 0
api/tests/integration_tests/model_runtime/hunyuan/test_provider.py

@@ -0,0 +1,25 @@
+import os
+
+import pytest
+
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.hunyuan.hunyuan import HunyuanProvider
+
+
def test_validate_provider_credentials():
    """Bad keys must raise; keys from the environment must validate cleanly."""
    provider = HunyuanProvider()

    bad_credentials = {
        'secret_id': 'invalid_key',
        'secret_key': 'invalid_key',
    }
    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(credentials=bad_credentials)

    provider.validate_provider_credentials(
        credentials={
            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY'),
        }
    )