|
@@ -4,14 +4,17 @@ import logging
|
|
|
import queue
|
|
|
import re
|
|
|
import threading
|
|
|
+from collections.abc import Iterable
|
|
|
|
|
|
from core.app.entities.queue_entities import (
|
|
|
+ MessageQueueMessage,
|
|
|
QueueAgentMessageEvent,
|
|
|
QueueLLMChunkEvent,
|
|
|
QueueNodeSucceededEvent,
|
|
|
QueueTextChunkEvent,
|
|
|
+ WorkflowQueueMessage,
|
|
|
)
|
|
|
-from core.model_manager import ModelManager
|
|
|
+from core.model_manager import ModelInstance, ModelManager
|
|
|
from core.model_runtime.entities.model_entities import ModelType
|
|
|
|
|
|
|
|
@@ -21,7 +24,7 @@ class AudioTrunk:
|
|
|
self.status = status
|
|
|
|
|
|
|
|
|
-def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
|
|
|
+def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str):
|
|
|
if not text_content or text_content.isspace():
|
|
|
return
|
|
|
return model_instance.invoke_tts(
|
|
@@ -29,13 +32,19 @@ def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
|
|
|
)
|
|
|
|
|
|
|
|
|
-def _process_future(future_queue, audio_queue):
|
|
|
+def _process_future(
|
|
|
+ future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None],
|
|
|
+ audio_queue: queue.Queue[AudioTrunk],
|
|
|
+):
|
|
|
while True:
|
|
|
try:
|
|
|
future = future_queue.get()
|
|
|
if future is None:
|
|
|
break
|
|
|
- for audio in future.result():
|
|
|
+ invoke_result = future.result()
|
|
|
+ if not invoke_result:
|
|
|
+ continue
|
|
|
+ for audio in invoke_result:
|
|
|
audio_base64 = base64.b64encode(bytes(audio))
|
|
|
audio_queue.put(AudioTrunk("responding", audio=audio_base64))
|
|
|
except Exception as e:
|
|
@@ -49,8 +58,8 @@ class AppGeneratorTTSPublisher:
|
|
|
self.logger = logging.getLogger(__name__)
|
|
|
self.tenant_id = tenant_id
|
|
|
self.msg_text = ""
|
|
|
- self._audio_queue = queue.Queue()
|
|
|
- self._msg_queue = queue.Queue()
|
|
|
+ self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue()
|
|
|
+ self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
|
|
|
self.match = re.compile(r"[。.!?]")
|
|
|
self.model_manager = ModelManager()
|
|
|
self.model_instance = self.model_manager.get_default_model_instance(
|
|
@@ -66,14 +75,11 @@ class AppGeneratorTTSPublisher:
|
|
|
self._runtime_thread = threading.Thread(target=self._runtime).start()
|
|
|
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)
|
|
|
|
|
|
- def publish(self, message):
|
|
|
- try:
|
|
|
- self._msg_queue.put(message)
|
|
|
- except Exception as e:
|
|
|
- self.logger.warning(e)
|
|
|
+ def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /):
|
|
|
+ self._msg_queue.put(message)
|
|
|
|
|
|
def _runtime(self):
|
|
|
- future_queue = queue.Queue()
|
|
|
+ future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None] = queue.Queue()
|
|
|
threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start()
|
|
|
while True:
|
|
|
try:
|
|
@@ -110,7 +116,7 @@ class AppGeneratorTTSPublisher:
|
|
|
break
|
|
|
future_queue.put(None)
|
|
|
|
|
|
- def check_and_get_audio(self) -> AudioTrunk | None:
|
|
|
+ def check_and_get_audio(self):
|
|
|
try:
|
|
|
if self._last_audio_event and self._last_audio_event.status == "finish":
|
|
|
if self.executor:
|