Browse Source

feat: add baichuan llm support (#1294)

Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
takatost 1 year ago
parent
commit
1d4f019de4

+ 3 - 0
api/core/model_providers/model_provider_factory.py

@@ -51,6 +51,9 @@ class ModelProviderFactory:
         elif provider_name == 'chatglm':
             from core.model_providers.providers.chatglm_provider import ChatGLMProvider
             return ChatGLMProvider
+        elif provider_name == 'baichuan':
+            from core.model_providers.providers.baichuan_provider import BaichuanProvider
+            return BaichuanProvider
         elif provider_name == 'azure_openai':
             from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
             return AzureOpenAIProvider

+ 61 - 0
api/core/model_providers/models/llm/baichuan_model.py

@@ -0,0 +1,61 @@
+from typing import List, Optional, Any
+
+from langchain.callbacks.manager import Callbacks
+from langchain.schema import LLMResult
+
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.entity.message import PromptMessage
+from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM
+
+
+class BaichuanModel(BaseLLM):
+    model_mode: ModelMode = ModelMode.CHAT
+
+    def _init_client(self) -> Any:
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
+        return BaichuanChatLLM(
+            streaming=self.streaming,
+            callbacks=self.callbacks,
+            **self.credentials,
+            **provider_model_kwargs
+        )
+
+    def _run(self, messages: List[PromptMessage],
+             stop: Optional[List[str]] = None,
+             callbacks: Callbacks = None,
+             **kwargs) -> LLMResult:
+        """
+        run predict by prompt messages and stop words.
+
+        :param messages:
+        :param stop:
+        :param callbacks:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return self._client.generate([prompts], stop, callbacks)
+
+    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
+        """
+        get num tokens of prompt messages.
+
+        :param messages:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return max(self._client.get_num_tokens_from_messages(prompts), 0)
+
+    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
+        for k, v in provider_model_kwargs.items():
+            if hasattr(self.client, k):
+                setattr(self.client, k, v)
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        return LLMBadRequestError(f"Baichuan: {str(ex)}")
+
+    @property
+    def support_streaming(self):
+        return True

+ 167 - 0
api/core/model_providers/providers/baichuan_provider.py

@@ -0,0 +1,167 @@
+import json
+from json import JSONDecodeError
+from typing import Type
+
+from langchain.schema import HumanMessage
+
+from core.helper import encrypter
+from core.model_providers.models.base import BaseProviderModel
+from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
+from core.model_providers.models.llm.baichuan_model import BaichuanModel
+from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
+from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM
+from models.provider import ProviderType
+
+
+class BaichuanProvider(BaseModelProvider):
+
+    @property
+    def provider_name(self):
+        """
+        Returns the name of a provider.
+        """
+        return 'baichuan'
+
+    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
+        if model_type == ModelType.TEXT_GENERATION:
+            return [
+                {
+                    'id': 'baichuan2-53b',
+                    'name': 'Baichuan2-53B',
+                }
+            ]
+        else:
+            return []
+
+    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
+        """
+        Returns the model class.
+
+        :param model_type:
+        :return:
+        """
+        if model_type == ModelType.TEXT_GENERATION:
+            model_class = BaichuanModel
+        else:
+            raise NotImplementedError
+
+        return model_class
+
+    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
+        """
+        get model parameter rules.
+
+        :param model_name:
+        :param model_type:
+        :return:
+        """
+        return ModelKwargsRules(
+            temperature=KwargRule[float](min=0, max=1, default=0.3, precision=2),
+            top_p=KwargRule[float](min=0, max=0.99, default=0.85, precision=2),
+            presence_penalty=KwargRule[float](enabled=False),
+            frequency_penalty=KwargRule[float](enabled=False),
+            max_tokens=KwargRule[int](enabled=False),
+        )
+
+    @classmethod
+    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
+        """
+        Validates the given credentials.
+        """
+        if 'api_key' not in credentials:
+            raise CredentialsValidateFailedError('Baichuan api_key must be provided.')
+
+        if 'secret_key' not in credentials:
+            raise CredentialsValidateFailedError('Baichuan secret_key must be provided.')
+
+        try:
+            credential_kwargs = {
+                'api_key': credentials['api_key'],
+                'secret_key': credentials['secret_key'],
+            }
+
+            llm = BaichuanChatLLM(
+                temperature=0,
+                **credential_kwargs
+            )
+
+            llm([HumanMessage(content='ping')])
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    @classmethod
+    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
+        credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
+        credentials['secret_key'] = encrypter.encrypt_token(tenant_id, credentials['secret_key'])
+        return credentials
+
+    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
+        if self.provider.provider_type == ProviderType.CUSTOM.value:
+            try:
+                credentials = json.loads(self.provider.encrypted_config)
+            except JSONDecodeError:
+                credentials = {
+                    'api_key': None,
+                    'secret_key': None,
+                }
+
+            if credentials['api_key']:
+                credentials['api_key'] = encrypter.decrypt_token(
+                    self.provider.tenant_id,
+                    credentials['api_key']
+                )
+
+                if obfuscated:
+                    credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])
+
+            if credentials['secret_key']:
+                credentials['secret_key'] = encrypter.decrypt_token(
+                    self.provider.tenant_id,
+                    credentials['secret_key']
+                )
+
+                if obfuscated:
+                    credentials['secret_key'] = encrypter.obfuscated_token(credentials['secret_key'])
+
+            return credentials
+        else:
+            return {}
+
+    def should_deduct_quota(self):
+        return True
+
+    @classmethod
+    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
+        """
+        check model credentials valid.
+
+        :param model_name:
+        :param model_type:
+        :param credentials:
+        """
+        return
+
+    @classmethod
+    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
+                                  credentials: dict) -> dict:
+        """
+        encrypt model credentials for save.
+
+        :param tenant_id:
+        :param model_name:
+        :param model_type:
+        :param credentials:
+        :return:
+        """
+        return {}
+
+    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
+        """
+        get credentials for llm use.
+
+        :param model_name:
+        :param model_type:
+        :param obfuscated:
+        :return:
+        """
+        return self.get_provider_credentials(obfuscated)

+ 2 - 1
api/core/model_providers/rules/_providers.json

@@ -7,10 +7,11 @@
   "spark",
   "wenxin",
   "zhipuai",
+  "baichuan",
   "chatglm",
   "replicate",
   "huggingface_hub",
   "xinference",
   "openllm",
   "localai"
-]
+]

+ 15 - 0
api/core/model_providers/rules/baichuan.json

@@ -0,0 +1,15 @@
+{
+    "support_provider_types": [
+        "custom"
+    ],
+    "system_config": null,
+    "model_flexibility": "fixed",
+    "price_config": {
+        "baichuan2-53b": {
+            "prompt": "0.01",
+            "completion": "0.01",
+            "unit": "0.001",
+            "currency": "RMB"
+        }
+    }
+}

+ 315 - 0
api/core/third_party/langchain/llms/baichuan_llm.py

@@ -0,0 +1,315 @@
+"""Wrapper around Baichuan APIs."""
+from __future__ import annotations
+
+import hashlib
+import json
+import logging
+import time
+from typing import (
+    Any,
+    Dict,
+    List,
+    Optional, Iterator,
+)
+
+import requests
+from langchain.chat_models.base import BaseChatModel
+from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
+from langchain.schema.messages import AIMessageChunk
+from langchain.schema.output import ChatResult, ChatGenerationChunk, ChatGeneration
+from pydantic import Extra, root_validator, BaseModel
+
+from langchain.callbacks.manager import (
+    CallbackManagerForLLMRun,
+)
+from langchain.utils import get_from_dict_or_env
+
+logger = logging.getLogger(__name__)
+
+
+class BaichuanModelAPI(BaseModel):
+    api_key: str
+    secret_key: str
+
+    base_url: str = "https://api.baichuan-ai.com/v1"
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+
+    def do_request(self, model: str, messages: list[dict], parameters: dict, **kwargs: Any):
+        stream = 'stream' in kwargs and kwargs['stream']
+
+        url = self.base_url + ("/stream/chat" if stream else "/chat")
+
+        data = {
+            "model": model,
+            "messages": messages,
+            "parameters": parameters
+        }
+
+        json_data = json.dumps(data)
+        time_stamp = int(time.time())
+        signature = self._calculate_md5(self.secret_key + json_data + str(time_stamp))
+
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": "Bearer " + self.api_key,
+            "X-BC-Request-Id": "your requestId",
+            "X-BC-Timestamp": str(time_stamp),
+            "X-BC-Signature": signature,
+            "X-BC-Sign-Algo": "MD5",
+        }
+
+        response = requests.post(url, data=json_data, headers=headers, stream=stream, timeout=(5, 60))
+
+        if not response.ok:
+            raise ValueError(f"HTTP {response.status_code} error: {response.text}")
+
+        if not stream:
+            json_response = response.json()
+            if json_response['code'] != 0:
+                raise ValueError(
+                    f"API {json_response['code']}"
+                    f" error: {json_response['msg']}"
+                )
+            return json_response
+        else:
+            return response
+
+    def _calculate_md5(self, input_string):
+        md5 = hashlib.md5()
+        md5.update(input_string.encode('utf-8'))
+        encrypted = md5.hexdigest()
+        return encrypted
+
+
+class BaichuanChatLLM(BaseChatModel):
+    """Wrapper around Baichuan large language models.
+    To use, you should pass the api_key as a named parameter to the constructor.
+    Example:
+     .. code-block:: python
+         from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM
+         model = BaichuanChatLLM(model="<model_name>", api_key="my-api-key", secret_key="my-secret-key")
+    """
+
+    @property
+    def lc_secrets(self) -> Dict[str, str]:
+        return {"api_key": "API_KEY", "secret_key": "SECRET_KEY"}
+
+    @property
+    def lc_serializable(self) -> bool:
+        return True
+
+    client: Any = None  #: :meta private:
+    model: str = "Baichuan2-53B"
+    """Model name to use."""
+    temperature: float = 0.3
+    """A non-negative float that tunes the degree of randomness in generation."""
+    top_p: float = 0.85
+    """Total probability mass of tokens to consider at each step."""
+    streaming: bool = False
+    """Whether to stream the response or return it all at once."""
+    api_key: Optional[str] = None
+    secret_key: Optional[str] = None
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that api key and python package exists in environment."""
+        values["api_key"] = get_from_dict_or_env(
+            values, "api_key", "BAICHUAN_API_KEY"
+        )
+
+        values["secret_key"] = get_from_dict_or_env(
+            values, "secret_key", "BAICHUAN_SECRET_KEY"
+        )
+
+        values['client'] = BaichuanModelAPI(
+            api_key=values['api_key'],
+            secret_key=values['secret_key']
+        )
+        return values
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling OpenAI API."""
+        return {
+            "model": self.model,
+            "parameters": {
+                "temperature": self.temperature,
+                "top_p": self.top_p
+            }
+        }
+
+    @property
+    def _identifying_params(self) -> Dict[str, Any]:
+        """Get the identifying parameters."""
+        return self._default_params
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "baichuan"
+
+    def _convert_message_to_dict(self, message: BaseMessage) -> dict:
+        if isinstance(message, ChatMessage):
+            message_dict = {"role": message.role, "content": message.content}
+        elif isinstance(message, HumanMessage):
+            message_dict = {"role": "user", "content": message.content}
+        elif isinstance(message, AIMessage):
+            message_dict = {"role": "assistant", "content": message.content}
+        elif isinstance(message, SystemMessage):
+            message_dict = {"role": "user", "content": message.content}
+        else:
+            raise ValueError(f"Got unknown type {message}")
+        return message_dict
+
+    def _convert_dict_to_message(self, _dict: Dict[str, Any]) -> BaseMessage:
+        role = _dict["role"]
+        if role == "user":
+            return HumanMessage(content=_dict["content"])
+        elif role == "assistant":
+            return AIMessage(content=_dict["content"])
+        elif role == "system":
+            return SystemMessage(content=_dict["content"])
+        else:
+            return ChatMessage(content=_dict["content"], role=role)
+
+    def _create_message_dicts(
+        self, messages: List[BaseMessage]
+    ) -> List[Dict[str, Any]]:
+        dict_messages = []
+        for m in messages:
+            message = self._convert_message_to_dict(m)
+            if dict_messages:
+                previous_message = dict_messages[-1]
+                if previous_message['role'] == message['role']:
+                    dict_messages[-1]['content'] += f"\n{message['content']}"
+                else:
+                    dict_messages.append(message)
+            else:
+                dict_messages.append(message)
+
+        return dict_messages
+
+    def _generate(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+    ) -> ChatResult:
+        if self.streaming:
+            generation: Optional[ChatGenerationChunk] = None
+            llm_output: Optional[Dict] = None
+            for chunk in self._stream(
+                    messages=messages, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                if generation is None:
+                    generation = chunk
+                else:
+                    generation += chunk
+
+                if chunk.generation_info is not None \
+                        and 'token_usage' in chunk.generation_info:
+                    llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}
+
+            assert generation is not None
+            return ChatResult(generations=[generation], llm_output=llm_output)
+        else:
+            message_dicts = self._create_message_dicts(messages)
+            params = self._default_params
+            params["messages"] = message_dicts
+            params.update(kwargs)
+            response = self.client.do_request(**params)
+            return self._create_chat_result(response)
+
+    def _stream(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        message_dicts = self._create_message_dicts(messages)
+        params = self._default_params
+        params["messages"] = message_dicts
+        params.update(kwargs)
+
+        for event in self.client.do_request(stream=True, **params).iter_lines():
+            if event:
+                event = event.decode("utf-8")
+
+                meta = json.loads(event)
+
+                if meta['code'] != 0:
+                    raise ValueError(
+                        f"API {meta['code']}"
+                        f" error: {meta['msg']}"
+                    )
+
+                content = meta['data']['messages'][0]['content']
+
+                chunk_kwargs = {
+                    'message': AIMessageChunk(content=content),
+                }
+
+                if 'usage' in meta:
+                    token_usage = meta['usage']
+                    overall_token_usage = {
+                        'prompt_tokens': token_usage.get('prompt_tokens', 0),
+                        'completion_tokens': token_usage.get('answer_tokens', 0),
+                        'total_tokens': token_usage.get('total_tokens', 0)
+                    }
+                    chunk_kwargs['generation_info'] = {'token_usage': overall_token_usage}
+
+                yield ChatGenerationChunk(**chunk_kwargs)
+                if run_manager:
+                    run_manager.on_llm_new_token(content)
+
+    def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
+        data = response["data"]
+        generations = []
+        for res in data["messages"]:
+            message = self._convert_dict_to_message(res)
+            gen = ChatGeneration(
+                message=message
+            )
+            generations.append(gen)
+        usage = response.get("usage")
+        token_usage = {
+            'prompt_tokens': usage.get('prompt_tokens', 0),
+            'completion_tokens': usage.get('answer_tokens', 0),
+            'total_tokens': usage.get('total_tokens', 0)
+        }
+        llm_output = {"token_usage": token_usage, "model_name": self.model}
+        return ChatResult(generations=generations, llm_output=llm_output)
+
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        """Get the number of tokens in the messages.
+
+        Useful for checking if an input will fit in a model's context window.
+
+        Args:
+            messages: The message inputs to tokenize.
+
+        Returns:
+            The sum of the number of tokens across the messages.
+        """
+        return sum([self.get_num_tokens(m.content) for m in messages])
+
+    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+        token_usage: dict = {}
+        for output in llm_outputs:
+            if output is None:
+                # Happens in streaming
+                continue
+            token_usage = output["token_usage"]
+
+        return {"token_usage": token_usage, "model_name": self.model}

+ 4 - 0
api/tests/integration_tests/.env.example

@@ -35,6 +35,10 @@ WENXIN_SECRET_KEY=
 # ZhipuAI Credentials
 ZHIPUAI_API_KEY=
 
+# Baichuan Credentials
+BAICHUAN_API_KEY=
+BAICHUAN_SECRET_KEY=
+
 # ChatGLM Credentials
 CHATGLM_API_BASE=
 

+ 81 - 0
api/tests/integration_tests/models/llm/test_baichuan_model.py

@@ -0,0 +1,81 @@
+import json
+import os
+from unittest.mock import patch
+
+
+from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.model_params import ModelKwargs
+from core.model_providers.models.llm.baichuan_model import BaichuanModel
+from core.model_providers.providers.baichuan_provider import BaichuanProvider
+from models.provider import Provider, ProviderType
+
+
+def get_mock_provider(valid_api_key, valid_secret_key):
+    return Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name='baichuan',
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config=json.dumps({
+            'api_key': valid_api_key,
+            'secret_key': valid_secret_key,
+        }),
+        is_valid=True,
+    )
+
+
+def get_mock_model(model_name: str, streaming: bool = False):
+    model_kwargs = ModelKwargs(
+        temperature=0.01,
+    )
+    valid_api_key = os.environ['BAICHUAN_API_KEY']
+    valid_secret_key = os.environ['BAICHUAN_SECRET_KEY']
+    model_provider = BaichuanProvider(provider=get_mock_provider(valid_api_key, valid_secret_key))
+    return BaichuanModel(
+        model_provider=model_provider,
+        name=model_name,
+        model_kwargs=model_kwargs,
+        streaming=streaming
+    )
+
+
+def decrypt_side_effect(tenant_id, encrypted_api_key):
+    return encrypted_api_key
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_chat_get_num_tokens(mock_decrypt):
+    model = get_mock_model('baichuan2-53b')
+    rst = model.get_num_tokens([
+        PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
+        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+    ])
+    assert rst > 0
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_chat_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
+    model = get_mock_model('baichuan2-53b')
+    messages = [
+        PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
+    ]
+    rst = model.run(
+        messages,
+    )
+    assert len(rst.content) > 0
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_chat_stream_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
+    model = get_mock_model('baichuan2-53b', streaming=True)
+    messages = [
+        PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
+    ]
+    rst = model.run(
+        messages
+    )
+    assert len(rst.content) > 0

+ 97 - 0
api/tests/unit_tests/model_providers/test_baichuan_provider.py

@@ -0,0 +1,97 @@
+import pytest
+from unittest.mock import patch
+import json
+
+from langchain.schema import ChatResult, ChatGeneration, AIMessage
+
+from core.model_providers.providers.baichuan_provider import BaichuanProvider
+from core.model_providers.providers.base import CredentialsValidateFailedError
+from models.provider import ProviderType, Provider
+
+
+PROVIDER_NAME = 'baichuan'
+MODEL_PROVIDER_CLASS = BaichuanProvider
+VALIDATE_CREDENTIAL = {
+    'api_key': 'valid_key',
+    'secret_key': 'valid_key',
+}
+
+
+def encrypt_side_effect(tenant_id, encrypt_key):
+    return f'encrypted_{encrypt_key}'
+
+
+def decrypt_side_effect(tenant_id, encrypted_key):
+    return encrypted_key.replace('encrypted_', '')
+
+
+def test_is_provider_credentials_valid_or_raise_valid(mocker):
+    mocker.patch('core.third_party.langchain.llms.baichuan_llm.BaichuanChatLLM._generate',
+                 return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))]))
+
+    MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
+
+
+def test_is_provider_credentials_valid_or_raise_invalid():
+    # raise CredentialsValidateFailedError if api_key is not in credentials
+    with pytest.raises(CredentialsValidateFailedError):
+        MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
+
+    credential = VALIDATE_CREDENTIAL.copy()
+    credential['api_key'] = 'invalid_key'
+    credential['secret_key'] = 'invalid_key'
+
+    # raise CredentialsValidateFailedError if api_key is invalid
+    with pytest.raises(CredentialsValidateFailedError):
+        MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
+
+
+@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
+def test_encrypt_credentials(mock_encrypt):
+    result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
+    assert result['api_key'] == f'encrypted_{VALIDATE_CREDENTIAL["api_key"]}'
+    assert result['secret_key'] == f'encrypted_{VALIDATE_CREDENTIAL["secret_key"]}'
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_get_credentials_custom(mock_decrypt):
+    encrypted_credential = VALIDATE_CREDENTIAL.copy()
+    encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
+    encrypted_credential['secret_key'] = 'encrypted_' + encrypted_credential['secret_key']
+
+    provider = Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name=PROVIDER_NAME,
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config=json.dumps(encrypted_credential),
+        is_valid=True,
+    )
+    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
+    result = model_provider.get_provider_credentials()
+    assert result['api_key'] == 'valid_key'
+    assert result['secret_key'] == 'valid_key'
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_get_credentials_obfuscated(mock_decrypt):
+    encrypted_credential = VALIDATE_CREDENTIAL.copy()
+    encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
+    encrypted_credential['secret_key'] = 'encrypted_' + encrypted_credential['secret_key']
+
+    provider = Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name=PROVIDER_NAME,
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config=json.dumps(encrypted_credential),
+        is_valid=True,
+    )
+    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
+    result = model_provider.get_provider_credentials(obfuscated=True)
+    middle_token = result['api_key'][6:-2]
+    secret_key_middle_token = result['secret_key'][6:-2]
+    assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_key']) - 8, 0)
+    assert len(secret_key_middle_token) == max(len(VALIDATE_CREDENTIAL['secret_key']) - 8, 0)
+    assert all(char == '*' for char in middle_token)
+    assert all(char == '*' for char in secret_key_middle_token)