ソースを参照

feat: optimize hf inference endpoint (#975)

takatost 1年前
コミット
a76fde3d23

+ 5 - 7
api/core/model_providers/models/llm/huggingface_hub_model.py

@@ -1,16 +1,14 @@
-import decimal
-from functools import wraps
 from typing import List, Optional, Any
 from typing import List, Optional, Any
 
 
 from langchain import HuggingFaceHub
 from langchain import HuggingFaceHub
 from langchain.callbacks.manager import Callbacks
 from langchain.callbacks.manager import Callbacks
-from langchain.llms import HuggingFaceEndpoint
 from langchain.schema import LLMResult
 from langchain.schema import LLMResult
 
 
 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.models.llm.base import BaseLLM
 from core.model_providers.models.llm.base import BaseLLM
-from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
 
 
 
 
 class HuggingfaceHubModel(BaseLLM):
 class HuggingfaceHubModel(BaseLLM):
@@ -19,12 +17,12 @@ class HuggingfaceHubModel(BaseLLM):
     def _init_client(self) -> Any:
     def _init_client(self) -> Any:
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
         if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
         if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
-            client = HuggingFaceEndpoint(
+            client = HuggingFaceEndpointLLM(
                 endpoint_url=self.credentials['huggingfacehub_endpoint_url'],
                 endpoint_url=self.credentials['huggingfacehub_endpoint_url'],
-                task='text2text-generation',
+                task=self.credentials['task_type'],
                 model_kwargs=provider_model_kwargs,
                 model_kwargs=provider_model_kwargs,
                 huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
                 huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
-                callbacks=self.callbacks,
+                callbacks=self.callbacks
             )
             )
         else:
         else:
             client = HuggingFaceHub(
             client = HuggingFaceHub(

+ 13 - 3
api/core/model_providers/providers/huggingface_hub_provider.py

@@ -2,7 +2,6 @@ import json
 from typing import Type
 from typing import Type
 
 
 from huggingface_hub import HfApi
 from huggingface_hub import HfApi
-from langchain.llms import HuggingFaceEndpoint
 
 
 from core.helper import encrypter
 from core.helper import encrypter
 from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
 from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
@@ -10,6 +9,7 @@ from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHub
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 
 
 from core.model_providers.models.base import BaseProviderModel
 from core.model_providers.models.base import BaseProviderModel
+from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
 from models.provider import ProviderType
 from models.provider import ProviderType
 
 
 
 
@@ -85,10 +85,16 @@ class HuggingfaceHubProvider(BaseModelProvider):
             if 'huggingfacehub_endpoint_url' not in credentials:
             if 'huggingfacehub_endpoint_url' not in credentials:
                 raise CredentialsValidateFailedError('Hugging Face Hub Endpoint URL must be provided.')
                 raise CredentialsValidateFailedError('Hugging Face Hub Endpoint URL must be provided.')
 
 
+            if 'task_type' not in credentials:
+                raise CredentialsValidateFailedError('Task Type must be provided.')
+
+            if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"):
+                raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, text-generation, summarization.')
+
             try:
             try:
-                llm = HuggingFaceEndpoint(
+                llm = HuggingFaceEndpointLLM(
                     endpoint_url=credentials['huggingfacehub_endpoint_url'],
                     endpoint_url=credentials['huggingfacehub_endpoint_url'],
-                    task="text2text-generation",
+                    task=credentials['task_type'],
                     model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
                     model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
                     huggingfacehub_api_token=credentials['huggingfacehub_api_token']
                     huggingfacehub_api_token=credentials['huggingfacehub_api_token']
                 )
                 )
@@ -160,6 +166,10 @@ class HuggingfaceHubProvider(BaseModelProvider):
             }
             }
 
 
         credentials = json.loads(provider_model.encrypted_config)
         credentials = json.loads(provider_model.encrypted_config)
+
+        if 'task_type' not in credentials:
+            credentials['task_type'] = 'text-generation'
+
         if credentials['huggingfacehub_api_token']:
         if credentials['huggingfacehub_api_token']:
             credentials['huggingfacehub_api_token'] = encrypter.decrypt_token(
             credentials['huggingfacehub_api_token'] = encrypter.decrypt_token(
                 self.provider.tenant_id,
                 self.provider.tenant_id,

+ 39 - 0
api/core/third_party/langchain/llms/huggingface_endpoint_llm.py

@@ -0,0 +1,39 @@
+from typing import Dict
+
+from langchain.llms import HuggingFaceEndpoint
+from pydantic import Extra, root_validator
+
+from langchain.utils import get_from_dict_or_env
+
+
+class HuggingFaceEndpointLLM(HuggingFaceEndpoint):
+    """HuggingFace Endpoint models.
+
+    To use, you should have the ``huggingface_hub`` python package installed, and the
+    environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
+    it as a named parameter to the constructor.
+
+    Supports the `text-generation`, `text2text-generation` and `summarization` tasks.
+
+    Example:
+        .. code-block:: python
+
+            from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
+            endpoint_url = (
+                "https://abcdefghijklmnop.us-east-1.aws.endpoints.huggingface.cloud"
+            )
+            hf = HuggingFaceEndpointLLM(
+                endpoint_url=endpoint_url,
+                huggingfacehub_api_token="my-api-key"
+            )
+    """
+
+    @root_validator(allow_reuse=True)
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that api key and python package exists in environment."""
+        huggingfacehub_api_token = get_from_dict_or_env(
+            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
+        )
+
+        values["huggingfacehub_api_token"] = huggingfacehub_api_token
+        return values

+ 2 - 1
api/tests/unit_tests/model_providers/test_huggingface_hub_provider.py

@@ -17,7 +17,8 @@ HOSTED_INFERENCE_API_VALIDATE_CREDENTIAL = {
 INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL = {
 INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL = {
     'huggingfacehub_api_type': 'inference_endpoints',
     'huggingfacehub_api_type': 'inference_endpoints',
     'huggingfacehub_api_token': 'valid_key',
     'huggingfacehub_api_token': 'valid_key',
-    'huggingfacehub_endpoint_url': 'valid_url'
+    'huggingfacehub_endpoint_url': 'valid_url',
+    'task_type': 'text-generation'
 }
 }
 
 
 def encrypt_side_effect(tenant_id, encrypt_key):
 def encrypt_side_effect(tenant_id, encrypt_key):