瀏覽代碼

Add model hunyuan-embedding (#6657)

Co-authored-by: sun <sun@centen.cn>
Giga Group 8 月之前
父節點
當前提交
c9ff0e3961

+ 1 - 0
api/core/model_runtime/model_providers/hunyuan/hunyuan.yaml

@@ -18,6 +18,7 @@ help:
     en_US: https://console.cloud.tencent.com/cam/capi
 supported_model_types:
   - llm
+  - text-embedding
 configurate_methods:
   - predefined-model
 provider_credential_schema:

+ 0 - 0
api/core/model_runtime/model_providers/hunyuan/text_embedding/__init__.py


+ 5 - 0
api/core/model_runtime/model_providers/hunyuan/text_embedding/hunyuan-text-embedding.yaml

@@ -0,0 +1,5 @@
+model: hunyuan-embedding
+model_type: text-embedding
+model_properties:
+  context_size: 1024
+  max_chunks: 1

+ 173 - 0
api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py

@@ -0,0 +1,173 @@
+import json
+import logging
+import time
+from typing import Optional
+
+from tencentcloud.common import credential
+from tencentcloud.common.exception import TencentCloudSDKException
+from tencentcloud.common.profile.client_profile import ClientProfile
+from tencentcloud.common.profile.http_profile import HttpProfile
+from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
+
+from core.model_runtime.entities.model_entities import PriceType
+from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
+from core.model_runtime.errors.invoke import (
+    InvokeError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+
+logger = logging.getLogger(__name__)
+
+class HunyuanTextEmbeddingModel(TextEmbeddingModel):
+    """
+    Model class for Hunyuan text embedding model.
+    """
+
+    def _invoke(self, model: str, credentials: dict,
+                texts: list[str], user: Optional[str] = None) \
+            -> TextEmbeddingResult:
+        """
+        Invoke text embedding model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :param user: unique user id
+        :return: embeddings result
+        """
+
+        if model != 'hunyuan-embedding':
+            raise ValueError('Invalid model name')
+        
+        client = self._setup_hunyuan_client(credentials)
+
+        embeddings = []
+        token_usage = 0
+
+        for input in texts:
+            request = models.GetEmbeddingRequest()
+            params = {
+                "Input": input
+            }
+            request.from_json_string(json.dumps(params))
+            response = client.GetEmbedding(request)
+            usage = response.Usage.TotalTokens
+
+            embeddings.extend([data.Embedding for data in response.Data])
+            token_usage += usage
+
+        result = TextEmbeddingResult(
+            model=model,
+            embeddings=embeddings,
+            usage=self._calc_response_usage(
+                model=model,
+                credentials=credentials,
+                tokens=token_usage
+            )
+        )
+
+        return result
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate credentials
+        """
+        try:
+            client = self._setup_hunyuan_client(credentials)
+
+            req = models.ChatCompletionsRequest()
+            params = {
+                "Model": model,
+                "Messages": [{
+                    "Role": "user",
+                    "Content": "hello"
+                }],
+                "TopP": 1,
+                "Temperature": 0,
+                "Stream": False
+            }
+            req.from_json_string(json.dumps(params))
+            client.ChatCompletions(req)
+        except Exception as e:
+            raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')
+
+    def _setup_hunyuan_client(self, credentials):
+        secret_id = credentials['secret_id']
+        secret_key = credentials['secret_key']
+        cred = credential.Credential(secret_id, secret_key)
+        httpProfile = HttpProfile()
+        httpProfile.endpoint = "hunyuan.tencentcloudapi.com"
+        clientProfile = ClientProfile()
+        clientProfile.httpProfile = httpProfile
+        client = hunyuan_client.HunyuanClient(cred, "", clientProfile)
+        return client
+    
+    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
+        """
+        Calculate response usage
+
+        :param model: model name
+        :param credentials: model credentials
+        :param tokens: input tokens
+        :return: usage
+        """
+        # get input price info
+        input_price_info = self.get_price(
+            model=model,
+            credentials=credentials,
+            price_type=PriceType.INPUT,
+            tokens=tokens
+        )
+
+        # transform usage
+        usage = EmbeddingUsage(
+            tokens=tokens,
+            total_tokens=tokens,
+            unit_price=input_price_info.unit_price,
+            price_unit=input_price_info.unit,
+            total_price=input_price_info.total_amount,
+            currency=input_price_info.currency,
+            latency=time.perf_counter() - self.started_at
+        )
+
+        return usage
+    
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeError: [TencentCloudSDKException],
+        }
+    
+    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+        """
+        Get number of tokens for given prompt messages
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :return:
+        """
+        # client = self._setup_hunyuan_client(credentials)
+
+        num_tokens = 0
+        for text in texts:
+            num_tokens += self._get_num_tokens_by_gpt2(text)
+            # use client.GetTokenCount to get num tokens
+            # request = models.GetTokenCountRequest()
+            # params = {
+            #     "Prompt": text
+            # }
+            # request.from_json_string(json.dumps(params))
+            # response = client.GetTokenCount(request)
+            # num_tokens += response.TokenCount
+
+        return num_tokens

+ 104 - 0
api/tests/integration_tests/model_runtime/hunyuan/test_text_embedding.py

@@ -0,0 +1,104 @@
+import os
+
+import pytest
+
+from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.hunyuan.text_embedding.text_embedding import HunyuanTextEmbeddingModel
+
+
+def test_validate_credentials():
+    model = HunyuanTextEmbeddingModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model='hunyuan-embedding',
+            credentials={
+                'secret_id': 'invalid_key',
+                'secret_key': 'invalid_key'
+            }
+        )
+
+    model.validate_credentials(
+        model='hunyuan-embedding',
+        credentials={
+            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
+            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
+        }
+    )
+
+
+def test_invoke_model():
+    model = HunyuanTextEmbeddingModel()
+
+    result = model.invoke(
+        model='hunyuan-embedding',
+        credentials={
+            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
+            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
+        },
+        texts=[
+            "hello",
+            "world"
+        ],
+        user="abc-123"
+    )
+
+    assert isinstance(result, TextEmbeddingResult)
+    assert len(result.embeddings) == 2
+    assert result.usage.total_tokens == 6
+
+def test_get_num_tokens():
+    model = HunyuanTextEmbeddingModel()
+
+    num_tokens = model.get_num_tokens(
+        model='hunyuan-embedding',
+        credentials={
+            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
+            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
+        },
+        texts=[
+            "hello",
+            "world"
+        ]
+    )
+
+    assert num_tokens == 2
+
+def test_max_chunks():
+    model = HunyuanTextEmbeddingModel()
+
+    result = model.invoke(
+        model='hunyuan-embedding',
+        credentials={
+            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
+            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
+        },
+        texts=[
+            "hello",
+            "world",
+            "hello",
+            "world",
+            "hello",
+            "world",
+            "hello",
+            "world",
+            "hello",
+            "world",
+            "hello",
+            "world",
+            "hello",
+            "world",
+            "hello",
+            "world",
+            "hello",
+            "world",
+            "hello",
+            "world",
+            "hello",
+            "world",
+        ]
+    )
+
+    assert isinstance(result, TextEmbeddingResult)
+    assert len(result.embeddings) == 22