|
@@ -0,0 +1,145 @@
|
|
|
+from collections.abc import Iterable
|
|
|
+from typing import Optional
|
|
|
+from urllib.parse import urljoin
|
|
|
+
|
|
|
+import requests
|
|
|
+
|
|
|
+from core.model_runtime.entities.common_entities import I18nObject
|
|
|
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
|
|
|
+from core.model_runtime.errors.invoke import InvokeBadRequestError
|
|
|
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
|
|
+from core.model_runtime.model_providers.__base.tts_model import TTSModel
|
|
|
+from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
|
|
|
+
|
|
|
+
|
|
|
class OAICompatText2SpeechModel(_CommonOaiApiCompat, TTSModel):
    """
    Model class for OpenAI-compatible text2speech model.

    Talks to any endpoint implementing the OpenAI ``/audio/speech`` API shape,
    streaming raw audio bytes back to the caller.
    """

    def _invoke(
        self,
        model: str,
        tenant_id: str,
        credentials: dict,
        content_text: str,
        voice: str,
        user: Optional[str] = None,
    ) -> Iterable[bytes]:
        """
        Invoke TTS model.

        :param model: model name
        :param tenant_id: user tenant id
        :param credentials: model credentials (expects ``endpoint_url``, optionally ``api_key``)
        :param content_text: text content to be translated
        :param voice: model voice/speaker
        :param user: unique user id
        :return: audio data as bytes iterator
        :raises InvokeBadRequestError: if the endpoint URL is missing or the server
            responds with a non-200 status
        """
        # Set up headers with authentication if provided
        headers: dict = {}
        if api_key := credentials.get("api_key"):
            headers["Authorization"] = f"Bearer {api_key}"

        # Construct endpoint URL. A missing endpoint_url previously crashed with
        # AttributeError on .endswith(); fail early with a clear error instead.
        endpoint_url = credentials.get("endpoint_url")
        if not endpoint_url:
            raise InvokeBadRequestError("endpoint_url is required in credentials")
        if not endpoint_url.endswith("/"):
            endpoint_url += "/"
        endpoint_url = urljoin(endpoint_url, "audio/speech")

        # Get audio format from model properties
        audio_format = self._get_model_audio_type(model, credentials)

        # Split text into chunks if needed based on word limit
        word_limit = self._get_model_word_limit(model, credentials)
        sentences = self._split_text_into_sentences(content_text, word_limit)

        for sentence in sentences:
            # Prepare request payload
            payload = {"model": model, "input": sentence, "voice": voice, "response_format": audio_format}

            # Make POST request; without a timeout, a hung connection would
            # block the worker indefinitely (requests has no default timeout).
            response = requests.post(
                endpoint_url, headers=headers, json=payload, stream=True, timeout=(10, 300)
            )

            if response.status_code != 200:
                raise InvokeBadRequestError(response.text)

            # Stream the audio data back to the caller chunk by chunk
            for chunk in response.iter_content(chunk_size=4096):
                if chunk:
                    yield chunk

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials by synthesizing a short test utterance.

        :param model: model name
        :param credentials: model credentials
        :raises CredentialsValidateFailedError: if the test invocation fails
        """
        try:
            # Get default voice for validation
            voice = self._get_model_default_voice(model, credentials)

            # Test with a simple text; next() forces at least one audio chunk
            # to actually be fetched from the endpoint.
            next(
                self._invoke(
                    model=model, tenant_id="validate", credentials=credentials, content_text="Test.", voice=voice
                )
            )
        except Exception as ex:
            # Preserve the original traceback for easier debugging
            raise CredentialsValidateFailedError(str(ex)) from ex

    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
        """
        Get customizable model schema.

        Voices are parsed from the comma-separated ``voices`` credential;
        the first parsed voice becomes the default voice.
        """
        # Parse voices from comma-separated string
        voice_names = credentials.get("voices", "alloy").strip().split(",")
        voices = []

        for voice in voice_names:
            voice = voice.strip()
            if not voice:
                continue

            # Use en-US for all voices
            voices.append(
                {
                    "name": voice,
                    "mode": voice,
                    "language": "en-US",
                }
            )

        # If no voices provided or all voices were empty strings, use 'alloy' as default
        if not voices:
            voices = [{"name": "Alloy", "mode": "alloy", "language": "en-US"}]

        return AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.TTS,
            model_properties={
                ModelPropertyKey.AUDIO_TYPE: credentials.get("audio_type", "mp3"),
                ModelPropertyKey.WORD_LIMIT: int(credentials.get("word_limit", 4096)),
                ModelPropertyKey.DEFAULT_VOICE: voices[0]["mode"],
                ModelPropertyKey.VOICES: voices,
            },
        )

    def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
        """
        Override base get_tts_model_voices to handle customizable voices.

        :param language: accepted for interface compatibility but intentionally
            ignored — all configured voices are returned regardless of language
        :raises ValueError: if the model schema declares no voices
        """
        model_schema = self.get_customizable_model_schema(model, credentials)

        if not model_schema or ModelPropertyKey.VOICES not in model_schema.model_properties:
            raise ValueError("this model does not support voice")

        voices = model_schema.model_properties[ModelPropertyKey.VOICES]

        # Always return all voices regardless of language
        return [{"name": d["name"], "value": d["mode"]} for d in voices]
|