瀏覽代碼

enhance:speedup xinference audio transcription (#3636)

呆萌闷油瓶 1 年之前
父節點
當前提交
f76ac8bdee
共有 1 個文件被更改,包括 16 次插入,11 次刪除
  1. 16 11
      api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py

+ 16 - 11
api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py

@@ -47,6 +47,20 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
             if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
                 raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
 
+            if credentials['server_url'].endswith('/'):
+                credentials['server_url'] = credentials['server_url'][:-1]
+
+            # initialize client
+            client = Client(
+                base_url=credentials['server_url']
+            )
+
+            xinference_client = client.get_model(model_uid=credentials['model_uid'])
+
+            if not isinstance(xinference_client, RESTfulAudioModelHandle):
+                raise InvokeBadRequestError(
+                    'please check model type, the model you want to invoke is not a audio model')
+
             audio_file_path = self._get_demo_file_path()
 
             with open(audio_file_path, 'rb') as audio_file:
@@ -110,17 +124,8 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
         if credentials['server_url'].endswith('/'):
             credentials['server_url'] = credentials['server_url'][:-1]
 
-        # initialize client
-        client = Client(
-            base_url=credentials['server_url']
-        )
-
-        xinference_client = client.get_model(model_uid=credentials['model_uid'])
-
-        if not isinstance(xinference_client, RESTfulAudioModelHandle):
-            raise InvokeBadRequestError('please check model type, the model you want to invoke is not a audio model')
-        
-        response = xinference_client.transcriptions(
+        handle = RESTfulAudioModelHandle(credentials['model_uid'],credentials['server_url'],auth_headers={})
+        response = handle.transcriptions(
             audio=file,
             language = language,
             prompt = prompt,