|
@@ -1,5 +1,6 @@
|
|
import json
|
|
import json
|
|
from typing import Type
|
|
from typing import Type
|
|
|
|
+import requests
|
|
|
|
|
|
from huggingface_hub import HfApi
|
|
from huggingface_hub import HfApi
|
|
|
|
|
|
@@ -10,8 +11,12 @@ from core.model_providers.providers.base import BaseModelProvider, CredentialsVa
|
|
|
|
|
|
from core.model_providers.models.base import BaseProviderModel
|
|
from core.model_providers.models.base import BaseProviderModel
|
|
from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
|
|
from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
|
|
|
|
+from core.third_party.langchain.embeddings.huggingface_hub_embedding import HuggingfaceHubEmbeddings
|
|
|
|
+from core.model_providers.models.embedding.huggingface_embedding import HuggingfaceEmbedding
|
|
from models.provider import ProviderType
|
|
from models.provider import ProviderType
|
|
|
|
|
|
|
|
+HUGGINGFACE_ENDPOINT_API = 'https://api.endpoints.huggingface.cloud/v2/endpoint/'
|
|
|
|
+
|
|
|
|
|
|
class HuggingfaceHubProvider(BaseModelProvider):
|
|
class HuggingfaceHubProvider(BaseModelProvider):
|
|
@property
|
|
@property
|
|
@@ -33,6 +38,8 @@ class HuggingfaceHubProvider(BaseModelProvider):
|
|
"""
|
|
"""
|
|
if model_type == ModelType.TEXT_GENERATION:
|
|
if model_type == ModelType.TEXT_GENERATION:
|
|
model_class = HuggingfaceHubModel
|
|
model_class = HuggingfaceHubModel
|
|
|
|
+ elif model_type == ModelType.EMBEDDINGS:
|
|
|
|
+ model_class = HuggingfaceEmbedding
|
|
else:
|
|
else:
|
|
raise NotImplementedError
|
|
raise NotImplementedError
|
|
|
|
|
|
@@ -63,7 +70,7 @@ class HuggingfaceHubProvider(BaseModelProvider):
|
|
:param model_type:
|
|
:param model_type:
|
|
:param credentials:
|
|
:param credentials:
|
|
"""
|
|
"""
|
|
- if model_type != ModelType.TEXT_GENERATION:
|
|
|
|
|
|
+ if model_type not in [ModelType.TEXT_GENERATION, ModelType.EMBEDDINGS]:
|
|
raise NotImplementedError
|
|
raise NotImplementedError
|
|
|
|
|
|
if 'huggingfacehub_api_type' not in credentials \
|
|
if 'huggingfacehub_api_type' not in credentials \
|
|
@@ -88,19 +95,15 @@ class HuggingfaceHubProvider(BaseModelProvider):
|
|
if 'task_type' not in credentials:
|
|
if 'task_type' not in credentials:
|
|
raise CredentialsValidateFailedError('Task Type must be provided.')
|
|
raise CredentialsValidateFailedError('Task Type must be provided.')
|
|
|
|
|
|
- if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"):
|
|
|
|
|
|
+ if credentials['task_type'] not in ("text2text-generation", "text-generation", 'feature-extraction'):
|
|
raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, '
|
|
raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, '
|
|
- 'text-generation, summarization.')
|
|
|
|
|
|
+ 'text-generation, feature-extraction.')
|
|
|
|
|
|
try:
|
|
try:
|
|
- llm = HuggingFaceEndpointLLM(
|
|
|
|
- endpoint_url=credentials['huggingfacehub_endpoint_url'],
|
|
|
|
- task=credentials['task_type'],
|
|
|
|
- model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
|
|
|
|
- huggingfacehub_api_token=credentials['huggingfacehub_api_token']
|
|
|
|
- )
|
|
|
|
-
|
|
|
|
- llm("ping")
|
|
|
|
|
|
+ if credentials['task_type'] == 'feature-extraction':
|
|
|
|
+ cls.check_embedding_valid(credentials, model_name)
|
|
|
|
+ else:
|
|
|
|
+ cls.check_llm_valid(credentials)
|
|
except Exception as e:
|
|
except Exception as e:
|
|
raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")
|
|
raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")
|
|
else:
|
|
else:
|
|
@@ -112,13 +115,64 @@ class HuggingfaceHubProvider(BaseModelProvider):
|
|
if 'inference' in model_info.cardData and not model_info.cardData['inference']:
|
|
if 'inference' in model_info.cardData and not model_info.cardData['inference']:
|
|
raise ValueError(f'Inference API has been turned off for this model {model_name}.')
|
|
raise ValueError(f'Inference API has been turned off for this model {model_name}.')
|
|
|
|
|
|
- VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
|
|
|
|
|
|
+ VALID_TASKS = ("text2text-generation", "text-generation", "feature-extraction")
|
|
if model_info.pipeline_tag not in VALID_TASKS:
|
|
if model_info.pipeline_tag not in VALID_TASKS:
|
|
raise ValueError(f"Model {model_name} is not a valid task, "
|
|
raise ValueError(f"Model {model_name} is not a valid task, "
|
|
f"must be one of {VALID_TASKS}.")
|
|
f"must be one of {VALID_TASKS}.")
|
|
except Exception as e:
|
|
except Exception as e:
|
|
raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")
|
|
raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")
|
|
|
|
|
|
|
|
+    @classmethod
+    def check_llm_valid(cls, credentials: dict):
+        """
+        Probe a text generation inference endpoint with the given credentials.
+
+        Builds a HuggingFaceEndpointLLM from the credentials and invokes it once with
+        the prompt "ping". Nothing is caught here, so any error raised while building
+        or invoking the client propagates to the caller.
+
+        :param credentials: must contain 'huggingfacehub_endpoint_url', 'task_type'
+            and 'huggingfacehub_api_token'
+        """
+        # Fixed sampling settings used only for this probe request.
+        probe_kwargs = {"temperature": 0.5, "max_new_tokens": 200}
+
+        endpoint_client = HuggingFaceEndpointLLM(
+            endpoint_url=credentials['huggingfacehub_endpoint_url'],
+            task=credentials['task_type'],
+            model_kwargs=probe_kwargs,
+            huggingfacehub_api_token=credentials['huggingfacehub_api_token'],
+        )
+        endpoint_client("ping")
|
|
|
|
+
|
|
|
|
+    @classmethod
+    def check_embedding_valid(cls, credentials: dict, model_name: str):
+        """
+        Probe an embedding (feature-extraction) inference endpoint.
+
+        First checks that the endpoint URL in the credentials is deployed with the
+        repository named by model_name, then builds a HuggingfaceHubEmbeddings client
+        and embeds the query "ping". Nothing is caught here, so any error propagates
+        to the caller.
+
+        :param credentials: forwarded as keyword arguments to HuggingfaceHubEmbeddings
+        :param model_name: model repository name the endpoint is expected to serve
+        """
+        cls.check_endpoint_url_model_repository_name(credentials, model_name)
+
+        embeddings_client = HuggingfaceHubEmbeddings(model=model_name, **credentials)
+        embeddings_client.embed_query("ping")
|
|
|
|
+
|
|
|
|
+    @classmethod
+    def check_endpoint_url_model_repository_name(cls, credentials: dict, model_name: str):
+        """
+        Verify that the configured inference endpoint serves the given model repository.
+
+        Requests HUGGINGFACE_ENDPOINT_API followed by credentials['huggingface_namespace'],
+        looks for the returned item whose status url equals
+        credentials['huggingfacehub_endpoint_url'], and compares that item's model
+        repository with model_name. If no item matches the URL, the comparison is made
+        against an empty string and therefore fails.
+
+        :param credentials: must contain 'huggingface_namespace', 'huggingfacehub_api_token'
+            and 'huggingfacehub_endpoint_url'
+        :param model_name: model repository name expected on the endpoint
+        :raises ValueError: for every failure (missing credential key, request error or
+            timeout, non-200 response, repository mismatch); the original exception is
+            attached as the cause
+        """
+        try:
+            url = f'{HUGGINGFACE_ENDPOINT_API}{credentials["huggingface_namespace"]}'
+            headers = {
+                'Authorization': f'Bearer {credentials["huggingfacehub_api_token"]}',
+                'Content-Type': 'application/json'
+            }
+
+            # requests has no default timeout; bound the call so credential validation
+            # cannot block forever when the API does not answer.
+            response = requests.get(url=url, headers=headers, timeout=10)
+
+            if response.status_code != 200:
+                raise ValueError('User Name or Organization Name is invalid.')
+
+            model_repository_name = ''
+
+            for item in response.json().get("items", []):
+                if item.get("status", {}).get("url") == credentials['huggingfacehub_endpoint_url']:
+                    model_repository_name = item.get("model", {}).get("repository")
+                    break
+
+            if model_repository_name != model_name:
+                raise ValueError(f'Model Name {model_name} is invalid. Please check it on the inference endpoints console.')
+
+        except Exception as e:
+            # Surface every failure as ValueError with the same message as before, but keep
+            # the underlying exception chained so the root cause stays visible in tracebacks.
+            raise ValueError(str(e)) from e
|
|
|
|
+
|
|
|
|
+
|
|
@classmethod
|
|
@classmethod
|
|
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
|
|
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
|
|
credentials: dict) -> dict:
|
|
credentials: dict) -> dict:
|