Przeglądaj źródła

feat: support fish audio TTS (#7982)

Leng Yue 7 miesięcy temu
rodzic
commit
bd0992275c

+ 1 - 0
api/core/model_runtime/model_providers/fishaudio/__init__.py

@@ -0,0 +1 @@
+

Plik diff jest za duży
+ 0 - 0
api/core/model_runtime/model_providers/fishaudio/_assets/fishaudio_l_en.svg


Plik diff jest za duży
+ 0 - 0
api/core/model_runtime/model_providers/fishaudio/_assets/fishaudio_s_en.svg


+ 28 - 0
api/core/model_runtime/model_providers/fishaudio/fishaudio.py

@@ -0,0 +1,28 @@
+import logging
+
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class FishAudioProvider(ModelProvider):
+    def validate_provider_credentials(self, credentials: dict) -> None:
+        """
+        Validate provider credentials
+
+        For debugging purposes, this method now always passes validation.
+
+        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
+        """
+        try:
+            model_instance = self.get_model_instance(ModelType.TTS)
+            model_instance.validate_credentials(
+                credentials=credentials
+            )
+        except CredentialsValidateFailedError as ex:
+            raise ex
+        except Exception as ex:
+            logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
+            raise ex

+ 76 - 0
api/core/model_runtime/model_providers/fishaudio/fishaudio.yaml

@@ -0,0 +1,76 @@
+provider: fishaudio
+label:
+  en_US: Fish Audio
+description:
+  en_US: Models provided by Fish Audio, currently only support TTS.
+  zh_Hans: Fish Audio 提供的模型,目前仅支持 TTS。
+icon_small:
+  en_US: fishaudio_s_en.svg
+icon_large:
+  en_US: fishaudio_l_en.svg
+background: "#E5E7EB"
+help:
+  title:
+    en_US: Get your API key from Fish Audio
+    zh_Hans: 从 Fish Audio 获取你的 API Key
+  url:
+    en_US: https://fish.audio/go-api/
+supported_model_types:
+  - tts
+configurate_methods:
+  - predefined-model
+provider_credential_schema:
+  credential_form_schemas:
+    - variable: api_key
+      label:
+        en_US: API Key
+      type: secret-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入您的 API Key
+        en_US: Enter your API Key
+    - variable: api_base
+      label:
+        en_US: API URL
+      type: text-input
+      required: false
+      default: https://api.fish.audio
+      placeholder:
+        en_US: Enter your API URL
+        zh_Hans: 在此输入您的 API URL
+    - variable: use_public_models
+      label:
+        en_US: Use Public Models
+      type: select
+      required: false
+      default: "false"
+      placeholder:
+        en_US: Toggle to use public models
+        zh_Hans: 切换以使用公共模型
+      options:
+        - value: "true"
+          label:
+            en_US: Allow Public Models
+            zh_Hans: 使用公共模型
+        - value: "false"
+          label:
+            en_US: Private Models Only
+            zh_Hans: 仅使用私有模型
+    - variable: latency
+      label:
+        en_US: Latency
+      type: select
+      required: false
+      default: "normal"
+      placeholder:
+        en_US: Toggle to choice latency
+        zh_Hans: 切换以调整延迟
+      options:
+        - value: "balanced"
+          label:
+            en_US: Low (may affect quality)
+            zh_Hans: 低延迟 (可能降低质量)
+        - value: "normal"
+          label:
+            en_US: Normal
+            zh_Hans: 标准

+ 0 - 0
api/core/model_runtime/model_providers/fishaudio/tts/__init__.py


+ 174 - 0
api/core/model_runtime/model_providers/fishaudio/tts/tts.py

@@ -0,0 +1,174 @@
+from typing import Optional
+
+import httpx
+
+from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.tts_model import TTSModel
+
+
+class FishAudioText2SpeechModel(TTSModel):
+    """
+    Model class for Fish.audio Text to Speech model.
+    """
+
+    def get_tts_model_voices(
+        self, model: str, credentials: dict, language: Optional[str] = None
+    ) -> list:
+        api_base = credentials.get("api_base", "https://api.fish.audio")
+        api_key = credentials.get("api_key")
+        use_public_models = credentials.get("use_public_models", "false") == "true"
+
+        params = {
+            "self": str(not use_public_models).lower(),
+            "page_size": "100",
+        }
+
+        if language is not None:
+            if "-" in language:
+                language = language.split("-")[0]
+            params["language"] = language
+
+        results = httpx.get(
+            f"{api_base}/model",
+            headers={"Authorization": f"Bearer {api_key}"},
+            params=params,
+        )
+
+        results.raise_for_status()
+        data = results.json()
+
+        return [{"name": i["title"], "value": i["_id"]} for i in data["items"]]
+
+    def _invoke(
+        self,
+        model: str,
+        tenant_id: str,
+        credentials: dict,
+        content_text: str,
+        voice: str,
+        user: Optional[str] = None,
+    ) -> any:
+        """
+        Invoke text2speech model
+
+        :param model: model name
+        :param tenant_id: user tenant id
+        :param credentials: model credentials
+        :param voice: model timbre
+        :param content_text: text content to be translated
+        :param user: unique user id
+        :return: generator yielding audio chunks
+        """
+
+        return self._tts_invoke_streaming(
+            model=model,
+            credentials=credentials,
+            content_text=content_text,
+            voice=voice,
+        )
+
+    def validate_credentials(
+        self, credentials: dict, user: Optional[str] = None
+    ) -> None:
+        """
+        Validate credentials for text2speech model
+
+        :param credentials: model credentials
+        :param user: unique user id
+        """
+
+        try:
+            self.get_tts_model_voices(
+                None,
+                credentials={
+                    "api_key": credentials["api_key"],
+                    "api_base": credentials["api_base"],
+                    # Disable public models will trigger a 403 error if user is not logged in
+                    "use_public_models": "false",
+                },
+            )
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    def _tts_invoke_streaming(
+        self, model: str, credentials: dict, content_text: str, voice: str
+    ) -> any:
+        """
+        Invoke streaming text2speech model
+        :param model: model name
+        :param credentials: model credentials
+        :param content_text: text content to be translated
+        :param voice: ID of the reference audio (if any)
+        :return: generator yielding audio chunks
+        """
+
+        try:
+            word_limit = self._get_model_word_limit(model, credentials)
+            if len(content_text) > word_limit:
+                sentences = self._split_text_into_sentences(
+                    content_text, max_length=word_limit
+                )
+            else:
+                sentences = [content_text.strip()]
+            
+            for i in range(len(sentences)):
+                yield from self._tts_invoke_streaming_sentence(
+                    credentials=credentials, content_text=sentences[i], voice=voice
+                )
+
+        except Exception as ex:
+            raise InvokeBadRequestError(str(ex))
+
+    def _tts_invoke_streaming_sentence(
+        self, credentials: dict, content_text: str, voice: Optional[str] = None
+    ) -> any:
+        """
+        Invoke streaming text2speech model
+
+        :param credentials: model credentials
+        :param content_text: text content to be translated
+        :param voice: ID of the reference audio (if any)
+        :return: generator yielding audio chunks
+        """
+        api_key = credentials.get("api_key")
+        api_url = credentials.get("api_base", "https://api.fish.audio")
+        latency = credentials.get("latency")
+
+        if not api_key:
+            raise InvokeBadRequestError("API key is required")
+
+        with httpx.stream(
+            "POST",
+            api_url + "/v1/tts",
+            json={
+                "text": content_text,
+                "reference_id": voice,
+                "latency": latency
+            },
+            headers={
+                "Authorization": f"Bearer {api_key}",
+            },
+            timeout=None,
+        ) as response:
+            if response.status_code != 200:
+                raise InvokeBadRequestError(
+                    f"Error: {response.status_code} - {response.text}"
+                )
+            yield from response.iter_bytes()
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeBadRequestError: [
+                httpx.HTTPStatusError,
+            ],
+        }

+ 5 - 0
api/core/model_runtime/model_providers/fishaudio/tts/tts.yaml

@@ -0,0 +1,5 @@
+model: tts-default
+model_type: tts
+model_properties:
+  word_limit: 1000
+  audio_type: 'mp3'

+ 82 - 0
api/tests/integration_tests/model_runtime/__mock/fishaudio.py

@@ -0,0 +1,82 @@
+import os
+from collections.abc import Callable
+from typing import Literal
+
+import httpx
+import pytest
+from _pytest.monkeypatch import MonkeyPatch
+
+
+def mock_get(*args, **kwargs):
+    if kwargs.get("headers", {}).get("Authorization") != "Bearer test":
+        raise httpx.HTTPStatusError(
+            "Invalid API key",
+            request=httpx.Request("GET", ""),
+            response=httpx.Response(401),
+        )
+
+    return httpx.Response(
+        200,
+        json={
+            "items": [
+                {"title": "Model 1", "_id": "model1"},
+                {"title": "Model 2", "_id": "model2"},
+            ]
+        },
+        request=httpx.Request("GET", ""),
+    )
+
+
+def mock_stream(*args, **kwargs):
+    class MockStreamResponse:
+        def __init__(self):
+            self.status_code = 200
+
+        def __enter__(self):
+            return self
+
+        def __exit__(self, exc_type, exc_val, exc_tb):
+            pass
+
+        def iter_bytes(self):
+            yield b"Mocked audio data"
+
+    return MockStreamResponse()
+
+
+def mock_fishaudio(
+    monkeypatch: MonkeyPatch,
+    methods: list[Literal["list-models", "tts"]],
+) -> Callable[[], None]:
+    """
+    mock fishaudio module
+
+    :param monkeypatch: pytest monkeypatch fixture
+    :return: unpatch function
+    """
+
+    def unpatch() -> None:
+        monkeypatch.undo()
+
+    if "list-models" in methods:
+        monkeypatch.setattr(httpx, "get", mock_get)
+
+    if "tts" in methods:
+        monkeypatch.setattr(httpx, "stream", mock_stream)
+
+    return unpatch
+
+
+MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
+
+
+@pytest.fixture
+def setup_fishaudio_mock(request, monkeypatch):
+    methods = request.param if hasattr(request, "param") else []
+    if MOCK:
+        unpatch = mock_fishaudio(monkeypatch, methods=methods)
+
+    yield
+
+    if MOCK:
+        unpatch()

+ 0 - 0
api/tests/integration_tests/model_runtime/fishaudio/__init__.py


+ 33 - 0
api/tests/integration_tests/model_runtime/fishaudio/test_provider.py

@@ -0,0 +1,33 @@
+import os
+
+import httpx
+import pytest
+
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.fishaudio.fishaudio import FishAudioProvider
+from tests.integration_tests.model_runtime.__mock.fishaudio import setup_fishaudio_mock
+
+
+@pytest.mark.parametrize("setup_fishaudio_mock", [["list-models"]], indirect=True)
+def test_validate_provider_credentials(setup_fishaudio_mock):
+    print("-----", httpx.get)
+    provider = FishAudioProvider()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        provider.validate_provider_credentials(
+            credentials={
+                "api_key": "bad_api_key",
+                "api_base": os.environ.get("FISH_AUDIO_API_BASE", "https://api.fish.audio"),
+                "use_public_models": "false",
+                "latency": "normal",
+            }
+        )
+
+    provider.validate_provider_credentials(
+        credentials={
+            "api_key": os.environ.get("FISH_AUDIO_API_KEY", "test"),
+            "api_base": os.environ.get("FISH_AUDIO_API_BASE", "https://api.fish.audio"),
+            "use_public_models": "false",
+            "latency": "normal",
+        }
+    )

+ 32 - 0
api/tests/integration_tests/model_runtime/fishaudio/test_tts.py

@@ -0,0 +1,32 @@
+import os
+
+import pytest
+
+from core.model_runtime.model_providers.fishaudio.tts.tts import (
+    FishAudioText2SpeechModel,
+)
+from tests.integration_tests.model_runtime.__mock.fishaudio import setup_fishaudio_mock
+
+
+@pytest.mark.parametrize("setup_fishaudio_mock", [["tts"]], indirect=True)
+def test_invoke_model(setup_fishaudio_mock):
+    model = FishAudioText2SpeechModel()
+
+    result = model.invoke(
+        model="tts-default",
+        tenant_id="test",
+        credentials={
+            "api_key": os.environ.get("FISH_AUDIO_API_KEY", "test"),
+            "api_base": os.environ.get("FISH_AUDIO_API_BASE", "https://api.fish.audio"),
+            "use_public_models": "false",
+            "latency": "normal",
+        },
+        content_text="Hello, world!",
+        voice="03397b4c4be74759b72533b663fbd001",
+    )
+
+    content = b""
+    for chunk in result:
+        content += chunk
+
+    assert content != b""

Niektóre pliki nie zostały wyświetlone z powodu dużej ilości zmienionych plików