Browse Source

feat: Add support for TEI API key authentication (#11006)

Signed-off-by: kenwoodjw <blackxin55+@gmail.com>
Co-authored-by: crazywoola <427733928@qq.com>
kenwoodjw 5 tháng trước
mục cha
commit
096c0ad564

+ 8 - 0
api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.yaml

@@ -34,3 +34,11 @@ model_credential_schema:
       placeholder:
         zh_Hans: 在此输入Text Embedding Inference的服务器地址,如 http://192.168.1.100:8080
         en_US: Enter the url of your Text Embedding Inference, e.g. http://192.168.1.100:8080
+    - variable: api_key
+      label:
+        en_US: API Key
+      type: secret-input
+      required: false
+      placeholder:
+        zh_Hans: 在此输入您的 API Key
+        en_US: Enter your API Key

+ 11 - 2
api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py

@@ -51,8 +51,13 @@ class HuggingfaceTeiRerankModel(RerankModel):
 
         server_url = server_url.removesuffix("/")
 
+        headers = {"Content-Type": "application/json"}
+        api_key = credentials.get("api_key")
+        if api_key:
+            headers["Authorization"] = f"Bearer {api_key}"
+
         try:
-            results = TeiHelper.invoke_rerank(server_url, query, docs)
+            results = TeiHelper.invoke_rerank(server_url, query, docs, headers)
 
             rerank_documents = []
             for result in results:
@@ -80,7 +85,11 @@ class HuggingfaceTeiRerankModel(RerankModel):
         """
         try:
             server_url = credentials["server_url"]
-            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
+            headers = {"Content-Type": "application/json"}
+            api_key = credentials.get("api_key")
+            if api_key:
+                headers["Authorization"] = f"Bearer {api_key}"
+            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model, headers)
             if extra_args.model_type != "reranker":
                 raise CredentialsValidateFailedError("Current model is not a rerank model")
 

+ 18 - 20
api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py

@@ -26,13 +26,15 @@ cache_lock = Lock()
 
 class TeiHelper:
     @staticmethod
-    def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
+    def get_tei_extra_parameter(
+        server_url: str, model_name: str, headers: Optional[dict] = None
+    ) -> TeiModelExtraParameter:
         TeiHelper._clean_cache()
         with cache_lock:
             if model_name not in cache:
                 cache[model_name] = {
                     "expires": time() + 300,
-                    "value": TeiHelper._get_tei_extra_parameter(server_url),
+                    "value": TeiHelper._get_tei_extra_parameter(server_url, headers),
                 }
             return cache[model_name]["value"]
 
@@ -47,7 +49,7 @@ class TeiHelper:
             pass
 
     @staticmethod
-    def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter:
+    def _get_tei_extra_parameter(server_url: str, headers: Optional[dict] = None) -> TeiModelExtraParameter:
         """
         get tei model extra parameter like model_type, max_input_length, max_batch_requests
         """
@@ -61,7 +63,7 @@ class TeiHelper:
         session.mount("https://", HTTPAdapter(max_retries=3))
 
         try:
-            response = session.get(url, timeout=10)
+            response = session.get(url, headers=headers, timeout=10)
         except (MissingSchema, ConnectionError, Timeout) as e:
             raise RuntimeError(f"get tei model extra parameter failed, url: {url}, error: {e}")
         if response.status_code != 200:
@@ -86,7 +88,7 @@ class TeiHelper:
         )
 
     @staticmethod
-    def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
+    def invoke_tokenize(server_url: str, texts: list[str], headers: Optional[dict] = None) -> list[list[dict]]:
         """
         Invoke tokenize endpoint
 
@@ -114,15 +116,15 @@ class TeiHelper:
         :param server_url: server url
         :param texts: texts to tokenize
         """
-        resp = httpx.post(
-            f"{server_url}/tokenize",
-            json={"inputs": texts},
-        )
+        url = f"{server_url}/tokenize"
+        json_data = {"inputs": texts}
+        resp = httpx.post(url, json=json_data, headers=headers)
+
         resp.raise_for_status()
         return resp.json()
 
     @staticmethod
-    def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
+    def invoke_embeddings(server_url: str, texts: list[str], headers: Optional[dict] = None) -> dict:
         """
         Invoke embeddings endpoint
 
@@ -147,15 +149,14 @@ class TeiHelper:
         :param texts: texts to embed
         """
         # Use OpenAI compatible API here, which has usage tracking
-        resp = httpx.post(
-            f"{server_url}/v1/embeddings",
-            json={"input": texts},
-        )
+        url = f"{server_url}/v1/embeddings"
+        json_data = {"input": texts}
+        resp = httpx.post(url, json=json_data, headers=headers)
         resp.raise_for_status()
         return resp.json()
 
     @staticmethod
-    def invoke_rerank(server_url: str, query: str, docs: list[str]) -> list[dict]:
+    def invoke_rerank(server_url: str, query: str, docs: list[str], headers: Optional[dict] = None) -> list[dict]:
         """
         Invoke rerank endpoint
 
@@ -173,10 +174,7 @@ class TeiHelper:
         :param candidates: candidates to rerank
         """
         params = {"query": query, "texts": docs, "return_text": True}
-
-        response = httpx.post(
-            server_url + "/rerank",
-            json=params,
-        )
+        url = f"{server_url}/rerank"
+        response = httpx.post(url, json=params, headers=headers)
         response.raise_for_status()
         return response.json()

+ 19 - 4
api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py

@@ -51,6 +51,10 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
 
         server_url = server_url.removesuffix("/")
 
+        headers = {"Content-Type": "application/json"}
+        api_key = credentials["api_key"]
+        if api_key:
+            headers["Authorization"] = f"Bearer {api_key}"
         # get model properties
         context_size = self._get_context_size(model, credentials)
         max_chunks = self._get_max_chunks(model, credentials)
@@ -60,7 +64,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
         used_tokens = 0
 
         # get tokenized results from TEI
-        batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts)
+        batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts, headers)
 
         for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)):
             # Check if the number of tokens is larger than the context size
@@ -97,7 +101,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
             used_tokens = 0
             for i in _iter:
                 iter_texts = inputs[i : i + max_chunks]
-                results = TeiHelper.invoke_embeddings(server_url, iter_texts)
+                results = TeiHelper.invoke_embeddings(server_url, iter_texts, headers)
                 embeddings = results["data"]
                 embeddings = [embedding["embedding"] for embedding in embeddings]
                 batched_embeddings.extend(embeddings)
@@ -127,7 +131,11 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
 
         server_url = server_url.removesuffix("/")
 
-        batch_tokens = TeiHelper.invoke_tokenize(server_url, texts)
+        headers = {
+            "Authorization": f"Bearer {credentials.get('api_key')}",
+        }
+
+        batch_tokens = TeiHelper.invoke_tokenize(server_url, texts, headers)
         num_tokens = sum(len(tokens) for tokens in batch_tokens)
         return num_tokens
 
@@ -141,7 +149,14 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
         """
         try:
             server_url = credentials["server_url"]
-            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
+            headers = {"Content-Type": "application/json"}
+
+            api_key = credentials.get("api_key")
+
+            if api_key:
+                headers["Authorization"] = f"Bearer {api_key}"
+
+            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model, headers)
             print(extra_args)
             if extra_args.model_type != "embedding":
                 raise CredentialsValidateFailedError("Current model is not a embedding model")

+ 1 - 0
api/pytest.ini

@@ -20,6 +20,7 @@ env =
     OPENAI_API_KEY = sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii
     TEI_EMBEDDING_SERVER_URL = http://a.abc.com:11451
     TEI_RERANK_SERVER_URL = http://a.abc.com:11451
+    TEI_API_KEY = ttttttttttttttt
     UPSTAGE_API_KEY = up-aaaaaaaaaaaaaaaaaaaa
     VOYAGE_API_KEY = va-aaaaaaaaaaaaaaaaaaaa
     XINFERENCE_CHAT_MODEL_UID = chat

+ 3 - 0
api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py

@@ -40,6 +40,7 @@ def test_validate_credentials(setup_tei_mock):
                 model="reranker",
                 credentials={
                     "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
+                    "api_key": os.environ.get("TEI_API_KEY", ""),
                 },
             )
 
@@ -47,6 +48,7 @@ def test_validate_credentials(setup_tei_mock):
         model=model_name,
         credentials={
             "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
+            "api_key": os.environ.get("TEI_API_KEY", ""),
         },
     )
 
@@ -60,6 +62,7 @@ def test_invoke_model(setup_tei_mock):
         model=model_name,
         credentials={
             "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
+            "api_key": os.environ.get("TEI_API_KEY", ""),
         },
         texts=["hello", "world"],
         user="abc-123",

+ 3 - 0
api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py

@@ -40,6 +40,7 @@ def test_validate_credentials(setup_tei_mock):
                 model="embedding",
                 credentials={
                     "server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
+                    "api_key": os.environ.get("TEI_API_KEY", ""),
                 },
             )
 
@@ -47,6 +48,7 @@ def test_validate_credentials(setup_tei_mock):
         model=model_name,
         credentials={
             "server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
+            "api_key": os.environ.get("TEI_API_KEY", ""),
         },
     )
 
@@ -61,6 +63,7 @@ def test_invoke_model(setup_tei_mock):
         model=model_name,
         credentials={
             "server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
+            "api_key": os.environ.get("TEI_API_KEY", ""),
         },
         query="Who is Kasumi?",
         docs=[