Quellcode durchsuchen

feat: support xinference's auth system (#7369)

SoaringEthan vor 8 Monaten
Ursprung
Commit
acd72e3ab2

+ 7 - 3
api/core/model_runtime/model_providers/xinference/llm/llm.py

@@ -85,7 +85,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             tools=tools, stop=stop, stream=stream, user=user,
             extra_model_kwargs=XinferenceHelper.get_xinference_extra_parameter(
                 server_url=credentials['server_url'],
-                model_uid=credentials['model_uid']
+                model_uid=credentials['model_uid'],
+                api_key=credentials.get('api_key'),
             )
         )
 
@@ -106,7 +107,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
 
             extra_param = XinferenceHelper.get_xinference_extra_parameter(
                 server_url=credentials['server_url'],
-                model_uid=credentials['model_uid']
+                model_uid=credentials['model_uid'],
+                api_key=credentials.get('api_key')
             )
             if 'completion_type' not in credentials:
                 if 'chat' in extra_param.model_ability:
@@ -396,7 +398,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         else:
             extra_args = XinferenceHelper.get_xinference_extra_parameter(
                 server_url=credentials['server_url'],
-                model_uid=credentials['model_uid']
+                model_uid=credentials['model_uid'],
+                api_key=credentials.get('api_key')
             )
 
             if 'chat' in extra_args.model_ability:
@@ -464,6 +467,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
 
         xinference_client = Client(
             base_url=credentials['server_url'],
+            api_key=credentials.get('api_key'),
         )
 
         xinference_model = xinference_client.get_model(credentials['model_uid'])

+ 2 - 1
api/core/model_runtime/model_providers/xinference/rerank/rerank.py

@@ -108,7 +108,8 @@ class XinferenceRerankModel(RerankModel):
 
             # initialize client
             client = Client(
-                base_url=credentials['server_url']
+                base_url=credentials['server_url'],
+                api_key=credentials.get('api_key'),
             )
 
             xinference_client = client.get_model(model_uid=credentials['model_uid'])

+ 2 - 1
api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py

@@ -52,7 +52,8 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
 
             # initialize client
             client = Client(
-                base_url=credentials['server_url']
+                base_url=credentials['server_url'],
+                api_key=credentials.get('api_key'),
             )
 
             xinference_client = client.get_model(model_uid=credentials['model_uid'])

+ 10 - 2
api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py

@@ -110,14 +110,22 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
 
             server_url = credentials['server_url']
             model_uid = credentials['model_uid']
-            extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid)
+            api_key = credentials.get('api_key')
+            extra_args = XinferenceHelper.get_xinference_extra_parameter(
+                server_url=server_url,
+                model_uid=model_uid,
+                api_key=api_key,
+            )
 
             if extra_args.max_tokens:
                 credentials['max_tokens'] = extra_args.max_tokens
             if server_url.endswith('/'):
                 server_url = server_url[:-1]
 
-            client = Client(base_url=server_url)
+            client = Client(
+                base_url=server_url,
+                api_key=api_key,
+            )
 
             try:
                 handle = client.get_model(model_uid=model_uid)

+ 7 - 2
api/core/model_runtime/model_providers/xinference/tts/tts.py

@@ -81,7 +81,8 @@ class XinferenceText2SpeechModel(TTSModel):
 
             extra_param = XinferenceHelper.get_xinference_extra_parameter(
                 server_url=credentials['server_url'],
-                model_uid=credentials['model_uid']
+                model_uid=credentials['model_uid'],
+                api_key=credentials.get('api_key'),
             )
 
             if 'text-to-audio' not in extra_param.model_ability:
@@ -203,7 +204,11 @@ class XinferenceText2SpeechModel(TTSModel):
             credentials['server_url'] = credentials['server_url'][:-1]
 
         try:
-            handle = RESTfulAudioModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={})
+            api_key = credentials.get('api_key')
+            auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
+            handle = RESTfulAudioModelHandle(
+                credentials['model_uid'], credentials['server_url'], auth_headers=auth_headers
+            )
 
             model_support_voice = [x.get("value") for x in
                                    self.get_tts_model_voices(model=model, credentials=credentials)]

+ 5 - 4
api/core/model_runtime/model_providers/xinference/xinference_helper.py

@@ -35,13 +35,13 @@ cache_lock = Lock()
 
 class XinferenceHelper:
     @staticmethod
-    def get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter:
+    def get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter:
         XinferenceHelper._clean_cache()
         with cache_lock:
             if model_uid not in cache:
                 cache[model_uid] = {
                     'expires': time() + 300,
-                    'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid)
+                    'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid, api_key)
                 }
             return cache[model_uid]['value']
 
@@ -56,7 +56,7 @@ class XinferenceHelper:
             pass
 
     @staticmethod
-    def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter:
+    def _get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter:
         """
             get xinference model extra parameter like model_format and model_handle_type
         """
@@ -70,9 +70,10 @@ class XinferenceHelper:
         session = Session()
         session.mount('http://', HTTPAdapter(max_retries=3))
         session.mount('https://', HTTPAdapter(max_retries=3))
+        headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
 
         try:
-            response = session.get(url, timeout=10)
+            response = session.get(url, headers=headers, timeout=10)
         except (MissingSchema, ConnectionError, Timeout) as e:
             raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}')
         if response.status_code != 200: