Kaynağa Gözat

feat: add tencent asr (#6091)

Lance Mao 9 ay önce
ebeveyn
işleme
7c55c39085

+ 1 - 0
api/core/model_runtime/model_providers/_position.yaml

@@ -23,6 +23,7 @@
 - tongyi
 - wenxin
 - moonshot
+- tencent
 - jina
 - chatglm
 - yi

+ 0 - 0
api/core/model_runtime/model_providers/tencent/__init__.py


Dosya farkı çok büyük olduğundan ihmal edildi
+ 11 - 0
api/core/model_runtime/model_providers/tencent/_assets/icon_l_en.svg


Dosya farkı çok büyük olduğundan ihmal edildi
+ 11 - 0
api/core/model_runtime/model_providers/tencent/_assets/icon_l_zh.svg


+ 11 - 0
api/core/model_runtime/model_providers/tencent/_assets/icon_s_en.svg

@@ -0,0 +1,11 @@
+<svg viewBox="0 83.15545000000002 85 76.44" data-name="图层 1" id="图层_1"
+  xmlns="http://www.w3.org/2000/svg" style="max-height: 500px" width="85"
+  height="76.44">
+  <defs>
+    <style>.cls-1{fill:#4999d4}</style>
+  </defs>
+  <title>tencent-cloud</title>
+  <path
+    d="M27.569 113.353a17.56 17.56 0 0 1 33.148-3.743.158.158 0 0 0 .194.105 21.267 21.267 0 0 1 7.008-.729c.235.018.327-.116.25-.33a24.828 24.828 0 0 0-47.933 4.444.082.082 0 0 0 .016 0 18.537 18.537 0 0 0-9.85 31.533 18.007 18.007 0 0 0 10.325 5h-.001a43.066 43.066 0 0 0 5.266.282c1.68.011 33.725.008 35.067.008 2.7 0 4.457-.002 6.345-.14a18.245 18.245 0 0 0 11.723-5.15 18.532 18.532 0 0 0-12.901-31.789 18.06 18.06 0 0 0-11.704 4.285c-1.467 1.196-3.006 2.626-4.944 4.508-.642.625-13.336 12.94-21.67 21.028-1.16-.005-2.828-.021-4.306-.07a11.704 11.704 0 0 1-8.125-3.148A11.275 11.275 0 0 1 23.33 120.1a11.706 11.706 0 0 1 7.646 3.062c1.44 1.192 4.633 4 6.035 5.263a.17.17 0 0 0 .24.002l4.945-4.825a.176.176 0 0 0-.004-.27c-2.378-2.15-5.749-5.158-7.778-6.669a18.874 18.874 0 0 0-6.844-3.31zm46.482 26.094a11.704 11.704 0 0 1-8.125 3.147 168.92 168.92 0 0 1-5.204.073h-22.38c8.142-7.91 15.245-14.808 16.051-15.59.738-.717 2.398-2.306 3.83-3.595 3.145-2.831 5.974-3.4 7.976-3.382a11.275 11.275 0 0 1 7.852 19.347z"
+    class="cls-1" />
+</svg>

+ 0 - 0
api/core/model_runtime/model_providers/tencent/speech2text/__init__.py


+ 156 - 0
api/core/model_runtime/model_providers/tencent/speech2text/flash_recognizer.py

@@ -0,0 +1,156 @@
+import base64
+import hashlib
+import hmac
+import time
+
+import requests
+
+
+class Credential:
+    def __init__(self, secret_id, secret_key):
+        self.secret_id = secret_id
+        self.secret_key = secret_key
+
+
+class FlashRecognitionRequest:
+    def __init__(self, voice_format="mp3", engine_type="16k_zh"):
+        self.engine_type = engine_type
+        self.speaker_diarization = 0
+        self.hotword_id = ""
+        self.customization_id = ""
+        self.filter_dirty = 0
+        self.filter_modal = 0
+        self.filter_punc = 0
+        self.convert_num_mode = 1
+        self.word_info = 0
+        self.voice_format = voice_format
+        self.first_channel_only = 1
+        self.reinforce_hotword = 0
+        self.sentence_max_length = 0
+
+    def set_first_channel_only(self, first_channel_only):
+        self.first_channel_only = first_channel_only
+
+    def set_speaker_diarization(self, speaker_diarization):
+        self.speaker_diarization = speaker_diarization
+
+    def set_filter_dirty(self, filter_dirty):
+        self.filter_dirty = filter_dirty
+
+    def set_filter_modal(self, filter_modal):
+        self.filter_modal = filter_modal
+
+    def set_filter_punc(self, filter_punc):
+        self.filter_punc = filter_punc
+
+    def set_convert_num_mode(self, convert_num_mode):
+        self.convert_num_mode = convert_num_mode
+
+    def set_word_info(self, word_info):
+        self.word_info = word_info
+
+    def set_hotword_id(self, hotword_id):
+        self.hotword_id = hotword_id
+
+    def set_customization_id(self, customization_id):
+        self.customization_id = customization_id
+
+    def set_voice_format(self, voice_format):
+        self.voice_format = voice_format
+
+    def set_sentence_max_length(self, sentence_max_length):
+        self.sentence_max_length = sentence_max_length
+
+    def set_reinforce_hotword(self, reinforce_hotword):
+        self.reinforce_hotword = reinforce_hotword
+
+
+class FlashRecognizer:
+    """
+    reponse:
+    request_id        string
+    status            Integer    
+    message           String    
+    audio_duration    Integer
+    flash_result      Result Array
+
+    Result:
+    text              String
+    channel_id        Integer
+    sentence_list     Sentence Array
+
+    Sentence:
+    text              String
+    start_time        Integer    
+    end_time          Integer    
+    speaker_id        Integer    
+    word_list         Word Array
+
+    Word:
+    word              String 
+    start_time        Integer 
+    end_time          Integer 
+    stable_flag:     Integer 
+    """
+
+    def __init__(self, appid, credential):
+        self.credential = credential
+        self.appid = appid
+
+    def _format_sign_string(self, param):
+        signstr = "POSTasr.cloud.tencent.com/asr/flash/v1/"
+        for t in param:
+            if 'appid' in t:
+                signstr += str(t[1])
+                break
+        signstr += "?"
+        for x in param:
+            tmp = x
+            if 'appid' in x:
+                continue
+            for t in tmp:
+                signstr += str(t)
+                signstr += "="
+            signstr = signstr[:-1]
+            signstr += "&"
+        signstr = signstr[:-1]
+        return signstr
+
+    def _build_header(self):
+        header = {"Host": "asr.cloud.tencent.com"}
+        return header
+
+    def _sign(self, signstr, secret_key):
+        hmacstr = hmac.new(secret_key.encode('utf-8'),
+                           signstr.encode('utf-8'), hashlib.sha1).digest()
+        s = base64.b64encode(hmacstr)
+        s = s.decode('utf-8')
+        return s
+
+    def _build_req_with_signature(self, secret_key, params, header):
+        query = sorted(params.items(), key=lambda d: d[0])
+        signstr = self._format_sign_string(query)
+        signature = self._sign(signstr, secret_key)
+        header["Authorization"] = signature
+        requrl = "https://"
+        requrl += signstr[4::]
+        return requrl
+
+    def _create_query_arr(self, req):
+        return {
+            'appid': self.appid, 'secretid': self.credential.secret_id, 'timestamp': str(int(time.time())),
+             'engine_type': req.engine_type, 'voice_format': req.voice_format,
+             'speaker_diarization': req.speaker_diarization, 'hotword_id': req.hotword_id,
+             'customization_id': req.customization_id, 'filter_dirty': req.filter_dirty,
+             'filter_modal': req.filter_modal, 'filter_punc': req.filter_punc,
+             'convert_num_mode': req.convert_num_mode, 'word_info': req.word_info,
+             'first_channel_only': req.first_channel_only, 'reinforce_hotword': req.reinforce_hotword,
+             'sentence_max_length': req.sentence_max_length
+        }
+
+    def recognize(self, req, data):
+        header = self._build_header()
+        query_arr = self._create_query_arr(req)
+        req_url = self._build_req_with_signature(self.credential.secret_key, query_arr, header)
+        r = requests.post(req_url, headers=header, data=data)
+        return r.text

+ 92 - 0
api/core/model_runtime/model_providers/tencent/speech2text/speech2text.py

@@ -0,0 +1,92 @@
+import json
+from typing import IO, Optional
+
+import requests
+
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeConnectionError,
+    InvokeError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
+from core.model_runtime.model_providers.tencent.speech2text.flash_recognizer import (
+    Credential,
+    FlashRecognitionRequest,
+    FlashRecognizer,
+)
+
+
+class TencentSpeech2TextModel(Speech2TextModel):
+    def _invoke(self, model: str, credentials: dict,
+                file: IO[bytes], user: Optional[str] = None) \
+            -> str:
+        """
+        Invoke speech2text model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param file: audio file
+        :param user: unique user id
+        :return: text for given audio file
+        """
+        return self._speech2text_invoke(model, credentials, file)
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            audio_file_path = self._get_demo_file_path()
+
+            with open(audio_file_path, 'rb') as audio_file:
+                self._speech2text_invoke(model, credentials, audio_file)
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    def _speech2text_invoke(self, model: str, credentials: dict, file: IO[bytes]) -> str:
+        """
+        Invoke speech2text model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param file: audio file
+        :return: text for given audio file
+        """
+        app_id = credentials["app_id"]
+        secret_id = credentials["secret_id"]
+        secret_key = credentials["secret_key"]
+        voice_format = file.voice_format if hasattr(file, "voice_format") else "mp3"
+        tencent_voice_recognizer = FlashRecognizer(app_id, Credential(secret_id, secret_key))
+        resp = tencent_voice_recognizer.recognize(FlashRecognitionRequest(voice_format), file)
+        resp = json.loads(resp)
+        code = resp["code"]
+        message = resp["message"]
+        if code == 4002:
+            raise CredentialsValidateFailedError(str(message))
+        elif code != 0:
+            return f"Tencent ASR Recognition failed with code {code} and message {message}"
+        return "\n".join(item["text"] for item in resp["flash_result"])
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeConnectionError: [
+                requests.exceptions.ConnectionError
+            ],
+            InvokeAuthorizationError: [
+                CredentialsValidateFailedError
+            ]
+        }

+ 5 - 0
api/core/model_runtime/model_providers/tencent/speech2text/tencent.yaml

@@ -0,0 +1,5 @@
+model: tencent
+model_type: speech2text
+model_properties:
+  file_upload_limit: 25
+  supported_file_extensions: flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm

+ 29 - 0
api/core/model_runtime/model_providers/tencent/tencent.py

@@ -0,0 +1,29 @@
+import logging
+
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class TencentProvider(ModelProvider):
+    def validate_provider_credentials(self, credentials: dict) -> None:
+        """
+        Validate provider credentials
+
+        if validate failed, raise exception
+
+        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
+        """
+        try:
+            model_instance = self.get_model_instance(ModelType.SPEECH2TEXT)
+            model_instance.validate_credentials(
+                model='tencent',
+                credentials=credentials
+            )
+        except CredentialsValidateFailedError as ex:
+            raise ex
+        except Exception as ex:
+            logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
+            raise ex

+ 49 - 0
api/core/model_runtime/model_providers/tencent/tencent.yaml

@@ -0,0 +1,49 @@
+provider: tencent
+label:
+  zh_Hans: 腾讯云
+  en_US: Tencent
+icon_small:
+  en_US: icon_s_en.svg
+icon_large:
+  zh_Hans: icon_l_zh.svg
+  en_US: icon_l_en.svg
+background: "#E5E7EB"
+help:
+  title:
+    en_US: Get your API key from Tencent AI
+    zh_Hans: 从腾讯云获取 API Key
+  url:
+    en_US: https://cloud.tencent.com/product/asr
+supported_model_types:
+  - speech2text
+configurate_methods:
+  - predefined-model
+provider_credential_schema:
+  credential_form_schemas:
+    - variable: app_id
+      label:
+        zh_Hans: APPID
+        en_US: APPID
+      type: text-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入您的腾讯语音识别服务的 APPID
+        en_US: Enter the APPID of your Tencent Cloud ASR service
+    - variable: secret_id
+      label:
+        zh_Hans: SecretId
+        en_US: SecretId
+      type: secret-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入您的腾讯语音识别服务的 SecretId
+        en_US: Enter the SecretId of your Tencent Cloud ASR service
+    - variable: secret_key
+      label:
+        zh_Hans: SecretKey
+        en_US: SecretKey
+      type: secret-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入您的腾讯语音识别服务的 SecretKey
+        en_US: Enter the SecretKey of your Tencent Cloud ASR service

Bu fark içinde çok fazla dosya değişikliği olduğu için bazı dosyalar gösterilmiyor