
Feat: Add model provider Text Embedding Inference for embedding and rerank (#7132)

Yanyi Liu 8 months ago
Parent
Commit
5b32f2e0dd

+ 0 - 0
api/core/model_runtime/model_providers/huggingface_tei/__init__.py


+ 11 - 0
api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.py

@@ -0,0 +1,11 @@
+import logging
+
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class HuggingfaceTeiProvider(ModelProvider):
+
+    def validate_provider_credentials(self, credentials: dict) -> None:
+        pass

+ 36 - 0
api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.yaml

@@ -0,0 +1,36 @@
+provider: huggingface_tei
+label:
+  en_US: Text Embedding Inference
+description:
+  en_US: A blazing fast inference solution for text embeddings models.
+  zh_Hans: 用于文本嵌入模型的超快速推理解决方案。
+background: "#FFF8DC"
+help:
+  title:
+    en_US: How to deploy Text Embedding Inference
+    zh_Hans: 如何部署 Text Embedding Inference
+  url:
+    en_US: https://github.com/huggingface/text-embeddings-inference
+supported_model_types:
+  - text-embedding
+  - rerank
+configurate_methods:
+  - customizable-model
+model_credential_schema:
+  model:
+    label:
+      en_US: Model Name
+      zh_Hans: 模型名称
+    placeholder:
+      en_US: Enter your model name
+      zh_Hans: 输入模型名称
+  credential_form_schemas:
+    - variable: server_url
+      label:
+        zh_Hans: 服务器URL
+        en_US: Server URL
+      type: secret-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入Text Embedding Inference的服务器地址,如 http://192.168.1.100:8080
+        en_US: Enter the url of your Text Embedding Inference, e.g. http://192.168.1.100:8080
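
For reference, the schema above exposes a single server_url secret per configured model. A minimal sketch of the credentials dict the runtime hands to the model classes below (the address is a placeholder, not a real endpoint):

    credentials = {
        # Base URL of the TEI deployment; a trailing slash is tolerated and stripped by the model code.
        'server_url': 'http://192.168.1.100:8080',
    }
    # validate_credentials() later enriches this dict with 'context_size' (and, for embedding
    # models, 'max_chunks') read from the server's /info endpoint.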

+ 0 - 0
api/core/model_runtime/model_providers/huggingface_tei/rerank/__init__.py


+ 137 - 0
api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py

@@ -0,0 +1,137 @@
+from typing import Optional
+
+import httpx
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
+from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.rerank_model import RerankModel
+from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper
+
+
+class HuggingfaceTeiRerankModel(RerankModel):
+    """
+    Model class for Text Embedding Inference rerank model.
+    """
+
+    def _invoke(
+        self,
+        model: str,
+        credentials: dict,
+        query: str,
+        docs: list[str],
+        score_threshold: Optional[float] = None,
+        top_n: Optional[int] = None,
+        user: Optional[str] = None,
+    ) -> RerankResult:
+        """
+        Invoke rerank model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param query: search query
+        :param docs: docs for reranking
+        :param score_threshold: score threshold
+        :param top_n: top n
+        :param user: unique user id
+        :return: rerank result
+        """
+        if len(docs) == 0:
+            return RerankResult(model=model, docs=[])
+        server_url = credentials['server_url']
+
+        if server_url.endswith('/'):
+            server_url = server_url[:-1]
+
+        try:
+            results = TeiHelper.invoke_rerank(server_url, query, docs)
+
+            rerank_documents = []
+            for result in results:  
+                rerank_document = RerankDocument(
+                    index=result['index'],
+                    text=result['text'],
+                    score=result['score'],
+                )
+                if score_threshold is None or result['score'] >= score_threshold:
+                    rerank_documents.append(rerank_document)
+                if top_n is not None and len(rerank_documents) >= top_n:
+                    break
+
+            return RerankResult(model=model, docs=rerank_documents)
+        except httpx.HTTPStatusError as e:
+            raise InvokeServerUnavailableError(str(e))  
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            server_url = credentials['server_url']
+            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
+            if extra_args.model_type != 'reranker':
+                raise CredentialsValidateFailedError('Current model is not a rerank model')
+
+            credentials['context_size'] = extra_args.max_input_length
+
+            self.invoke(
+                model=model,
+                credentials=credentials,
+                query='Whose kasumi',
+                docs=[
+                    'Kasumi is a girl\'s name of Japanese origin meaning "mist".',
+                    'Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ',
+                    'and she leads a team named PopiParty.',
+                ],
+                score_threshold=0.8,
+            )
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeConnectionError: [InvokeConnectionError],
+            InvokeServerUnavailableError: [InvokeServerUnavailableError],
+            InvokeRateLimitError: [InvokeRateLimitError],
+            InvokeAuthorizationError: [InvokeAuthorizationError],
+            InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError],
+        }
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+        """
+        used to define customizable model schema
+        """
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(en_US=model),
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_type=ModelType.RERANK,
+            model_properties={
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)),
+            },
+            parameter_rules=[],
+        )
+
+        return entity
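
A minimal usage sketch for the rerank model above, assuming a TEI reranker is reachable at a placeholder localhost URL (the model name is arbitrary for a customizable-model provider):

    from core.model_runtime.model_providers.huggingface_tei.rerank.rerank import HuggingfaceTeiRerankModel

    model = HuggingfaceTeiRerankModel()
    result = model.invoke(
        model='bge-reranker-base',                            # illustrative model name
        credentials={'server_url': 'http://localhost:8080'},  # placeholder TEI endpoint
        query='What is Deep Learning?',
        docs=[
            'Deep Learning is a subfield of machine learning.',
            'A cat sits on the mat.',
        ],
        score_threshold=0.5,
        top_n=1,
    )
    print([(doc.index, doc.score) for doc in result.docs])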

+ 183 - 0
api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py

@@ -0,0 +1,183 @@
+from threading import Lock
+from time import time
+from typing import Optional
+
+import httpx
+from requests.adapters import HTTPAdapter
+from requests.exceptions import ConnectionError, MissingSchema, Timeout
+from requests.sessions import Session
+from yarl import URL
+
+
+class TeiModelExtraParameter:
+    model_type: str
+    max_input_length: int
+    max_client_batch_size: int
+
+    def __init__(self, model_type: str, max_input_length: int, max_client_batch_size: Optional[int] = None) -> None:
+        self.model_type = model_type
+        self.max_input_length = max_input_length
+        self.max_client_batch_size = max_client_batch_size
+
+
+cache = {}
+cache_lock = Lock()
+
+
+class TeiHelper:
+    @staticmethod
+    def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
+        TeiHelper._clean_cache()
+        with cache_lock:
+            if model_name not in cache:
+                cache[model_name] = {
+                    'expires': time() + 300,
+                    'value': TeiHelper._get_tei_extra_parameter(server_url),
+                }
+            return cache[model_name]['value']
+
+    @staticmethod
+    def _clean_cache() -> None:
+        try:
+            with cache_lock:
+                expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()]
+                for model_uid in expired_keys:
+                    del cache[model_uid]
+        except RuntimeError:
+            pass  # best-effort cleanup; ignore errors from concurrent cache mutation
+
+    @staticmethod
+    def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter:
+        """
+        get tei model extra parameter like model_type, max_input_length, max_batch_requests
+        """
+
+        url = str(URL(server_url) / 'info')
+
+        # This method runs under the cache lock, and a plain requests call may hang forever, so we mount an HTTPAdapter with max_retries=3
+        session = Session()
+        session.mount('http://', HTTPAdapter(max_retries=3))
+        session.mount('https://', HTTPAdapter(max_retries=3))
+
+        try:
+            response = session.get(url, timeout=10)
+        except (MissingSchema, ConnectionError, Timeout) as e:
+            raise RuntimeError(f'get tei model extra parameter failed, url: {url}, error: {e}')
+        if response.status_code != 200:
+            raise RuntimeError(
+                f'get tei model extra parameter failed, status code: {response.status_code}, response: {response.text}'
+            )
+
+        response_json = response.json()
+
+        model_type = response_json.get('model_type', {})
+        if len(model_type.keys()) < 1:
+            raise RuntimeError('model_type is empty')
+        model_type = list(model_type.keys())[0]
+        if model_type not in ['embedding', 'reranker']:
+            raise RuntimeError(f'invalid model_type: {model_type}')
+        
+        max_input_length = response_json.get('max_input_length', 512)
+        max_client_batch_size = response_json.get('max_client_batch_size', 1)
+
+        return TeiModelExtraParameter(
+            model_type=model_type,
+            max_input_length=max_input_length,
+            max_client_batch_size=max_client_batch_size
+        )
+    
+    @staticmethod
+    def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
+        """
+        Invoke tokenize endpoint
+
+        Example response:
+        [
+            [
+                {
+                    "id": 0,
+                    "text": "<s>",
+                    "special": true,
+                    "start": null,
+                    "stop": null
+                },
+                {
+                    "id": 7704,
+                    "text": "str",
+                    "special": false,
+                    "start": 0,
+                    "stop": 3
+                },
+                < MORE TOKENS >
+            ]
+        ]
+
+        :param server_url: server url
+        :param texts: texts to tokenize
+        """
+        resp = httpx.post(
+            f'{server_url}/tokenize',
+            json={'inputs': texts},
+        )
+        resp.raise_for_status()
+        return resp.json()
+    
+    @staticmethod
+    def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
+        """
+        Invoke embeddings endpoint
+
+        Example response:
+        {
+            "object": "list",
+            "data": [
+                {
+                    "object": "embedding",
+                    "embedding": [...],
+                    "index": 0
+                }
+            ],
+            "model": "MODEL_NAME",
+            "usage": {
+                "prompt_tokens": 3,
+                "total_tokens": 3
+            }
+        }
+
+        :param server_url: server url
+        :param texts: texts to embed
+        """
+        # Use OpenAI compatible API here, which has usage tracking
+        resp = httpx.post(
+            f'{server_url}/v1/embeddings',
+            json={'input': texts},
+        )
+        resp.raise_for_status()
+        return resp.json()
+
+    @staticmethod
+    def invoke_rerank(server_url: str, query: str, docs: list[str]) -> list[dict]:
+        """
+        Invoke rerank endpoint
+
+        Example response:
+        [
+            {
+                "index": 0,
+                "text": "Deep Learning is ...",
+                "score": 0.9950755
+            }
+        ]
+
+        :param server_url: server url
+        :param query: search query
+        :param docs: documents to rerank
+        """
+        params = {'query': query, 'texts': docs, 'return_text': True}
+
+        response = httpx.post(
+            server_url + '/rerank',
+            json=params,
+        )
+        response.raise_for_status() 
+        return response.json()
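
A minimal sketch of exercising the helper directly against a running TEI instance (the URL is a placeholder; /info, /tokenize, /v1/embeddings and /rerank are the endpoints used above):

    from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper

    server_url = 'http://localhost:8080'  # placeholder TEI deployment
    params = TeiHelper.get_tei_extra_parameter(server_url, 'my-model')
    print(params.model_type, params.max_input_length, params.max_client_batch_size)

    if params.model_type == 'embedding':
        tokens = TeiHelper.invoke_tokenize(server_url, ['hello world'])
        print(len(tokens[0]), 'tokens in the first input')
        print(TeiHelper.invoke_embeddings(server_url, ['hello world'])['usage'])
    elif params.model_type == 'reranker':
        print(TeiHelper.invoke_rerank(server_url, 'greeting', ['hello world', 'goodbye']))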

+ 0 - 0
api/core/model_runtime/model_providers/huggingface_tei/text_embedding/__init__.py


+ 204 - 0
api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py

@@ -0,0 +1,204 @@
+import time
+from typing import Optional
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
+from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper
+
+
+class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
+    """
+    Model class for Text Embedding Inference text embedding model.
+    """
+
+    def _invoke(
+        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+    ) -> TextEmbeddingResult:
+        """
+        Invoke text embedding model
+
+        credentials should be like:
+        {
+            'server_url': 'server url',
+        }
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :param user: unique user id
+        :return: embeddings result
+        """
+        server_url = credentials['server_url']
+
+        if server_url.endswith('/'):
+            server_url = server_url[:-1]
+
+        # get model properties
+        context_size = self._get_context_size(model, credentials)
+        max_chunks = self._get_max_chunks(model, credentials)
+
+        inputs = []
+        indices = []
+        used_tokens = 0
+
+        # get tokenized results from TEI
+        batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts)
+
+        for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)):
+
+            # Check if the number of tokens is larger than the context size
+            num_tokens = len(tokenize_result)
+
+            if num_tokens >= context_size:
+                # Find the best cutoff point
+                pre_special_token_count = 0
+                for token in tokenize_result:
+                    if token['special']:
+                        pre_special_token_count += 1
+                    else:
+                        break
+                rest_special_token_count = len([token for token in tokenize_result if token['special']]) - pre_special_token_count
+
+                # Calculate the cutoff point, leaving 20 tokens of headroom to avoid exceeding the limit
+                token_cutoff = context_size - rest_special_token_count - 20
+
+                # Find the cutoff index
+                cutpoint_token = tokenize_result[token_cutoff]
+                cutoff = cutpoint_token['start']
+
+                inputs.append(text[0: cutoff])
+            else:
+                inputs.append(text)
+            indices += [i]
+
+        batched_embeddings = []
+        _iter = range(0, len(inputs), max_chunks)
+
+        try:
+            used_tokens = 0
+            for i in _iter:
+                iter_texts = inputs[i : i + max_chunks]
+                results = TeiHelper.invoke_embeddings(server_url, iter_texts)
+                embeddings = results['data']
+                embeddings = [embedding['embedding'] for embedding in embeddings]
+                batched_embeddings.extend(embeddings)
+
+                usage = results['usage']
+                used_tokens += usage['total_tokens']
+        except RuntimeError as e:
+            raise InvokeServerUnavailableError(str(e))
+
+        usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
+
+        result = TextEmbeddingResult(model=model, embeddings=batched_embeddings, usage=usage)
+
+        return result
+
+    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+        """
+        Get number of tokens for given prompt messages
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :return:
+        """
+        num_tokens = 0
+        server_url = credentials['server_url']
+
+        if server_url.endswith('/'):
+            server_url = server_url[:-1]
+
+        batch_tokens = TeiHelper.invoke_tokenize(server_url, texts)
+        num_tokens = sum(len(tokens) for tokens in batch_tokens)
+        return num_tokens
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            server_url = credentials['server_url']
+            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
+            if extra_args.model_type != 'embedding':
+                raise CredentialsValidateFailedError('Current model is not an embedding model')
+
+            credentials['context_size'] = extra_args.max_input_length
+            credentials['max_chunks'] = extra_args.max_client_batch_size
+            self._invoke(model=model, credentials=credentials, texts=['ping'])
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        return {
+            InvokeConnectionError: [InvokeConnectionError],
+            InvokeServerUnavailableError: [InvokeServerUnavailableError],
+            InvokeRateLimitError: [InvokeRateLimitError],
+            InvokeAuthorizationError: [InvokeAuthorizationError],
+            InvokeBadRequestError: [KeyError],
+        }
+
+    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
+        """
+        Calculate response usage
+
+        :param model: model name
+        :param credentials: model credentials
+        :param tokens: input tokens
+        :return: usage
+        """
+        # get input price info
+        input_price_info = self.get_price(
+            model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
+        )
+
+        # transform usage
+        usage = EmbeddingUsage(
+            tokens=tokens,
+            total_tokens=tokens,
+            unit_price=input_price_info.unit_price,
+            price_unit=input_price_info.unit,
+            total_price=input_price_info.total_amount,
+            currency=input_price_info.currency,
+            latency=time.perf_counter() - self.started_at,
+        )
+
+        return usage
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+        """
+        used to define customizable model schema
+        """
+
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(en_US=model),
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_type=ModelType.TEXT_EMBEDDING,
+            model_properties={
+                ModelPropertyKey.MAX_CHUNKS: int(credentials.get('max_chunks', 1)),
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)),
+            },
+            parameter_rules=[],
+        )
+
+        return entity
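
To make the truncation step above concrete: with a context size of 512, one leading special token and one trailing special token, the cutoff index is 512 - 1 - 20 = 491, and the input text is sliced at the character offset where token 491 starts. A minimal worked sketch of that arithmetic (the token dict shape follows the /tokenize example in tei_helper.py; the values are illustrative):

    context_size = 512
    # One leading and one trailing special token around 600 word-piece tokens.
    tokenize_result = [
        {'special': True, 'start': None},
        *[{'special': False, 'start': i * 4} for i in range(600)],
        {'special': True, 'start': None},
    ]
    pre_special_token_count = 0
    for token in tokenize_result:
        if token['special']:
            pre_special_token_count += 1
        else:
            break
    rest_special_token_count = sum(1 for t in tokenize_result if t['special']) - pre_special_token_count
    token_cutoff = context_size - rest_special_token_count - 20  # 512 - 1 - 20 = 491
    cutoff = tokenize_result[token_cutoff]['start']              # character offset used to slice the text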

+ 2 - 0
api/pyproject.toml

@@ -93,6 +93,8 @@ CODE_MAX_STRING_LENGTH = "80000"
 CODE_EXECUTION_ENDPOINT = "http://127.0.0.1:8194"
 CODE_EXECUTION_API_KEY = "dify-sandbox"
 FIRECRAWL_API_KEY = "fc-"
+TEI_EMBEDDING_SERVER_URL = "http://a.abc.com:11451"
+TEI_RERANK_SERVER_URL = "http://a.abc.com:11451"
 
 [tool.poetry]
 name = "dify-api"

+ 94 - 0
api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py

@@ -0,0 +1,94 @@
+from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiModelExtraParameter
+
+
+class MockTEIClass:
+    @staticmethod
+    def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
+        # During mock, we don't have a real server to query, so we just return a dummy value
+        if 'rerank' in model_name:
+            model_type = 'reranker'
+        else:
+            model_type = 'embedding'
+
+        return TeiModelExtraParameter(model_type=model_type, max_input_length=512, max_client_batch_size=1)
+    
+    @staticmethod
+    def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
+        # Use space as token separator, and split the text into tokens
+        tokenized_texts = []
+        for text in texts:
+            tokens = text.split(' ')
+            current_index = 0
+            tokenized_text = []
+            for idx, token in enumerate(tokens):
+                s_token = {
+                    'id': idx,
+                    'text': token,
+                    'special': False,
+                    'start': current_index,
+                    'stop': current_index + len(token),
+                }
+                current_index += len(token) + 1
+                tokenized_text.append(s_token)
+            tokenized_texts.append(tokenized_text)
+        return tokenized_texts
+
+    @staticmethod
+    def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
+        # {
+        #     "object": "list",
+        #     "data": [
+        #         {
+        #             "object": "embedding",
+        #             "embedding": [...],
+        #             "index": 0
+        #         }
+        #     ],
+        #     "model": "MODEL_NAME",
+        #     "usage": {
+        #         "prompt_tokens": 3,
+        #         "total_tokens": 3
+        #     }
+        # }
+        embeddings = []
+        for idx, text in enumerate(texts):
+            embedding = [0.1] * 768
+            embeddings.append(
+                {
+                    'object': 'embedding',
+                    'embedding': embedding,
+                    'index': idx,
+                }
+            )
+        return {
+            'object': 'list',
+            'data': embeddings,
+            'model': 'MODEL_NAME',
+            'usage': {
+                'prompt_tokens': sum(len(text.split(' ')) for text in texts),
+                'total_tokens': sum(len(text.split(' ')) for text in texts),
+            },
+        }
+
+    @staticmethod
+    def invoke_rerank(server_url: str, query: str, texts: list[str]) -> list[dict]:
+        # Example response:
+        # [
+        #     {
+        #         "index": 0,
+        #         "text": "Deep Learning is ...",
+        #         "score": 0.9950755
+        #     }
+        # ]
+        reranked_docs = []
+        for idx, text in enumerate(texts):
+            reranked_docs.append(
+                {
+                    'index': idx,
+                    'text': text,
+                    'score': 0.9,
+                }
+            )
+            # For mock, only return the first document
+            break
+        return reranked_docs

+ 0 - 0
api/tests/integration_tests/model_runtime/huggingface_tei/__init__.py


+ 72 - 0
api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py

@@ -0,0 +1,72 @@
+import os
+
+import pytest
+
+from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import (
+    HuggingfaceTeiTextEmbeddingModel,
+    TeiHelper,
+)
+from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass
+
+MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
+
+
+@pytest.fixture
+def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch):
+    if MOCK:
+        monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter)
+        monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize)
+        monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings)
+        monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank)
+    yield
+
+    if MOCK:
+        monkeypatch.undo()
+
+
+@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
+def test_validate_credentials(setup_tei_mock):
+    model = HuggingfaceTeiTextEmbeddingModel()
+    # model name is only used in mock
+    model_name = 'embedding'
+
+    if MOCK:
+        # The TEI provider checks the model type via the server's /info endpoint. Against a real
+        # server the type is always correct, so the mismatch case is only exercised in the mock.
+        with pytest.raises(CredentialsValidateFailedError):
+            model.validate_credentials(
+                model='reranker',
+                credentials={
+                    'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
+                }
+            )
+
+    model.validate_credentials(
+        model=model_name,
+        credentials={
+            'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
+        }
+    )
+
+@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
+def test_invoke_model(setup_tei_mock):
+    model = HuggingfaceTeiTextEmbeddingModel()
+    model_name = 'embedding'
+
+    result = model.invoke(
+        model=model_name,
+        credentials={
+            'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
+        },
+        texts=[
+            "hello",
+            "world"
+        ],
+        user="abc-123"
+    )
+
+    assert isinstance(result, TextEmbeddingResult)
+    assert len(result.embeddings) == 2
+    assert result.usage.total_tokens > 0

+ 76 - 0
api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py

@@ -0,0 +1,76 @@
+import os
+
+import pytest
+
+from core.model_runtime.entities.rerank_entities import RerankResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.huggingface_tei.rerank.rerank import (
+    HuggingfaceTeiRerankModel,
+)
+from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import TeiHelper
+from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass
+
+MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
+
+
+@pytest.fixture
+def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch):
+    if MOCK:
+        monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter)
+        monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize)
+        monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings)
+        monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank)
+    yield
+
+    if MOCK:
+        monkeypatch.undo()
+
+@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
+def test_validate_credentials(setup_tei_mock):
+    model = HuggingfaceTeiRerankModel()
+    # model name is only used in mock
+    model_name = 'reranker'
+
+    if MOCK:
+        # The TEI provider checks the model type via the server's /info endpoint. Against a real
+        # server the type is always correct, so the mismatch case is only exercised in the mock.
+        with pytest.raises(CredentialsValidateFailedError):
+            model.validate_credentials(
+                model='embedding',
+                credentials={
+                    'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
+                }
+            )
+
+    model.validate_credentials(
+        model=model_name,
+        credentials={
+            'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
+        }
+    )
+
+@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
+def test_invoke_model(setup_tei_mock):
+    model = HuggingfaceTeiRerankModel()
+    # model name is only used in mock
+    model_name = 'reranker'
+
+    result = model.invoke(
+        model=model_name,
+        credentials={
+            'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
+        },
+        query="Who is Kasumi?",
+        docs=[
+            "Kasumi is a girl's name of Japanese origin meaning \"mist\".",
+            "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ",
+            "and she leads a team named PopiParty."
+        ],
+        score_threshold=0.8
+    )
+
+    assert isinstance(result, RerankResult)
+    assert len(result.docs) == 1
+    assert result.docs[0].index == 0
+    assert result.docs[0].score >= 0.8