chore: optimize streaming tts of xinference (#6966)

takatost 8 months ago
parent
commit
ea30174057

+ 64 - 62
api/core/model_runtime/model_providers/xinference/tts/tts.py

@@ -1,11 +1,7 @@
 import concurrent.futures
-from functools import reduce
-from io import BytesIO
 from typing import Optional
 
-from flask import Response
-from pydub import AudioSegment
-from xinference_client.client.restful.restful_client import Client, RESTfulAudioModelHandle
+from xinference_client.client.restful.restful_client import RESTfulAudioModelHandle
 
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
@@ -19,6 +15,7 @@ from core.model_runtime.errors.invoke import (
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.tts_model import TTSModel
+from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper
 
 
 class XinferenceText2SpeechModel(TTSModel):
@@ -26,7 +23,12 @@ class XinferenceText2SpeechModel(TTSModel):
     def __init__(self):
         # preset voices, need support custom voice
         self.model_voices = {
-            'chattts': {
+            '__default': {
+                'all': [
+                    {'name': 'Default', 'value': 'default'},
+                ]
+            },
+            'ChatTTS': {
                 'all': [
                     {'name': 'Alloy', 'value': 'alloy'},
                     {'name': 'Echo', 'value': 'echo'},
@@ -36,7 +38,7 @@ class XinferenceText2SpeechModel(TTSModel):
                     {'name': 'Shimmer', 'value': 'shimmer'},
                 ]
             },
-            'cosyvoice': {
+            'CosyVoice': {
                 'zh-Hans': [
                     {'name': '中文男', 'value': '中文男'},
                     {'name': '中文女', 'value': '中文女'},
@@ -77,18 +79,21 @@ class XinferenceText2SpeechModel(TTSModel):
             if credentials['server_url'].endswith('/'):
                 credentials['server_url'] = credentials['server_url'][:-1]
 
-            # initialize client
-            client = Client(
-                base_url=credentials['server_url']
+            extra_param = XinferenceHelper.get_xinference_extra_parameter(
+                server_url=credentials['server_url'],
+                model_uid=credentials['model_uid']
             )
 
-            xinference_client = client.get_model(model_uid=credentials['model_uid'])
-
-            if not isinstance(xinference_client, RESTfulAudioModelHandle):
+            if 'text-to-audio' not in extra_param.model_ability:
                 raise InvokeBadRequestError(
-                    'please check model type, the model you want to invoke is not a audio model')
+                    'please check model type, the model you want to invoke is not a text-to-audio model')
+
+            if extra_param.model_family and extra_param.model_family in self.model_voices:
+                credentials['audio_model_name'] = extra_param.model_family
+            else:
+                credentials['audio_model_name'] = '__default'
 
-            self._tts_invoke(
+            self._tts_invoke_streaming(
                 model=model,
                 credentials=credentials,
                 content_text='Hello Dify!',
@@ -110,7 +115,7 @@ class XinferenceText2SpeechModel(TTSModel):
         :param user: unique user id
         :return: text translated to audio file
         """
-        return self._tts_invoke(model, credentials, content_text, voice)
+        return self._tts_invoke_streaming(model, credentials, content_text, voice)
 
     def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
         """
@@ -161,13 +166,15 @@ class XinferenceText2SpeechModel(TTSModel):
         }
 
     def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
+        audio_model_name = credentials.get('audio_model_name', '__default')
         for key, voices in self.model_voices.items():
-            if key in model.lower():
-                if language in voices:
+            if key in audio_model_name:
+                if language and language in voices:
                     return voices[language]
                 elif 'all' in voices:
                     return voices['all']
-        return []
+
+        return self.model_voices['__default']['all']
 
     def _get_model_default_voice(self, model: str, credentials: dict) -> any:
         return ""
@@ -181,60 +188,55 @@ class XinferenceText2SpeechModel(TTSModel):
     def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
         return 5
 
-    def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> any:
+    def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str,
+                              voice: str) -> any:
         """
-        _tts_invoke text2speech model
+        _tts_invoke_streaming text2speech model
 
         :param model: model name
         :param credentials: model credentials
-        :param voice: model timbre
         :param content_text: text content to be translated
+        :param voice: model timbre
         :return: text translated to audio file
         """
         if credentials['server_url'].endswith('/'):
             credentials['server_url'] = credentials['server_url'][:-1]
 
-        word_limit = self._get_model_word_limit(model, credentials)
-        audio_type = self._get_model_audio_type(model, credentials)
-        handle = RESTfulAudioModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={})
-
         try:
-            sentences = list(self._split_text_into_sentences(org_text=content_text, max_length=word_limit))
-            audio_bytes_list = []
-
-            with concurrent.futures.ThreadPoolExecutor(max_workers=min((3, len(sentences)))) as executor:
+            handle = RESTfulAudioModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={})
+
+            model_support_voice = [x.get("value") for x in
+                                   self.get_tts_model_voices(model=model, credentials=credentials)]
+            if not voice or voice not in model_support_voice:
+                voice = self._get_model_default_voice(model, credentials)
+            word_limit = self._get_model_word_limit(model, credentials)
+            if len(content_text) > word_limit:
+                sentences = self._split_text_into_sentences(content_text, max_length=word_limit)
+                executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences)))
                 futures = [executor.submit(
-                    handle.speech, input=sentence, voice=voice, response_format="mp3", speed=1.0, stream=False)
-                    for sentence in sentences]
-            for future in futures:
-                try:
-                    if future.result():
-                        audio_bytes_list.append(future.result())
-                except Exception as ex:
-                    raise InvokeBadRequestError(str(ex))
-
-            if len(audio_bytes_list) > 0:
-                audio_segments = [AudioSegment.from_file(
-                    BytesIO(audio_bytes), format=audio_type) for audio_bytes in
-                    audio_bytes_list if audio_bytes]
-                combined_segment = reduce(lambda x, y: x + y, audio_segments)
-                buffer: BytesIO = BytesIO()
-                combined_segment.export(buffer, format=audio_type)
-                buffer.seek(0)
-                return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}")
+                    handle.speech,
+                    input=sentences[i],
+                    voice=voice,
+                    response_format="mp3",
+                    speed=1.0,
+                    stream=False
+                )
+                    for i in range(len(sentences))]
+
+                for index, future in enumerate(futures):
+                    response = future.result()
+                    for i in range(0, len(response), 1024):
+                        yield response[i:i + 1024]
+            else:
+                response = handle.speech(
+                    input=content_text.strip(),
+                    voice=voice,
+                    response_format="mp3",
+                    speed=1.0,
+                    stream=False
+                )
+
+                for i in range(0, len(response), 1024):
+                    yield response[i:i + 1024]
         except Exception as ex:
             raise InvokeBadRequestError(str(ex))
-
-    def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any:
-        """
-        _tts_invoke_streaming text2speech model
-
-        Attention:  stream api may return error [Parallel generation is not supported by ggml]
-
-        :param model: model name
-        :param credentials: model credentials
-        :param voice: model timbre
-        :param content_text: text content to be translated
-        :return: text translated to audio file
-        """
-        pass
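What the rewrite amounts to: the old _tts_invoke collected every per-sentence mp3 result, merged them with pydub's AudioSegment, and returned a single Flask Response; the new _tts_invoke_streaming is a generator that yields raw mp3 bytes in 1024-byte chunks as each future resolves (or directly, for texts under the word limit), so a caller can start consuming audio before the whole clip is synthesized. validate_credentials now exercises this streaming path with 'Hello Dify!'. Below is a minimal consumption sketch, not part of this commit: the server URL and model UID are placeholders, it assumes the Dify api/ directory is on PYTHONPATH, and it calls the private method directly purely for illustration.

# Hypothetical usage sketch (not from this commit).
from core.model_runtime.model_providers.xinference.tts.tts import XinferenceText2SpeechModel

tts = XinferenceText2SpeechModel()
credentials = {
    'server_url': 'http://localhost:9997',  # placeholder Xinference server
    'model_uid': 'my-chattts-uid',          # placeholder model UID
}

with open('out.mp3', 'wb') as f:
    # The generator yields mp3 bytes in 1024-byte chunks as they are produced.
    for chunk in tts._tts_invoke_streaming(model='ChatTTS', credentials=credentials,
                                           content_text='Hello Dify!', voice='default'):
        f.write(chunk)

Note that for long inputs the chunks are still emitted future-by-future in sentence order, so ordering is preserved even though synthesis runs on up to three worker threads.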

+ 16 - 4
api/core/model_runtime/model_providers/xinference/xinference_helper.py

@@ -1,5 +1,6 @@
 from threading import Lock
 from time import time
+from typing import Optional
 
 from requests.adapters import HTTPAdapter
 from requests.exceptions import ConnectionError, MissingSchema, Timeout
@@ -15,9 +16,11 @@ class XinferenceModelExtraParameter:
     context_length: int = 2048
     support_function_call: bool = False
     support_vision: bool = False
+    model_family: Optional[str]
 
     def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str],
-                 support_function_call: bool, support_vision: bool, max_tokens: int, context_length: int) -> None:
+                 support_function_call: bool, support_vision: bool, max_tokens: int, context_length: int,
+                 model_family: Optional[str]) -> None:
         self.model_format = model_format
         self.model_handle_type = model_handle_type
         self.model_ability = model_ability
@@ -25,6 +28,7 @@ class XinferenceModelExtraParameter:
         self.support_vision = support_vision
         self.max_tokens = max_tokens
         self.context_length = context_length
+        self.model_family = model_family
 
 cache = {}
 cache_lock = Lock()
@@ -78,9 +82,16 @@ class XinferenceHelper:
 
         model_format = response_json.get('model_format', 'ggmlv3')
         model_ability = response_json.get('model_ability', [])
+        model_family = response_json.get('model_family', None)
 
         if response_json.get('model_type') == 'embedding':
             model_handle_type = 'embedding'
+        elif response_json.get('model_type') == 'audio':
+            model_handle_type = 'audio'
+            if model_family and model_family in ['ChatTTS', 'CosyVoice']:
+                model_ability.append('text-to-audio')
+            else:
+                model_ability.append('audio-to-text')
         elif model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']:
             model_handle_type = 'chatglm'
         elif 'generate' in model_ability:
@@ -88,7 +99,7 @@ class XinferenceHelper:
         elif 'chat' in model_ability:
             model_handle_type = 'chat'
         else:
-            raise NotImplementedError(f'xinference model handle type {model_handle_type} is not supported')
+            raise NotImplementedError('xinference model handle type is not supported')
 
         support_function_call = 'tools' in model_ability
         support_vision = 'vision' in model_ability
@@ -103,5 +114,6 @@ class XinferenceHelper:
             support_function_call=support_function_call,
             support_vision=support_vision,
             max_tokens=max_tokens,
-            context_length=context_length
-        )
+            context_length=context_length,
+            model_family=model_family
+        )
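The helper change is what feeds the new voice lookup: for model_type == 'audio', the model description fetched from the Xinference server now contributes model_family, the family decides whether the model counts as text-to-audio (ChatTTS, CosyVoice) or audio-to-text, and the TTS model stores that family in credentials['audio_model_name'] to select the matching voice table, falling back to '__default'. A rough standalone restatement of the dispatch follows; the payload is hand-written, standing in for a real server response.

# Hypothetical, self-contained restatement of the new audio dispatch.
response_json = {
    'model_type': 'audio',
    'model_family': 'ChatTTS',  # placeholder: try 'whisper' to hit the other branch
    'model_ability': [],
}

model_ability = response_json.get('model_ability', [])
model_family = response_json.get('model_family', None)

if response_json.get('model_type') == 'audio':
    model_handle_type = 'audio'
    # Only the ChatTTS and CosyVoice families are treated as TTS-capable;
    # every other audio family is assumed to be speech-to-text.
    if model_family and model_family in ['ChatTTS', 'CosyVoice']:
        model_ability.append('text-to-audio')
    else:
        model_ability.append('audio-to-text')

# This is the ability that validate_credentials checks before invoking TTS.
assert 'text-to-audio' in model_ability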