Bläddra i källkod

feat: add LocalAI local embedding model support (#1021)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
takatost 1 år sedan
förälder
incheckning
417c19577a
24 ändrade filer med 1118 tillägg och 7 borttagningar
  1. 3 0
      api/core/model_providers/model_provider_factory.py
  2. 29 0
      api/core/model_providers/models/embedding/localai_embedding.py
  3. 131 0
      api/core/model_providers/models/llm/localai_model.py
  4. 164 0
      api/core/model_providers/providers/localai_provider.py
  5. 2 1
      api/core/model_providers/rules/_providers.json
  6. 7 0
      api/core/model_providers/rules/localai.json
  7. 2 1
      api/core/third_party/langchain/llms/chat_open_ai.py
  8. 32 3
      api/core/third_party/langchain/llms/open_ai.py
  9. 4 1
      api/tests/integration_tests/.env.example
  10. 61 0
      api/tests/integration_tests/models/embedding/test_localai_embedding.py
  11. 68 0
      api/tests/integration_tests/models/llm/test_localai_model.py
  12. 116 0
      api/tests/unit_tests/model_providers/test_localai_provider.py
  13. 19 0
      web/app/components/base/icons/assets/public/llm/localai-text.svg
  14. 12 0
      web/app/components/base/icons/assets/public/llm/localai.svg
  15. 97 0
      web/app/components/base/icons/src/public/llm/Localai.json
  16. 14 0
      web/app/components/base/icons/src/public/llm/Localai.tsx
  17. 160 0
      web/app/components/base/icons/src/public/llm/LocalaiText.json
  18. 14 0
      web/app/components/base/icons/src/public/llm/LocalaiText.tsx
  19. 2 0
      web/app/components/base/icons/src/public/llm/index.ts
  20. 2 0
      web/app/components/header/account-setting/model-page/configs/index.ts
  21. 176 0
      web/app/components/header/account-setting/model-page/configs/localai.tsx
  22. 1 0
      web/app/components/header/account-setting/model-page/declarations.ts
  23. 1 0
      web/app/components/header/account-setting/model-page/index.tsx
  24. 1 1
      web/app/components/header/account-setting/model-page/utils.ts

+ 3 - 0
api/core/model_providers/model_provider_factory.py

@@ -63,6 +63,9 @@ class ModelProviderFactory:
         elif provider_name == 'openllm':
             from core.model_providers.providers.openllm_provider import OpenLLMProvider
             return OpenLLMProvider
+        elif provider_name == 'localai':
+            from core.model_providers.providers.localai_provider import LocalAIProvider
+            return LocalAIProvider
         else:
             raise NotImplementedError
 

+ 29 - 0
api/core/model_providers/models/embedding/localai_embedding.py

@@ -0,0 +1,29 @@
+from langchain.embeddings import LocalAIEmbeddings
+
+from replicate.exceptions import ModelError, ReplicateError
+
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.providers.base import BaseModelProvider
+from core.model_providers.models.embedding.base import BaseEmbedding
+
+
+class LocalAIEmbedding(BaseEmbedding):
+    def __init__(self, model_provider: BaseModelProvider, name: str):
+        credentials = model_provider.get_model_credentials(
+            model_name=name,
+            model_type=self.type
+        )
+
+        client = LocalAIEmbeddings(
+            model=name,
+            openai_api_key="1",
+            openai_api_base=credentials['server_url'],
+        )
+
+        super().__init__(model_provider, client, name)
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        if isinstance(ex, (ModelError, ReplicateError)):
+            return LLMBadRequestError(f"LocalAI embedding: {str(ex)}")
+        else:
+            return ex

+ 131 - 0
api/core/model_providers/models/llm/localai_model.py

@@ -0,0 +1,131 @@
+import logging
+from typing import List, Optional, Any
+
+import openai
+from langchain.callbacks.manager import Callbacks
+from langchain.schema import LLMResult, get_buffer_string
+
+from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
+    LLMRateLimitError, LLMAuthorizationError
+from core.model_providers.providers.base import BaseModelProvider
+from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
+from core.third_party.langchain.llms.open_ai import EnhanceOpenAI
+from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.entity.message import PromptMessage
+from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+
+
+class LocalAIModel(BaseLLM):
+    def __init__(self, model_provider: BaseModelProvider,
+                 name: str,
+                 model_kwargs: ModelKwargs,
+                 streaming: bool = False,
+                 callbacks: Callbacks = None):
+        credentials = model_provider.get_model_credentials(
+            model_name=name,
+            model_type=self.type
+        )
+
+        if credentials['completion_type'] == 'chat_completion':
+            self.model_mode = ModelMode.CHAT
+        else:
+            self.model_mode = ModelMode.COMPLETION
+
+        super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
+
+    def _init_client(self) -> Any:
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
+        if self.model_mode == ModelMode.COMPLETION:
+            client = EnhanceOpenAI(
+                model_name=self.name,
+                streaming=self.streaming,
+                callbacks=self.callbacks,
+                request_timeout=60,
+                openai_api_key="1",
+                openai_api_base=self.credentials['server_url'] + '/v1',
+                **provider_model_kwargs
+            )
+        else:
+            extra_model_kwargs = {
+                'top_p': provider_model_kwargs.get('top_p')
+            }
+
+            client = EnhanceChatOpenAI(
+                model_name=self.name,
+                temperature=provider_model_kwargs.get('temperature'),
+                max_tokens=provider_model_kwargs.get('max_tokens'),
+                model_kwargs=extra_model_kwargs,
+                streaming=self.streaming,
+                callbacks=self.callbacks,
+                request_timeout=60,
+                openai_api_key="1",
+                openai_api_base=self.credentials['server_url'] + '/v1'
+            )
+
+        return client
+
+    def _run(self, messages: List[PromptMessage],
+             stop: Optional[List[str]] = None,
+             callbacks: Callbacks = None,
+             **kwargs) -> LLMResult:
+        """
+        run predict by prompt messages and stop words.
+
+        :param messages:
+        :param stop:
+        :param callbacks:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return self._client.generate([prompts], stop, callbacks)
+
+    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
+        """
+        get num tokens of prompt messages.
+
+        :param messages:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        if isinstance(prompts, str):
+            return self._client.get_num_tokens(prompts)
+        else:
+            return max(sum([self._client.get_num_tokens(get_buffer_string([m])) for m in prompts]) - len(prompts), 0)
+
+    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
+        if self.model_mode == ModelMode.COMPLETION:
+            for k, v in provider_model_kwargs.items():
+                if hasattr(self.client, k):
+                    setattr(self.client, k, v)
+        else:
+            extra_model_kwargs = {
+                'top_p': provider_model_kwargs.get('top_p')
+            }
+
+            self.client.temperature = provider_model_kwargs.get('temperature')
+            self.client.max_tokens = provider_model_kwargs.get('max_tokens')
+            self.client.model_kwargs = extra_model_kwargs
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        if isinstance(ex, openai.error.InvalidRequestError):
+            logging.warning("Invalid request to LocalAI API.")
+            return LLMBadRequestError(str(ex))
+        elif isinstance(ex, openai.error.APIConnectionError):
+            logging.warning("Failed to connect to LocalAI API.")
+            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
+            logging.warning("LocalAI service unavailable.")
+            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, openai.error.RateLimitError):
+            return LLMRateLimitError(str(ex))
+        elif isinstance(ex, openai.error.AuthenticationError):
+            return LLMAuthorizationError(str(ex))
+        elif isinstance(ex, openai.error.OpenAIError):
+            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
+        else:
+            return ex
+
+    @classmethod
+    def support_streaming(cls):
+        return True

+ 164 - 0
api/core/model_providers/providers/localai_provider.py

@@ -0,0 +1,164 @@
+import json
+from typing import Type
+
+from langchain.embeddings import LocalAIEmbeddings
+from langchain.schema import HumanMessage
+
+from core.helper import encrypter
+from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
+from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule
+from core.model_providers.models.llm.localai_model import LocalAIModel
+from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
+
+from core.model_providers.models.base import BaseProviderModel
+from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
+from core.third_party.langchain.llms.open_ai import EnhanceOpenAI
+from models.provider import ProviderType
+
+
+class LocalAIProvider(BaseModelProvider):
+    @property
+    def provider_name(self):
+        """
+        Returns the name of a provider.
+        """
+        return 'localai'
+
+    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
+        return []
+
+    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
+        """
+        Returns the model class.
+
+        :param model_type:
+        :return:
+        """
+        if model_type == ModelType.TEXT_GENERATION:
+            model_class = LocalAIModel
+        elif model_type == ModelType.EMBEDDINGS:
+            model_class = LocalAIEmbedding
+        else:
+            raise NotImplementedError
+
+        return model_class
+
+    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
+        """
+        get model parameter rules.
+
+        :param model_name:
+        :param model_type:
+        :return:
+        """
+        return ModelKwargsRules(
+            temperature=KwargRule[float](min=0, max=2, default=0.7),
+            top_p=KwargRule[float](min=0, max=1, default=1),
+            max_tokens=KwargRule[int](min=10, max=4097, default=16),
+        )
+
+    @classmethod
+    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
+        """
+        check model credentials valid.
+
+        :param model_name:
+        :param model_type:
+        :param credentials:
+        """
+        if 'server_url' not in credentials:
+            raise CredentialsValidateFailedError('LocalAI Server URL must be provided.')
+
+        try:
+            if model_type == ModelType.EMBEDDINGS:
+                model = LocalAIEmbeddings(
+                    model=model_name,
+                    openai_api_key='1',
+                    openai_api_base=credentials['server_url']
+                )
+
+                model.embed_query("ping")
+            else:
+                if ('completion_type' not in credentials
+                        or credentials['completion_type'] not in ['completion', 'chat_completion']):
+                    raise CredentialsValidateFailedError('LocalAI Completion Type must be provided.')
+
+                if credentials['completion_type'] == 'chat_completion':
+                    model = EnhanceChatOpenAI(
+                        model_name=model_name,
+                        openai_api_key='1',
+                        openai_api_base=credentials['server_url'] + '/v1',
+                        max_tokens=10,
+                        request_timeout=60,
+                    )
+
+                    model([HumanMessage(content='ping')])
+                else:
+                    model = EnhanceOpenAI(
+                        model_name=model_name,
+                        openai_api_key='1',
+                        openai_api_base=credentials['server_url'] + '/v1',
+                        max_tokens=10,
+                        request_timeout=60,
+                    )
+
+                    model('ping')
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    @classmethod
+    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
+                                  credentials: dict) -> dict:
+        """
+        encrypt model credentials for save.
+
+        :param tenant_id:
+        :param model_name:
+        :param model_type:
+        :param credentials:
+        :return:
+        """
+        credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
+        return credentials
+
+    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
+        """
+        get credentials for llm use.
+
+        :param model_name:
+        :param model_type:
+        :param obfuscated:
+        :return:
+        """
+        if self.provider.provider_type != ProviderType.CUSTOM.value:
+            raise NotImplementedError
+
+        provider_model = self._get_provider_model(model_name, model_type)
+
+        if not provider_model.encrypted_config:
+            return {
+                'server_url': None,
+            }
+
+        credentials = json.loads(provider_model.encrypted_config)
+        if credentials['server_url']:
+            credentials['server_url'] = encrypter.decrypt_token(
+                self.provider.tenant_id,
+                credentials['server_url']
+            )
+
+            if obfuscated:
+                credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url'])
+
+        return credentials
+
+    @classmethod
+    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
+        return
+
+    @classmethod
+    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
+        return {}
+
+    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
+        return {}

+ 2 - 1
api/core/model_providers/rules/_providers.json

@@ -10,5 +10,6 @@
   "replicate",
   "huggingface_hub",
   "xinference",
-  "openllm"
+  "openllm",
+  "localai"
 ]

+ 7 - 0
api/core/model_providers/rules/localai.json

@@ -0,0 +1,7 @@
+{
+    "support_provider_types": [
+        "custom"
+    ],
+    "system_config": null,
+    "model_flexibility": "configurable"
+}

+ 2 - 1
api/core/third_party/langchain/llms/chat_open_ai.py

@@ -42,7 +42,8 @@ class EnhanceChatOpenAI(ChatOpenAI):
         return {
             **super()._default_params,
             "api_type": 'openai',
-            "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
+            "api_base": self.openai_api_base if self.openai_api_base
+            else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
             "api_version": None,
             "api_key": self.openai_api_key,
             "organization": self.openai_organization if self.openai_organization else None,

+ 32 - 3
api/core/third_party/langchain/llms/open_ai.py

@@ -1,7 +1,10 @@
 import os
 
-from typing import Dict, Any, Mapping, Optional, Union, Tuple
+from typing import Dict, Any, Mapping, Optional, Union, Tuple, List, Iterator
 from langchain import OpenAI
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.llms.openai import completion_with_retry, _stream_response_to_generation_chunk
+from langchain.schema.output import GenerationChunk
 from pydantic import root_validator
 
 
@@ -33,7 +36,8 @@ class EnhanceOpenAI(OpenAI):
     def _invocation_params(self) -> Dict[str, Any]:
         return {**super()._invocation_params, **{
             "api_type": 'openai',
-            "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
+            "api_base": self.openai_api_base if self.openai_api_base
+            else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
             "api_version": None,
             "api_key": self.openai_api_key,
             "organization": self.openai_organization if self.openai_organization else None,
@@ -43,8 +47,33 @@ class EnhanceOpenAI(OpenAI):
     def _identifying_params(self) -> Mapping[str, Any]:
         return {**super()._identifying_params, **{
             "api_type": 'openai',
-            "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
+            "api_base": self.openai_api_base if self.openai_api_base
+            else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
             "api_version": None,
             "api_key": self.openai_api_key,
             "organization": self.openai_organization if self.openai_organization else None,
         }}
+
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        params = {**self._invocation_params, **kwargs, "stream": True}
+        self.get_sub_prompts(params, [prompt], stop)  # this mutates params
+        for stream_resp in completion_with_retry(
+            self, prompt=prompt, run_manager=run_manager, **params
+        ):
+            if 'text' in stream_resp["choices"][0]:
+                chunk = _stream_response_to_generation_chunk(stream_resp)
+                yield chunk
+                if run_manager:
+                    run_manager.on_llm_new_token(
+                        chunk.text,
+                        verbose=self.verbose,
+                        logprobs=chunk.generation_info["logprobs"]
+                        if chunk.generation_info
+                        else None,
+                    )

+ 4 - 1
api/tests/integration_tests/.env.example

@@ -39,4 +39,7 @@ XINFERENCE_SERVER_URL=
 XINFERENCE_MODEL_UID=
 
 # OpenLLM Credentials
-OPENLLM_SERVER_URL=
+OPENLLM_SERVER_URL=
+
+# LocalAI Credentials
+LOCALAI_SERVER_URL=

+ 61 - 0
api/tests/integration_tests/models/embedding/test_localai_embedding.py

@@ -0,0 +1,61 @@
+import json
+import os
+from unittest.mock import patch, MagicMock
+
+from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
+from core.model_providers.models.entity.model_params import ModelType
+from core.model_providers.providers.localai_provider import LocalAIProvider
+from models.provider import Provider, ProviderType, ProviderModel
+
+
+def get_mock_provider():
+    return Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name='localai',
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config='',
+        is_valid=True,
+    )
+
+
+def get_mock_embedding_model(mocker):
+    model_name = 'text-embedding-ada-002'
+    server_url = os.environ['LOCALAI_SERVER_URL']
+    model_provider = LocalAIProvider(provider=get_mock_provider())
+
+    mock_query = MagicMock()
+    mock_query.filter.return_value.first.return_value = ProviderModel(
+        provider_name='localai',
+        model_name=model_name,
+        model_type=ModelType.EMBEDDINGS.value,
+        encrypted_config=json.dumps({
+            'server_url': server_url,
+        }),
+        is_valid=True,
+    )
+    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
+
+    return LocalAIEmbedding(
+        model_provider=model_provider,
+        name=model_name
+    )
+
+
+def decrypt_side_effect(tenant_id, encrypted_api_key):
+    return encrypted_api_key
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_embed_documents(mock_decrypt, mocker):
+    embedding_model = get_mock_embedding_model(mocker)
+    rst = embedding_model.client.embed_documents(['test', 'test1'])
+    assert isinstance(rst, list)
+    assert len(rst) == 2
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_embed_query(mock_decrypt, mocker):
+    embedding_model = get_mock_embedding_model(mocker)
+    rst = embedding_model.client.embed_query('test')
+    assert isinstance(rst, list)

+ 68 - 0
api/tests/integration_tests/models/llm/test_localai_model.py

@@ -0,0 +1,68 @@
+import json
+import os
+from unittest.mock import patch, MagicMock
+
+from core.model_providers.models.llm.localai_model import LocalAIModel
+from core.model_providers.providers.localai_provider import LocalAIProvider
+from core.model_providers.models.entity.message import PromptMessage
+from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
+from models.provider import Provider, ProviderType, ProviderModel
+
+
+def get_mock_provider(server_url):
+    return Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name='localai',
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config=json.dumps({}),
+        is_valid=True,
+    )
+
+
+def get_mock_model(model_name, mocker):
+    model_kwargs = ModelKwargs(
+        max_tokens=10,
+        temperature=0
+    )
+    server_url = os.environ['LOCALAI_SERVER_URL']
+
+    mock_query = MagicMock()
+    mock_query.filter.return_value.first.return_value = ProviderModel(
+        provider_name='localai',
+        model_name=model_name,
+        model_type=ModelType.TEXT_GENERATION.value,
+        encrypted_config=json.dumps({'server_url': server_url, 'completion_type': 'completion'}),
+        is_valid=True,
+    )
+    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
+
+    openai_provider = LocalAIProvider(provider=get_mock_provider(server_url))
+    return LocalAIModel(
+        model_provider=openai_provider,
+        name=model_name,
+        model_kwargs=model_kwargs
+    )
+
+
+def decrypt_side_effect(tenant_id, encrypted_openai_api_key):
+    return encrypted_openai_api_key
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_get_num_tokens(mock_decrypt, mocker):
+    openai_model = get_mock_model('ggml-gpt4all-j', mocker)
+    rst = openai_model.get_num_tokens([PromptMessage(content='you are a kindness Assistant.')])
+    assert rst > 0
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
+    openai_model = get_mock_model('ggml-gpt4all-j', mocker)
+    rst = openai_model.run(
+        [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')],
+        stop=['\nHuman:'],
+    )
+    assert len(rst.content) > 0

+ 116 - 0
api/tests/unit_tests/model_providers/test_localai_provider.py

@@ -0,0 +1,116 @@
+import pytest
+from unittest.mock import patch, MagicMock
+import json
+
+from core.model_providers.models.entity.model_params import ModelType
+from core.model_providers.providers.base import CredentialsValidateFailedError
+from core.model_providers.providers.localai_provider import LocalAIProvider
+from models.provider import ProviderType, Provider, ProviderModel
+
+PROVIDER_NAME = 'localai'
+MODEL_PROVIDER_CLASS = LocalAIProvider
+VALIDATE_CREDENTIAL = {
+    'server_url': 'http://127.0.0.1:8080/'
+}
+
+
+def encrypt_side_effect(tenant_id, encrypt_key):
+    return f'encrypted_{encrypt_key}'
+
+
+def decrypt_side_effect(tenant_id, encrypted_key):
+    return encrypted_key.replace('encrypted_', '')
+
+
+def test_is_credentials_valid_or_raise_valid(mocker):
+    mocker.patch('langchain.embeddings.localai.LocalAIEmbeddings.embed_query',
+                 return_value="abc")
+
+    MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
+        model_name='username/test_model_name',
+        model_type=ModelType.EMBEDDINGS,
+        credentials=VALIDATE_CREDENTIAL.copy()
+    )
+
+
+def test_is_credentials_valid_or_raise_invalid():
+    # raise CredentialsValidateFailedError if server_url is not in credentials
+    with pytest.raises(CredentialsValidateFailedError):
+        MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
+            model_name='test_model_name',
+            model_type=ModelType.EMBEDDINGS,
+            credentials={}
+        )
+
+
+@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
+def test_encrypt_model_credentials(mock_encrypt, mocker):
+    server_url = 'http://127.0.0.1:8080/'
+
+    result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
+        tenant_id='tenant_id',
+        model_name='test_model_name',
+        model_type=ModelType.EMBEDDINGS,
+        credentials=VALIDATE_CREDENTIAL.copy()
+    )
+    mock_encrypt.assert_called_with('tenant_id', server_url)
+    assert result['server_url'] == f'encrypted_{server_url}'
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_get_model_credentials_custom(mock_decrypt, mocker):
+    provider = Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name=PROVIDER_NAME,
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config=None,
+        is_valid=True,
+    )
+
+    encrypted_credential = VALIDATE_CREDENTIAL.copy()
+    encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']
+
+    mock_query = MagicMock()
+    mock_query.filter.return_value.first.return_value = ProviderModel(
+        encrypted_config=json.dumps(encrypted_credential)
+    )
+    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
+
+    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
+    result = model_provider.get_model_credentials(
+        model_name='test_model_name',
+        model_type=ModelType.EMBEDDINGS
+    )
+    assert result['server_url'] == 'http://127.0.0.1:8080/'
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_get_model_credentials_obfuscated(mock_decrypt, mocker):
+    provider = Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name=PROVIDER_NAME,
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config=None,
+        is_valid=True,
+    )
+
+    encrypted_credential = VALIDATE_CREDENTIAL.copy()
+    encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']
+
+    mock_query = MagicMock()
+    mock_query.filter.return_value.first.return_value = ProviderModel(
+        encrypted_config=json.dumps(encrypted_credential)
+    )
+    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
+
+    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
+    result = model_provider.get_model_credentials(
+        model_name='test_model_name',
+        model_type=ModelType.EMBEDDINGS,
+        obfuscated=True
+    )
+    middle_token = result['server_url'][6:-2]
+    assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['server_url']) - 8, 0)
+    assert all(char == '*' for char in middle_token)

Filskillnaden har hållts tillbaka eftersom den är för stor
+ 19 - 0
web/app/components/base/icons/assets/public/llm/localai-text.svg


Filskillnaden har hållts tillbaka eftersom den är för stor
+ 12 - 0
web/app/components/base/icons/assets/public/llm/localai.svg


Filskillnaden har hållts tillbaka eftersom den är för stor
+ 97 - 0
web/app/components/base/icons/src/public/llm/Localai.json


+ 14 - 0
web/app/components/base/icons/src/public/llm/Localai.tsx

@@ -0,0 +1,14 @@
+// GENERATE BY script
+// DON NOT EDIT IT MANUALLY
+
+import * as React from 'react'
+import data from './Localai.json'
+import IconBase from '@/app/components/base/icons/IconBase'
+import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase'
+
+const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseProps, 'data'>>((
+  props,
+  ref,
+) => <IconBase {...props} ref={ref} data={data as IconData} />)
+
+export default Icon

Filskillnaden har hållts tillbaka eftersom den är för stor
+ 160 - 0
web/app/components/base/icons/src/public/llm/LocalaiText.json


+ 14 - 0
web/app/components/base/icons/src/public/llm/LocalaiText.tsx

@@ -0,0 +1,14 @@
+// GENERATE BY script
+// DON NOT EDIT IT MANUALLY
+
+import * as React from 'react'
+import data from './LocalaiText.json'
+import IconBase from '@/app/components/base/icons/IconBase'
+import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase'
+
+const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseProps, 'data'>>((
+  props,
+  ref,
+) => <IconBase {...props} ref={ref} data={data as IconData} />)
+
+export default Icon

+ 2 - 0
web/app/components/base/icons/src/public/llm/index.ts

@@ -14,6 +14,8 @@ export { default as Huggingface } from './Huggingface'
 export { default as IflytekSparkTextCn } from './IflytekSparkTextCn'
 export { default as IflytekSparkText } from './IflytekSparkText'
 export { default as IflytekSpark } from './IflytekSpark'
+export { default as LocalaiText } from './LocalaiText'
+export { default as Localai } from './Localai'
 export { default as Microsoft } from './Microsoft'
 export { default as OpenaiBlack } from './OpenaiBlack'
 export { default as OpenaiBlue } from './OpenaiBlue'

+ 2 - 0
web/app/components/header/account-setting/model-page/configs/index.ts

@@ -10,6 +10,7 @@ import minimax from './minimax'
 import chatglm from './chatglm'
 import xinference from './xinference'
 import openllm from './openllm'
+import localai from './localai'
 
 export default {
   openai,
@@ -24,4 +25,5 @@ export default {
   chatglm,
   xinference,
   openllm,
+  localai,
 }

+ 176 - 0
web/app/components/header/account-setting/model-page/configs/localai.tsx

@@ -0,0 +1,176 @@
+import { ProviderEnum } from '../declarations'
+import type { FormValue, ProviderConfig } from '../declarations'
+import { Localai, LocalaiText } from '@/app/components/base/icons/src/public/llm'
+
+const config: ProviderConfig = {
+  selector: {
+    name: {
+      'en': 'LocalAI',
+      'zh-Hans': 'LocalAI',
+    },
+    icon: <Localai className='w-full h-full' />,
+  },
+  item: {
+    key: ProviderEnum.localai,
+    titleIcon: {
+      'en': <LocalaiText className='h-6' />,
+      'zh-Hans': <LocalaiText className='h-6' />,
+    },
+    disable: {
+      tip: {
+        'en': 'Only supports the ',
+        'zh-Hans': '仅支持',
+      },
+      link: {
+        href: {
+          'en': 'https://docs.dify.ai/getting-started/install-self-hosted',
+          'zh-Hans': 'https://docs.dify.ai/v/zh-hans/getting-started/install-self-hosted',
+        },
+        label: {
+          'en': 'community open-source version',
+          'zh-Hans': '社区开源版本',
+        },
+      },
+    },
+  },
+  modal: {
+    key: ProviderEnum.localai,
+    title: {
+      'en': 'LocalAI',
+      'zh-Hans': 'LocalAI',
+    },
+    icon: <Localai className='h-6' />,
+    link: {
+      href: 'https://github.com/go-skynet/LocalAI',
+      label: {
+        'en': 'How to deploy LocalAI',
+        'zh-Hans': '如何部署 LocalAI',
+      },
+    },
+    defaultValue: {
+      model_type: 'text-generation',
+      completion_type: 'completion',
+    },
+    validateKeys: (v?: FormValue) => {
+      if (v?.model_type === 'text-generation') {
+        return [
+          'model_type',
+          'model_name',
+          'server_url',
+          'completion_type',
+        ]
+      }
+      if (v?.model_type === 'embeddings') {
+        return [
+          'model_type',
+          'model_name',
+          'server_url',
+        ]
+      }
+      return []
+    },
+    filterValue: (v?: FormValue) => {
+      let filteredKeys: string[] = []
+      if (v?.model_type === 'text-generation') {
+        filteredKeys = [
+          'model_type',
+          'model_name',
+          'server_url',
+          'completion_type',
+        ]
+      }
+      if (v?.model_type === 'embeddings') {
+        filteredKeys = [
+          'model_type',
+          'model_name',
+          'server_url',
+        ]
+      }
+      return filteredKeys.reduce((prev: FormValue, next: string) => {
+        prev[next] = v?.[next] || ''
+        return prev
+      }, {})
+    },
+    fields: [
+      {
+        type: 'radio',
+        key: 'model_type',
+        required: true,
+        label: {
+          'en': 'Model Type',
+          'zh-Hans': '模型类型',
+        },
+        options: [
+          {
+            key: 'text-generation',
+            label: {
+              'en': 'Text Generation',
+              'zh-Hans': '文本生成',
+            },
+          },
+          {
+            key: 'embeddings',
+            label: {
+              'en': 'Embeddings',
+              'zh-Hans': 'Embeddings',
+            },
+          },
+        ],
+      },
+      {
+        type: 'text',
+        key: 'model_name',
+        required: true,
+        label: {
+          'en': 'Model Name',
+          'zh-Hans': '模型名称',
+        },
+        placeholder: {
+          'en': 'Enter your Model Name here',
+          'zh-Hans': '在此输入您的模型名称',
+        },
+      },
+      {
+        hidden: (value?: FormValue) => value?.model_type === 'embeddings',
+        type: 'radio',
+        key: 'completion_type',
+        required: true,
+        label: {
+          'en': 'Completion Type',
+          'zh-Hans': 'Completion Type',
+        },
+        options: [
+          {
+            key: 'completion',
+            label: {
+              'en': 'Completion',
+              'zh-Hans': 'Completion',
+            },
+          },
+          {
+            key: 'chat_completion',
+            label: {
+              'en': 'Chat Completion',
+              'zh-Hans': 'Chat Completion',
+            },
+          },
+        ],
+      },
+      {
+        type: 'text',
+        key: 'server_url',
+        required: true,
+        label: {
+          'en': 'Server url',
+          'zh-Hans': 'Server url',
+        },
+        placeholder: {
+          'en': 'Enter your Server Url, eg: https://example.com/xxx',
+          'zh-Hans': '在此输入您的 Server Url,如:https://example.com/xxx',
+        },
+      },
+    ],
+  },
+}
+
+export default config

+ 1 - 0
web/app/components/header/account-setting/model-page/declarations.ts

@@ -41,6 +41,7 @@ export enum ProviderEnum {
   'chatglm' = 'chatglm',
   'xinference' = 'xinference',
   'openllm' = 'openllm',
+  'localai' = 'localai',
 }
 
 export type ProviderConfigItem = {

+ 1 - 0
web/app/components/header/account-setting/model-page/index.tsx

@@ -99,6 +99,7 @@ const ModelPage = () => {
       config.chatglm,
       config.xinference,
       config.openllm,
+      config.localai,
     ]
   }
 

+ 1 - 1
web/app/components/header/account-setting/model-page/utils.ts

@@ -2,7 +2,7 @@ import { ValidatedStatus } from '../key-validator/declarations'
 import { ProviderEnum } from './declarations'
 import { validateModelProvider } from '@/service/common'
 
-export const ConfigurableProviders = [ProviderEnum.azure_openai, ProviderEnum.replicate, ProviderEnum.huggingface_hub, ProviderEnum.xinference, ProviderEnum.openllm]
+export const ConfigurableProviders = [ProviderEnum.azure_openai, ProviderEnum.replicate, ProviderEnum.huggingface_hub, ProviderEnum.xinference, ProviderEnum.openllm, ProviderEnum.localai]
 
 export const validateModelProviderFn = async (providerName: ProviderEnum, v: any) => {
   let body, url

Vissa filer visades inte eftersom för många filer har ändrats