Explorar el Código

feat:support azure tts (#2751)

呆萌闷油瓶 hace 1 año
padre
commit
f49b1afd6c

+ 110 - 0
api/core/model_runtime/model_providers/azure_openai/_constant.py

@@ -583,3 +583,113 @@ SPEECH2TEXT_BASE_MODELS = [
         )
     )
 ]
+TTS_BASE_MODELS = [
+    AzureBaseModel(
+        base_model_name='tts-1',
+        entity=AIModelEntity(
+            model='fake-deployment-name',
+            label=I18nObject(
+                en_US='fake-deployment-name-label'
+            ),
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_type=ModelType.TTS,
+            model_properties={
+                ModelPropertyKey.DEFAULT_VOICE: 'alloy',
+                ModelPropertyKey.VOICES: [
+                    {
+                        'mode': 'alloy',
+                        'name': 'Alloy',
+                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
+                    },
+                    {
+                        'mode': 'echo',
+                        'name': 'Echo',
+                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
+                    },
+                    {
+                        'mode': 'fable',
+                        'name': 'Fable',
+                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
+                    },
+                    {
+                        'mode': 'onyx',
+                        'name': 'Onyx',
+                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
+                    },
+                    {
+                        'mode': 'nova',
+                        'name': 'Nova',
+                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
+                    },
+                    {
+                        'mode': 'shimmer',
+                        'name': 'Shimmer',
+                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
+                    },
+                ],
+                ModelPropertyKey.WORD_LIMIT: 120,
+                ModelPropertyKey.AUDOI_TYPE: 'mp3',
+                ModelPropertyKey.MAX_WORKERS: 5
+            },
+            pricing=PriceConfig(
+                input=0.015,
+                unit=0.001,
+                currency='USD',
+            )
+        )
+    ),
+    AzureBaseModel(
+        base_model_name='tts-1-hd',
+        entity=AIModelEntity(
+            model='fake-deployment-name',
+            label=I18nObject(
+                en_US='fake-deployment-name-label'
+            ),
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_type=ModelType.TTS,
+            model_properties={
+                ModelPropertyKey.DEFAULT_VOICE: 'alloy',
+                ModelPropertyKey.VOICES: [
+                    {
+                        'mode': 'alloy',
+                        'name': 'Alloy',
+                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
+                    },
+                    {
+                        'mode': 'echo',
+                        'name': 'Echo',
+                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
+                    },
+                    {
+                        'mode': 'fable',
+                        'name': 'Fable',
+                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
+                    },
+                    {
+                        'mode': 'onyx',
+                        'name': 'Onyx',
+                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
+                    },
+                    {
+                        'mode': 'nova',
+                        'name': 'Nova',
+                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
+                    },
+                    {
+                        'mode': 'shimmer',
+                        'name': 'Shimmer',
+                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
+                    },
+                ],
+                ModelPropertyKey.WORD_LIMIT: 120,
+                ModelPropertyKey.AUDOI_TYPE: 'mp3',
+                ModelPropertyKey.MAX_WORKERS: 5
+            },
+            pricing=PriceConfig(
+                input=0.03,
+                unit=0.001,
+                currency='USD',
+            )
+        )
+    )
+]

+ 13 - 0
api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml

@@ -16,6 +16,7 @@ supported_model_types:
   - llm
   - text-embedding
   - speech2text
+  - tts
 configurate_methods:
   - customizable-model
 model_credential_schema:
@@ -118,6 +119,18 @@ model_credential_schema:
           show_on:
             - variable: __model_type
               value: speech2text
+        - label:
+            en_US: tts-1
+          value: tts-1
+          show_on:
+            - variable: __model_type
+              value: tts
+        - label:
+            en_US: tts-1-hd
+          value: tts-1-hd
+          show_on:
+            - variable: __model_type
+              value: tts
       placeholder:
         zh_Hans: 在此输入您的模型版本
         en_US: Enter your model version

+ 0 - 0
api/core/model_runtime/model_providers/azure_openai/tts/__init__.py


+ 174 - 0
api/core/model_runtime/model_providers/azure_openai/tts/tts.py

@@ -0,0 +1,174 @@
+import concurrent.futures
+import copy
+from functools import reduce
+from io import BytesIO
+from typing import Optional
+
+from flask import Response, stream_with_context
+from openai import AzureOpenAI
+from pydub import AudioSegment
+
+from core.model_runtime.entities.model_entities import AIModelEntity
+from core.model_runtime.errors.invoke import InvokeBadRequestError
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.tts_model import TTSModel
+from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
+from core.model_runtime.model_providers.azure_openai._constant import TTS_BASE_MODELS, AzureBaseModel
+from extensions.ext_storage import storage
+
+
+class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
+    """
+    Model class for OpenAI Speech to text model.
+    """
+
+    def _invoke(self, model: str, tenant_id: str, credentials: dict,
+                content_text: str, voice: str, streaming: bool, user: Optional[str] = None) -> any:
+        """
+        _invoke text2speech model
+
+        :param model: model name
+        :param tenant_id: user tenant id
+        :param credentials: model credentials
+        :param content_text: text content to be translated
+        :param voice: model timbre
+        :param streaming: output is streaming
+        :param user: unique user id
+        :return: text translated to audio file
+        """
+        audio_type = self._get_model_audio_type(model, credentials)
+        if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
+            voice = self._get_model_default_voice(model, credentials)
+        if streaming:
+            return Response(stream_with_context(self._tts_invoke_streaming(model=model,
+                                                                           credentials=credentials,
+                                                                           content_text=content_text,
+                                                                           tenant_id=tenant_id,
+                                                                           voice=voice)),
+                            status=200, mimetype=f'audio/{audio_type}')
+        else:
+            return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice)
+
+    def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
+        """
+        validate credentials text2speech model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param user: unique user id
+        :return: text translated to audio file
+        """
+        try:
+            self._tts_invoke(
+                model=model,
+                credentials=credentials,
+                content_text='Hello Dify!',
+                voice=self._get_model_default_voice(model, credentials),
+            )
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> Response:
+        """
+        _tts_invoke text2speech model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param content_text: text content to be translated
+        :param voice: model timbre
+        :return: text translated to audio file
+        """
+        audio_type = self._get_model_audio_type(model, credentials)
+        word_limit = self._get_model_word_limit(model, credentials)
+        max_workers = self._get_model_workers_limit(model, credentials)
+        try:
+            sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
+            audio_bytes_list = list()
+
+            # Create a thread pool and map the function to the list of sentences
+            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+                futures = [executor.submit(self._process_sentence, sentence=sentence, model=model, voice=voice,
+                                           credentials=credentials) for sentence in sentences]
+                for future in futures:
+                    try:
+                        if future.result():
+                            audio_bytes_list.append(future.result())
+                    except Exception as ex:
+                        raise InvokeBadRequestError(str(ex))
+
+            if len(audio_bytes_list) > 0:
+                audio_segments = [AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type) for audio_bytes in
+                                  audio_bytes_list if audio_bytes]
+                combined_segment = reduce(lambda x, y: x + y, audio_segments)
+                buffer: BytesIO = BytesIO()
+                combined_segment.export(buffer, format=audio_type)
+                buffer.seek(0)
+                return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}")
+        except Exception as ex:
+            raise InvokeBadRequestError(str(ex))
+
+    # Todo: To improve the streaming function
+    def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str,
+                              voice: str) -> any:
+        """
+        _tts_invoke_streaming text2speech model
+
+        :param model: model name
+        :param tenant_id: user tenant id
+        :param credentials: model credentials
+        :param content_text: text content to be translated
+        :param voice: model timbre
+        :return: text translated to audio file
+        """
+        # transform credentials to kwargs for model instance
+        credentials_kwargs = self._to_credential_kwargs(credentials)
+        if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials):
+            voice = self._get_model_default_voice(model, credentials)
+        word_limit = self._get_model_word_limit(model, credentials)
+        audio_type = self._get_model_audio_type(model, credentials)
+        tts_file_id = self._get_file_name(content_text)
+        file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}'
+        try:
+            client = AzureOpenAI(**credentials_kwargs)
+            sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
+            for sentence in sentences:
+                response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip())
+                # response.stream_to_file(file_path)
+                storage.save(file_path, response.read())
+        except Exception as ex:
+            raise InvokeBadRequestError(str(ex))
+
+    def _process_sentence(self, sentence: str, model: str,
+                          voice, credentials: dict):
+        """
+        _tts_invoke openai text2speech model api
+
+        :param model: model name
+        :param credentials: model credentials
+        :param voice: model timbre
+        :param sentence: text content to be translated
+        :return: text translated to audio file
+        """
+        # transform credentials to kwargs for model instance
+        credentials_kwargs = self._to_credential_kwargs(credentials)
+        client = AzureOpenAI(**credentials_kwargs)
+        response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip())
+        if isinstance(response.read(), bytes):
+            return response.read()
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
+        ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
+        return ai_model_entity.entity
+
+
+    @staticmethod
+    def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
+        for ai_model_entity in TTS_BASE_MODELS:
+            if ai_model_entity.base_model_name == base_model_name:
+                ai_model_entity_copy = copy.deepcopy(ai_model_entity)
+                ai_model_entity_copy.entity.model = model
+                ai_model_entity_copy.entity.label.en_US = model
+                ai_model_entity_copy.entity.label.zh_Hans = model
+                return ai_model_entity_copy
+
+        return None

+ 1 - 1
api/requirements.txt

@@ -11,7 +11,7 @@ flask-cors~=4.0.0
 gunicorn~=21.2.0
 gevent~=23.9.1
 langchain==0.0.250
-openai~=1.3.6
+openai~=1.13.3
 tiktoken~=0.5.2
 psycopg2-binary~=2.9.6
 pycryptodome==3.19.1