Parcourir la source

Add TTS to OpenAI_API_Compatible (#11071)

Tao Wang il y a 4 mois
Parent
commit
aa135a3780

+ 1 - 1
api/core/model_runtime/model_providers/azure_openai/tts/tts.py

@@ -14,7 +14,7 @@ from core.model_runtime.model_providers.azure_openai._constant import TTS_BASE_M
 
 class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
     """
-    Model class for OpenAI Speech to text model.
+    Model class for OpenAI text2speech model.
     """
 
     def _invoke(

+ 1 - 1
api/core/model_runtime/model_providers/gitee_ai/tts/tts.py

@@ -10,7 +10,7 @@ from core.model_runtime.model_providers.gitee_ai._common import _CommonGiteeAI
 
 class GiteeAIText2SpeechModel(_CommonGiteeAI, TTSModel):
     """
-    Model class for OpenAI Speech to text model.
+    Model class for OpenAI text2speech model.
     """
 
     def _invoke(

+ 1 - 1
api/core/model_runtime/model_providers/openai/tts/tts.py

@@ -11,7 +11,7 @@ from core.model_runtime.model_providers.openai._common import _CommonOpenAI
 
 class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
     """
-    Model class for OpenAI Speech to text model.
+    Model class for OpenAI text2speech model.
     """
 
     def _invoke(

+ 21 - 4
api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml

@@ -9,6 +9,7 @@ supported_model_types:
   - text-embedding
   - speech2text
   - rerank
+  - tts
 configurate_methods:
   - customizable-model
 model_credential_schema:
@@ -67,7 +68,7 @@ model_credential_schema:
         - variable: __model_type
           value: llm
       type: text-input
-      default: '4096'
+      default: "4096"
       placeholder:
         zh_Hans: 在此输入您的模型上下文长度
         en_US: Enter your Model context size
@@ -80,7 +81,7 @@ model_credential_schema:
         - variable: __model_type
           value: text-embedding
       type: text-input
-      default: '4096'
+      default: "4096"
       placeholder:
         zh_Hans: 在此输入您的模型上下文长度
         en_US: Enter your Model context size
@@ -93,7 +94,7 @@ model_credential_schema:
         - variable: __model_type
           value: rerank
       type: text-input
-      default: '4096'
+      default: "4096"
       placeholder:
         zh_Hans: 在此输入您的模型上下文长度
         en_US: Enter your Model context size
@@ -104,7 +105,7 @@ model_credential_schema:
       show_on:
         - variable: __model_type
           value: llm
-      default: '4096'
+      default: "4096"
       type: text-input
     - variable: function_calling_type
       show_on:
@@ -174,3 +175,19 @@ model_credential_schema:
           value: llm
       default: '\n\n'
       type: text-input
+    - variable: voices
+      show_on:
+        - variable: __model_type
+          value: tts
+      label:
+        en_US: Available Voices (comma-separated)
+        zh_Hans: 可用声音(用英文逗号分隔)
+      type: text-input
+      required: false
+      default: "alloy"
+      placeholder:
+        en_US: "alloy,echo,fable,onyx,nova,shimmer"
+        zh_Hans: "alloy,echo,fable,onyx,nova,shimmer"
+      help:
+        en_US: "List voice names separated by commas. First voice will be used as default."
+        zh_Hans: "用英文逗号分隔的声音列表。第一个声音将作为默认值。"

+ 0 - 0
api/core/model_runtime/model_providers/openai_api_compatible/tts/__init__.py


+ 145 - 0
api/core/model_runtime/model_providers/openai_api_compatible/tts/tts.py

@@ -0,0 +1,145 @@
+from collections.abc import Iterable
+from typing import Optional
+from urllib.parse import urljoin
+
+import requests
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
+from core.model_runtime.errors.invoke import InvokeBadRequestError
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.tts_model import TTSModel
+from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
+
+
+class OAICompatText2SpeechModel(_CommonOaiApiCompat, TTSModel):
+    """
+    Model class for OpenAI-compatible text2speech model.
+    """
+
+    def _invoke(
+        self,
+        model: str,
+        tenant_id: str,
+        credentials: dict,
+        content_text: str,
+        voice: str,
+        user: Optional[str] = None,
+    ) -> Iterable[bytes]:
+        """
+        Invoke TTS model
+
+        :param model: model name
+        :param tenant_id: user tenant id
+        :param credentials: model credentials
+        :param content_text: text content to be translated
+        :param voice: model voice/speaker
+        :param user: unique user id
+        :return: audio data as bytes iterator
+        """
+        # Set up headers with authentication if provided
+        headers = {}
+        if api_key := credentials.get("api_key"):
+            headers["Authorization"] = f"Bearer {api_key}"
+
+        # Construct endpoint URL
+        endpoint_url = credentials.get("endpoint_url")
+        if not endpoint_url.endswith("/"):
+            endpoint_url += "/"
+        endpoint_url = urljoin(endpoint_url, "audio/speech")
+
+        # Get audio format from model properties
+        audio_format = self._get_model_audio_type(model, credentials)
+
+        # Split text into chunks if needed based on word limit
+        word_limit = self._get_model_word_limit(model, credentials)
+        sentences = self._split_text_into_sentences(content_text, word_limit)
+
+        for sentence in sentences:
+            # Prepare request payload
+            payload = {"model": model, "input": sentence, "voice": voice, "response_format": audio_format}
+
+            # Make POST request
+            response = requests.post(endpoint_url, headers=headers, json=payload, stream=True)
+
+            if response.status_code != 200:
+                raise InvokeBadRequestError(response.text)
+
+            # Stream the audio data
+            for chunk in response.iter_content(chunk_size=4096):
+                if chunk:
+                    yield chunk
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            # Get default voice for validation
+            voice = self._get_model_default_voice(model, credentials)
+
+            # Test with a simple text
+            next(
+                self._invoke(
+                    model=model, tenant_id="validate", credentials=credentials, content_text="Test.", voice=voice
+                )
+            )
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
+        """
+        Get customizable model schema
+        """
+        # Parse voices from comma-separated string
+        voice_names = credentials.get("voices", "alloy").strip().split(",")
+        voices = []
+
+        for voice in voice_names:
+            voice = voice.strip()
+            if not voice:
+                continue
+
+            # Use en-US for all voices
+            voices.append(
+                {
+                    "name": voice,
+                    "mode": voice,
+                    "language": "en-US",
+                }
+            )
+
+        # If no voices provided or all voices were empty strings, use 'alloy' as default
+        if not voices:
+            voices = [{"name": "Alloy", "mode": "alloy", "language": "en-US"}]
+
+        return AIModelEntity(
+            model=model,
+            label=I18nObject(en_US=model),
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_type=ModelType.TTS,
+            model_properties={
+                ModelPropertyKey.AUDIO_TYPE: credentials.get("audio_type", "mp3"),
+                ModelPropertyKey.WORD_LIMIT: int(credentials.get("word_limit", 4096)),
+                ModelPropertyKey.DEFAULT_VOICE: voices[0]["mode"],
+                ModelPropertyKey.VOICES: voices,
+            },
+        )
+
+    def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
+        """
+        Override base get_tts_model_voices to handle customizable voices
+        """
+        model_schema = self.get_customizable_model_schema(model, credentials)
+
+        if not model_schema or ModelPropertyKey.VOICES not in model_schema.model_properties:
+            raise ValueError("this model does not support voice")
+
+        voices = model_schema.model_properties[ModelPropertyKey.VOICES]
+
+        # Always return all voices regardless of language
+        return [{"name": d["name"], "value": d["mode"]} for d in voices]