Переглянути джерело

feat: add support of speech2text function for OpenAI-API-compatible and Siliconflow (#7197)

shAlfred 8 місяців тому
батько
коміт
a12ddc47e7

+ 17 - 0
api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml

@@ -7,6 +7,7 @@ description:
 supported_model_types:
   - llm
   - text-embedding
+  - speech2text
 configurate_methods:
   - customizable-model
 model_credential_schema:
@@ -61,6 +62,22 @@ model_credential_schema:
         zh_Hans: 模型上下文长度
         en_US: Model context size
       required: true
+      show_on:
+        - variable: __model_type
+          value: llm
+      type: text-input
+      default: '4096'
+      placeholder:
+        zh_Hans: 在此输入您的模型上下文长度
+        en_US: Enter your Model context size
+    - variable: context_size
+      label:
+        zh_Hans: 模型上下文长度
+        en_US: Model context size
+      required: true
+      show_on:
+        - variable: __model_type
+          value: text-embedding
       type: text-input
       default: '4096'
       placeholder:

+ 0 - 0
api/core/model_runtime/model_providers/openai_api_compatible/speech2text/__init__.py


+ 63 - 0
api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py

@@ -0,0 +1,63 @@
+from typing import IO, Optional
+from urllib.parse import urljoin
+
+import requests
+
+from core.model_runtime.errors.invoke import InvokeBadRequestError
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
+from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat
+
+
+class OAICompatSpeech2TextModel(_CommonOAI_API_Compat, Speech2TextModel):
+    """
+    Model class for OpenAI Compatible Speech to text model.
+    """
+
+    def _invoke(
+            self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None
+    ) -> str:
+        """
+        Invoke speech2text model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param file: audio file
+        :param user: unique user id
+        :return: text for given audio file
+        """
+        headers = {}
+
+        api_key = credentials.get("api_key")
+        if api_key:
+            headers["Authorization"] = f"Bearer {api_key}"
+
+        endpoint_url = credentials.get("endpoint_url")
+        if not endpoint_url.endswith("/"):
+            endpoint_url += "/"
+        endpoint_url = urljoin(endpoint_url, "audio/transcriptions")
+
+        payload = {"model": model}
+        files = [("file", file)]
+        response = requests.post(endpoint_url, headers=headers, data=payload, files=files)
+
+        if response.status_code != 200:
+            raise InvokeBadRequestError(response.text)
+        response_data = response.json()
+        return response_data["text"]
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            audio_file_path = self._get_demo_file_path()
+
+            with open(audio_file_path, "rb") as audio_file:
+                self._invoke(model, credentials, audio_file)
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))

+ 1 - 0
api/core/model_runtime/model_providers/siliconflow/siliconflow.py

@@ -6,6 +6,7 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid
 
 logger = logging.getLogger(__name__)
 
+
 class SiliconflowProvider(ModelProvider):
 
     def validate_provider_credentials(self, credentials: dict) -> None:

+ 1 - 0
api/core/model_runtime/model_providers/siliconflow/siliconflow.yaml

@@ -16,6 +16,7 @@ help:
 supported_model_types:
   - llm
   - text-embedding
+  - speech2text
 configurate_methods:
   - predefined-model
 provider_credential_schema:

+ 0 - 0
api/core/model_runtime/model_providers/siliconflow/speech2text/__init__.py


+ 5 - 0
api/core/model_runtime/model_providers/siliconflow/speech2text/sense-voice-small.yaml

@@ -0,0 +1,5 @@
+model: iic/SenseVoiceSmall
+model_type: speech2text
+model_properties:
+  file_upload_limit: 1
+  supported_file_extensions: mp3,wav

+ 32 - 0
api/core/model_runtime/model_providers/siliconflow/speech2text/speech2text.py

@@ -0,0 +1,32 @@
+from typing import IO, Optional
+
+from core.model_runtime.model_providers.openai_api_compatible.speech2text.speech2text import OAICompatSpeech2TextModel
+
+
+class SiliconflowSpeech2TextModel(OAICompatSpeech2TextModel):
+    """
+    Model class for Siliconflow Speech to text model.
+    """
+
+    def _invoke(
+            self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None
+    ) -> str:
+        """
+        Invoke speech2text model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param file: audio file
+        :param user: unique user id
+        :return: text for given audio file
+        """
+        self._add_custom_parameters(credentials)
+        return super()._invoke(model, credentials, file)
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        self._add_custom_parameters(credentials)
+        return super().validate_credentials(model, credentials)
+
+    @classmethod
+    def _add_custom_parameters(cls, credentials: dict) -> None:
+        credentials["endpoint_url"] = "https://api.siliconflow.cn/v1"

+ 59 - 0
api/tests/integration_tests/model_runtime/openai_api_compatible/test_speech2text.py

@@ -0,0 +1,59 @@
+import os
+
+import pytest
+
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.openai_api_compatible.speech2text.speech2text import (
+    OAICompatSpeech2TextModel,
+)
+
+
+def test_validate_credentials():
+    model = OAICompatSpeech2TextModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model="whisper-1",
+            credentials={
+                "api_key": "invalid_key",
+                "endpoint_url": "https://api.openai.com/v1/"
+            },
+        )
+
+    model.validate_credentials(
+        model="whisper-1",
+        credentials={
+            "api_key": os.environ.get("OPENAI_API_KEY"),
+            "endpoint_url": "https://api.openai.com/v1/"
+        },
+    )
+
+
+def test_invoke_model():
+    model = OAICompatSpeech2TextModel()
+
+    # Get the directory of the current file
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+
+    # Get assets directory
+    assets_dir = os.path.join(os.path.dirname(current_dir), "assets")
+
+    # Construct the path to the audio file
+    audio_file_path = os.path.join(assets_dir, "audio.mp3")
+
+    # Open the file and get the file object
+    with open(audio_file_path, "rb") as audio_file:
+        file = audio_file
+
+        result = model.invoke(
+            model="whisper-1",
+            credentials={
+                "api_key": os.environ.get("OPENAI_API_KEY"),
+                "endpoint_url": "https://api.openai.com/v1/"
+            },
+            file=file,
+            user="abc-123",
+        )
+
+        assert isinstance(result, str)
+        assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10'

+ 53 - 0
api/tests/integration_tests/model_runtime/siliconflow/test_speech2text.py

@@ -0,0 +1,53 @@
+import os
+
+import pytest
+
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.siliconflow.speech2text.speech2text import SiliconflowSpeech2TextModel
+
+
+def test_validate_credentials():
+    model = SiliconflowSpeech2TextModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model="iic/SenseVoiceSmall",
+            credentials={
+                "api_key": "invalid_key"
+            },
+        )
+
+    model.validate_credentials(
+        model="iic/SenseVoiceSmall",
+        credentials={
+            "api_key": os.environ.get("API_KEY")
+        },
+    )
+
+
+def test_invoke_model():
+    model = SiliconflowSpeech2TextModel()
+
+    # Get the directory of the current file
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+
+    # Get assets directory
+    assets_dir = os.path.join(os.path.dirname(current_dir), "assets")
+
+    # Construct the path to the audio file
+    audio_file_path = os.path.join(assets_dir, "audio.mp3")
+
+    # Open the file and get the file object
+    with open(audio_file_path, "rb") as audio_file:
+        file = audio_file
+
+        result = model.invoke(
+            model="iic/SenseVoiceSmall",
+            credentials={
+                "api_key": os.environ.get("API_KEY")
+            },
+            file=file
+        )
+
+        assert isinstance(result, str)
+        assert result == '1,2,3,4,5,6,7,8,9,10.'