فهرست منبع

support xinference tts (#6746)

Weaxs 8 ماه پیش
والد
کامیت
f6e8e120a1

+ 0 - 0
api/core/model_runtime/model_providers/xinference/tts/__init__.py


+ 240 - 0
api/core/model_runtime/model_providers/xinference/tts/tts.py

@@ -0,0 +1,240 @@
+import concurrent.futures
+from functools import reduce
+from io import BytesIO
+from typing import Optional
+
+from flask import Response
+from pydub import AudioSegment
+from xinference_client.client.restful.restful_client import Client, RESTfulAudioModelHandle
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.tts_model import TTSModel
+
+
+class XinferenceText2SpeechModel(TTSModel):
+
+    def __init__(self):
+        # preset voices, need support custom voice
+        self.model_voices = {
+            'chattts': {
+                'all': [
+                    {'name': 'Alloy', 'value': 'alloy'},
+                    {'name': 'Echo', 'value': 'echo'},
+                    {'name': 'Fable', 'value': 'fable'},
+                    {'name': 'Onyx', 'value': 'onyx'},
+                    {'name': 'Nova', 'value': 'nova'},
+                    {'name': 'Shimmer', 'value': 'shimmer'},
+                ]
+            },
+            'cosyvoice': {
+                'zh-Hans': [
+                    {'name': '中文男', 'value': '中文男'},
+                    {'name': '中文女', 'value': '中文女'},
+                    {'name': '粤语女', 'value': '粤语女'},
+                ],
+                'zh-Hant': [
+                    {'name': '中文男', 'value': '中文男'},
+                    {'name': '中文女', 'value': '中文女'},
+                    {'name': '粤语女', 'value': '粤语女'},
+                ],
+                'en-US': [
+                    {'name': '英文男', 'value': '英文男'},
+                    {'name': '英文女', 'value': '英文女'},
+                ],
+                'ja-JP': [
+                    {'name': '日语男', 'value': '日语男'},
+                ],
+                'ko-KR': [
+                    {'name': '韩语女', 'value': '韩语女'},
+                ]
+            }
+        }
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+                Validate model credentials
+
+                :param model: model name
+                :param credentials: model credentials
+                :return:
+                """
+        try:
+            if ("/" in credentials['model_uid'] or
+                    "?" in credentials['model_uid'] or
+                    "#" in credentials['model_uid']):
+                raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
+
+            if credentials['server_url'].endswith('/'):
+                credentials['server_url'] = credentials['server_url'][:-1]
+
+            # initialize client
+            client = Client(
+                base_url=credentials['server_url']
+            )
+
+            xinference_client = client.get_model(model_uid=credentials['model_uid'])
+
+            if not isinstance(xinference_client, RESTfulAudioModelHandle):
+                raise InvokeBadRequestError(
+                    'please check model type, the model you want to invoke is not a audio model')
+
+            self._tts_invoke(
+                model=model,
+                credentials=credentials,
+                content_text='Hello Dify!',
+                voice=self._get_model_default_voice(model, credentials),
+            )
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
+                user: Optional[str] = None):
+        """
+        _invoke text2speech model
+
+        :param model: model name
+        :param tenant_id: user tenant id
+        :param credentials: model credentials
+        :param voice: model timbre
+        :param content_text: text content to be translated
+        :param user: unique user id
+        :return: text translated to audio file
+        """
+        return self._tts_invoke(model, credentials, content_text, voice)
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+        """
+            used to define customizable model schema
+        """
+
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(
+                en_US=model
+            ),
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_type=ModelType.TTS,
+            model_properties={},
+            parameter_rules=[]
+        )
+
+        return entity
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeConnectionError: [
+                InvokeConnectionError
+            ],
+            InvokeServerUnavailableError: [
+                InvokeServerUnavailableError
+            ],
+            InvokeRateLimitError: [
+                InvokeRateLimitError
+            ],
+            InvokeAuthorizationError: [
+                InvokeAuthorizationError
+            ],
+            InvokeBadRequestError: [
+                InvokeBadRequestError,
+                KeyError,
+                ValueError
+            ]
+        }
+
+    def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
+        for key, voices in self.model_voices.items():
+            if key in model.lower():
+                if language in voices:
+                    return voices[language]
+                elif 'all' in voices:
+                    return voices['all']
+        return []
+
+    def _get_model_default_voice(self, model: str, credentials: dict) -> any:
+        return ""
+
+    def _get_model_word_limit(self, model: str, credentials: dict) -> int:
+        return 3500
+
+    def _get_model_audio_type(self, model: str, credentials: dict) -> str:
+        return "mp3"
+
+    def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
+        return 5
+
+    def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> any:
+        """
+        _tts_invoke text2speech model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param voice: model timbre
+        :param content_text: text content to be translated
+        :return: text translated to audio file
+        """
+        if credentials['server_url'].endswith('/'):
+            credentials['server_url'] = credentials['server_url'][:-1]
+
+        word_limit = self._get_model_word_limit(model, credentials)
+        audio_type = self._get_model_audio_type(model, credentials)
+        handle = RESTfulAudioModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={})
+
+        try:
+            sentences = list(self._split_text_into_sentences(org_text=content_text, max_length=word_limit))
+            audio_bytes_list = []
+
+            with concurrent.futures.ThreadPoolExecutor(max_workers=min((3, len(sentences)))) as executor:
+                futures = [executor.submit(
+                    handle.speech, input=sentence, voice=voice, response_format="mp3", speed=1.0, stream=False)
+                    for sentence in sentences]
+            for future in futures:
+                try:
+                    if future.result():
+                        audio_bytes_list.append(future.result())
+                except Exception as ex:
+                    raise InvokeBadRequestError(str(ex))
+
+            if len(audio_bytes_list) > 0:
+                audio_segments = [AudioSegment.from_file(
+                    BytesIO(audio_bytes), format=audio_type) for audio_bytes in
+                    audio_bytes_list if audio_bytes]
+                combined_segment = reduce(lambda x, y: x + y, audio_segments)
+                buffer: BytesIO = BytesIO()
+                combined_segment.export(buffer, format=audio_type)
+                buffer.seek(0)
+                return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}")
+        except Exception as ex:
+            raise InvokeBadRequestError(str(ex))
+
+    def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any:
+        """
+        _tts_invoke_streaming text2speech model
+
+        Attention:  stream api may return error [Parallel generation is not supported by ggml]
+
+        :param model: model name
+        :param credentials: model credentials
+        :param voice: model timbre
+        :param content_text: text content to be translated
+        :return: text translated to audio file
+        """
+        pass

+ 1 - 0
api/core/model_runtime/model_providers/xinference/xinference.yaml

@@ -17,6 +17,7 @@ supported_model_types:
   - text-embedding
   - rerank
   - speech2text
+  - tts
 configurate_methods:
   - customizable-model
 model_credential_schema:

+ 4 - 4
api/poetry.lock

@@ -9098,13 +9098,13 @@ h11 = ">=0.9.0,<1"
 
 [[package]]
 name = "xinference-client"
-version = "0.9.4"
+version = "0.13.3"
 description = "Client for Xinference"
 optional = false
 python-versions = "*"
 files = [
-    {file = "xinference-client-0.9.4.tar.gz", hash = "sha256:21934bc9f3142ade66aaed33c2b6cf244c274d5b4b3163f9981bebdddacf205f"},
-    {file = "xinference_client-0.9.4-py3-none-any.whl", hash = "sha256:6d3f1df3537a011f0afee5f9c9ca4f3ff564ca32cc999cf7038b324c0b907d0c"},
+    {file = "xinference-client-0.13.3.tar.gz", hash = "sha256:822b722100affdff049c27760be7d62ac92de58c87a40d3361066df446ba648f"},
+    {file = "xinference_client-0.13.3-py3-none-any.whl", hash = "sha256:f0eff3858b1ebcef2129726f82b09259c177e11db466a7ca23def3d4849c419f"},
 ]
 
 [package.dependencies]
@@ -9502,4 +9502,4 @@ cffi = ["cffi (>=1.11)"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "a8b61d74d9322302b7447b6f8728ad606abc160202a8a122a05a8ef3cec7055b"
+content-hash = "ca55e4a4bb354fe969cc73c823557525c7598b0375e8791fcd77febc59e03b96"

+ 1 - 1
api/pyproject.toml

@@ -173,7 +173,7 @@ transformers = "~4.35.0"
 unstructured = { version = "~0.10.27", extras = ["docx", "epub", "md", "msg", "ppt", "pptx"] }
 websocket-client = "~1.7.0"
 werkzeug = "~3.0.1"
-xinference-client = "0.9.4"
+xinference-client = "0.13.3"
 yarl = "~1.9.4"
 zhipuai = "1.0.7"
 rank-bm25 = "~0.2.2"