Sfoglia il codice sorgente

Feat/huggingface embedding support (#1211)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Garfield Dai 1 anno fa
parent
commit
e409895c02

+ 22 - 0
api/core/model_providers/models/embedding/huggingface_embedding.py

@@ -0,0 +1,22 @@
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.providers.base import BaseModelProvider
+from core.third_party.langchain.embeddings.huggingface_hub_embedding import HuggingfaceHubEmbeddings
+from core.model_providers.models.embedding.base import BaseEmbedding
+
+
+class HuggingfaceEmbedding(BaseEmbedding):
+    def __init__(self, model_provider: BaseModelProvider, name: str):
+        credentials = model_provider.get_model_credentials(
+            model_name=name,
+            model_type=self.type
+        )
+
+        client = HuggingfaceHubEmbeddings(
+            model=name,
+            **credentials
+        )
+
+        super().__init__(model_provider, client, name)
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        return LLMBadRequestError(f"Huggingface embedding: {str(ex)}")

+ 66 - 12
api/core/model_providers/providers/huggingface_hub_provider.py

@@ -1,5 +1,6 @@
 import json
 from typing import Type
+import requests
 
 from huggingface_hub import HfApi
 
@@ -10,8 +11,12 @@ from core.model_providers.providers.base import BaseModelProvider, CredentialsVa
 
 from core.model_providers.models.base import BaseProviderModel
 from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
+from core.third_party.langchain.embeddings.huggingface_hub_embedding import HuggingfaceHubEmbeddings
+from core.model_providers.models.embedding.huggingface_embedding import HuggingfaceEmbedding
 from models.provider import ProviderType
 
+HUGGINGFACE_ENDPOINT_API = 'https://api.endpoints.huggingface.cloud/v2/endpoint/'
+
 
 class HuggingfaceHubProvider(BaseModelProvider):
     @property
@@ -33,6 +38,8 @@ class HuggingfaceHubProvider(BaseModelProvider):
         """
         if model_type == ModelType.TEXT_GENERATION:
             model_class = HuggingfaceHubModel
+        elif model_type == ModelType.EMBEDDINGS:
+            model_class = HuggingfaceEmbedding
         else:
             raise NotImplementedError
 
@@ -63,7 +70,7 @@ class HuggingfaceHubProvider(BaseModelProvider):
         :param model_type:
         :param credentials:
         """
-        if model_type != ModelType.TEXT_GENERATION:
+        if model_type not in [ModelType.TEXT_GENERATION, ModelType.EMBEDDINGS]:
             raise NotImplementedError
 
         if 'huggingfacehub_api_type' not in credentials \
@@ -88,19 +95,15 @@ class HuggingfaceHubProvider(BaseModelProvider):
             if 'task_type' not in credentials:
                 raise CredentialsValidateFailedError('Task Type must be provided.')
 
-            if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"):
+            if credentials['task_type'] not in ("text2text-generation", "text-generation", 'feature-extraction'):
                 raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, '
-                                                     'text-generation, summarization.')
+                                                     'text-generation, feature-extraction.')
 
             try:
-                llm = HuggingFaceEndpointLLM(
-                    endpoint_url=credentials['huggingfacehub_endpoint_url'],
-                    task=credentials['task_type'],
-                    model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
-                    huggingfacehub_api_token=credentials['huggingfacehub_api_token']
-                )
-
-                llm("ping")
+                if credentials['task_type'] == 'feature-extraction':
+                    cls.check_embedding_valid(credentials, model_name)
+                else:
+                    cls.check_llm_valid(credentials)    
             except Exception as e:
                 raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")
         else:
@@ -112,13 +115,64 @@ class HuggingfaceHubProvider(BaseModelProvider):
                 if 'inference' in model_info.cardData and not model_info.cardData['inference']:
                     raise ValueError(f'Inference API has been turned off for this model {model_name}.')
 
-                VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
+                VALID_TASKS = ("text2text-generation", "text-generation", "feature-extraction")
                 if model_info.pipeline_tag not in VALID_TASKS:
                     raise ValueError(f"Model {model_name} is not a valid task, "
                                      f"must be one of {VALID_TASKS}.")
             except Exception as e:
                 raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")
 
+    @classmethod
+    def check_llm_valid(cls, credentials: dict):
+        llm = HuggingFaceEndpointLLM(
+            endpoint_url=credentials['huggingfacehub_endpoint_url'],
+            task=credentials['task_type'],
+            model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
+            huggingfacehub_api_token=credentials['huggingfacehub_api_token']
+        )
+
+        llm("ping")
+
+    @classmethod
+    def check_embedding_valid(cls, credentials: dict, model_name: str):
+
+        cls.check_endpoint_url_model_repository_name(credentials, model_name)
+        
+        embedding_model = HuggingfaceHubEmbeddings(
+            model=model_name,
+            **credentials
+        )
+
+        embedding_model.embed_query("ping")
+
+    @classmethod
+    def check_endpoint_url_model_repository_name(cls, credentials: dict, model_name: str):
+        try:
+            url = f'{HUGGINGFACE_ENDPOINT_API}{credentials["huggingface_namespace"]}'
+            headers = {
+                'Authorization': f'Bearer {credentials["huggingfacehub_api_token"]}',
+                'Content-Type': 'application/json'
+            }
+
+            response =requests.get(url=url, headers=headers)
+
+            if response.status_code != 200:
+                raise ValueError('User Name or Organization Name is invalid.')
+
+            model_repository_name = ''
+
+            for item in response.json().get("items", []):
+                if item.get("status", {}).get("url") == credentials['huggingfacehub_endpoint_url']:
+                    model_repository_name = item.get("model", {}).get("repository")
+                    break
+            
+            if model_repository_name != model_name:
+                raise ValueError(f'Model Name {model_name} is invalid. Please check it on the inference endpoints console.')
+
+        except Exception as e:
+            raise ValueError(str(e))
+        
+
     @classmethod
     def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                   credentials: dict) -> dict:

+ 74 - 0
api/core/third_party/langchain/embeddings/huggingface_hub_embedding.py

@@ -0,0 +1,74 @@
+from typing import Any, Dict, List, Optional
+import json
+import numpy as np
+
+from pydantic import BaseModel, Extra, root_validator
+
+from langchain.embeddings.base import Embeddings
+from langchain.utils import get_from_dict_or_env
+from huggingface_hub import InferenceClient
+
+HOSTED_INFERENCE_API = 'hosted_inference_api'
+INFERENCE_ENDPOINTS = 'inference_endpoints'
+
+
+class HuggingfaceHubEmbeddings(BaseModel, Embeddings):
+    client: Any
+    model: str
+
+    huggingface_namespace: Optional[str] = None
+    task_type: Optional[str] = None
+    huggingfacehub_api_type: Optional[str] = None
+    huggingfacehub_api_token: Optional[str] = None
+    huggingfacehub_endpoint_url: Optional[str] = None
+
+    class Config:
+        extra = Extra.forbid
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        values['huggingfacehub_api_token'] = get_from_dict_or_env(
+            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
+        )
+
+        values['client'] = InferenceClient(token=values['huggingfacehub_api_token'])
+
+        return values
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        model = ''
+
+        if self.huggingfacehub_api_type == HOSTED_INFERENCE_API:
+            model = self.model
+        else:
+            model = self.huggingfacehub_endpoint_url
+
+        output = self.client.post(
+            json={
+                "inputs": texts,
+                "options": {
+                    "wait_for_model": False,
+                    "use_cache": False
+                }
+            }, model=model)
+        
+        embeddings =  json.loads(output.decode())
+        return self.mean_pooling(embeddings)
+
+    def embed_query(self, text: str) -> List[float]:
+        return self.embed_documents([text])[0]
+    
+    # https://huggingface.co/docs/api-inference/detailed_parameters#feature-extraction-task
+    # Returned values are a list of floats, or a list of list of floats 
+    # (depending on if you sent a string or a list of string, 
+    # and if the automatic reduction, usually mean_pooling for instance was applied for you or not. 
+    # This should be explained on the model's README.)
+    def mean_pooling(self, embeddings: List) -> List[float]:
+        # If automatic reduction by giving model, no need to mean_pooling.
+        # For example one: List[List[float]]
+        if not isinstance(embeddings[0][0], list):
+            return embeddings
+
+        # For example two: List[List[List[float]]], need to mean_pooling.
+        sentence_embeddings = [np.mean(embedding[0], axis=0).tolist() for embedding in embeddings]
+        return sentence_embeddings

+ 1 - 1
api/core/third_party/langchain/llms/huggingface_hub_llm.py

@@ -16,7 +16,7 @@ class HuggingFaceHubLLM(HuggingFaceHub):
     environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
     it as a named parameter to the constructor.
 
-    Only supports `text-generation`, `text2text-generation` and `summarization` for now.
+    Only supports `text-generation`, `text2text-generation` for now.
 
     Example:
         .. code-block:: python

+ 1 - 1
api/requirements.txt

@@ -51,4 +51,4 @@ stripe~=5.5.0
 pandas==1.5.3
 xinference==0.4.2
 safetensors==0.3.2
-zhipuai==1.0.7
+zhipuai==1.0.7

+ 1 - 0
api/tests/integration_tests/.env.example

@@ -14,6 +14,7 @@ REPLICATE_API_TOKEN=
 # Hugging Face API Key
 HUGGINGFACE_API_KEY=
 HUGGINGFACE_ENDPOINT_URL=
+HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL=
 
 # Minimax Credentials
 MINIMAX_API_KEY=

+ 136 - 0
api/tests/integration_tests/models/embedding/test_huggingface_hub_embedding.py

@@ -0,0 +1,136 @@
+import json
+import os
+from unittest.mock import patch, MagicMock
+
+from core.model_providers.models.entity.model_params import ModelType
+from core.model_providers.models.embedding.huggingface_embedding import HuggingfaceEmbedding
+from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
+from models.provider import Provider, ProviderType, ProviderModel
+
+DEFAULT_MODEL_NAME = 'obrizum/all-MiniLM-L6-v2'
+
+def get_mock_provider():
+    return Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name='huggingface_hub',
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config='',
+        is_valid=True,
+    )
+
+
+def get_mock_embedding_model(model_name, huggingfacehub_api_type, mocker):
+    valid_api_key = os.environ['HUGGINGFACE_API_KEY']
+    endpoint_url = os.environ['HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL']
+    model_provider = HuggingfaceHubProvider(provider=get_mock_provider())
+
+    credentials = {
+        'huggingfacehub_api_type': huggingfacehub_api_type,
+        'huggingfacehub_api_token': valid_api_key,
+        'task_type': 'feature-extraction'
+    }
+
+    if huggingfacehub_api_type == 'inference_endpoints':
+        credentials['huggingfacehub_endpoint_url'] = endpoint_url
+
+    mock_query = MagicMock()
+    mock_query.filter.return_value.first.return_value = ProviderModel(
+        provider_name='huggingface_hub',
+        model_name=model_name,
+        model_type=ModelType.EMBEDDINGS.value,
+        encrypted_config=json.dumps(credentials),
+        is_valid=True,
+    )
+    mocker.patch('extensions.ext_database.db.session.query',
+                 return_value=mock_query)
+
+    return HuggingfaceEmbedding(
+        model_provider=model_provider,
+        name=model_name
+    )
+
+
+def decrypt_side_effect(tenant_id, encrypted_api_key):
+    return encrypted_api_key
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_hosted_inference_api_embed_documents(mock_decrypt, mocker):
+    embedding_model = get_mock_embedding_model(
+        DEFAULT_MODEL_NAME,
+        'hosted_inference_api',
+        mocker)
+    rst = embedding_model.client.embed_documents(['test', 'test1'])
+    assert isinstance(rst, list)
+    assert len(rst) == 2
+    assert len(rst[0]) == 384
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_endpoint_url_inference_api_embed_documents(mock_decrypt, mocker):
+    embedding_model = get_mock_embedding_model(
+        '',
+        'inference_endpoints',
+        mocker)
+    mocker.patch('core.third_party.langchain.embeddings.huggingface_hub_embedding.InferenceClient.post'
+                 , return_value=bytes(json.dumps([[1, 2, 3], [4, 5, 6]]), 'utf-8'))
+    
+    rst = embedding_model.client.embed_documents(['test', 'test1'])
+    assert isinstance(rst, list)
+    assert len(rst) == 2
+    assert len(rst[0]) == 3
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_endpoint_url_inference_api_embed_documents_two(mock_decrypt, mocker):
+    embedding_model = get_mock_embedding_model(
+        '',
+        'inference_endpoints',
+        mocker)
+    mocker.patch('core.third_party.langchain.embeddings.huggingface_hub_embedding.InferenceClient.post'
+                 , return_value=bytes(json.dumps([[[[1,2,3],[4,5,6],[7,8,9]]],[[[1,2,3],[4,5,6],[7,8,9]]]]), 'utf-8'))
+    
+    rst = embedding_model.client.embed_documents(['test', 'test1'])
+    assert isinstance(rst, list)
+    assert len(rst) == 2
+    assert len(rst[0]) == 3
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_hosted_inference_api_embed_query(mock_decrypt, mocker):
+    embedding_model = get_mock_embedding_model(
+        DEFAULT_MODEL_NAME,
+        'hosted_inference_api',
+        mocker)
+    rst = embedding_model.client.embed_query('test')
+    assert isinstance(rst, list)
+    assert len(rst) == 384
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_endpoint_url_inference_api_embed_query(mock_decrypt, mocker):
+    embedding_model = get_mock_embedding_model(
+        '',
+        'inference_endpoints',
+        mocker)
+    
+    mocker.patch('core.third_party.langchain.embeddings.huggingface_hub_embedding.InferenceClient.post'
+                 , return_value=bytes(json.dumps([[1, 2, 3]]), 'utf-8'))
+
+    rst = embedding_model.client.embed_query('test')
+    assert isinstance(rst, list)
+    assert len(rst) == 3
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_endpoint_url_inference_api_embed_query_two(mock_decrypt, mocker):
+    embedding_model = get_mock_embedding_model(
+        '',
+        'inference_endpoints',
+        mocker)
+    
+    mocker.patch('core.third_party.langchain.embeddings.huggingface_hub_embedding.InferenceClient.post'
+                 , return_value=bytes(json.dumps([[[[1,2,3],[4,5,6],[7,8,9]]]]), 'utf-8'))
+
+    rst = embedding_model.client.embed_query('test')
+    assert isinstance(rst, list)
+    assert len(rst) == 3

+ 89 - 9
web/app/components/header/account-setting/model-page/configs/huggingface_hub.tsx

@@ -48,6 +48,15 @@ const config: ProviderConfig = {
         ]
       }
       if (v?.huggingfacehub_api_type === 'inference_endpoints') {
+        if (v.model_type === 'embeddings') {
+          return [
+            'huggingfacehub_api_token',
+            'huggingface_namespace',
+            'model_name',
+            'huggingfacehub_endpoint_url',
+            'task_type',
+          ]
+        }
         return [
           'huggingfacehub_api_token',
           'model_name',
@@ -68,14 +77,27 @@ const config: ProviderConfig = {
         ]
       }
       if (v?.huggingfacehub_api_type === 'inference_endpoints') {
-        filteredKeys = [
-          'huggingfacehub_api_type',
-          'huggingfacehub_api_token',
-          'model_name',
-          'huggingfacehub_endpoint_url',
-          'task_type',
-          'model_type',
-        ]
+        if (v.model_type === 'embeddings') {
+          filteredKeys = [
+            'huggingfacehub_api_type',
+            'huggingfacehub_api_token',
+            'huggingface_namespace',
+            'model_name',
+            'huggingfacehub_endpoint_url',
+            'task_type',
+            'model_type',
+          ]
+        }
+        else {
+          filteredKeys = [
+            'huggingfacehub_api_type',
+            'huggingfacehub_api_token',
+            'model_name',
+            'huggingfacehub_endpoint_url',
+            'task_type',
+            'model_type',
+          ]
+        }
       }
       return filteredKeys.reduce((prev: FormValue, next: string) => {
         prev[next] = v?.[next] || ''
@@ -83,6 +105,31 @@ const config: ProviderConfig = {
       }, {})
     },
     fields: [
+      {
+        type: 'radio',
+        key: 'model_type',
+        required: true,
+        label: {
+          'en': 'Model Type',
+          'zh-Hans': '模型类型',
+        },
+        options: [
+          {
+            key: 'text-generation',
+            label: {
+              'en': 'Text Generation',
+              'zh-Hans': '文本生成',
+            },
+          },
+          {
+            key: 'embeddings',
+            label: {
+              'en': 'Embeddings',
+              'zh-Hans': 'Embeddings',
+            },
+          },
+        ],
+      },
       {
         type: 'radio',
         key: 'huggingfacehub_api_type',
@@ -121,6 +168,20 @@ const config: ProviderConfig = {
           'zh-Hans': '在此输入您的 Hugging Face Hub API Token',
         },
       },
+      {
+        hidden: (value?: FormValue) => !(value?.huggingfacehub_api_type === 'inference_endpoints' && value?.model_type === 'embeddings'),
+        type: 'text',
+        key: 'huggingface_namespace',
+        required: true,
+        label: {
+          'en': 'User Name / Organization Name',
+          'zh-Hans': '用户名 / 组织名称',
+        },
+        placeholder: {
+          'en': 'Enter your User Name / Organization Name here',
+          'zh-Hans': '在此输入您的用户名 / 组织名称',
+        },
+      },
       {
         type: 'text',
         key: 'model_name',
@@ -148,7 +209,7 @@ const config: ProviderConfig = {
         },
       },
       {
-        hidden: (value?: FormValue) => value?.huggingfacehub_api_type === 'hosted_inference_api',
+        hidden: (value?: FormValue) => value?.huggingfacehub_api_type === 'hosted_inference_api' || value?.model_type === 'embeddings',
         type: 'radio',
         key: 'task_type',
         required: true,
@@ -173,6 +234,25 @@ const config: ProviderConfig = {
           },
         ],
       },
+      {
+        hidden: (value?: FormValue) => !(value?.huggingfacehub_api_type === 'inference_endpoints' && value?.model_type === 'embeddings'),
+        type: 'radio',
+        key: 'task_type',
+        required: true,
+        label: {
+          'en': 'Task',
+          'zh-Hans': 'Task',
+        },
+        options: [
+          {
+            key: 'feature-extraction',
+            label: {
+              'en': 'Feature Extraction',
+              'zh-Hans': 'Feature Extraction',
+            },
+          },
+        ],
+      },
     ],
   },
 }

+ 25 - 4
web/app/components/header/account-setting/model-page/model-modal/Form.tsx

@@ -1,7 +1,7 @@
 import { useEffect, useState } from 'react'
 import type { Dispatch, FC, SetStateAction } from 'react'
 import { useContext } from 'use-context-selector'
-import type { Field, FormValue, ProviderConfigModal } from '../declarations'
+import { type Field, type FormValue, type ProviderConfigModal, ProviderEnum } from '../declarations'
 import { useValidate } from '../../key-validator/hooks'
 import { ValidatingTip } from '../../key-validator/ValidateStatus'
 import { validateModelProviderFn } from '../utils'
@@ -85,10 +85,31 @@ const Form: FC<FormProps> = ({
   }
 
   const handleFormChange = (k: string, v: string) => {
-    if (mode === 'edit' && !cleared)
+    if (mode === 'edit' && !cleared) {
       handleClear({ [k]: v })
-    else
-      handleMultiFormChange({ ...value, [k]: v }, k)
+    }
+    else {
+      const extraValue: Record<string, string> = {}
+      if (
+        (
+          (k === 'model_type' && v === 'embeddings' && value.huggingfacehub_api_type === 'inference_endpoints')
+          || (k === 'huggingfacehub_api_type' && v === 'inference_endpoints' && value.model_type === 'embeddings')
+        )
+        && modelModal?.key === ProviderEnum.huggingface_hub
+      )
+        extraValue.task_type = 'feature-extraction'
+
+      if (
+        (
+          (k === 'model_type' && v === 'text-generation' && value.huggingfacehub_api_type === 'inference_endpoints')
+          || (k === 'huggingfacehub_api_type' && v === 'inference_endpoints' && value.model_type === 'text-generation')
+        )
+        && modelModal?.key === ProviderEnum.huggingface_hub
+      )
+        extraValue.task_type = 'text-generation'
+
+      handleMultiFormChange({ ...value, [k]: v, ...extraValue }, k)
+    }
   }
 
   const handleFocus = () => {

+ 1 - 1
web/app/components/header/account-setting/model-page/model-modal/index.tsx

@@ -92,7 +92,7 @@ const ModelModal: FC<ModelModalProps> = ({
   return (
     <Portal>
       <div className='fixed inset-0 flex items-center justify-center bg-black/[.25]'>
-        <div className='w-[640px] max-h-screen bg-white shadow-xl rounded-2xl overflow-y-auto'>
+        <div className='w-[640px] max-h-[calc(100vh-120px)] bg-white shadow-xl rounded-2xl overflow-y-auto'>
           <div className='px-8 pt-8'>
             <div className='flex justify-between items-center mb-2'>
               <div className='text-xl font-semibold text-gray-900'>{renderTitlePrefix()}</div>