Explorar o código

fix: remove openllm pypi package because of this package too large (#931)

takatost hai 1 ano
pai
achega
6c832ee328

+ 2 - 2
api/core/model_providers/models/llm/openllm_model.py

@@ -1,13 +1,13 @@
 from typing import List, Optional, Any
 
 from langchain.callbacks.manager import Callbacks
-from langchain.llms import OpenLLM
 from langchain.schema import LLMResult
 
 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.models.llm.base import BaseLLM
 from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.openllm import OpenLLM
 
 
 class OpenLLMModel(BaseLLM):
@@ -19,7 +19,7 @@ class OpenLLMModel(BaseLLM):
         client = OpenLLM(
             server_url=self.credentials.get('server_url'),
             callbacks=self.callbacks,
-            **self.provider_model_kwargs
+            llm_kwargs=self.provider_model_kwargs
         )
 
         return client

+ 6 - 5
api/core/model_providers/providers/openllm_provider.py

@@ -1,14 +1,13 @@
 import json
 from typing import Type
 
-from langchain.llms import OpenLLM
-
 from core.helper import encrypter
 from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
 from core.model_providers.models.llm.openllm_model import OpenLLMModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 
 from core.model_providers.models.base import BaseProviderModel
+from core.third_party.langchain.llms.openllm import OpenLLM
 from models.provider import ProviderType
 
 
@@ -46,11 +45,11 @@ class OpenLLMProvider(BaseModelProvider):
         :return:
         """
         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=2, default=1),
+            temperature=KwargRule[float](min=0.01, max=2, default=1),
             top_p=KwargRule[float](min=0, max=1, default=0.7),
             presence_penalty=KwargRule[float](min=-2, max=2, default=0),
             frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
-            max_tokens=KwargRule[int](min=10, max=4000, default=128),
+            max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128),
         )
 
     @classmethod
@@ -71,7 +70,9 @@ class OpenLLMProvider(BaseModelProvider):
             }
 
             llm = OpenLLM(
-                max_tokens=10,
+                llm_kwargs={
+                    'max_new_tokens': 10
+                },
                 **credential_kwargs
             )
 

+ 87 - 0
api/core/third_party/langchain/llms/openllm.py

@@ -0,0 +1,87 @@
+from __future__ import annotations
+
+import logging
+from typing import (
+    Any,
+    Dict,
+    List,
+    Optional,
+)
+
+import requests
+from langchain.llms.utils import enforce_stop_tokens
+from pydantic import Field
+
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain.llms.base import LLM
+
+logger = logging.getLogger(__name__)
+
+
+class OpenLLM(LLM):
+    """OpenLLM, supporting both in-process model
+    instance and remote OpenLLM servers.
+
+    If you have a OpenLLM server running, you can also use it remotely:
+        .. code-block:: python
+
+            from langchain.llms import OpenLLM
+            llm = OpenLLM(server_url='http://localhost:3000')
+            llm("What is the difference between a duck and a goose?")
+    """
+
+    server_url: Optional[str] = None
+    """Optional server URL that currently runs a LLMServer with 'openllm start'."""
+    llm_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Key word arguments to be passed to openllm.LLM"""
+
+    @property
+    def _llm_type(self) -> str:
+        return "openllm"
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: CallbackManagerForLLMRun | None = None,
+        **kwargs: Any,
+    ) -> str:
+        params = {
+            "prompt": prompt,
+            "llm_config": self.llm_kwargs
+        }
+
+        headers = {"Content-Type": "application/json"}
+        response = requests.post(
+            f'{self.server_url}/v1/generate',
+            headers=headers,
+            json=params
+        )
+
+        if not response.ok:
+            raise ValueError(f"OpenLLM HTTP {response.status_code} error: {response.text}")
+
+        json_response = response.json()
+        completion = json_response["responses"][0]
+
+        if completion:
+            completion = completion[len(prompt):]
+
+        if stop is not None:
+            completion = enforce_stop_tokens(completion, stop)
+
+        return completion
+
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        raise NotImplementedError(
+            "Async call is not supported for OpenLLM at the moment."
+        )

+ 1 - 2
api/requirements.txt

@@ -49,5 +49,4 @@ huggingface_hub~=0.16.4
 transformers~=4.31.0
 stripe~=5.5.0
 pandas==1.5.3
-xinference==0.2.0
-openllm~=0.2.26
+xinference==0.2.0

+ 1 - 4
api/tests/unit_tests/model_providers/test_openllm_provider.py

@@ -23,8 +23,7 @@ def decrypt_side_effect(tenant_id, encrypted_key):
 
 
 def test_is_credentials_valid_or_raise_valid(mocker):
-    mocker.patch('langchain.llms.openllm.OpenLLM._identifying_params', return_value=None)
-    mocker.patch('langchain.llms.openllm.OpenLLM._call',
+    mocker.patch('core.third_party.langchain.llms.openllm.OpenLLM._call',
                  return_value="abc")
 
     MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
@@ -35,8 +34,6 @@ def test_is_credentials_valid_or_raise_valid(mocker):
 
 
 def test_is_credentials_valid_or_raise_invalid(mocker):
-    mocker.patch('langchain.llms.openllm.OpenLLM._identifying_params', return_value=None)
-
     # raise CredentialsValidateFailedError if credential is not in credentials
     with pytest.raises(CredentialsValidateFailedError):
         MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(