Explorar o código

feat: hf inference endpoint stream support (#1028)

takatost hai 1 ano
pai
achega
0796791de5

+ 2 - 2
api/core/model_providers/models/llm/anthropic_model.py

@@ -75,7 +75,7 @@ class AnthropicModel(BaseLLM):
         else:
             return ex
 
-    @classmethod
-    def support_streaming(cls):
+    @property
+    def support_streaming(self):
         return True
 

+ 3 - 3
api/core/model_providers/models/llm/azure_openai_model.py

@@ -141,6 +141,6 @@ class AzureOpenAIModel(BaseLLM):
         else:
             return ex
 
-    @classmethod
-    def support_streaming(cls):
-        return True
+    @property
+    def support_streaming(self):
+        return True

+ 4 - 4
api/core/model_providers/models/llm/base.py

@@ -138,7 +138,7 @@ class BaseLLM(BaseProviderModel):
                 result = self._run(
                     messages=messages,
                     stop=stop,
-                    callbacks=callbacks if not (self.streaming and not self.support_streaming()) else None,
+                    callbacks=callbacks if not (self.streaming and not self.support_streaming) else None,
                     **kwargs
                 )
             except Exception as ex:
@@ -149,7 +149,7 @@ class BaseLLM(BaseProviderModel):
         else:
             completion_content = result.generations[0][0].text
 
-        if self.streaming and not self.support_streaming():
+        if self.streaming and not self.support_streaming:
             # use FakeLLM to simulate streaming when current model not support streaming but streaming is True
             prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
             fake_llm = FakeLLM(
@@ -298,8 +298,8 @@ class BaseLLM(BaseProviderModel):
         else:
             self.client.callbacks.extend(callbacks)
 
-    @classmethod
-    def support_streaming(cls):
+    @property
+    def support_streaming(self):
         return False
 
     def get_prompt(self, mode: str,

+ 0 - 4
api/core/model_providers/models/llm/chatglm_model.py

@@ -61,7 +61,3 @@ class ChatGLMModel(BaseLLM):
             return LLMBadRequestError(f"ChatGLM: {str(ex)}")
         else:
             return ex
-
-    @classmethod
-    def support_streaming(cls):
-        return False

+ 13 - 4
api/core/model_providers/models/llm/huggingface_hub_model.py

@@ -17,12 +17,18 @@ class HuggingfaceHubModel(BaseLLM):
     def _init_client(self) -> Any:
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
         if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
+            streaming = self.streaming
+
+            if 'baichuan' in self.name.lower():
+                streaming = False
+
             client = HuggingFaceEndpointLLM(
                 endpoint_url=self.credentials['huggingfacehub_endpoint_url'],
                 task=self.credentials['task_type'],
                 model_kwargs=provider_model_kwargs,
                 huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
-                callbacks=self.callbacks
+                callbacks=self.callbacks,
+                streaming=streaming
             )
         else:
             client = HuggingFaceHub(
@@ -76,7 +82,10 @@ class HuggingfaceHubModel(BaseLLM):
     def handle_exceptions(self, ex: Exception) -> Exception:
         return LLMBadRequestError(f"Huggingface Hub: {str(ex)}")
 
-    @classmethod
-    def support_streaming(cls):
-        return False
+    @property
+    def support_streaming(self):
+        if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
+            if 'baichuan' in self.name.lower():
+                return False
 
+        return True

+ 2 - 2
api/core/model_providers/models/llm/openai_model.py

@@ -154,8 +154,8 @@ class OpenAIModel(BaseLLM):
         else:
             return ex
 
-    @classmethod
-    def support_streaming(cls):
+    @property
+    def support_streaming(self):
         return True
 
     # def is_model_valid_or_raise(self):

+ 0 - 4
api/core/model_providers/models/llm/openllm_model.py

@@ -63,7 +63,3 @@ class OpenLLMModel(BaseLLM):
 
     def handle_exceptions(self, ex: Exception) -> Exception:
         return LLMBadRequestError(f"OpenLLM: {str(ex)}")
-
-    @classmethod
-    def support_streaming(cls):
-        return False

+ 3 - 3
api/core/model_providers/models/llm/replicate_model.py

@@ -91,6 +91,6 @@ class ReplicateModel(BaseLLM):
         else:
             return ex
 
-    @classmethod
-    def support_streaming(cls):
-        return True
+    @property
+    def support_streaming(self):
+        return True

+ 3 - 3
api/core/model_providers/models/llm/spark_model.py

@@ -65,6 +65,6 @@ class SparkModel(BaseLLM):
         else:
             return ex
 
-    @classmethod
-    def support_streaming(cls):
-        return True
+    @property
+    def support_streaming(self):
+        return True

+ 2 - 2
api/core/model_providers/models/llm/tongyi_model.py

@@ -69,6 +69,6 @@ class TongyiModel(BaseLLM):
         else:
             return ex
 
-    @classmethod
-    def support_streaming(cls):
+    @property
+    def support_streaming(self):
         return True

+ 0 - 4
api/core/model_providers/models/llm/wenxin_model.py

@@ -57,7 +57,3 @@ class WenxinModel(BaseLLM):
 
     def handle_exceptions(self, ex: Exception) -> Exception:
         return LLMBadRequestError(f"Wenxin: {str(ex)}")
-
-    @classmethod
-    def support_streaming(cls):
-        return False

+ 2 - 2
api/core/model_providers/models/llm/xinference_model.py

@@ -74,6 +74,6 @@ class XinferenceModel(BaseLLM):
     def handle_exceptions(self, ex: Exception) -> Exception:
         return LLMBadRequestError(f"Xinference: {str(ex)}")
 
-    @classmethod
-    def support_streaming(cls):
+    @property
+    def support_streaming(self):
         return True

+ 91 - 2
api/core/third_party/langchain/llms/huggingface_endpoint_llm.py

@@ -1,7 +1,11 @@
-from typing import Dict
+from typing import Dict, Any, Optional, List, Iterable, Iterator
 
+from huggingface_hub import InferenceClient
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.embeddings.huggingface_hub import VALID_TASKS
 from langchain.llms import HuggingFaceEndpoint
-from pydantic import Extra, root_validator
+from langchain.llms.utils import enforce_stop_tokens
+from pydantic import root_validator
 
 from langchain.utils import get_from_dict_or_env
 
@@ -27,6 +31,8 @@ class HuggingFaceEndpointLLM(HuggingFaceEndpoint):
                 huggingfacehub_api_token="my-api-key"
             )
     """
+    client: Any
+    streaming: bool = False
 
     @root_validator(allow_reuse=True)
     def validate_environment(cls, values: Dict) -> Dict:
@@ -35,5 +41,88 @@ class HuggingFaceEndpointLLM(HuggingFaceEndpoint):
             values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
         )
 
+        values['client'] = InferenceClient(values['endpoint_url'], token=huggingfacehub_api_token)
+
         values["huggingfacehub_api_token"] = huggingfacehub_api_token
         return values
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call out to HuggingFace Hub's inference endpoint.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            The string generated by the model.
+
+        Example:
+            .. code-block:: python
+
+                response = hf("Tell me a joke.")
+        """
+        _model_kwargs = self.model_kwargs or {}
+
+        # payload samples
+        params = {**_model_kwargs, **kwargs}
+
+        # generation parameter
+        gen_kwargs = {
+            **params,
+            'stop_sequences': stop
+        }
+
+        response = self.client.text_generation(prompt, stream=self.streaming, details=True, **gen_kwargs)
+
+        if self.streaming and isinstance(response, Iterable):
+            combined_text_output = ""
+            for token in self._stream_response(response, run_manager):
+                combined_text_output += token
+            completion = combined_text_output
+        else:
+            completion = response.generated_text
+
+        if self.task == "text-generation":
+            text = completion
+            # Remove prompt if included in generated text.
+            if text.startswith(prompt):
+                text = text[len(prompt) :]
+        elif self.task == "text2text-generation":
+            text = completion
+        else:
+            raise ValueError(
+                f"Got invalid task {self.task}, "
+                f"currently only {VALID_TASKS} are supported"
+            )
+
+        if stop is not None:
+            # This is a bit hacky, but I can't figure out a better way to enforce
+            # stop tokens when making calls to huggingface_hub.
+            text = enforce_stop_tokens(text, stop)
+
+        return text
+
+    def _stream_response(
+            self,
+            response: Iterable,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> Iterator[str]:
+        for r in response:
+            # skip special tokens
+            if r.token.special:
+                continue
+
+            token = r.token.text
+            if run_manager:
+                run_manager.on_llm_new_token(
+                    token=token, verbose=self.verbose, log_probs=None
+                )
+
+            # yield the generated token
+            yield token

+ 3 - 1
api/tests/unit_tests/model_providers/test_huggingface_hub_provider.py

@@ -63,7 +63,7 @@ def test_hosted_inference_api_is_credentials_valid_or_raise_invalid(mock_model_i
 
 def test_inference_endpoints_is_credentials_valid_or_raise_valid(mocker):
     mocker.patch('huggingface_hub.hf_api.HfApi.whoami', return_value=None)
-    mocker.patch('langchain.llms.huggingface_endpoint.HuggingFaceEndpoint._call', return_value="abc")
+    mocker.patch('core.third_party.langchain.llms.huggingface_endpoint_llm.HuggingFaceEndpointLLM._call', return_value="abc")
 
     MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
         model_name='test_model_name',
@@ -71,8 +71,10 @@ def test_inference_endpoints_is_credentials_valid_or_raise_valid(mocker):
         credentials=INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL
     )
 
+
 def test_inference_endpoints_is_credentials_valid_or_raise_invalid(mocker):
     mocker.patch('huggingface_hub.hf_api.HfApi.whoami', return_value=None)
+    mocker.patch('core.third_party.langchain.llms.huggingface_endpoint_llm.HuggingFaceEndpointLLM._call', return_value="abc")
 
     with pytest.raises(CredentialsValidateFailedError):
         MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(