|
@@ -47,6 +47,20 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
|
|
|
if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
|
|
|
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
|
|
|
|
|
|
+ if credentials['server_url'].endswith('/'):
|
|
|
+ credentials['server_url'] = credentials['server_url'][:-1]
|
|
|
+
|
|
|
+ # initialize client
|
|
|
+ client = Client(
|
|
|
+ base_url=credentials['server_url']
|
|
|
+ )
|
|
|
+
|
|
|
+ xinference_client = client.get_model(model_uid=credentials['model_uid'])
|
|
|
+
|
|
|
+ if not isinstance(xinference_client, RESTfulAudioModelHandle):
|
|
|
+ raise InvokeBadRequestError(
|
|
|
+ 'please check model type, the model you want to invoke is not a audio model')
|
|
|
+
|
|
|
audio_file_path = self._get_demo_file_path()
|
|
|
|
|
|
with open(audio_file_path, 'rb') as audio_file:
|
|
@@ -110,17 +124,8 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
|
|
|
if credentials['server_url'].endswith('/'):
|
|
|
credentials['server_url'] = credentials['server_url'][:-1]
|
|
|
|
|
|
- # initialize client
|
|
|
- client = Client(
|
|
|
- base_url=credentials['server_url']
|
|
|
- )
|
|
|
-
|
|
|
- xinference_client = client.get_model(model_uid=credentials['model_uid'])
|
|
|
-
|
|
|
- if not isinstance(xinference_client, RESTfulAudioModelHandle):
|
|
|
- raise InvokeBadRequestError('please check model type, the model you want to invoke is not a audio model')
|
|
|
-
|
|
|
- response = xinference_client.transcriptions(
|
|
|
+ handle = RESTfulAudioModelHandle(credentials['model_uid'],credentials['server_url'],auth_headers={})
|
|
|
+ response = handle.transcriptions(
|
|
|
audio=file,
|
|
|
language = language,
|
|
|
prompt = prompt,
|