Просмотр исходного кода

Add support for local ai speech to text (#3921)

Co-authored-by: Yeuoly <admin@srmxy.cn>
Tomy 11 месяцев назад
Родитель
Сommit
bb7c62777d

+ 4 - 0
api/core/model_runtime/model_providers/localai/localai.yaml

@@ -15,6 +15,7 @@ help:
 supported_model_types:
   - llm
   - text-embedding
+  - speech2text
 configurate_methods:
   - customizable-model
 model_credential_schema:
@@ -57,6 +58,9 @@ model_credential_schema:
         zh_Hans: 在此输入LocalAI的服务器地址,如 http://192.168.1.100:8080
         en_US: Enter the url of your LocalAI, e.g. http://192.168.1.100:8080
     - variable: context_size
+      show_on:
+        - variable: __model_type
+          value: llm
       label:
         zh_Hans: 上下文大小
         en_US: Context size

+ 0 - 0
api/core/model_runtime/model_providers/localai/speech2text/__init__.py


+ 101 - 0
api/core/model_runtime/model_providers/localai/speech2text/speech2text.py

@@ -0,0 +1,101 @@
+from typing import IO, Optional
+
+from requests import Request, Session
+from yarl import URL
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
+
+
+class LocalAISpeech2text(Speech2TextModel):
+    """
+    Model class for Local AI Text to speech model.
+    """
+
+    def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
+        """
+        Invoke large language model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param file: audio file
+        :param user: unique user id
+        :return: text for given audio file
+        """
+        
+        url = str(URL(credentials['server_url']) / "v1/audio/transcriptions")
+        data = {"model": model}
+        files = {"file": file}
+
+        session = Session()
+        request = Request("POST", url, data=data, files=files)
+        prepared_request = session.prepare_request(request)
+        response = session.send(prepared_request)
+
+        if 'error' in response.json():
+            raise InvokeServerUnavailableError("Empty response")
+
+        return response.json()["text"]
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            audio_file_path = self._get_demo_file_path()
+
+            with open(audio_file_path, 'rb') as audio_file:
+                self._invoke(model, credentials, audio_file)
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        return {
+            InvokeConnectionError: [
+                InvokeConnectionError
+            ],
+            InvokeServerUnavailableError: [
+                InvokeServerUnavailableError
+            ],
+            InvokeRateLimitError: [
+                InvokeRateLimitError
+            ],
+            InvokeAuthorizationError: [
+                InvokeAuthorizationError
+            ],
+            InvokeBadRequestError: [
+                InvokeBadRequestError
+            ],
+        }
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+        """
+            used to define customizable model schema
+        """
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(
+                en_US=model
+            ),
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_type=ModelType.SPEECH2TEXT,
+            model_properties={},
+            parameter_rules=[]
+        )
+
+        return entity

+ 54 - 0
api/tests/integration_tests/model_runtime/localai/test_speech2text.py

@@ -0,0 +1,54 @@
+import os
+
+import pytest
+
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.localai.speech2text.speech2text import LocalAISpeech2text
+
+
+def test_validate_credentials():
+    model = LocalAISpeech2text()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model='whisper-1',
+            credentials={
+                'server_url': 'invalid_url'
+            }
+        )
+
+    model.validate_credentials(
+        model='whisper-1',
+        credentials={
+            'server_url': os.environ.get('LOCALAI_SERVER_URL')
+        }
+    )
+
+
+def test_invoke_model():
+    model = LocalAISpeech2text()
+
+    # Get the directory of the current file
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+
+    # Get assets directory
+    assets_dir = os.path.join(os.path.dirname(current_dir), 'assets')
+
+    # Construct the path to the audio file
+    audio_file_path = os.path.join(assets_dir, 'audio.mp3')
+
+    # Open the file and get the file object
+    with open(audio_file_path, 'rb') as audio_file:
+        file = audio_file
+
+        result = model.invoke(
+            model='whisper-1',
+            credentials={
+                'server_url': os.environ.get('LOCALAI_SERVER_URL')
+            },
+            file=file,
+            user="abc-123"
+        )
+
+        assert isinstance(result, str)
+        assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10'