Przeglądaj źródła

feat: add openllm support (#928)

takatost 1 rok temu
rodzic
commit
3ea8d7a019

+ 3 - 0
api/core/model_providers/model_provider_factory.py

@@ -60,6 +60,9 @@ class ModelProviderFactory:
         elif provider_name == 'xinference':
             from core.model_providers.providers.xinference_provider import XinferenceProvider
             return XinferenceProvider
+        elif provider_name == 'openllm':
+            from core.model_providers.providers.openllm_provider import OpenLLMProvider
+            return OpenLLMProvider
         else:
             raise NotImplementedError
 

+ 60 - 0
api/core/model_providers/models/llm/openllm_model.py

@@ -0,0 +1,60 @@
+from typing import List, Optional, Any
+
+from langchain.callbacks.manager import Callbacks
+from langchain.llms import OpenLLM
+from langchain.schema import LLMResult
+
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.entity.message import PromptMessage
+from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+
+
+class OpenLLMModel(BaseLLM):
+    model_mode: ModelMode = ModelMode.COMPLETION
+
+    def _init_client(self) -> Any:
+        self.provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
+
+        client = OpenLLM(
+            server_url=self.credentials.get('server_url'),
+            callbacks=self.callbacks,
+            **self.provider_model_kwargs
+        )
+
+        return client
+
+    def _run(self, messages: List[PromptMessage],
+             stop: Optional[List[str]] = None,
+             callbacks: Callbacks = None,
+             **kwargs) -> LLMResult:
+        """
+        run predict by prompt messages and stop words.
+
+        :param messages:
+        :param stop:
+        :param callbacks:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return self._client.generate([prompts], stop, callbacks)
+
+    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
+        """
+        get num tokens of prompt messages.
+
+        :param messages:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return max(self._client.get_num_tokens(prompts), 0)
+
+    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
+        pass
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        return LLMBadRequestError(f"OpenLLM: {str(ex)}")
+
+    @classmethod
+    def support_streaming(cls):
+        return False

+ 137 - 0
api/core/model_providers/providers/openllm_provider.py

@@ -0,0 +1,137 @@
+import json
+from typing import Type
+
+from langchain.llms import OpenLLM
+
+from core.helper import encrypter
+from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
+from core.model_providers.models.llm.openllm_model import OpenLLMModel
+from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
+
+from core.model_providers.models.base import BaseProviderModel
+from models.provider import ProviderType
+
+
+class OpenLLMProvider(BaseModelProvider):
+    @property
+    def provider_name(self):
+        """
+        Returns the name of a provider.
+        """
+        return 'openllm'
+
+    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
+        return []
+
+    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
+        """
+        Returns the model class.
+
+        :param model_type:
+        :return:
+        """
+        if model_type == ModelType.TEXT_GENERATION:
+            model_class = OpenLLMModel
+        else:
+            raise NotImplementedError
+
+        return model_class
+
+    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
+        """
+        get model parameter rules.
+
+        :param model_name:
+        :param model_type:
+        :return:
+        """
+        return ModelKwargsRules(
+            temperature=KwargRule[float](min=0, max=2, default=1),
+            top_p=KwargRule[float](min=0, max=1, default=0.7),
+            presence_penalty=KwargRule[float](min=-2, max=2, default=0),
+            frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
+            max_tokens=KwargRule[int](min=10, max=4000, default=128),
+        )
+
+    @classmethod
+    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
+        """
+        check model credentials valid.
+
+        :param model_name:
+        :param model_type:
+        :param credentials:
+        """
+        if 'server_url' not in credentials:
+            raise CredentialsValidateFailedError('OpenLLM Server URL must be provided.')
+
+        try:
+            credential_kwargs = {
+                'server_url': credentials['server_url']
+            }
+
+            llm = OpenLLM(
+                max_tokens=10,
+                **credential_kwargs
+            )
+
+            llm("ping")
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    @classmethod
+    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
+                                  credentials: dict) -> dict:
+        """
+        encrypt model credentials for save.
+
+        :param tenant_id:
+        :param model_name:
+        :param model_type:
+        :param credentials:
+        :return:
+        """
+        credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
+        return credentials
+
+    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
+        """
+        get credentials for llm use.
+
+        :param model_name:
+        :param model_type:
+        :param obfuscated:
+        :return:
+        """
+        if self.provider.provider_type != ProviderType.CUSTOM.value:
+            raise NotImplementedError
+
+        provider_model = self._get_provider_model(model_name, model_type)
+
+        if not provider_model.encrypted_config:
+            return {
+                'server_url': None
+            }
+
+        credentials = json.loads(provider_model.encrypted_config)
+        if credentials['server_url']:
+            credentials['server_url'] = encrypter.decrypt_token(
+                self.provider.tenant_id,
+                credentials['server_url']
+            )
+
+            if obfuscated:
+                credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url'])
+
+        return credentials
+
+    @classmethod
+    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
+        return
+
+    @classmethod
+    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
+        return {}
+
+    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
+        return {}

+ 2 - 1
api/core/model_providers/rules/_providers.json

@@ -9,5 +9,6 @@
   "chatglm",
   "replicate",
   "huggingface_hub",
-  "xinference"
+  "xinference",
+  "openllm"
 ]

+ 7 - 0
api/core/model_providers/rules/openllm.json

@@ -0,0 +1,7 @@
+{
+    "support_provider_types": [
+        "custom"
+    ],
+    "system_config": null,
+    "model_flexibility": "configurable"
+}

+ 2 - 1
api/requirements.txt

@@ -49,4 +49,5 @@ huggingface_hub~=0.16.4
 transformers~=4.31.0
 stripe~=5.5.0
 pandas==1.5.3
-xinference==0.2.0
+xinference==0.2.0
+openllm~=0.2.26

+ 4 - 1
api/tests/integration_tests/.env.example

@@ -36,4 +36,7 @@ CHATGLM_API_BASE=
 
 # Xinference Credentials
 XINFERENCE_SERVER_URL=
-XINFERENCE_MODEL_UID=
+XINFERENCE_MODEL_UID=
+
+# OpenLLM Credentials
+OPENLLM_SERVER_URL=

+ 72 - 0
api/tests/integration_tests/models/llm/test_openllm_model.py

@@ -0,0 +1,72 @@
+import json
+import os
+from unittest.mock import patch, MagicMock
+
+from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
+from core.model_providers.models.llm.openllm_model import OpenLLMModel
+from core.model_providers.providers.openllm_provider import OpenLLMProvider
+from models.provider import Provider, ProviderType, ProviderModel
+
+
+def get_mock_provider():
+    return Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name='openllm',
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config='',
+        is_valid=True,
+    )
+
+
+def get_mock_model(model_name, mocker):
+    model_kwargs = ModelKwargs(
+        max_tokens=10,
+        temperature=0.01
+    )
+    server_url = os.environ['OPENLLM_SERVER_URL']
+    model_provider = OpenLLMProvider(provider=get_mock_provider())
+
+    mock_query = MagicMock()
+    mock_query.filter.return_value.first.return_value = ProviderModel(
+        provider_name='openllm',
+        model_name=model_name,
+        model_type=ModelType.TEXT_GENERATION.value,
+        encrypted_config=json.dumps({
+            'server_url': server_url
+        }),
+        is_valid=True,
+    )
+    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
+
+    return OpenLLMModel(
+        model_provider=model_provider,
+        name=model_name,
+        model_kwargs=model_kwargs
+    )
+
+
+def decrypt_side_effect(tenant_id, encrypted_api_key):
+    return encrypted_api_key
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_get_num_tokens(mock_decrypt, mocker):
+    model = get_mock_model('facebook/opt-125m', mocker)
+    rst = model.get_num_tokens([
+        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+    ])
+    assert rst == 5
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
+    model = get_mock_model('facebook/opt-125m', mocker)
+    messages = [PromptMessage(content='Human: who are you? \nAnswer: ')]
+    rst = model.run(
+        messages
+    )
+    assert len(rst.content) > 0

+ 125 - 0
api/tests/unit_tests/model_providers/test_openllm_provider.py

@@ -0,0 +1,125 @@
+import pytest
+from unittest.mock import patch, MagicMock
+import json
+
+from core.model_providers.models.entity.model_params import ModelType
+from core.model_providers.providers.base import CredentialsValidateFailedError
+from core.model_providers.providers.openllm_provider import OpenLLMProvider
+from models.provider import ProviderType, Provider, ProviderModel
+
+PROVIDER_NAME = 'openllm'
+MODEL_PROVIDER_CLASS = OpenLLMProvider
+VALIDATE_CREDENTIAL = {
+    'server_url': 'http://127.0.0.1:3333/'
+}
+
+
+def encrypt_side_effect(tenant_id, encrypt_key):
+    return f'encrypted_{encrypt_key}'
+
+
+def decrypt_side_effect(tenant_id, encrypted_key):
+    return encrypted_key.replace('encrypted_', '')
+
+
+def test_is_credentials_valid_or_raise_valid(mocker):
+    mocker.patch('langchain.llms.openllm.OpenLLM._identifying_params', return_value=None)
+    mocker.patch('langchain.llms.openllm.OpenLLM._call',
+                 return_value="abc")
+
+    MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
+        model_name='username/test_model_name',
+        model_type=ModelType.TEXT_GENERATION,
+        credentials=VALIDATE_CREDENTIAL.copy()
+    )
+
+
+def test_is_credentials_valid_or_raise_invalid(mocker):
+    mocker.patch('langchain.llms.openllm.OpenLLM._identifying_params', return_value=None)
+
+    # raise CredentialsValidateFailedError if credential is not in credentials
+    with pytest.raises(CredentialsValidateFailedError):
+        MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
+            model_name='test_model_name',
+            model_type=ModelType.TEXT_GENERATION,
+            credentials={}
+        )
+
+    # raise CredentialsValidateFailedError if credential is invalid
+    with pytest.raises(CredentialsValidateFailedError):
+        MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
+            model_name='test_model_name',
+            model_type=ModelType.TEXT_GENERATION,
+            credentials={'server_url': 'invalid'})
+
+
+@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
+def test_encrypt_model_credentials(mock_encrypt):
+    api_key = 'http://127.0.0.1:3333/'
+    result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
+        tenant_id='tenant_id',
+        model_name='test_model_name',
+        model_type=ModelType.TEXT_GENERATION,
+        credentials=VALIDATE_CREDENTIAL.copy()
+    )
+    mock_encrypt.assert_called_with('tenant_id', api_key)
+    assert result['server_url'] == f'encrypted_{api_key}'
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_get_model_credentials_custom(mock_decrypt, mocker):
+    provider = Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name=PROVIDER_NAME,
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config=None,
+        is_valid=True,
+    )
+
+    encrypted_credential = VALIDATE_CREDENTIAL.copy()
+    encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']
+
+    mock_query = MagicMock()
+    mock_query.filter.return_value.first.return_value = ProviderModel(
+        encrypted_config=json.dumps(encrypted_credential)
+    )
+    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
+
+    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
+    result = model_provider.get_model_credentials(
+        model_name='test_model_name',
+        model_type=ModelType.TEXT_GENERATION
+    )
+    assert result['server_url'] == 'http://127.0.0.1:3333/'
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_get_model_credentials_obfuscated(mock_decrypt, mocker):
+    provider = Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name=PROVIDER_NAME,
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config=None,
+        is_valid=True,
+    )
+
+    encrypted_credential = VALIDATE_CREDENTIAL.copy()
+    encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']
+
+    mock_query = MagicMock()
+    mock_query.filter.return_value.first.return_value = ProviderModel(
+        encrypted_config=json.dumps(encrypted_credential)
+    )
+    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
+
+    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
+    result = model_provider.get_model_credentials(
+        model_name='test_model_name',
+        model_type=ModelType.TEXT_GENERATION,
+        obfuscated=True
+    )
+    middle_token = result['server_url'][6:-2]
+    assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['server_url']) - 8, 0)
+    assert all(char == '*' for char in middle_token)