Explorar el Código

feat(api): support wenxin text embedding (#7377)

Chengyu Yan hace 8 meses
padre
commit
bfd905602f

+ 195 - 0
api/core/model_runtime/model_providers/wenxin/_common.py

@@ -0,0 +1,195 @@
+from datetime import datetime, timedelta
+from threading import Lock
+
+from requests import post
+
+from core.model_runtime.model_providers.wenxin.wenxin_errors import (
+    BadRequestError,
+    InternalServerError,
+    InvalidAPIKeyError,
+    InvalidAuthenticationError,
+    RateLimitReachedError,
+)
+
+baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {}
+baidu_access_tokens_lock = Lock()
+
+
+class BaiduAccessToken:
+    api_key: str
+    access_token: str
+    expires: datetime
+
+    def __init__(self, api_key: str) -> None:
+        self.api_key = api_key
+        self.access_token = ''
+        self.expires = datetime.now() + timedelta(days=3)
+
+    @staticmethod
+    def _get_access_token(api_key: str, secret_key: str) -> str:
+        """
+            request access token from Baidu
+        """
+        try:
+            response = post(
+                url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}',
+                headers={
+                    'Content-Type': 'application/json',
+                    'Accept': 'application/json'
+                },
+            )
+        except Exception as e:
+            raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}')
+
+        resp = response.json()
+        if 'error' in resp:
+            if resp['error'] == 'invalid_client':
+                raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}')
+            elif resp['error'] == 'unknown_error':
+                raise InternalServerError(f'Internal server error: {resp["error_description"]}')
+            elif resp['error'] == 'invalid_request':
+                raise BadRequestError(f'Bad request: {resp["error_description"]}')
+            elif resp['error'] == 'rate_limit_exceeded':
+                raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}')
+            else:
+                raise Exception(f'Unknown error: {resp["error_description"]}')
+
+        return resp['access_token']
+
+    @staticmethod
+    def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken':
+        """
+            LLM from Baidu requires access token to invoke the API.
+            however, we have api_key and secret_key, and access token is valid for 30 days.
+            so we can cache the access token for 3 days. (avoid memory leak)
+
+            it may be more efficient to use a ticker to refresh access token, but it will cause
+            more complexity, so we just refresh access tokens when get_access_token is called.
+        """
+
+        # loop up cache, remove expired access token
+        baidu_access_tokens_lock.acquire()
+        now = datetime.now()
+        for key in list(baidu_access_tokens.keys()):
+            token = baidu_access_tokens[key]
+            if token.expires < now:
+                baidu_access_tokens.pop(key)
+
+        if api_key not in baidu_access_tokens:
+            # if access token not in cache, request it
+            token = BaiduAccessToken(api_key)
+            baidu_access_tokens[api_key] = token
+            # release it to enhance performance
+            # btw, _get_access_token will raise exception if failed, release lock here to avoid deadlock
+            baidu_access_tokens_lock.release()
+            # try to get access token
+            token_str = BaiduAccessToken._get_access_token(api_key, secret_key)
+            token.access_token = token_str
+            token.expires = now + timedelta(days=3)
+            return token
+        else:
+            # if access token in cache, return it
+            token = baidu_access_tokens[api_key]
+            baidu_access_tokens_lock.release()
+            return token
+
+
+class _CommonWenxin:
+    api_bases = {
+        'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
+        'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
+        'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
+        'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
+        'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
+        'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205',
+        'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222',
+        'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
+        'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k',
+        'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
+        'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
+        'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed',
+        'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k',
+        'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
+        'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
+        'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
+        'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
+        'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
+        'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k',
+        'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
+        'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat',
+        'embedding-v1': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1',
+    }
+
+    function_calling_supports = [
+        'ernie-bot',
+        'ernie-bot-8k',
+        'ernie-3.5-8k',
+        'ernie-3.5-8k-0205',
+        'ernie-3.5-8k-1222',
+        'ernie-3.5-4k-0205',
+        'ernie-3.5-128k',
+        'ernie-4.0-8k',
+        'ernie-4.0-turbo-8k',
+        'ernie-4.0-turbo-8k-preview',
+        'yi_34b_chat'
+    ]
+
+    api_key: str = ''
+    secret_key: str = ''
+
+    def __init__(self, api_key: str, secret_key: str):
+        self.api_key = api_key
+        self.secret_key = secret_key
+
+    @staticmethod
+    def _to_credential_kwargs(credentials: dict) -> dict:
+        credentials_kwargs = {
+            "api_key": credentials['api_key'],
+            "secret_key": credentials['secret_key']
+        }
+        return credentials_kwargs
+
+    def _handle_error(self, code: int, msg: str):
+        error_map = {
+            1: InternalServerError,
+            2: InternalServerError,
+            3: BadRequestError,
+            4: RateLimitReachedError,
+            6: InvalidAuthenticationError,
+            13: InvalidAPIKeyError,
+            14: InvalidAPIKeyError,
+            15: InvalidAPIKeyError,
+            17: RateLimitReachedError,
+            18: RateLimitReachedError,
+            19: RateLimitReachedError,
+            100: InvalidAPIKeyError,
+            111: InvalidAPIKeyError,
+            200: InternalServerError,
+            336000: InternalServerError,
+            336001: BadRequestError,
+            336002: BadRequestError,
+            336003: BadRequestError,
+            336004: InvalidAuthenticationError,
+            336005: InvalidAPIKeyError,
+            336006: BadRequestError,
+            336007: BadRequestError,
+            336008: BadRequestError,
+            336100: InternalServerError,
+            336101: BadRequestError,
+            336102: BadRequestError,
+            336103: BadRequestError,
+            336104: BadRequestError,
+            336105: BadRequestError,
+            336200: InternalServerError,
+            336303: BadRequestError,
+            337006: BadRequestError
+        }
+
+        if code in error_map:
+            raise error_map[code](msg)
+        else:
+            raise InternalServerError(f'Unknown error: {msg}')
+
+    def _get_access_token(self) -> str:
+        token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key)
+        return token.access_token

+ 3 - 177
api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py

@@ -1,102 +1,17 @@
 from collections.abc import Generator
-from datetime import datetime, timedelta
 from enum import Enum
 from json import dumps, loads
-from threading import Lock
 from typing import Any, Union
 
 from requests import Response, post
 
 from core.model_runtime.entities.message_entities import PromptMessageTool
-from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (
+from core.model_runtime.model_providers.wenxin._common import _CommonWenxin
+from core.model_runtime.model_providers.wenxin.wenxin_errors import (
     BadRequestError,
     InternalServerError,
-    InvalidAPIKeyError,
-    InvalidAuthenticationError,
-    RateLimitReachedError,
 )
 
-# map api_key to access_token
-baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {}
-baidu_access_tokens_lock = Lock()
-
-class BaiduAccessToken:
-    api_key: str
-    access_token: str
-    expires: datetime
-
-    def __init__(self, api_key: str) -> None:
-        self.api_key = api_key
-        self.access_token = ''
-        self.expires = datetime.now() + timedelta(days=3)
-
-    def _get_access_token(api_key: str, secret_key: str) -> str:
-        """
-            request access token from Baidu
-        """
-        try:
-            response = post(
-                url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}',
-                headers={
-                    'Content-Type': 'application/json',
-                    'Accept': 'application/json'
-                },
-            )
-        except Exception as e:
-            raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}')
-
-        resp = response.json()
-        if 'error' in resp:
-            if resp['error'] == 'invalid_client':
-                raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}')
-            elif resp['error'] == 'unknown_error':
-                raise InternalServerError(f'Internal server error: {resp["error_description"]}')
-            elif resp['error'] == 'invalid_request':
-                raise BadRequestError(f'Bad request: {resp["error_description"]}')
-            elif resp['error'] == 'rate_limit_exceeded':
-                raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}')
-            else:
-                raise Exception(f'Unknown error: {resp["error_description"]}')
-
-        return resp['access_token']
-
-    @staticmethod
-    def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken':
-        """
-            LLM from Baidu requires access token to invoke the API.
-            however, we have api_key and secret_key, and access token is valid for 30 days.
-            so we can cache the access token for 3 days. (avoid memory leak)
-
-            it may be more efficient to use a ticker to refresh access token, but it will cause
-            more complexity, so we just refresh access tokens when get_access_token is called.
-        """
-
-        # loop up cache, remove expired access token
-        baidu_access_tokens_lock.acquire()
-        now = datetime.now()
-        for key in list(baidu_access_tokens.keys()):
-            token = baidu_access_tokens[key]
-            if token.expires < now:
-                baidu_access_tokens.pop(key)
-
-        if api_key not in baidu_access_tokens:
-            # if access token not in cache, request it
-            token = BaiduAccessToken(api_key)
-            baidu_access_tokens[api_key] = token
-            # release it to enhance performance
-            # btw, _get_access_token will raise exception if failed, release lock here to avoid deadlock
-            baidu_access_tokens_lock.release()
-            # try to get access token
-            token_str = BaiduAccessToken._get_access_token(api_key, secret_key)
-            token.access_token = token_str
-            token.expires = now + timedelta(days=3)
-            return token
-        else:
-            # if access token in cache, return it
-            token = baidu_access_tokens[api_key]
-            baidu_access_tokens_lock.release()
-            return token
-
 
 class ErnieMessage:
     class Role(Enum):
@@ -120,51 +35,7 @@ class ErnieMessage:
         self.content = content
         self.role = role
 
-class ErnieBotModel:
-    api_bases = {
-        'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
-        'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
-        'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
-        'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
-        'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
-        'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205',
-        'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222',
-        'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
-        'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k',
-        'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
-        'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
-        'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed',
-        'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k',
-        'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
-        'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
-        'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
-        'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
-        'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
-        'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k',
-        'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
-        'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat',
-    }
-
-    function_calling_supports = [
-        'ernie-bot',
-        'ernie-bot-8k',
-        'ernie-3.5-8k',
-        'ernie-3.5-8k-0205',
-        'ernie-3.5-8k-1222',
-        'ernie-3.5-4k-0205',
-        'ernie-3.5-128k',
-        'ernie-4.0-8k',
-        'ernie-4.0-turbo-8k',
-        'ernie-4.0-turbo-8k-preview',
-        'yi_34b_chat'
-    ]
-
-    api_key: str = ''
-    secret_key: str = ''
-
-    def __init__(self, api_key: str, secret_key: str):
-        self.api_key = api_key
-        self.secret_key = secret_key
+class ErnieBotModel(_CommonWenxin):
 
     def generate(self, model: str, stream: bool, messages: list[ErnieMessage],
                  parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \
@@ -199,51 +70,6 @@ class ErnieBotModel:
             return self._handle_chat_stream_generate_response(resp)
         return self._handle_chat_generate_response(resp)
 
-    def _handle_error(self, code: int, msg: str):
-        error_map = {
-            1: InternalServerError,
-            2: InternalServerError,
-            3: BadRequestError,
-            4: RateLimitReachedError,
-            6: InvalidAuthenticationError,
-            13: InvalidAPIKeyError,
-            14: InvalidAPIKeyError,
-            15: InvalidAPIKeyError,
-            17: RateLimitReachedError,
-            18: RateLimitReachedError,
-            19: RateLimitReachedError,
-            100: InvalidAPIKeyError,
-            111: InvalidAPIKeyError,
-            200: InternalServerError,
-            336000: InternalServerError,
-            336001: BadRequestError,
-            336002: BadRequestError,
-            336003: BadRequestError,
-            336004: InvalidAuthenticationError,
-            336005: InvalidAPIKeyError,
-            336006: BadRequestError,
-            336007: BadRequestError,
-            336008: BadRequestError,
-            336100: InternalServerError,
-            336101: BadRequestError,
-            336102: BadRequestError,
-            336103: BadRequestError,
-            336104: BadRequestError,
-            336105: BadRequestError,
-            336200: InternalServerError,
-            336303: BadRequestError,
-            337006: BadRequestError
-        }
-
-        if code in error_map:
-            raise error_map[code](msg)
-        else:
-            raise InternalServerError(f'Unknown error: {msg}')
-
-    def _get_access_token(self) -> str:
-        token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key)
-        return token.access_token
-
     def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]:
         return [ErnieMessage(message.content, message.role) for message in messages]
 

+ 0 - 17
api/core/model_runtime/model_providers/wenxin/llm/ernie_bot_errors.py

@@ -1,17 +0,0 @@
-class InvalidAuthenticationError(Exception):
-    pass
-
-class InvalidAPIKeyError(Exception):
-    pass
-
-class RateLimitReachedError(Exception):
-    pass
-
-class InsufficientAccountBalance(Exception):
-    pass
-
-class InternalServerError(Exception):
-    pass
-
-class BadRequestError(Exception):
-    pass

+ 5 - 34
api/core/model_runtime/model_providers/wenxin/llm/llm.py

@@ -11,24 +11,13 @@ from core.model_runtime.entities.message_entities import (
     UserPromptMessage,
 )
 from core.model_runtime.errors.invoke import (
-    InvokeAuthorizationError,
-    InvokeBadRequestError,
-    InvokeConnectionError,
     InvokeError,
-    InvokeRateLimitError,
-    InvokeServerUnavailableError,
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.model_runtime.model_providers.wenxin.llm.ernie_bot import BaiduAccessToken, ErnieBotModel, ErnieMessage
-from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (
-    BadRequestError,
-    InsufficientAccountBalance,
-    InternalServerError,
-    InvalidAPIKeyError,
-    InvalidAuthenticationError,
-    RateLimitReachedError,
-)
+from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken
+from core.model_runtime.model_providers.wenxin.llm.ernie_bot import ErnieBotModel, ErnieMessage
+from core.model_runtime.model_providers.wenxin.wenxin_errors import invoke_error_mapping
 
 ERNIE_BOT_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
 The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
@@ -140,7 +129,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
         api_key = credentials['api_key']
         secret_key = credentials['secret_key']
         try:
-            BaiduAccessToken._get_access_token(api_key, secret_key)
+            BaiduAccessToken.get_access_token(api_key, secret_key)
         except Exception as e:
             raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')
 
@@ -254,22 +243,4 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
 
         :return: Invoke error mapping
         """
-        return {
-            InvokeConnectionError: [
-            ],
-            InvokeServerUnavailableError: [
-                InternalServerError
-            ],
-            InvokeRateLimitError: [
-                RateLimitReachedError
-            ],
-            InvokeAuthorizationError: [
-                InvalidAuthenticationError,
-                InsufficientAccountBalance,
-                InvalidAPIKeyError,
-            ],
-            InvokeBadRequestError: [
-                BadRequestError,
-                KeyError
-            ]
-        }
+        return invoke_error_mapping()

+ 0 - 0
api/core/model_runtime/model_providers/wenxin/text_embedding/__init__.py


+ 9 - 0
api/core/model_runtime/model_providers/wenxin/text_embedding/embedding-v1.yaml

@@ -0,0 +1,9 @@
+model: embedding-v1
+model_type: text-embedding
+model_properties:
+  context_size: 384
+  max_chunks: 16
+pricing:
+  input: '0.0005'
+  unit: '0.001'
+  currency: RMB

+ 184 - 0
api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py

@@ -0,0 +1,184 @@
+import time
+from abc import abstractmethod
+from collections.abc import Mapping
+from json import dumps
+from typing import Any, Optional
+
+import numpy as np
+from requests import Response, post
+
+from core.model_runtime.entities.model_entities import PriceType
+from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
+from core.model_runtime.errors.invoke import InvokeError
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken, _CommonWenxin
+from core.model_runtime.model_providers.wenxin.wenxin_errors import (
+    BadRequestError,
+    InternalServerError,
+    invoke_error_mapping,
+)
+
+
+class TextEmbedding:
+    @abstractmethod
+    def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int):
+        raise NotImplementedError
+
+
+class WenxinTextEmbedding(_CommonWenxin, TextEmbedding):
+    def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int):
+        access_token = self._get_access_token()
+        url = f'{self.api_bases[model]}?access_token={access_token}'
+        body = self._build_embed_request_body(model, texts, user)
+        headers = {
+            'Content-Type': 'application/json',
+        }
+
+        resp = post(url, data=dumps(body), headers=headers)
+        if resp.status_code != 200:
+            raise InternalServerError(f'Failed to invoke ernie bot: {resp.text}')
+        return self._handle_embed_response(model, resp)
+
+    def _build_embed_request_body(self, model: str, texts: list[str], user: str) -> dict[str, Any]:
+        if len(texts) == 0:
+            raise BadRequestError('The number of texts should not be zero.')
+        body = {
+            'input': texts,
+            'user_id': user,
+        }
+        return body
+
+    def _handle_embed_response(self, model: str, response: Response) -> (list[list[float]], int, int):
+        data = response.json()
+        if 'error_code' in data:
+            code = data['error_code']
+            msg = data['error_msg']
+            # raise error
+            self._handle_error(code, msg)
+
+        embeddings = [v['embedding'] for v in data['data']]
+        _usage = data['usage']
+        tokens = _usage['prompt_tokens']
+        total_tokens = _usage['total_tokens']
+
+        return embeddings, tokens, total_tokens
+
+
+class WenxinTextEmbeddingModel(TextEmbeddingModel):
+    def _create_text_embedding(self, api_key: str, secret_key: str) -> TextEmbedding:
+        return WenxinTextEmbedding(api_key, secret_key)
+
+    def _invoke(self, model: str, credentials: dict, texts: list[str],
+                user: Optional[str] = None) -> TextEmbeddingResult:
+        """
+                Invoke text embedding model
+
+                :param model: model name
+                :param credentials: model credentials
+                :param texts: texts to embed
+                :param user: unique user id
+                :return: embeddings result
+                """
+
+        api_key = credentials['api_key']
+        secret_key = credentials['secret_key']
+        embedding: TextEmbedding = self._create_text_embedding(api_key, secret_key)
+        user = user if user else 'ErnieBotDefault'
+
+        context_size = self._get_context_size(model, credentials)
+        max_chunks = self._get_max_chunks(model, credentials)
+        inputs = []
+        indices = []
+        used_tokens = 0
+        used_total_tokens = 0
+
+        for i, text in enumerate(texts):
+
+            # Here token count is only an approximation based on the GPT2 tokenizer
+            num_tokens = self._get_num_tokens_by_gpt2(text)
+
+            if num_tokens >= context_size:
+                cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
+                # if num tokens is larger than context length, only use the start
+                inputs.append(text[0:cutoff])
+            else:
+                inputs.append(text)
+            indices += [i]
+
+        batched_embeddings = []
+        _iter = range(0, len(inputs), max_chunks)
+        for i in _iter:
+            embeddings_batch, _used_tokens, _total_used_tokens = embedding.embed_documents(
+                model,
+                inputs[i: i + max_chunks],
+                user)
+            used_tokens += _used_tokens
+            used_total_tokens += _total_used_tokens
+            batched_embeddings += embeddings_batch
+
+        usage = self._calc_response_usage(model, credentials, used_tokens, used_total_tokens)
+        return TextEmbeddingResult(
+            model=model,
+            embeddings=batched_embeddings,
+            usage=usage,
+        )
+
+    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+        """
+        Get number of tokens for given prompt messages
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :return:
+        """
+        if len(texts) == 0:
+            return 0
+        total_num_tokens = 0
+        for text in texts:
+            total_num_tokens += self._get_num_tokens_by_gpt2(text)
+
+        return total_num_tokens
+
+    def validate_credentials(self, model: str, credentials: Mapping) -> None:
+        api_key = credentials['api_key']
+        secret_key = credentials['secret_key']
+        try:
+            BaiduAccessToken.get_access_token(api_key, secret_key)
+        except Exception as e:
+            raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        return invoke_error_mapping()
+
+    def _calc_response_usage(self, model: str, credentials: dict, tokens: int, total_tokens: int) -> EmbeddingUsage:
+        """
+        Calculate response usage
+
+        :param model: model name
+        :param credentials: model credentials
+        :param tokens: input tokens
+        :return: usage
+        """
+        # get input price info
+        input_price_info = self.get_price(
+            model=model,
+            credentials=credentials,
+            price_type=PriceType.INPUT,
+            tokens=tokens
+        )
+
+        # transform usage
+        usage = EmbeddingUsage(
+            tokens=tokens,
+            total_tokens=total_tokens,
+            unit_price=input_price_info.unit_price,
+            price_unit=input_price_info.unit,
+            total_price=input_price_info.total_amount,
+            currency=input_price_info.currency,
+            latency=time.perf_counter() - self.started_at
+        )
+
+        return usage

+ 1 - 0
api/core/model_runtime/model_providers/wenxin/wenxin.yaml

@@ -17,6 +17,7 @@ help:
     en_US: https://cloud.baidu.com/wenxin.html
 supported_model_types:
   - llm
+  - text-embedding
 configurate_methods:
   - predefined-model
 provider_credential_schema:

+ 57 - 0
api/core/model_runtime/model_providers/wenxin/wenxin_errors.py

@@ -0,0 +1,57 @@
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+
+
+def invoke_error_mapping() -> dict[type[InvokeError], list[type[Exception]]]:
+    """
+    Map model invoke error to unified error
+    The key is the error type thrown to the caller
+    The value is the error type thrown by the model,
+    which needs to be converted into a unified error type for the caller.
+
+    :return: Invoke error mapping
+    """
+    return {
+        InvokeConnectionError: [
+        ],
+        InvokeServerUnavailableError: [
+            InternalServerError
+        ],
+        InvokeRateLimitError: [
+            RateLimitReachedError
+        ],
+        InvokeAuthorizationError: [
+            InvalidAuthenticationError,
+            InsufficientAccountBalance,
+            InvalidAPIKeyError,
+        ],
+        InvokeBadRequestError: [
+            BadRequestError,
+            KeyError
+        ]
+    }
+
+
+class InvalidAuthenticationError(Exception):
+    pass
+
+class InvalidAPIKeyError(Exception):
+    pass
+
+class RateLimitReachedError(Exception):
+    pass
+
+class InsufficientAccountBalance(Exception):
+    pass
+
+class InternalServerError(Exception):
+    pass
+
+class BadRequestError(Exception):
+    pass

+ 24 - 0
api/tests/integration_tests/model_runtime/wenxin/test_embedding.py

@@ -0,0 +1,24 @@
+import os
+from time import sleep
+
+from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
+from core.model_runtime.model_providers.wenxin.text_embedding.text_embedding import WenxinTextEmbeddingModel
+
+
+def test_invoke_embedding_model():
+    sleep(3)
+    model = WenxinTextEmbeddingModel()
+
+    response = model.invoke(
+        model='embedding-v1',
+        credentials={
+            'api_key': os.environ.get('WENXIN_API_KEY'),
+            'secret_key': os.environ.get('WENXIN_SECRET_KEY')
+        },
+        texts=['hello', '你好', 'xxxxx'],
+        user="abc-123"
+    )
+
+    assert isinstance(response, TextEmbeddingResult)
+    assert len(response.embeddings) == 3
+    assert isinstance(response.embeddings[0], list)

+ 0 - 0
api/tests/unit_tests/core/model_runtime/__init__.py


+ 0 - 0
api/tests/unit_tests/core/model_runtime/model_providers/__init__.py


+ 0 - 0
api/tests/unit_tests/core/model_runtime/model_providers/wenxin/__init__.py


+ 75 - 0
api/tests/unit_tests/core/model_runtime/model_providers/wenxin/test_text_embedding.py

@@ -0,0 +1,75 @@
+import numpy as np
+
+from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
+from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
+from core.model_runtime.model_providers.wenxin.text_embedding.text_embedding import (
+    TextEmbedding,
+    WenxinTextEmbeddingModel,
+)
+
+
+def test_max_chunks():
+    class _MockTextEmbedding(TextEmbedding):
+        def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int):
+            embeddings = [[1.0, 2.0, 3.0] for i in range(len(texts))]
+            tokens = 0
+            for text in texts:
+                tokens += len(text)
+
+            return embeddings, tokens, tokens
+
+    def _create_text_embedding(api_key: str, secret_key: str) -> TextEmbedding:
+        return _MockTextEmbedding()
+
+    model = 'embedding-v1'
+    credentials = {
+        'api_key': 'xxxx',
+        'secret_key': 'yyyy',
+    }
+    embedding_model = WenxinTextEmbeddingModel()
+    context_size = embedding_model._get_context_size(model, credentials)
+    max_chunks = embedding_model._get_max_chunks(model, credentials)
+    embedding_model._create_text_embedding = _create_text_embedding
+
+    texts = ['0123456789' for i in range(0, max_chunks * 2)]
+    result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, 'test')
+    assert len(result.embeddings) == max_chunks * 2
+
+
+def test_context_size():
+    def get_num_tokens_by_gpt2(text: str) -> int:
+        return GPT2Tokenizer.get_num_tokens(text)
+
+    def mock_text(token_size: int) -> str:
+        _text = "".join(['0' for i in range(token_size)])
+        num_tokens = get_num_tokens_by_gpt2(_text)
+        ratio = int(np.floor(len(_text) / num_tokens))
+        m_text = "".join([_text for i in range(ratio)])
+        return m_text
+
+    model = 'embedding-v1'
+    credentials = {
+        'api_key': 'xxxx',
+        'secret_key': 'yyyy',
+    }
+    embedding_model = WenxinTextEmbeddingModel()
+    context_size = embedding_model._get_context_size(model, credentials)
+
+    class _MockTextEmbedding(TextEmbedding):
+        def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int):
+            embeddings = [[1.0, 2.0, 3.0] for i in range(len(texts))]
+            tokens = 0
+            for text in texts:
+                tokens += get_num_tokens_by_gpt2(text)
+            return embeddings, tokens, tokens
+
+    def _create_text_embedding(api_key: str, secret_key: str) -> TextEmbedding:
+        return _MockTextEmbedding()
+
+    embedding_model._create_text_embedding = _create_text_embedding
+    text = mock_text(context_size * 2)
+    assert get_num_tokens_by_gpt2(text) == context_size * 2
+
+    texts = [text]
+    result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, 'test')
+    assert result.usage.tokens == context_size